mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5790d3aae8 | ||
|
|
744294f00e | ||
|
|
3697b8dde9 | ||
|
|
f3bfbab44d | ||
|
|
afcce0d6c8 | ||
|
|
69d45bed21 | ||
|
|
4bd8c24f54 | ||
|
|
8d09693654 | ||
|
|
7d7534928f | ||
|
|
e9bf5757c4 | ||
|
|
f9f694dfcf | ||
|
|
022c23500a | ||
|
|
5d4f811a65 | ||
|
|
3c81b248a3 | ||
|
|
fdee5ecf70 | ||
|
|
29bd693bab | ||
|
|
30934969b8 | ||
|
|
4a333b6c0f | ||
|
|
b5cc167e31 |
30
apps/pandora_server.py
Normal file
30
apps/pandora_server.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import asyncio
|
||||
import click
|
||||
import logging
|
||||
|
||||
from bumble.pandora import PandoraDevice, serve
|
||||
|
||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||
ROOTCANAL_PORT_CUTTLEFISH = 7300
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--grpc-port', help='gRPC port to serve', default=BUMBLE_SERVER_GRPC_PORT)
|
||||
@click.option(
|
||||
'--rootcanal-port', help='Rootcanal TCP port', default=ROOTCANAL_PORT_CUTTLEFISH
|
||||
)
|
||||
@click.option(
|
||||
'--transport',
|
||||
help='HCI transport',
|
||||
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
|
||||
)
|
||||
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
|
||||
if '<rootcanal-port>' in transport:
|
||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||
device = PandoraDevice({'transport': transport})
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(serve(device, port=grpc_port))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
@@ -152,7 +152,12 @@ class UUID:
|
||||
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
|
||||
UUIDS: List[UUID] = [] # Registry of all instances created
|
||||
|
||||
def __init__(self, uuid_str_or_int, name=None):
|
||||
uuid_bytes: bytes
|
||||
name: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
|
||||
) -> None:
|
||||
if isinstance(uuid_str_or_int, int):
|
||||
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
|
||||
else:
|
||||
@@ -172,7 +177,7 @@ class UUID:
|
||||
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
||||
self.name = name
|
||||
|
||||
def register(self):
|
||||
def register(self) -> UUID:
|
||||
# Register this object in the class registry, and update the entry's name if
|
||||
# it wasn't set already
|
||||
for uuid in self.UUIDS:
|
||||
@@ -196,22 +201,22 @@ class UUID:
|
||||
raise ValueError('only 2, 4 and 16 bytes are allowed')
|
||||
|
||||
@classmethod
|
||||
def from_16_bits(cls, uuid_16, name=None):
|
||||
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
|
||||
return cls.from_bytes(struct.pack('<H', uuid_16), name)
|
||||
|
||||
@classmethod
|
||||
def from_32_bits(cls, uuid_32, name=None):
|
||||
def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
|
||||
return cls.from_bytes(struct.pack('<I', uuid_32), name)
|
||||
|
||||
@classmethod
|
||||
def parse_uuid(cls, uuid_as_bytes, offset):
|
||||
def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
|
||||
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
|
||||
|
||||
@classmethod
|
||||
def parse_uuid_2(cls, uuid_as_bytes, offset):
|
||||
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
|
||||
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
|
||||
|
||||
def to_bytes(self, force_128=False):
|
||||
def to_bytes(self, force_128: bool = False) -> bytes:
|
||||
'''
|
||||
Serialize UUID in little-endian byte-order
|
||||
'''
|
||||
@@ -227,7 +232,7 @@ class UUID:
|
||||
else:
|
||||
assert False, "unreachable"
|
||||
|
||||
def to_pdu_bytes(self):
|
||||
def to_pdu_bytes(self) -> bytes:
|
||||
'''
|
||||
Convert to bytes for use in an ATT PDU.
|
||||
According to Vol 3, Part F - 3.2.1 Attribute Type:
|
||||
@@ -236,11 +241,11 @@ class UUID:
|
||||
'''
|
||||
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
|
||||
|
||||
def to_hex_str(self) -> str:
|
||||
def to_hex_str(self, separator: str = '') -> str:
|
||||
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
|
||||
return bytes(reversed(self.uuid_bytes)).hex().upper()
|
||||
|
||||
return ''.join(
|
||||
return separator.join(
|
||||
[
|
||||
bytes(reversed(self.uuid_bytes[12:16])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[10:12])).hex(),
|
||||
@@ -250,10 +255,10 @@ class UUID:
|
||||
]
|
||||
).upper()
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, UUID):
|
||||
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
|
||||
|
||||
@@ -262,35 +267,19 @@ class UUID:
|
||||
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.uuid_bytes)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
result = self.to_hex_str(separator='-')
|
||||
if len(self.uuid_bytes) == 2:
|
||||
uuid = struct.unpack('<H', self.uuid_bytes)[0]
|
||||
result = f'UUID-16:{uuid:04X}'
|
||||
result = 'UUID-16:' + result
|
||||
elif len(self.uuid_bytes) == 4:
|
||||
uuid = struct.unpack('<I', self.uuid_bytes)[0]
|
||||
result = f'UUID-32:{uuid:08X}'
|
||||
else:
|
||||
result = '-'.join(
|
||||
[
|
||||
bytes(reversed(self.uuid_bytes[12:16])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[10:12])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[8:10])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[6:8])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[0:6])).hex(),
|
||||
]
|
||||
).upper()
|
||||
|
||||
result = 'UUID-32:' + result
|
||||
if self.name is not None:
|
||||
return result + f' ({self.name})'
|
||||
|
||||
result += f' ({self.name})'
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Common UUID constants
|
||||
@@ -773,7 +762,7 @@ class AdvertisingData:
|
||||
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
|
||||
uuids = []
|
||||
offset = 0
|
||||
while (uuid_size * (offset + 1)) <= len(ad_data):
|
||||
while (offset + uuid_size) <= len(ad_data):
|
||||
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
|
||||
offset += uuid_size
|
||||
return uuids
|
||||
|
||||
179
bumble/device.py
179
bumble/device.py
@@ -23,7 +23,7 @@ import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager, AsyncExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from .colors import color
|
||||
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
|
||||
@@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter):
|
||||
transport: int
|
||||
self_address: Address
|
||||
peer_address: Address
|
||||
peer_resolvable_address: Optional[Address]
|
||||
role: int
|
||||
encryption: int
|
||||
authenticated: bool
|
||||
@@ -888,7 +889,7 @@ def host_event_handler(function):
|
||||
# List of host event handlers for the Device class.
|
||||
# (we define this list outside the class, because referencing a class in method
|
||||
# decorators is not straightforward)
|
||||
device_host_event_handlers: list[str] = []
|
||||
device_host_event_handlers: List[str] = []
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -2196,13 +2197,23 @@ class Device(CompositeEventEmitter):
|
||||
await self.stop_discovery()
|
||||
|
||||
@property
|
||||
def pairing_config_factory(self):
|
||||
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
|
||||
return self.smp_manager.pairing_config_factory
|
||||
|
||||
@pairing_config_factory.setter
|
||||
def pairing_config_factory(self, pairing_config_factory):
|
||||
def pairing_config_factory(
|
||||
self, pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
) -> None:
|
||||
self.smp_manager.pairing_config_factory = pairing_config_factory
|
||||
|
||||
@property
|
||||
def smp_session_proxy(self) -> Type[smp.Session]:
|
||||
return self.smp_manager.session_proxy
|
||||
|
||||
@smp_session_proxy.setter
|
||||
def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None:
|
||||
self.smp_manager.session_proxy = session_proxy
|
||||
|
||||
async def pair(self, connection):
|
||||
return await self.smp_manager.pair(connection)
|
||||
|
||||
@@ -2232,7 +2243,7 @@ class Device(CompositeEventEmitter):
|
||||
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
|
||||
return keys.ltk_peripheral.value
|
||||
|
||||
async def get_link_key(self, address):
|
||||
async def get_link_key(self, address: Address) -> Optional[bytes]:
|
||||
# Look for the key in the keystore
|
||||
if self.keystore is not None:
|
||||
keys = await self.keystore.get(str(address))
|
||||
@@ -2243,6 +2254,7 @@ class Device(CompositeEventEmitter):
|
||||
return None
|
||||
|
||||
return keys.link_key.value
|
||||
return None
|
||||
|
||||
# [Classic only]
|
||||
async def authenticate(self, connection):
|
||||
@@ -2772,89 +2784,103 @@ class Device(CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_confirmation_request(self, connection, code):
|
||||
def on_authentication_user_confirmation_request(self, connection, code) -> None:
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
io_capability = pairing_config.delegate.classic_io_capability
|
||||
peer_io_capability = connection.peer_pairing_io_capability
|
||||
|
||||
# Respond
|
||||
if io_capability == HCI_DISPLAY_YES_NO_IO_CAPABILITY:
|
||||
if connection.peer_pairing_io_capability in (
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
):
|
||||
# Display the code and ask the user to compare
|
||||
async def prompt():
|
||||
return (
|
||||
await pairing_config.delegate.compare_numbers(code, digits=6),
|
||||
async def confirm() -> bool:
|
||||
# Ask the user to confirm the pairing, without display
|
||||
return await pairing_config.delegate.confirm()
|
||||
|
||||
async def auto_confirm() -> bool:
|
||||
# Ask the user to auto-confirm the pairing, without display
|
||||
return await pairing_config.delegate.confirm(auto=True)
|
||||
|
||||
async def display_confirm() -> bool:
|
||||
# Display the code and ask the user to compare
|
||||
return await pairing_config.delegate.compare_numbers(code, digits=6)
|
||||
|
||||
async def display_auto_confirm() -> bool:
|
||||
# Display the code to the user and ask the delegate to auto-confirm
|
||||
await pairing_config.delegate.display_number(code, digits=6)
|
||||
return await pairing_config.delegate.confirm(auto=True)
|
||||
|
||||
async def na() -> bool:
|
||||
assert False, "N/A: unreachable"
|
||||
|
||||
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
|
||||
methods = {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: na,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: na,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: auto_confirm,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
}
|
||||
|
||||
method = methods[peer_io_capability][io_capability]
|
||||
|
||||
async def reply() -> None:
|
||||
if await connection.abort_on('disconnection', method()):
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
|
||||
)
|
||||
else:
|
||||
# Ask the user to confirm the pairing, without showing a code
|
||||
async def prompt():
|
||||
return await pairing_config.delegate.confirm()
|
||||
|
||||
async def confirm():
|
||||
if await prompt():
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
AsyncRunner.spawn(connection.abort_on('disconnection', confirm()))
|
||||
return
|
||||
|
||||
if io_capability == HCI_DISPLAY_ONLY_IO_CAPABILITY:
|
||||
# Display the code to the user
|
||||
AsyncRunner.spawn(pairing_config.delegate.display_number(code, 6))
|
||||
|
||||
# Automatic confirmation
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
AsyncRunner.spawn(reply())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_passkey_request(self, connection):
|
||||
def on_authentication_user_passkey_request(self, connection) -> None:
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
io_capability = pairing_config.delegate.classic_io_capability
|
||||
|
||||
# Respond
|
||||
if io_capability == HCI_KEYBOARD_ONLY_IO_CAPABILITY:
|
||||
# Ask the user to input a number
|
||||
async def get_number():
|
||||
number = await connection.abort_on(
|
||||
'disconnection', pairing_config.delegate.get_number()
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address, numeric_value=number
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
asyncio.create_task(get_number())
|
||||
else:
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
async def reply() -> None:
|
||||
number = await connection.abort_on(
|
||||
'disconnection', pairing_config.delegate.get_number()
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address, numeric_value=number
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
AsyncRunner.spawn(reply())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -3059,18 +3085,15 @@ class Device(CompositeEventEmitter):
|
||||
connection.emit('role_change_failure', error)
|
||||
self.emit('role_change_failure', address, error)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing_start(self, connection):
|
||||
def on_pairing_start(self, connection: Connection) -> None:
|
||||
connection.emit('pairing_start')
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing(self, connection, keys, sc):
|
||||
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
|
||||
connection.sc = sc
|
||||
connection.authenticated = True
|
||||
connection.emit('pairing', keys)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing_failure(self, connection, reason):
|
||||
def on_pairing_failure(self, connection: Connection, reason: int) -> None:
|
||||
connection.emit('pairing_failure', reason)
|
||||
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -205,8 +205,16 @@ class Service(Attribute):
|
||||
'''
|
||||
|
||||
uuid: UUID
|
||||
characteristics: List[Characteristic]
|
||||
included_services: List[Service]
|
||||
|
||||
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
characteristics: List[Characteristic],
|
||||
primary=True,
|
||||
included_services: List[Service] = [],
|
||||
):
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
uuid = UUID(uuid)
|
||||
@@ -219,7 +227,7 @@ class Service(Attribute):
|
||||
uuid.to_pdu_bytes(),
|
||||
)
|
||||
self.uuid = uuid
|
||||
# self.included_services = []
|
||||
self.included_services = included_services[:]
|
||||
self.characteristics = characteristics[:]
|
||||
self.primary = primary
|
||||
|
||||
@@ -247,12 +255,39 @@ class TemplateService(Service):
|
||||
to expose their UUID as a class property
|
||||
'''
|
||||
|
||||
UUID = None
|
||||
UUID: Optional[UUID] = None
|
||||
|
||||
def __init__(self, characteristics, primary=True):
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class IncludedServiceDeclaration(Attribute):
|
||||
'''
|
||||
See Vol 3, Part G - 3.2 INCLUDE DEFINITION
|
||||
'''
|
||||
|
||||
service: Service
|
||||
|
||||
def __init__(self, service):
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
)
|
||||
super().__init__(
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
|
||||
)
|
||||
self.service = service
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
||||
f'group_starting_handle=0x{self.service.handle:04X}, '
|
||||
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
|
||||
f'uuid={self.service.uuid}, '
|
||||
f'{self.service.properties!s})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Characteristic(Attribute):
|
||||
'''
|
||||
|
||||
@@ -63,6 +63,7 @@ from .gatt import (
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
)
|
||||
@@ -109,6 +110,7 @@ class AttributeProxy(EventEmitter):
|
||||
class ServiceProxy(AttributeProxy):
|
||||
uuid: UUID
|
||||
characteristics: List[CharacteristicProxy]
|
||||
included_services: List[ServiceProxy]
|
||||
|
||||
@staticmethod
|
||||
def from_client(service_class, client, service_uuid):
|
||||
@@ -502,12 +504,69 @@ class Client:
|
||||
|
||||
return services
|
||||
|
||||
async def discover_included_services(self, _service):
|
||||
async def discover_included_services(
|
||||
self, service: ServiceProxy
|
||||
) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.5.1 Find Included Services
|
||||
'''
|
||||
# TODO
|
||||
return []
|
||||
|
||||
starting_handle = service.handle
|
||||
ending_handle = service.end_group_handle
|
||||
|
||||
included_services: List[ServiceProxy] = []
|
||||
while starting_handle <= ending_handle:
|
||||
response = await self.send_request(
|
||||
ATT_Read_By_Type_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=ending_handle,
|
||||
attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
)
|
||||
)
|
||||
if response is None:
|
||||
# TODO raise appropriate exception
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering included services: '
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
raise ATT_Error(
|
||||
error_code=response.error_code,
|
||||
message='Unexpected error while discovering included services',
|
||||
)
|
||||
break
|
||||
|
||||
# Stop if for some reason the list was empty
|
||||
if not response.attributes:
|
||||
break
|
||||
|
||||
# Process all included services returned in this iteration
|
||||
for attribute_handle, attribute_value in response.attributes:
|
||||
if attribute_handle < starting_handle:
|
||||
# Something's not right
|
||||
logger.warning(f'bogus handle value: {attribute_handle}')
|
||||
return []
|
||||
|
||||
group_starting_handle, group_ending_handle = struct.unpack_from(
|
||||
'<HH', attribute_value
|
||||
)
|
||||
service_uuid = UUID.from_bytes(attribute_value[4:])
|
||||
included_service = ServiceProxy(
|
||||
self, group_starting_handle, group_ending_handle, service_uuid, True
|
||||
)
|
||||
|
||||
included_services.append(included_service)
|
||||
|
||||
# Move on to the next included services
|
||||
starting_handle = response.attributes[-1][0] + 1
|
||||
|
||||
service.included_services = included_services
|
||||
return included_services
|
||||
|
||||
async def discover_characteristics(
|
||||
self, uuids, service: Optional[ServiceProxy]
|
||||
|
||||
@@ -68,6 +68,7 @@ from .gatt import (
|
||||
Characteristic,
|
||||
CharacteristicDeclaration,
|
||||
CharacteristicValue,
|
||||
IncludedServiceDeclaration,
|
||||
Descriptor,
|
||||
Service,
|
||||
)
|
||||
@@ -94,6 +95,7 @@ class Server(EventEmitter):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.services = []
|
||||
self.attributes = [] # Attributes, ordered by increasing handle values
|
||||
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
||||
self.max_mtu = (
|
||||
@@ -222,7 +224,14 @@ class Server(EventEmitter):
|
||||
# Add the service attribute to the DB
|
||||
self.add_attribute(service)
|
||||
|
||||
# TODO: add included services
|
||||
# Add all included service
|
||||
for included_service in service.included_services:
|
||||
# Not registered yet, register the included service first.
|
||||
if included_service not in self.services:
|
||||
self.add_service(included_service)
|
||||
# TODO: Handle circular service reference
|
||||
include_declaration = IncludedServiceDeclaration(included_service)
|
||||
self.add_attribute(include_declaration)
|
||||
|
||||
# Add all characteristics
|
||||
for characteristic in service.characteristics:
|
||||
@@ -274,6 +283,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Update the service group end
|
||||
service.end_group_handle = self.attributes[-1].handle
|
||||
self.services.append(service)
|
||||
|
||||
def add_services(self, services):
|
||||
for service in services:
|
||||
|
||||
@@ -25,7 +25,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
@@ -139,19 +139,19 @@ class PairingKeys:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class KeyStore:
|
||||
async def delete(self, name):
|
||||
async def delete(self, name: str):
|
||||
pass
|
||||
|
||||
async def update(self, name, keys):
|
||||
async def update(self, name: str, keys: PairingKeys) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, _name):
|
||||
return PairingKeys()
|
||||
async def get(self, _name: str) -> Optional[PairingKeys]:
|
||||
return None
|
||||
|
||||
async def get_all(self):
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
return []
|
||||
|
||||
async def delete_all(self):
|
||||
async def delete_all(self) -> None:
|
||||
all_keys = await self.get_all()
|
||||
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
|
||||
|
||||
@@ -177,15 +177,15 @@ class KeyStore:
|
||||
separator = '\n'
|
||||
|
||||
@staticmethod
|
||||
def create_for_device(device: Device) -> Optional[KeyStore]:
|
||||
def create_for_device(device: Device) -> KeyStore:
|
||||
if device.config.keystore is None:
|
||||
return None
|
||||
return MemoryKeyStore()
|
||||
|
||||
keystore_type = device.config.keystore.split(':', 1)[0]
|
||||
if keystore_type == 'JsonKeyStore':
|
||||
return JsonKeyStore.from_device(device)
|
||||
|
||||
return None
|
||||
return MemoryKeyStore()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore):
|
||||
return None
|
||||
|
||||
return PairingKeys.from_dict(keys)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MemoryKeyStore(KeyStore):
|
||||
all_keys: Dict[str, PairingKeys]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all_keys = {}
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
if name in self.all_keys:
|
||||
del self.all_keys[name]
|
||||
|
||||
async def update(self, name: str, keys: PairingKeys) -> None:
|
||||
self.all_keys[name] = keys
|
||||
|
||||
async def get(self, name: str) -> Optional[PairingKeys]:
|
||||
return self.all_keys.get(name)
|
||||
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
return list(self.all_keys.items())
|
||||
|
||||
@@ -65,8 +65,9 @@ class PairingDelegate:
|
||||
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
|
||||
|
||||
DEFAULT_KEY_DISTRIBUTION: int = (
|
||||
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG
|
||||
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
|
||||
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
|
||||
| KeyDistribution.DISTRIBUTE_IDENTITY_KEY
|
||||
)
|
||||
|
||||
# Default mapping from abstract to Classic I/O capabilities.
|
||||
@@ -85,9 +86,9 @@ class PairingDelegate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability=NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
@@ -113,8 +114,11 @@ class PairingDelegate:
|
||||
"""Accept or reject a Pairing request."""
|
||||
return True
|
||||
|
||||
async def confirm(self) -> bool:
|
||||
"""Respond yes or no to a Pairing confirmation question."""
|
||||
async def confirm(self, auto: bool = False) -> bool:
|
||||
"""
|
||||
Respond yes or no to a Pairing confirmation question.
|
||||
The `auto` parameter stands for automatic confirmation.
|
||||
"""
|
||||
return True
|
||||
|
||||
# pylint: disable-next=unused-argument
|
||||
@@ -129,7 +133,7 @@ class PairingDelegate:
|
||||
"""
|
||||
return 0
|
||||
|
||||
async def get_string(self, max_length) -> Optional[str]:
|
||||
async def get_string(self, max_length: int) -> Optional[str]:
|
||||
"""
|
||||
Return a string whose utf-8 encoding is up to max_length bytes.
|
||||
"""
|
||||
|
||||
105
bumble/pandora/__init__.py
Normal file
105
bumble/pandora/__init__.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Bumble Pandora server.
|
||||
This module implement the Pandora Bluetooth test APIs for the Bumble stack.
|
||||
"""
|
||||
|
||||
__version__ = "0.0.1"
|
||||
|
||||
import grpc
|
||||
import grpc.aio
|
||||
|
||||
from .config import Config
|
||||
from .device import PandoraDevice
|
||||
from .host import HostService
|
||||
from .security import SecurityService, SecurityStorageService
|
||||
from pandora.host_grpc_aio import add_HostServicer_to_server
|
||||
from pandora.security_grpc_aio import (
|
||||
add_SecurityServicer_to_server,
|
||||
add_SecurityStorageServicer_to_server,
|
||||
)
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
# public symbols
|
||||
__all__ = [
|
||||
'register_servicer_hook',
|
||||
'serve',
|
||||
'Config',
|
||||
'PandoraDevice',
|
||||
]
|
||||
|
||||
|
||||
# Add servicers hooks.
|
||||
_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
|
||||
|
||||
|
||||
def register_servicer_hook(
|
||||
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
|
||||
) -> None:
|
||||
_SERVICERS_HOOKS.append(hook)
|
||||
|
||||
|
||||
async def serve(
|
||||
bumble: PandoraDevice,
|
||||
config: Config = Config(),
|
||||
grpc_server: Optional[grpc.aio.Server] = None,
|
||||
port: int = 0,
|
||||
) -> None:
|
||||
# initialize a gRPC server if not provided.
|
||||
server = grpc_server if grpc_server is not None else grpc.aio.server()
|
||||
port = server.add_insecure_port(f'localhost:{port}')
|
||||
|
||||
try:
|
||||
while True:
|
||||
# load server config from dict.
|
||||
config.load_from_dict(bumble.config.get('server', {}))
|
||||
|
||||
# add Pandora services to the gRPC server.
|
||||
add_HostServicer_to_server(
|
||||
HostService(server, bumble.device, config), server
|
||||
)
|
||||
add_SecurityServicer_to_server(
|
||||
SecurityService(bumble.device, config), server
|
||||
)
|
||||
add_SecurityStorageServicer_to_server(
|
||||
SecurityStorageService(bumble.device, config), server
|
||||
)
|
||||
|
||||
# call hooks if any.
|
||||
for hook in _SERVICERS_HOOKS:
|
||||
hook(bumble, config, server)
|
||||
|
||||
# open device.
|
||||
await bumble.open()
|
||||
try:
|
||||
# Pandora require classic devices to be discoverable & connectable.
|
||||
if bumble.device.classic_enabled:
|
||||
await bumble.device.set_discoverable(True)
|
||||
await bumble.device.set_connectable(True)
|
||||
|
||||
# start & serve gRPC server.
|
||||
await server.start()
|
||||
await server.wait_for_termination()
|
||||
finally:
|
||||
# close device.
|
||||
await bumble.close()
|
||||
|
||||
# re-initialize the gRPC server.
|
||||
server = grpc.aio.server()
|
||||
server.add_insecure_port(f'localhost:{port}')
|
||||
finally:
|
||||
# stop server.
|
||||
await server.stop(None)
|
||||
48
bumble/pandora/config.py
Normal file
48
bumble/pandora/config.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from bumble.pairing import PairingDelegate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
|
||||
pairing_sc_enable: bool = True
|
||||
pairing_mitm_enable: bool = True
|
||||
pairing_bonding_enable: bool = True
|
||||
smp_local_initiator_key_distribution: PairingDelegate.KeyDistribution = (
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
|
||||
)
|
||||
smp_local_responder_key_distribution: PairingDelegate.KeyDistribution = (
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
|
||||
)
|
||||
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
io_capability_name: str = config.get(
|
||||
'io_capability', 'no_output_no_input'
|
||||
).upper()
|
||||
self.io_capability = getattr(PairingDelegate, io_capability_name)
|
||||
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
|
||||
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
|
||||
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
|
||||
self.smp_local_initiator_key_distribution = config.get(
|
||||
'smp_local_initiator_key_distribution',
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
)
|
||||
self.smp_local_responder_key_distribution = config.get(
|
||||
'smp_local_responder_key_distribution',
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
)
|
||||
157
bumble/pandora/device.py
Normal file
157
bumble/pandora/device.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
BT_HANDSFREE_SERVICE,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_RFCOMM_PROTOCOL_ID,
|
||||
)
|
||||
from bumble.device import Device, DeviceConfiguration
|
||||
from bumble.host import Host
|
||||
from bumble.sdp import (
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class PandoraDevice:
|
||||
"""
|
||||
Small wrapper around a Bumble device and it's HCI transport.
|
||||
Notes:
|
||||
- The Bumble device is idle by default.
|
||||
- Repetitive calls to `open`/`close` will result on new Bumble device instances.
|
||||
"""
|
||||
|
||||
# Bumble device instance & configuration.
|
||||
device: Device
|
||||
config: Dict[str, Any]
|
||||
|
||||
# HCI transport name & instance.
|
||||
_hci_name: str
|
||||
_hci: Optional[transport.Transport] # type: ignore[name-defined]
|
||||
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.device = _make_device(config)
|
||||
self._hci_name = config.get('transport', '')
|
||||
self._hci = None
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return self._hci is None
|
||||
|
||||
async def open(self) -> None:
|
||||
if self._hci is not None:
|
||||
return
|
||||
|
||||
# open HCI transport & set device host.
|
||||
self._hci = await transport.open_transport(self._hci_name)
|
||||
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
|
||||
|
||||
# power-on.
|
||||
await self.device.power_on()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._hci is None:
|
||||
return
|
||||
|
||||
# flush & re-initialize device.
|
||||
await self.device.host.flush()
|
||||
self.device.host = None # type: ignore[assignment]
|
||||
self.device = _make_device(self.config)
|
||||
|
||||
# close HCI transport.
|
||||
await self._hci.close()
|
||||
self._hci = None
|
||||
|
||||
async def reset(self) -> None:
|
||||
await self.close()
|
||||
await self.open()
|
||||
|
||||
def info(self) -> Optional[Dict[str, str]]:
|
||||
return {
|
||||
'public_bd_address': str(self.device.public_address),
|
||||
'random_address': str(self.device.random_address),
|
||||
}
|
||||
|
||||
|
||||
def _make_device(config: Dict[str, Any]) -> Device:
|
||||
"""Initialize an idle Bumble device instance."""
|
||||
|
||||
# initialize bumble device.
|
||||
device_config = DeviceConfiguration()
|
||||
device_config.load_from_dict(config)
|
||||
device = Device(config=device_config, host=None)
|
||||
|
||||
# Add fake a2dp service to avoid Android disconnect
|
||||
device.sdp_service_records = _make_sdp_records(1)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
# TODO(b/267540823): remove when Pandora A2dp is supported
|
||||
def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
|
||||
return {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(0x00010001),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HANDSFREE_SERVICE),
|
||||
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_8(rfcomm_channel),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HANDSFREE_SERVICE),
|
||||
DataElement.unsigned_integer_16(0x0105),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
}
|
||||
856
bumble/pandora/host.py
Normal file
856
bumble/pandora/host.py
Normal file
@@ -0,0 +1,856 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
ConnectionError,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
)
|
||||
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora.host_pb2 import (
|
||||
NOT_CONNECTABLE,
|
||||
NOT_DISCOVERABLE,
|
||||
PRIMARY_1M,
|
||||
PRIMARY_CODED,
|
||||
SECONDARY_1M,
|
||||
SECONDARY_2M,
|
||||
SECONDARY_CODED,
|
||||
SECONDARY_NONE,
|
||||
AdvertiseRequest,
|
||||
AdvertiseResponse,
|
||||
Connection,
|
||||
ConnectLERequest,
|
||||
ConnectLEResponse,
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
DataTypes,
|
||||
DisconnectRequest,
|
||||
InquiryResponse,
|
||||
PrimaryPhy,
|
||||
ReadLocalAddressResponse,
|
||||
ScanningResponse,
|
||||
ScanRequest,
|
||||
SecondaryPhy,
|
||||
SetConnectabilityModeRequest,
|
||||
SetDiscoverabilityModeRequest,
|
||||
WaitConnectionRequest,
|
||||
WaitConnectionResponse,
|
||||
WaitDisconnectionRequest,
|
||||
)
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
|
||||
# Default value reported by Bumble for legacy Advertising reports.
|
||||
# FIXME(uael): `None` might be a better value, but Bumble need to change accordingly.
|
||||
0: PRIMARY_1M,
|
||||
1: PRIMARY_1M,
|
||||
3: PRIMARY_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
|
||||
0: SECONDARY_NONE,
|
||||
1: SECONDARY_1M,
|
||||
2: SECONDARY_2M,
|
||||
3: SECONDARY_CODED,
|
||||
}
|
||||
|
||||
|
||||
class HostService(HostServicer):
|
||||
waited_connections: Set[int]
|
||||
|
||||
def __init__(
|
||||
self, grpc_server: grpc.aio.Server, device: Device, config: Config
|
||||
) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'Host', 'device': device}
|
||||
)
|
||||
self.grpc_server = grpc_server
|
||||
self.device = device
|
||||
self.config = config
|
||||
self.waited_connections = set()
|
||||
|
||||
@utils.rpc
|
||||
async def FactoryReset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('FactoryReset')
|
||||
|
||||
# delete all bonds
|
||||
if self.device.keystore is not None:
|
||||
await self.device.keystore.delete_all()
|
||||
|
||||
# trigger gRCP server stop then return
|
||||
asyncio.create_task(self.grpc_server.stop(None))
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def Reset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('Reset')
|
||||
|
||||
# clear service.
|
||||
self.waited_connections.clear()
|
||||
|
||||
# (re) power device on
|
||||
await self.device.power_on()
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def ReadLocalAddress(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> ReadLocalAddressResponse:
|
||||
self.log.info('ReadLocalAddress')
|
||||
return ReadLocalAddressResponse(
|
||||
address=bytes(reversed(bytes(self.device.public_address)))
|
||||
)
|
||||
|
||||
@utils.rpc
|
||||
async def Connect(
|
||||
self, request: ConnectRequest, context: grpc.ServicerContext
|
||||
) -> ConnectResponse:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
address = Address(
|
||||
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
self.log.info(f"Connect to {address}")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
self.log.warning(f"Peer not found: {e}")
|
||||
return ConnectResponse(peer_not_found=empty_pb2.Empty())
|
||||
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
self.log.warning(f"Connection already exists: {e}")
|
||||
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"Connect to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def WaitConnection(
|
||||
self, request: WaitConnectionRequest, context: grpc.ServicerContext
|
||||
) -> WaitConnectionResponse:
|
||||
if not request.address:
|
||||
raise ValueError('Request address field must be set')
|
||||
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
address = Address(
|
||||
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"WaitConnection from {address}...")
|
||||
|
||||
connection = self.device.find_connection_by_bd_addr(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
if connection and id(connection) in self.waited_connections:
|
||||
# this connection was already returned: wait for a new one.
|
||||
connection = None
|
||||
|
||||
if not connection:
|
||||
connection = await self.device.accept(address)
|
||||
|
||||
# save connection has waited and respond.
|
||||
self.waited_connections.add(id(connection))
|
||||
|
||||
self.log.info(
|
||||
f"WaitConnection from {address} done (handle={connection.handle})"
|
||||
)
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return WaitConnectionResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def ConnectLE(
|
||||
self, request: ConnectLERequest, context: grpc.ServicerContext
|
||||
) -> ConnectLEResponse:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"ConnectLE to {address}...")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
own_address_type=request.own_address_type,
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
self.log.warning(f"Peer not found: {e}")
|
||||
return ConnectLEResponse(peer_not_found=empty_pb2.Empty())
|
||||
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
self.log.warning(f"Connection already exists: {e}")
|
||||
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectLEResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def Disconnect(
|
||||
self, request: DisconnectRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Disconnect: {connection_handle}")
|
||||
|
||||
self.log.info("Disconnecting...")
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
|
||||
self.log.info("Disconnected")
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def WaitDisconnection(
|
||||
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitDisconnection: {connection_handle}")
|
||||
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
disconnection_future: asyncio.Future[
|
||||
None
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_disconnection(_: None) -> None:
|
||||
disconnection_future.set_result(None)
|
||||
|
||||
connection.on('disconnection', on_disconnection)
|
||||
try:
|
||||
await disconnection_future
|
||||
self.log.info("Disconnected")
|
||||
finally:
|
||||
connection.remove_listener('disconnection', on_disconnection) # type: ignore
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def Advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if not request.legacy:
|
||||
raise NotImplementedError(
|
||||
"TODO: add support for extended advertising in Bumble"
|
||||
)
|
||||
if request.interval:
|
||||
raise NotImplementedError("TODO: add support for `request.interval`")
|
||||
if request.interval_range:
|
||||
raise NotImplementedError("TODO: add support for `request.interval_range`")
|
||||
if request.primary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.primary_phy`")
|
||||
if request.secondary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.secondary_phy`")
|
||||
|
||||
if self.device.is_advertising:
|
||||
raise NotImplementedError('TODO: add support for advertising sets')
|
||||
|
||||
if data := request.data:
|
||||
self.device.advertising_data = bytes(self.unpack_data_types(data))
|
||||
|
||||
if scan_response_data := request.scan_response_data:
|
||||
self.device.scan_response_data = bytes(
|
||||
self.unpack_data_types(scan_response_data)
|
||||
)
|
||||
scannable = True
|
||||
else:
|
||||
scannable = False
|
||||
|
||||
# Retrieve services data
|
||||
for service in self.device.gatt_server.attributes:
|
||||
if isinstance(service, Service) and (
|
||||
service_data := service.get_advertising_data()
|
||||
):
|
||||
service_uuid = service.uuid.to_hex_str('-')
|
||||
if (
|
||||
service_uuid in request.data.incomplete_service_class_uuids16
|
||||
or service_uuid in request.data.complete_service_class_uuids16
|
||||
or service_uuid in request.data.incomplete_service_class_uuids32
|
||||
or service_uuid in request.data.complete_service_class_uuids32
|
||||
or service_uuid
|
||||
in request.data.incomplete_service_class_uuids128
|
||||
or service_uuid in request.data.complete_service_class_uuids128
|
||||
):
|
||||
self.device.advertising_data += service_data
|
||||
if (
|
||||
service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids16
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids16
|
||||
or service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids32
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids32
|
||||
or service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids128
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids128
|
||||
):
|
||||
self.device.scan_response_data += service_data
|
||||
|
||||
target = None
|
||||
if request.connectable and scannable:
|
||||
advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE
|
||||
elif scannable:
|
||||
advertising_type = AdvertisingType.UNDIRECTED_SCANNABLE
|
||||
else:
|
||||
advertising_type = AdvertisingType.UNDIRECTED
|
||||
else:
|
||||
target = None
|
||||
advertising_type = AdvertisingType.UNDIRECTED
|
||||
|
||||
if request.target:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
else:
|
||||
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if not self.device.is_advertising:
|
||||
self.log.info('Advertise')
|
||||
await self.device.start_advertising(
|
||||
target=target,
|
||||
advertising_type=advertising_type,
|
||||
own_address_type=request.own_address_type,
|
||||
)
|
||||
|
||||
if not request.connectable:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
pending_connection: asyncio.Future[
|
||||
bumble.device.Connection
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
self.log.info('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
|
||||
self.log.info(
|
||||
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
|
||||
)
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
yield AdvertiseResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
# wait a small delay before restarting the advertisement.
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
if request.connectable:
|
||||
self.device.remove_listener('connection', on_connection) # type: ignore
|
||||
|
||||
try:
|
||||
self.log.info('Stop advertising')
|
||||
await self.device.abort_on('flush', self.device.stop_advertising())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def Scan(
|
||||
self, request: ScanRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[ScanningResponse, None]:
|
||||
# TODO: modify `start_scanning` to accept floats instead of int for ms values
|
||||
if request.phys:
|
||||
raise NotImplementedError("TODO: add support for `request.phys`")
|
||||
|
||||
self.log.info('Scan')
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
await self.device.start_scanning(
|
||||
legacy=request.legacy,
|
||||
active=not request.passive,
|
||||
own_address_type=request.own_address_type,
|
||||
scan_interval=int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
scan_window=int(request.window)
|
||||
if request.window
|
||||
else DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
)
|
||||
|
||||
try:
|
||||
# TODO: add support for `direct_address` in Bumble
|
||||
# TODO: add support for `periodic_advertising_interval` in Bumble
|
||||
while adv := await scan_queue.get():
|
||||
sr = ScanningResponse(
|
||||
legacy=adv.is_legacy,
|
||||
connectable=adv.is_connectable,
|
||||
scannable=adv.is_scannable,
|
||||
truncated=adv.is_truncated,
|
||||
sid=adv.sid,
|
||||
primary_phy=PRIMARY_PHY_MAP[adv.primary_phy],
|
||||
secondary_phy=SECONDARY_PHY_MAP[adv.secondary_phy],
|
||||
tx_power=adv.tx_power,
|
||||
rssi=adv.rssi,
|
||||
data=self.pack_data_types(adv.data),
|
||||
)
|
||||
|
||||
if adv.address.address_type == Address.PUBLIC_DEVICE_ADDRESS:
|
||||
sr.public = bytes(reversed(bytes(adv.address)))
|
||||
elif adv.address.address_type == Address.RANDOM_DEVICE_ADDRESS:
|
||||
sr.random = bytes(reversed(bytes(adv.address)))
|
||||
elif adv.address.address_type == Address.PUBLIC_IDENTITY_ADDRESS:
|
||||
sr.public_identity = bytes(reversed(bytes(adv.address)))
|
||||
else:
|
||||
sr.random_static_identity = bytes(reversed(bytes(adv.address)))
|
||||
|
||||
yield sr
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('advertisement', handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop scanning')
|
||||
await self.device.abort_on('flush', self.device.stop_scanning())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def Inquiry(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[InquiryResponse, None]:
|
||||
self.log.info('Inquiry')
|
||||
|
||||
inquiry_queue: asyncio.Queue[
|
||||
Optional[Tuple[Address, int, AdvertisingData, int]]
|
||||
] = asyncio.Queue()
|
||||
complete_handler = self.device.on(
|
||||
'inquiry_complete', lambda: inquiry_queue.put_nowait(None)
|
||||
)
|
||||
result_handler = self.device.on( # type: ignore
|
||||
'inquiry_result',
|
||||
lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore
|
||||
(address, class_of_device, eir_data, rssi) # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
await self.device.start_discovery(auto_restart=False)
|
||||
try:
|
||||
while inquiry_result := await inquiry_queue.get():
|
||||
(address, class_of_device, eir_data, rssi) = inquiry_result
|
||||
# FIXME: if needed, add support for `page_scan_repetition_mode` and `clock_offset` in Bumble
|
||||
yield InquiryResponse(
|
||||
address=bytes(reversed(bytes(address))),
|
||||
class_of_device=class_of_device,
|
||||
rssi=rssi,
|
||||
data=self.pack_data_types(eir_data),
|
||||
)
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
|
||||
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop inquiry')
|
||||
await self.device.abort_on('flush', self.device.stop_discovery())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def SetDiscoverabilityMode(
|
||||
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetDiscoverabilityMode")
|
||||
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def SetConnectabilityMode(
|
||||
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetConnectabilityMode")
|
||||
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:
|
||||
ad_structures: List[Tuple[int, bytes]] = []
|
||||
|
||||
uuids: List[str]
|
||||
datas: Dict[str, bytes]
|
||||
|
||||
def uuid128_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
|
||||
to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid.replace('-', ''))))
|
||||
|
||||
def uuid32_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 32-bit uuid encoded as XXXXXXXX to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid)))
|
||||
|
||||
def uuid16_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 16-bit uuid encoded as XXXX to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid)))
|
||||
|
||||
if uuids := dt.incomplete_service_class_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.incomplete_service_class_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.incomplete_service_class_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_shortened_local_name'):
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.SHORTENED_LOCAL_NAME,
|
||||
bytes(self.device.name[:8], 'utf-8'),
|
||||
)
|
||||
)
|
||||
elif dt.shortened_local_name:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.SHORTENED_LOCAL_NAME,
|
||||
bytes(dt.shortened_local_name, 'utf-8'),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_complete_local_name'):
|
||||
ad_structures.append(
|
||||
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.device.name, 'utf-8'))
|
||||
)
|
||||
elif dt.complete_local_name:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(dt.complete_local_name, 'utf-8'),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_tx_power_level'):
|
||||
raise ValueError('unsupported data type')
|
||||
elif dt.tx_power_level:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.TX_POWER_LEVEL,
|
||||
bytes(struct.pack('<I', dt.tx_power_level)[:1]),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_class_of_device'):
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.CLASS_OF_DEVICE,
|
||||
bytes(struct.pack('<I', self.device.class_of_device)[:-1]),
|
||||
)
|
||||
)
|
||||
elif dt.class_of_device:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.CLASS_OF_DEVICE,
|
||||
bytes(struct.pack('<I', dt.class_of_device)[:-1]),
|
||||
)
|
||||
)
|
||||
if dt.peripheral_connection_interval_min:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE,
|
||||
bytes(
|
||||
[
|
||||
*struct.pack('<H', dt.peripheral_connection_interval_min),
|
||||
*struct.pack(
|
||||
'<H',
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if datas := dt.service_data_uuid16:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
uuid16_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if datas := dt.service_data_uuid32:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_32_BIT_UUID,
|
||||
uuid32_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if datas := dt.service_data_uuid128:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_128_BIT_UUID,
|
||||
uuid128_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if dt.appearance:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.APPEARANCE, struct.pack('<H', dt.appearance))
|
||||
)
|
||||
if dt.advertising_interval:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.ADVERTISING_INTERVAL,
|
||||
struct.pack('<H', dt.advertising_interval),
|
||||
)
|
||||
)
|
||||
if dt.uri:
|
||||
ad_structures.append((AdvertisingData.URI, bytes(dt.uri, 'utf-8')))
|
||||
if dt.le_supported_features:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_SUPPORTED_FEATURES, dt.le_supported_features)
|
||||
)
|
||||
if dt.manufacturer_specific_data:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
|
||||
dt.manufacturer_specific_data,
|
||||
)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
|
||||
dt = DataTypes()
|
||||
uuids: List[UUID]
|
||||
s: str
|
||||
i: int
|
||||
ij: Tuple[int, int]
|
||||
uuid_data: Tuple[UUID, bytes]
|
||||
data: bytes
|
||||
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if s := cast(str, ad.get(AdvertisingData.SHORTENED_LOCAL_NAME)):
|
||||
dt.shortened_local_name = s
|
||||
if s := cast(str, ad.get(AdvertisingData.COMPLETE_LOCAL_NAME)):
|
||||
dt.complete_local_name = s
|
||||
if i := cast(int, ad.get(AdvertisingData.TX_POWER_LEVEL)):
|
||||
dt.tx_power_level = i
|
||||
if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
|
||||
dt.class_of_device = i
|
||||
if ij := cast(
|
||||
Tuple[int, int],
|
||||
ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE),
|
||||
):
|
||||
dt.peripheral_connection_interval_min = ij[0]
|
||||
dt.peripheral_connection_interval_max = ij[1]
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
|
||||
dt.public_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if data := cast(bytes, ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True)):
|
||||
dt.random_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = i
|
||||
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
|
||||
dt.advertising_interval = i
|
||||
if s := cast(str, ad.get(AdvertisingData.URI)):
|
||||
dt.uri = s
|
||||
if data := cast(bytes, ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True)):
|
||||
dt.le_supported_features = data
|
||||
if data := cast(
|
||||
bytes, ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True)
|
||||
):
|
||||
dt.manufacturer_specific_data = data
|
||||
|
||||
return dt
|
||||
0
bumble/pandora/py.typed
Normal file
0
bumble/pandora/py.typed
Normal file
529
bumble/pandora/security.py
Normal file
529
bumble/pandora/security.py
Normal file
@@ -0,0 +1,529 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import grpc
|
||||
import logging
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble import hci
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
from contextlib import suppress
|
||||
from google.protobuf import (
|
||||
any_pb2,
|
||||
empty_pb2,
|
||||
wrappers_pb2,
|
||||
) # pytype: disable=pyi-error
|
||||
from google.protobuf.wrappers_pb2 import BoolValue # pytype: disable=pyi-error
|
||||
from pandora.host_pb2 import Connection
|
||||
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
|
||||
from pandora.security_pb2 import (
|
||||
LE_LEVEL1,
|
||||
LE_LEVEL2,
|
||||
LE_LEVEL3,
|
||||
LE_LEVEL4,
|
||||
LEVEL0,
|
||||
LEVEL1,
|
||||
LEVEL2,
|
||||
LEVEL3,
|
||||
LEVEL4,
|
||||
DeleteBondRequest,
|
||||
IsBondedRequest,
|
||||
LESecurityLevel,
|
||||
PairingEvent,
|
||||
PairingEventAnswer,
|
||||
SecureRequest,
|
||||
SecureResponse,
|
||||
SecurityLevel,
|
||||
WaitSecurityRequest,
|
||||
WaitSecurityResponse,
|
||||
)
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
|
||||
|
||||
|
||||
class PairingDelegate(BasePairingDelegate):
|
||||
def __init__(
|
||||
self,
|
||||
connection: BumbleConnection,
|
||||
service: "SecurityService",
|
||||
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(),
|
||||
{'service_name': 'Security', 'device': connection.device},
|
||||
)
|
||||
self.connection = connection
|
||||
self.service = service
|
||||
super().__init__(
|
||||
io_capability,
|
||||
local_initiator_key_distribution,
|
||||
local_responder_key_distribution,
|
||||
)
|
||||
|
||||
async def accept(self) -> bool:
|
||||
return True
|
||||
|
||||
def add_origin(self, ev: PairingEvent) -> PairingEvent:
|
||||
if not self.connection.is_incomplete:
|
||||
assert ev.connection
|
||||
ev.connection.CopyFrom(
|
||||
Connection(
|
||||
cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
|
||||
)
|
||||
)
|
||||
else:
|
||||
# In BR/EDR, connection may not be complete,
|
||||
# use address instead
|
||||
assert self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
ev.address = bytes(reversed(bytes(self.connection.peer_address)))
|
||||
|
||||
return ev
|
||||
|
||||
async def confirm(self, auto: bool = False) -> bool:
|
||||
self.log.info(
|
||||
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
return True
|
||||
|
||||
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
|
||||
self.log.info(
|
||||
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled number comparison request')
|
||||
|
||||
event = self.add_origin(PairingEvent(numeric_comparison=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def get_number(self) -> Optional[int]:
|
||||
self.log.info(
|
||||
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled number request')
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
assert answer.answer_variant() == 'passkey'
|
||||
return answer.passkey
|
||||
|
||||
async def get_string(self, max_length: int) -> Optional[str]:
|
||||
self.log.info(
|
||||
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled pin_code request')
|
||||
|
||||
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
assert answer.answer_variant() == 'pin'
|
||||
|
||||
if answer.pin is None:
|
||||
return None
|
||||
|
||||
pin = answer.pin.decode('utf-8')
|
||||
if not pin or len(pin) > max_length:
|
||||
raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')
|
||||
|
||||
return pin
|
||||
|
||||
async def display_number(self, number: int, digits: int = 6) -> None:
|
||||
if (
|
||||
self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
|
||||
):
|
||||
return
|
||||
|
||||
self.log.info(
|
||||
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None:
|
||||
raise RuntimeError('security: unhandled number display request')
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_notification=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
|
||||
|
||||
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LEVEL0: lambda connection: True,
|
||||
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
|
||||
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
|
||||
LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
in (
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
),
|
||||
LEVEL4: lambda connection: connection.encryption
|
||||
== hci.HCI_Encryption_Change_Event.AES_CCM
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
}
|
||||
|
||||
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LE_LEVEL1: lambda connection: True,
|
||||
LE_LEVEL2: lambda connection: connection.encryption != 0,
|
||||
LE_LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated,
|
||||
LE_LEVEL4: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.sc,
|
||||
}
|
||||
|
||||
|
||||
class SecurityService(SecurityServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'Security', 'device': device}
|
||||
)
|
||||
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
|
||||
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
|
||||
return PairingConfig(
|
||||
sc=config.pairing_sc_enable,
|
||||
mitm=config.pairing_mitm_enable,
|
||||
bonding=config.pairing_bonding_enable,
|
||||
delegate=PairingDelegate(
|
||||
connection,
|
||||
self,
|
||||
io_capability=config.io_capability,
|
||||
local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
|
||||
local_responder_key_distribution=config.smp_local_responder_key_distribution,
|
||||
),
|
||||
)
|
||||
|
||||
self.device.pairing_config_factory = pairing_config_factory
|
||||
|
||||
@utils.rpc
|
||||
async def OnPairing(
|
||||
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[PairingEvent, None]:
|
||||
self.log.info('OnPairing')
|
||||
|
||||
if self.event_queue is not None:
|
||||
raise RuntimeError('already streaming pairing events')
|
||||
|
||||
if len(self.device.connections):
|
||||
raise RuntimeError(
|
||||
'the `OnPairing` method shall be initiated before establishing any connections.'
|
||||
)
|
||||
|
||||
self.event_queue = asyncio.Queue()
|
||||
self.event_answer = request
|
||||
|
||||
try:
|
||||
while event := await self.event_queue.get():
|
||||
yield event
|
||||
|
||||
finally:
|
||||
self.event_queue = None
|
||||
self.event_answer = None
|
||||
|
||||
@utils.rpc
|
||||
async def Secure(
|
||||
self, request: SecureRequest, context: grpc.ServicerContext
|
||||
) -> SecureResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Secure: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
|
||||
oneof = request.WhichOneof('level')
|
||||
level = getattr(request, oneof)
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
connection.transport
|
||||
] == oneof
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
|
||||
# trigger pairing if needed
|
||||
if self.need_pairing(connection, level):
|
||||
try:
|
||||
self.log.info('Pair...')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
wait_for_security: asyncio.Future[
|
||||
bool
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
||||
|
||||
connection.request_pairing()
|
||||
|
||||
await wait_for_security
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
self.log.info('Paired')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Pairing failure: {e}")
|
||||
return SecureResponse(pairing_failure=empty_pb2.Empty())
|
||||
|
||||
# trigger authentication if needed
|
||||
if self.need_authentication(connection, level):
|
||||
try:
|
||||
self.log.info('Authenticate...')
|
||||
await connection.authenticate()
|
||||
self.log.info('Authenticated')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during authentication")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Authentication failure: {e}")
|
||||
return SecureResponse(authentication_failure=empty_pb2.Empty())
|
||||
|
||||
# trigger encryption if needed
|
||||
if self.need_encryption(connection, level):
|
||||
try:
|
||||
self.log.info('Encrypt...')
|
||||
await connection.encrypt()
|
||||
self.log.info('Encrypted')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Encryption failure: {e}")
|
||||
return SecureResponse(encryption_failure=empty_pb2.Empty())
|
||||
|
||||
# security level has been reached ?
|
||||
if self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
return SecureResponse(not_reached=empty_pb2.Empty())
|
||||
|
||||
@utils.rpc
|
||||
async def WaitSecurity(
|
||||
self, request: WaitSecurityRequest, context: grpc.ServicerContext
|
||||
) -> WaitSecurityResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitSecurity: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
|
||||
assert request.level
|
||||
level = request.level
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
connection.transport
|
||||
] == request.level_variant()
|
||||
|
||||
wait_for_security: asyncio.Future[
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
if (encryption := connection.encryption) != 0:
|
||||
self.log.debug('Disable encryption...')
|
||||
try:
|
||||
await connection.encrypt(enable=False)
|
||||
except:
|
||||
pass
|
||||
self.log.debug('Disable encryption: done')
|
||||
|
||||
self.log.debug('Authenticate...')
|
||||
await connection.authenticate()
|
||||
self.log.debug('Authenticate: done')
|
||||
|
||||
if encryption != 0 and connection.encryption != encryption:
|
||||
self.log.debug('Re-enable encryption...')
|
||||
await connection.encrypt()
|
||||
self.log.debug('Re-enable encryption: done')
|
||||
|
||||
def set_failure(name: str) -> Callable[..., None]:
|
||||
def wrapper(*args: Any) -> None:
|
||||
self.log.info(f'Wait for security: error `{name}`: {args}')
|
||||
wait_for_security.set_result(name)
|
||||
|
||||
return wrapper
|
||||
|
||||
def try_set_success(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
|
||||
def on_encryption_change(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
elif (
|
||||
connection.transport == BT_BR_EDR_TRANSPORT
|
||||
and self.need_authentication(connection, level)
|
||||
):
|
||||
nonlocal authenticate_task
|
||||
if authenticate_task is None:
|
||||
authenticate_task = asyncio.create_task(authenticate())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
'connection_authentication_failure': set_failure('authentication_failure'),
|
||||
'connection_encryption_failure': set_failure('encryption_failure'),
|
||||
'pairing': try_set_success,
|
||||
'connection_authentication': try_set_success,
|
||||
'connection_encryption_change': on_encryption_change,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.info('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
self.log.info('Wait for authentication...')
|
||||
try:
|
||||
await authenticate_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.info('Authenticated')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
def reached_security_level(
|
||||
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
|
||||
) -> bool:
|
||||
self.log.debug(
|
||||
str(
|
||||
{
|
||||
'level': level,
|
||||
'encryption': connection.encryption,
|
||||
'authenticated': connection.authenticated,
|
||||
'sc': connection.sc,
|
||||
'link_key_type': connection.link_key_type,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(level, LESecurityLevel):
|
||||
return LE_LEVEL_REACHED[level](connection)
|
||||
|
||||
return BR_LEVEL_REACHED[level](connection)
|
||||
|
||||
def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return level >= LE_LEVEL3 and not connection.authenticated
|
||||
return False
|
||||
|
||||
def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return False
|
||||
if level == LEVEL2 and connection.encryption != 0:
|
||||
return not connection.authenticated
|
||||
return level >= LEVEL2 and not connection.authenticated
|
||||
|
||||
def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
|
||||
# TODO(abel): need to support MITM
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return level == LE_LEVEL2 and not connection.encryption
|
||||
return level >= LEVEL2 and not connection.encryption
|
||||
|
||||
|
||||
class SecurityStorageService(SecurityStorageServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
|
||||
)
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
@utils.rpc
|
||||
async def IsBonded(
|
||||
self, request: IsBondedRequest, context: grpc.ServicerContext
|
||||
) -> wrappers_pb2.BoolValue:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"IsBonded: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
is_bonded = await self.device.keystore.get(str(address)) is not None
|
||||
else:
|
||||
is_bonded = False
|
||||
|
||||
return BoolValue(value=is_bonded)
|
||||
|
||||
@utils.rpc
|
||||
async def DeleteBond(
|
||||
self, request: DeleteBondRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"DeleteBond: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
with suppress(KeyError):
|
||||
await self.device.keystore.delete(str(address))
|
||||
|
||||
return empty_pb2.Empty()
|
||||
112
bumble/pandora/utils.py
Normal file
112
bumble/pandora/utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.hci import Address
|
||||
from google.protobuf.message import Message # pytype: disable=pyi-error
|
||||
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
|
||||
|
||||
ADDRESS_TYPES: Dict[str, int] = {
|
||||
"public": Address.PUBLIC_DEVICE_ADDRESS,
|
||||
"random": Address.RANDOM_DEVICE_ADDRESS,
|
||||
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
"random_static_identity": Address.RANDOM_IDENTITY_ADDRESS,
|
||||
}
|
||||
|
||||
|
||||
def address_from_request(request: Message, field: Optional[str]) -> Address:
|
||||
if field is None:
|
||||
return Address.ANY
|
||||
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
|
||||
|
||||
|
||||
class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
"""Formats logs from the PandoraClient."""
|
||||
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> Tuple[str, MutableMapping[str, Any]]:
|
||||
assert self.extra
|
||||
service_name = self.extra['service_name']
|
||||
assert isinstance(service_name, str)
|
||||
device = self.extra['device']
|
||||
assert isinstance(device, Device)
|
||||
addr_bytes = bytes(
|
||||
reversed(bytes(device.public_address))
|
||||
) # pytype: disable=attribute-error
|
||||
addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]])
|
||||
return (f'[bumble.{service_name}:{addr}] {msg}', kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def exception_to_rpc_error(
|
||||
context: grpc.ServicerContext,
|
||||
) -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
except NotImplementedError as e:
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
except ValueError as e:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
except RuntimeError as e:
|
||||
context.set_code(grpc.StatusCode.ABORTED) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
|
||||
|
||||
# Decorate an RPC servicer method with a wrapper that transform exceptions to gRPC errors.
|
||||
def rpc(func: Any) -> Any:
|
||||
@functools.wraps(func)
|
||||
async def asyncgen_wrapper(
|
||||
self: Any, request: Any, context: grpc.ServicerContext
|
||||
) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
async for v in func(self, request, context):
|
||||
yield v
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
self: Any, request: Any, context: grpc.ServicerContext
|
||||
) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
return await func(self, request, context)
|
||||
|
||||
@functools.wraps(func)
|
||||
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
for v in func(self, request, context):
|
||||
yield v
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
return func(self, request, context)
|
||||
|
||||
if inspect.isasyncgenfunction(func):
|
||||
return asyncgen_wrapper
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
|
||||
if inspect.isgenerator(func):
|
||||
return gen_wrapper
|
||||
|
||||
return wrapper
|
||||
376
bumble/smp.py
376
bumble/smp.py
@@ -26,16 +26,22 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import secrets
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .hci import (
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
Address,
|
||||
HCI_LE_Enable_Encryption_Command,
|
||||
HCI_Object,
|
||||
@@ -51,6 +57,10 @@ from .core import (
|
||||
from .keys import PairingKeys
|
||||
from . import crypto
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.pairing import PairingConfig
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
def error_name(error_code):
|
||||
def error_name(error_code: int) -> str:
|
||||
return name_or_number(SMP_ERROR_NAMES, error_code)
|
||||
|
||||
|
||||
@@ -197,11 +207,12 @@ class SMP_Command:
|
||||
'''
|
||||
|
||||
smp_classes: Dict[int, Type[SMP_Command]] = {}
|
||||
fields: Any
|
||||
code = 0
|
||||
name = ''
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
def from_bytes(pdu: bytes) -> "SMP_Command":
|
||||
code = pdu[0]
|
||||
|
||||
cls = SMP_Command.smp_classes.get(code)
|
||||
@@ -217,11 +228,11 @@ class SMP_Command:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def command_name(code):
|
||||
def command_name(code: int) -> str:
|
||||
return name_or_number(SMP_COMMAND_NAMES, code)
|
||||
|
||||
@staticmethod
|
||||
def auth_req_str(value):
|
||||
def auth_req_str(value: int) -> str:
|
||||
bonding_flags = value & 3
|
||||
mitm = (value >> 2) & 1
|
||||
sc = (value >> 3) & 1
|
||||
@@ -234,12 +245,12 @@ class SMP_Command:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def io_capability_name(io_capability):
|
||||
def io_capability_name(io_capability: int) -> str:
|
||||
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
|
||||
|
||||
@staticmethod
|
||||
def key_distribution_str(value):
|
||||
key_types = []
|
||||
def key_distribution_str(value: int) -> str:
|
||||
key_types: List[str] = []
|
||||
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('ENC')
|
||||
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
||||
@@ -251,7 +262,7 @@ class SMP_Command:
|
||||
return ','.join(key_types)
|
||||
|
||||
@staticmethod
|
||||
def keypress_notification_type_name(notification_type):
|
||||
def keypress_notification_type_name(notification_type: int) -> str:
|
||||
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
|
||||
|
||||
@staticmethod
|
||||
@@ -272,14 +283,14 @@ class SMP_Command:
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu=None, **kwargs):
|
||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
|
||||
if hasattr(self, 'fields') and kwargs:
|
||||
HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
|
||||
'''
|
||||
|
||||
io_capability: int
|
||||
oob_data_flag: int
|
||||
auth_req: int
|
||||
maximum_encryption_key_size: int
|
||||
initiator_key_distribution: int
|
||||
responder_key_distribution: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
|
||||
'''
|
||||
|
||||
io_capability: int
|
||||
oob_data_flag: int
|
||||
auth_req: int
|
||||
maximum_encryption_key_size: int
|
||||
initiator_key_distribution: int
|
||||
responder_key_distribution: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('confirm_value', 16)])
|
||||
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
|
||||
'''
|
||||
|
||||
confirm_value: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('random_value', 16)])
|
||||
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
|
||||
'''
|
||||
|
||||
random_value: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
|
||||
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
|
||||
'''
|
||||
|
||||
reason: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
|
||||
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
|
||||
'''
|
||||
|
||||
public_key_x: bytes
|
||||
public_key_y: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
|
||||
'''
|
||||
|
||||
dhkey_check: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
|
||||
'''
|
||||
|
||||
notification_type: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('long_term_key', 16)])
|
||||
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
|
||||
'''
|
||||
|
||||
long_term_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('ediv', 2), ('rand', 8)])
|
||||
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
|
||||
'''
|
||||
|
||||
ediv: int
|
||||
rand: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('identity_resolving_key', 16)])
|
||||
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
|
||||
'''
|
||||
|
||||
identity_resolving_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
|
||||
'''
|
||||
|
||||
addr_type: int
|
||||
bd_addr: Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('signature_key', 16)])
|
||||
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
|
||||
'''
|
||||
|
||||
signature_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
|
||||
'''
|
||||
|
||||
auth_req: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def smp_auth_req(bonding, mitm, sc, keypress, ct2):
|
||||
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
|
||||
value = 0
|
||||
if bonding:
|
||||
value |= SMP_BONDING_AUTHREQ
|
||||
@@ -574,11 +626,17 @@ class Session:
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, manager, connection, pairing_config, is_initiator):
|
||||
def __init__(
|
||||
self,
|
||||
manager: Manager,
|
||||
connection: Connection,
|
||||
pairing_config: PairingConfig,
|
||||
is_initiator: bool,
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.preq = None
|
||||
self.pres = None
|
||||
self.preq: Optional[bytes] = None
|
||||
self.pres: Optional[bytes] = None
|
||||
self.ea = None
|
||||
self.eb = None
|
||||
self.tk = bytes(16)
|
||||
@@ -588,29 +646,29 @@ class Session:
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key = None
|
||||
self.initiator_key_distribution = 0
|
||||
self.responder_key_distribution = 0
|
||||
self.peer_random_value = None
|
||||
self.peer_public_key_x = bytes(32)
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
self.peer_public_key_x: bytes = bytes(32)
|
||||
self.peer_public_key_y = bytes(32)
|
||||
self.peer_ltk = None
|
||||
self.peer_ediv = None
|
||||
self.peer_rand = None
|
||||
self.peer_rand: Optional[bytes] = None
|
||||
self.peer_identity_resolving_key = None
|
||||
self.peer_bd_addr = None
|
||||
self.peer_bd_addr: Optional[Address] = None
|
||||
self.peer_signature_key = None
|
||||
self.peer_expected_distributions = []
|
||||
self.peer_expected_distributions: List[Type[SMP_Command]] = []
|
||||
self.dh_key = None
|
||||
self.confirm_value = None
|
||||
self.passkey = None
|
||||
self.passkey: Optional[int] = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
self.passkey_step = 0
|
||||
self.passkey_display = False
|
||||
self.pairing_method = 0
|
||||
self.pairing_config = pairing_config
|
||||
self.wait_before_continuing = None
|
||||
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
|
||||
self.completed = False
|
||||
self.ctkd_task = None
|
||||
self.ctkd_task: Optional[Awaitable[None]] = None
|
||||
|
||||
# Decide if we're the initiator or the responder
|
||||
self.is_initiator = is_initiator
|
||||
@@ -628,7 +686,9 @@ class Session:
|
||||
|
||||
# Create a future that can be used to wait for the session to complete
|
||||
if self.is_initiator:
|
||||
self.pairing_result = asyncio.get_running_loop().create_future()
|
||||
self.pairing_result: Optional[
|
||||
asyncio.Future[None]
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
else:
|
||||
self.pairing_result = None
|
||||
|
||||
@@ -641,11 +701,11 @@ class Session:
|
||||
)
|
||||
|
||||
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
|
||||
self.bonding = pairing_config.bonding
|
||||
self.sc = pairing_config.sc
|
||||
self.mitm = pairing_config.mitm
|
||||
self.bonding: bool = pairing_config.bonding
|
||||
self.sc: bool = pairing_config.sc
|
||||
self.mitm: bool = pairing_config.mitm
|
||||
self.keypress = False
|
||||
self.ct2 = False
|
||||
self.ct2: bool = False
|
||||
|
||||
# I/O Capabilities
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
@@ -669,34 +729,35 @@ class Session:
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
@property
|
||||
def pkx(self):
|
||||
def pkx(self) -> Tuple[bytes, bytes]:
|
||||
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
|
||||
|
||||
@property
|
||||
def pka(self):
|
||||
def pka(self) -> bytes:
|
||||
return self.pkx[0 if self.is_initiator else 1]
|
||||
|
||||
@property
|
||||
def pkb(self):
|
||||
def pkb(self) -> bytes:
|
||||
return self.pkx[0 if self.is_responder else 1]
|
||||
|
||||
@property
|
||||
def nx(self):
|
||||
def nx(self) -> Tuple[bytes, bytes]:
|
||||
assert self.peer_random_value
|
||||
return (self.r, self.peer_random_value)
|
||||
|
||||
@property
|
||||
def na(self):
|
||||
def na(self) -> bytes:
|
||||
return self.nx[0 if self.is_initiator else 1]
|
||||
|
||||
@property
|
||||
def nb(self):
|
||||
def nb(self) -> bytes:
|
||||
return self.nx[0 if self.is_responder else 1]
|
||||
|
||||
@property
|
||||
def auth_req(self):
|
||||
def auth_req(self) -> int:
|
||||
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
|
||||
|
||||
def get_long_term_key(self, rand, ediv):
|
||||
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
|
||||
if not self.sc and not self.completed:
|
||||
if rand == self.ltk_rand and ediv == self.ltk_ediv:
|
||||
return self.stk
|
||||
@@ -706,13 +767,13 @@ class Session:
|
||||
return None
|
||||
|
||||
def decide_pairing_method(
|
||||
self, auth_req, initiator_io_capability, responder_io_capability
|
||||
):
|
||||
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
|
||||
) -> None:
|
||||
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
|
||||
self.pairing_method = self.JUST_WORKS
|
||||
return
|
||||
|
||||
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability]
|
||||
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
|
||||
if isinstance(details, tuple) and len(details) == 2:
|
||||
# One entry for legacy pairing and one for secure connections
|
||||
details = details[1 if self.sc else 0]
|
||||
@@ -724,7 +785,9 @@ class Session:
|
||||
self.pairing_method = details[0]
|
||||
self.passkey_display = details[1 if self.is_initiator else 2]
|
||||
|
||||
def check_expected_value(self, expected, received, error):
|
||||
def check_expected_value(
|
||||
self, expected: bytes, received: bytes, error: int
|
||||
) -> bool:
|
||||
logger.debug(f'expected={expected.hex()} got={received.hex()}')
|
||||
if expected != received:
|
||||
logger.info(color('pairing confirm/check mismatch', 'red'))
|
||||
@@ -732,8 +795,8 @@ class Session:
|
||||
return False
|
||||
return True
|
||||
|
||||
def prompt_user_for_confirmation(self, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug('ask for confirmation')
|
||||
try:
|
||||
response = await self.pairing_config.delegate.confirm()
|
||||
@@ -747,8 +810,10 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def prompt_user_for_numeric_comparison(self, code, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_numeric_comparison(
|
||||
self, code: int, next_steps: Callable[[], None]
|
||||
) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug(f'verification code: {code}')
|
||||
try:
|
||||
response = await self.pairing_config.delegate.compare_numbers(
|
||||
@@ -764,11 +829,15 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def prompt_user_for_number(self, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug('prompting user for passkey')
|
||||
try:
|
||||
passkey = await self.pairing_config.delegate.get_number()
|
||||
if passkey is None:
|
||||
logger.debug('Passkey request rejected')
|
||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||
return
|
||||
logger.debug(f'user input: {passkey}')
|
||||
next_steps(passkey)
|
||||
except Exception as error:
|
||||
@@ -777,9 +846,10 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def display_passkey(self):
|
||||
def display_passkey(self) -> None:
|
||||
# Generate random Passkey/PIN code
|
||||
self.passkey = secrets.randbelow(1000000)
|
||||
assert self.passkey is not None
|
||||
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
|
||||
self.passkey_ready.set()
|
||||
|
||||
@@ -793,9 +863,9 @@ class Session:
|
||||
self.pairing_config.delegate.display_number(self.passkey, digits=6),
|
||||
)
|
||||
|
||||
def input_passkey(self, next_steps=None):
|
||||
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
|
||||
# Prompt the user for the passkey displayed on the peer
|
||||
def after_input(passkey):
|
||||
def after_input(passkey: int) -> None:
|
||||
self.passkey = passkey
|
||||
|
||||
if not self.sc:
|
||||
@@ -809,7 +879,9 @@ class Session:
|
||||
|
||||
self.prompt_user_for_number(after_input)
|
||||
|
||||
def display_or_input_passkey(self, next_steps=None):
|
||||
def display_or_input_passkey(
|
||||
self, next_steps: Optional[Callable[[], None]] = None
|
||||
) -> None:
|
||||
if self.passkey_display:
|
||||
self.display_passkey()
|
||||
if next_steps is not None:
|
||||
@@ -817,14 +889,14 @@ class Session:
|
||||
else:
|
||||
self.input_passkey(next_steps)
|
||||
|
||||
def send_command(self, command):
|
||||
def send_command(self, command: SMP_Command) -> None:
|
||||
self.manager.send_command(self.connection, command)
|
||||
|
||||
def send_pairing_failed(self, error):
|
||||
def send_pairing_failed(self, error: int) -> None:
|
||||
self.send_command(SMP_Pairing_Failed_Command(reason=error))
|
||||
self.on_pairing_failure(error)
|
||||
|
||||
def send_pairing_request_command(self):
|
||||
def send_pairing_request_command(self) -> None:
|
||||
self.manager.on_session_start(self)
|
||||
|
||||
command = SMP_Pairing_Request_Command(
|
||||
@@ -838,7 +910,7 @@ class Session:
|
||||
self.preq = bytes(command)
|
||||
self.send_command(command)
|
||||
|
||||
def send_pairing_response_command(self):
|
||||
def send_pairing_response_command(self) -> None:
|
||||
response = SMP_Pairing_Response_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
@@ -850,18 +922,19 @@ class Session:
|
||||
self.pres = bytes(response)
|
||||
self.send_command(response)
|
||||
|
||||
def send_pairing_confirm_command(self):
|
||||
def send_pairing_confirm_command(self) -> None:
|
||||
self.r = crypto.r()
|
||||
logger.debug(f'generated random: {self.r.hex()}')
|
||||
|
||||
if self.sc:
|
||||
|
||||
async def next_steps():
|
||||
async def next_steps() -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
z = 0
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
# We need a passkey
|
||||
await self.passkey_ready.wait()
|
||||
assert self.passkey
|
||||
|
||||
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
|
||||
else:
|
||||
@@ -892,10 +965,10 @@ class Session:
|
||||
|
||||
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
|
||||
|
||||
def send_pairing_random_command(self):
|
||||
def send_pairing_random_command(self) -> None:
|
||||
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
|
||||
|
||||
def send_public_key_command(self):
|
||||
def send_public_key_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_Public_Key_Command(
|
||||
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
|
||||
@@ -903,18 +976,18 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
def send_pairing_dhkey_check_command(self):
|
||||
def send_pairing_dhkey_check_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_DHKey_Check_Command(
|
||||
dhkey_check=self.ea if self.is_initiator else self.eb
|
||||
)
|
||||
)
|
||||
|
||||
def start_encryption(self, key):
|
||||
def start_encryption(self, key: bytes) -> None:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
@@ -922,7 +995,7 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
async def derive_ltk(self):
|
||||
async def derive_ltk(self) -> None:
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
ilk = (
|
||||
@@ -932,7 +1005,7 @@ class Session:
|
||||
)
|
||||
self.ltk = crypto.h6(ilk, b'brle')
|
||||
|
||||
def distribute_keys(self):
|
||||
def distribute_keys(self) -> None:
|
||||
# Distribute the keys as required
|
||||
if self.is_initiator:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1032,7 +1105,7 @@ class Session:
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags):
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
self.peer_expected_distributions = []
|
||||
if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
|
||||
@@ -1055,7 +1128,7 @@ class Session:
|
||||
f'{[c.__name__ for c in self.peer_expected_distributions]}'
|
||||
)
|
||||
|
||||
def check_key_distribution(self, command_class):
|
||||
def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
|
||||
# First, check that the connection is encrypted
|
||||
if not self.connection.is_encrypted:
|
||||
logger.warning(
|
||||
@@ -1083,7 +1156,7 @@ class Session:
|
||||
)
|
||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
||||
|
||||
async def pair(self):
|
||||
async def pair(self) -> None:
|
||||
# Start pairing as an initiator
|
||||
# TODO: check that this session isn't already active
|
||||
|
||||
@@ -1091,9 +1164,10 @@ class Session:
|
||||
self.send_pairing_request_command()
|
||||
|
||||
# Wait for the pairing process to finish
|
||||
assert self.pairing_result
|
||||
await self.connection.abort_on('disconnection', self.pairing_result)
|
||||
|
||||
def on_disconnection(self, _):
|
||||
def on_disconnection(self, _: int) -> None:
|
||||
self.connection.remove_listener('disconnection', self.on_disconnection)
|
||||
self.connection.remove_listener(
|
||||
'connection_encryption_change', self.on_connection_encryption_change
|
||||
@@ -1104,14 +1178,14 @@ class Session:
|
||||
)
|
||||
self.manager.on_session_end(self)
|
||||
|
||||
def on_peer_key_distribution_complete(self):
|
||||
def on_peer_key_distribution_complete(self) -> None:
|
||||
# The initiator can now send its keys
|
||||
if self.is_initiator:
|
||||
self.distribute_keys()
|
||||
|
||||
self.connection.abort_on('disconnection', self.on_pairing())
|
||||
|
||||
def on_connection_encryption_change(self):
|
||||
def on_connection_encryption_change(self) -> None:
|
||||
if self.connection.is_encrypted:
|
||||
if self.is_responder:
|
||||
# The responder distributes its keys first, the initiator later
|
||||
@@ -1121,11 +1195,11 @@ class Session:
|
||||
if not self.peer_expected_distributions:
|
||||
self.on_peer_key_distribution_complete()
|
||||
|
||||
def on_connection_encryption_key_refresh(self):
|
||||
def on_connection_encryption_key_refresh(self) -> None:
|
||||
# Do as if the connection had just been encrypted
|
||||
self.on_connection_encryption_change()
|
||||
|
||||
async def on_pairing(self):
|
||||
async def on_pairing(self) -> None:
|
||||
logger.debug('pairing complete')
|
||||
|
||||
if self.completed:
|
||||
@@ -1137,7 +1211,7 @@ class Session:
|
||||
self.pairing_result.set_result(None)
|
||||
|
||||
# Use the peer address from the pairing protocol or the connection
|
||||
if self.peer_bd_addr:
|
||||
if self.peer_bd_addr is not None:
|
||||
peer_address = self.peer_bd_addr
|
||||
else:
|
||||
peer_address = self.connection.peer_address
|
||||
@@ -1186,7 +1260,7 @@ class Session:
|
||||
)
|
||||
self.manager.on_pairing(self, peer_address, keys)
|
||||
|
||||
def on_pairing_failure(self, reason):
|
||||
def on_pairing_failure(self, reason: int) -> None:
|
||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
||||
|
||||
if self.completed:
|
||||
@@ -1199,7 +1273,7 @@ class Session:
|
||||
self.pairing_result.set_exception(error)
|
||||
self.manager.on_pairing_failure(self, reason)
|
||||
|
||||
def on_smp_command(self, command):
|
||||
def on_smp_command(self, command: SMP_Command) -> None:
|
||||
# Find the handler method
|
||||
handler_name = f'on_{command.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
@@ -1215,12 +1289,16 @@ class Session:
|
||||
else:
|
||||
logger.error(color('SMP command not handled???', 'red'))
|
||||
|
||||
def on_smp_pairing_request_command(self, command):
|
||||
def on_smp_pairing_request_command(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
self.connection.abort_on(
|
||||
'disconnection', self.on_smp_pairing_request_command_async(command)
|
||||
)
|
||||
|
||||
async def on_smp_pairing_request_command_async(self, command):
|
||||
async def on_smp_pairing_request_command_async(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
# Check if the request should proceed
|
||||
accepted = await self.pairing_config.delegate.accept()
|
||||
if not accepted:
|
||||
@@ -1280,7 +1358,9 @@ class Session:
|
||||
):
|
||||
self.distribute_keys()
|
||||
|
||||
def on_smp_pairing_response_command(self, command):
|
||||
def on_smp_pairing_response_command(
|
||||
self, command: SMP_Pairing_Response_Command
|
||||
) -> None:
|
||||
if self.is_responder:
|
||||
logger.warning(color('received pairing response as a responder', 'red'))
|
||||
return
|
||||
@@ -1331,7 +1411,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command_legacy(self, _):
|
||||
def on_smp_pairing_confirm_command_legacy(
|
||||
self, _: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
if self.is_initiator:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
@@ -1341,7 +1423,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command_secure_connections(self, _):
|
||||
def on_smp_pairing_confirm_command_secure_connections(
|
||||
self, _: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.is_initiator:
|
||||
self.r = crypto.r()
|
||||
@@ -1352,14 +1436,18 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command(self, command):
|
||||
def on_smp_pairing_confirm_command(
|
||||
self, command: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
self.confirm_value = command.confirm_value
|
||||
if self.sc:
|
||||
self.on_smp_pairing_confirm_command_secure_connections(command)
|
||||
else:
|
||||
self.on_smp_pairing_confirm_command_legacy(command)
|
||||
|
||||
def on_smp_pairing_random_command_legacy(self, command):
|
||||
def on_smp_pairing_random_command_legacy(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
# Check that the confirmation values match
|
||||
confirm_verifier = crypto.c1(
|
||||
self.tk,
|
||||
@@ -1371,6 +1459,7 @@ class Session:
|
||||
self.ia,
|
||||
self.ra,
|
||||
)
|
||||
assert self.confirm_value
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
):
|
||||
@@ -1394,7 +1483,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_random_command()
|
||||
|
||||
def on_smp_pairing_random_command_secure_connections(self, command):
|
||||
def on_smp_pairing_random_command_secure_connections(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
if self.pairing_method == self.PASSKEY and self.passkey is None:
|
||||
logger.warning('no passkey entered, ignoring command')
|
||||
return
|
||||
@@ -1402,6 +1493,7 @@ class Session:
|
||||
# pylint: disable=too-many-return-statements
|
||||
if self.is_initiator:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
assert self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pkb, self.pka, command.random_value, bytes([0])
|
||||
@@ -1411,6 +1503,7 @@ class Session:
|
||||
):
|
||||
return
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pkb,
|
||||
@@ -1435,6 +1528,7 @@ class Session:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pka,
|
||||
@@ -1468,19 +1562,21 @@ class Session:
|
||||
ra = bytes(16)
|
||||
rb = ra
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey
|
||||
ra = self.passkey.to_bytes(16, byteorder='little')
|
||||
rb = ra
|
||||
else:
|
||||
# OOB not implemented yet
|
||||
return
|
||||
|
||||
assert self.preq and self.pres
|
||||
io_cap_a = self.preq[1:4]
|
||||
io_cap_b = self.pres[1:4]
|
||||
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
|
||||
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
|
||||
|
||||
# Next steps to be performed after possible user confirmation
|
||||
def next_steps():
|
||||
def next_steps() -> None:
|
||||
# The initiator sends the DH Key check to the responder
|
||||
if self.is_initiator:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
@@ -1502,14 +1598,18 @@ class Session:
|
||||
else:
|
||||
next_steps()
|
||||
|
||||
def on_smp_pairing_random_command(self, command):
|
||||
def on_smp_pairing_random_command(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
self.peer_random_value = command.random_value
|
||||
if self.sc:
|
||||
self.on_smp_pairing_random_command_secure_connections(command)
|
||||
else:
|
||||
self.on_smp_pairing_random_command_legacy(command)
|
||||
|
||||
def on_smp_pairing_public_key_command(self, command):
|
||||
def on_smp_pairing_public_key_command(
|
||||
self, command: SMP_Pairing_Public_Key_Command
|
||||
) -> None:
|
||||
# Store the public key so that we can compute the confirmation value later
|
||||
self.peer_public_key_x = command.public_key_x
|
||||
self.peer_public_key_y = command.public_key_y
|
||||
@@ -1538,9 +1638,12 @@ class Session:
|
||||
# We can now send the confirmation value
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_dhkey_check_command(self, command):
|
||||
def on_smp_pairing_dhkey_check_command(
|
||||
self, command: SMP_Pairing_DHKey_Check_Command
|
||||
) -> None:
|
||||
# Check that what we received matches what we computed earlier
|
||||
expected = self.eb if self.is_initiator else self.ea
|
||||
assert expected
|
||||
if not self.check_expected_value(
|
||||
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
|
||||
):
|
||||
@@ -1549,7 +1652,8 @@ class Session:
|
||||
if self.is_responder:
|
||||
if self.wait_before_continuing is not None:
|
||||
|
||||
async def next_steps():
|
||||
async def next_steps() -> None:
|
||||
assert self.wait_before_continuing
|
||||
await self.wait_before_continuing
|
||||
self.wait_before_continuing = None
|
||||
self.send_pairing_dhkey_check_command()
|
||||
@@ -1558,29 +1662,42 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
assert self.ltk
|
||||
self.start_encryption(self.ltk)
|
||||
|
||||
def on_smp_pairing_failed_command(self, command):
|
||||
def on_smp_pairing_failed_command(
|
||||
self, command: SMP_Pairing_Failed_Command
|
||||
) -> None:
|
||||
self.on_pairing_failure(command.reason)
|
||||
|
||||
def on_smp_encryption_information_command(self, command):
|
||||
def on_smp_encryption_information_command(
|
||||
self, command: SMP_Encryption_Information_Command
|
||||
) -> None:
|
||||
self.peer_ltk = command.long_term_key
|
||||
self.check_key_distribution(SMP_Encryption_Information_Command)
|
||||
|
||||
def on_smp_master_identification_command(self, command):
|
||||
def on_smp_master_identification_command(
|
||||
self, command: SMP_Master_Identification_Command
|
||||
) -> None:
|
||||
self.peer_ediv = command.ediv
|
||||
self.peer_rand = command.rand
|
||||
self.check_key_distribution(SMP_Master_Identification_Command)
|
||||
|
||||
def on_smp_identity_information_command(self, command):
|
||||
def on_smp_identity_information_command(
|
||||
self, command: SMP_Identity_Information_Command
|
||||
) -> None:
|
||||
self.peer_identity_resolving_key = command.identity_resolving_key
|
||||
self.check_key_distribution(SMP_Identity_Information_Command)
|
||||
|
||||
def on_smp_identity_address_information_command(self, command):
|
||||
def on_smp_identity_address_information_command(
|
||||
self, command: SMP_Identity_Address_Information_Command
|
||||
) -> None:
|
||||
self.peer_bd_addr = command.bd_addr
|
||||
self.check_key_distribution(SMP_Identity_Address_Information_Command)
|
||||
|
||||
def on_smp_signing_information_command(self, command):
|
||||
def on_smp_signing_information_command(
|
||||
self, command: SMP_Signing_Information_Command
|
||||
) -> None:
|
||||
self.peer_signature_key = command.signature_key
|
||||
self.check_key_distribution(SMP_Signing_Information_Command)
|
||||
|
||||
@@ -1591,14 +1708,24 @@ class Manager(EventEmitter):
|
||||
Implements the Initiator and Responder roles of the Security Manager Protocol
|
||||
'''
|
||||
|
||||
def __init__(self, device, pairing_config_factory):
|
||||
device: Device
|
||||
sessions: Dict[int, Session]
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
session_proxy: Type[Session]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device,
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.sessions = {}
|
||||
self._ecc_key = None
|
||||
self.pairing_config_factory = pairing_config_factory
|
||||
self.session_proxy = Session
|
||||
|
||||
def send_command(self, connection, command):
|
||||
def send_command(self, connection: Connection, command: SMP_Command) -> None:
|
||||
logger.debug(
|
||||
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
@@ -1606,20 +1733,15 @@ class Manager(EventEmitter):
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
|
||||
def on_smp_pdu(self, connection, pdu):
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||
# Look for a session with this connection, and create one if none exists
|
||||
if not (session := self.sessions.get(connection.handle)):
|
||||
if connection.role == BT_CENTRAL_ROLE:
|
||||
logger.warning('Remote starts pairing as Peripheral!')
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config is None:
|
||||
# Pairing disabled
|
||||
self.send_command(
|
||||
connection,
|
||||
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
|
||||
)
|
||||
return
|
||||
session = Session(self, connection, pairing_config, is_initiator=False)
|
||||
session = self.session_proxy(
|
||||
self, connection, pairing_config, is_initiator=False
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
@@ -1633,23 +1755,24 @@ class Manager(EventEmitter):
|
||||
session.on_smp_command(command)
|
||||
|
||||
@property
|
||||
def ecc_key(self):
|
||||
def ecc_key(self) -> crypto.EccKey:
|
||||
if self._ecc_key is None:
|
||||
self._ecc_key = crypto.EccKey.generate()
|
||||
assert self._ecc_key
|
||||
return self._ecc_key
|
||||
|
||||
async def pair(self, connection):
|
||||
async def pair(self, connection: Connection) -> None:
|
||||
# TODO: check if there's already a session for this connection
|
||||
if connection.role != BT_CENTRAL_ROLE:
|
||||
logger.warning('Start pairing as Peripheral!')
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config is None:
|
||||
raise ValueError('pairing config must not be None when initiating')
|
||||
session = Session(self, connection, pairing_config, is_initiator=True)
|
||||
session = self.session_proxy(
|
||||
self, connection, pairing_config, is_initiator=True
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
return await session.pair()
|
||||
|
||||
def request_pairing(self, connection):
|
||||
def request_pairing(self, connection: Connection) -> None:
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config:
|
||||
auth_req = smp_auth_req(
|
||||
@@ -1663,15 +1786,18 @@ class Manager(EventEmitter):
|
||||
auth_req = 0
|
||||
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
|
||||
|
||||
def on_session_start(self, session):
|
||||
self.device.on_pairing_start(session.connection.handle)
|
||||
def on_session_start(self, session: Session) -> None:
|
||||
self.device.on_pairing_start(session.connection)
|
||||
|
||||
def on_pairing(self, session, identity_address, keys):
|
||||
def on_pairing(
|
||||
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
|
||||
) -> None:
|
||||
# Store the keys in the key store
|
||||
if self.device.keystore and identity_address is not None:
|
||||
|
||||
async def store_keys():
|
||||
try:
|
||||
assert self.device.keystore
|
||||
await self.device.keystore.update(str(identity_address), keys)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! error while storing keys: {error}')
|
||||
@@ -1679,17 +1805,19 @@ class Manager(EventEmitter):
|
||||
self.device.abort_on('flush', store_keys())
|
||||
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
||||
self.device.on_pairing(session.connection, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session, reason):
|
||||
self.device.on_pairing_failure(session.connection.handle, reason)
|
||||
def on_pairing_failure(self, session: Session, reason: int) -> None:
|
||||
self.device.on_pairing_failure(session.connection, reason)
|
||||
|
||||
def on_session_end(self, session):
|
||||
def on_session_end(self, session: Session) -> None:
|
||||
logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
|
||||
if session.connection.handle in self.sessions:
|
||||
del self.sessions[session.connection.handle]
|
||||
|
||||
def get_long_term_key(self, connection, rand, ediv):
|
||||
def get_long_term_key(
|
||||
self, connection: Connection, rand: bytes, ediv: int
|
||||
) -> Optional[bytes]:
|
||||
if session := self.sessions.get(connection.handle):
|
||||
return session.get_long_term_key(rand, ediv)
|
||||
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
mkdocs == 1.4.0
|
||||
mkdocs-material == 8.5.6
|
||||
mkdocs-material-extensions == 1.0.3
|
||||
pymdown-extensions == 9.6
|
||||
pymdown-extensions == 10.0
|
||||
mkdocstrings-python == 0.7.1
|
||||
|
||||
@@ -40,6 +40,9 @@ disable = [
|
||||
"too-many-statements",
|
||||
]
|
||||
|
||||
[tool.pylint.main]
|
||||
ignore="pandora" # FIXME: pylint does not support stubs yet:
|
||||
|
||||
[tool.pylint.typecheck]
|
||||
signature-mutators="AsyncRunner.run_in_task"
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ url = https://github.com/google/bumble
|
||||
|
||||
[options]
|
||||
python_requires = >=3.8
|
||||
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay
|
||||
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay, bumble.pandora
|
||||
package_dir =
|
||||
bumble = bumble
|
||||
bumble.apps = apps
|
||||
@@ -33,7 +33,7 @@ install_requires =
|
||||
appdirs >= 1.4
|
||||
click >= 7.1.2; platform_system!='Emscripten'
|
||||
cryptography == 35; platform_system!='Emscripten'
|
||||
grpcio >= 1.46; platform_system!='Emscripten'
|
||||
grpcio == 1.51.1; platform_system!='Emscripten'
|
||||
libusb1 >= 2.0.1; platform_system!='Emscripten'
|
||||
libusb-package == 1.0.26.1; platform_system!='Emscripten'
|
||||
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
|
||||
@@ -45,6 +45,7 @@ install_requires =
|
||||
websockets >= 8.1; platform_system!='Emscripten'
|
||||
prettytable >= 3.6.0
|
||||
humanize >= 4.6.0
|
||||
bt-test-interfaces >= 0.0.2
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
@@ -60,6 +61,7 @@ console_scripts =
|
||||
bumble-usb-probe = bumble.apps.usb_probe:main
|
||||
bumble-link-relay = bumble.apps.link_relay.link_relay:main
|
||||
bumble-bench = bumble.apps.bench:main
|
||||
bumble-pandora-server = bumble.apps.pandora_server:main
|
||||
|
||||
[options.package_data]
|
||||
* = py.typed, *.pyi
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from bumble.core import AdvertisingData, get_dict_key_by_value
|
||||
from bumble.core import AdvertisingData, UUID, get_dict_key_by_value
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ad_data():
|
||||
@@ -49,6 +49,24 @@ def test_get_dict_key_by_value():
|
||||
assert get_dict_key_by_value(dictionary, 3) is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_uuid_to_hex_str() -> None:
|
||||
assert UUID("b5ea").to_hex_str() == "B5EA"
|
||||
assert UUID("df5ce654").to_hex_str() == "DF5CE654"
|
||||
assert (
|
||||
UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str()
|
||||
== "DF5CE654E05911EDB5EA0242AC120002"
|
||||
)
|
||||
assert UUID("b5ea").to_hex_str('-') == "B5EA"
|
||||
assert UUID("df5ce654").to_hex_str('-') == "DF5CE654"
|
||||
assert (
|
||||
UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str('-')
|
||||
== "DF5CE654-E059-11ED-B5EA-0242AC120002"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_ad_data()
|
||||
test_get_dict_key_by_value()
|
||||
test_uuid_to_hex_str()
|
||||
|
||||
@@ -190,7 +190,9 @@ async def test_self_gatt():
|
||||
|
||||
s1 = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', [c1, c2, c3])
|
||||
s2 = Service('97210A0F-1875-4D05-9E5D-326EB171257A', [c4])
|
||||
two_devices.devices[1].add_services([s1, s2])
|
||||
s3 = Service('1853', [])
|
||||
s4 = Service('3A12C182-14E2-4FE0-8C5B-65D7C569F9DB', [], included_services=[s2, s3])
|
||||
two_devices.devices[1].add_services([s1, s2, s4])
|
||||
|
||||
# Start
|
||||
await two_devices.devices[0].power_on()
|
||||
@@ -225,6 +227,13 @@ async def test_self_gatt():
|
||||
assert result is not None
|
||||
assert result == c1.value
|
||||
|
||||
result = await peer.discover_service(s4.uuid)
|
||||
assert len(result) == 1
|
||||
result = await peer.discover_included_services(result[0])
|
||||
assert len(result) == 2
|
||||
# Service UUID is only present when the UUID is 16-bit Bluetooth UUID
|
||||
assert result[1].uuid.to_bytes() == s3.uuid.to_bytes()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user