diff --git a/bumble/core.py b/bumble/core.py index 34412c7..1cc10ec 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -152,7 +152,12 @@ class UUID: BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian UUIDS: List[UUID] = [] # Registry of all instances created - def __init__(self, uuid_str_or_int, name=None): + uuid_bytes: bytes + name: Optional[str] + + def __init__( + self, uuid_str_or_int: Union[str, int], name: Optional[str] = None + ) -> None: if isinstance(uuid_str_or_int, int): self.uuid_bytes = struct.pack(' UUID: # Register this object in the class registry, and update the entry's name if # it wasn't set already for uuid in self.UUIDS: @@ -196,22 +201,22 @@ class UUID: raise ValueError('only 2, 4 and 16 bytes are allowed') @classmethod - def from_16_bits(cls, uuid_16, name=None): + def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID: return cls.from_bytes(struct.pack(' UUID: return cls.from_bytes(struct.pack(' Tuple[int, UUID]: return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:]) @classmethod - def parse_uuid_2(cls, uuid_as_bytes, offset): + def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]: return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2]) - def to_bytes(self, force_128=False): + def to_bytes(self, force_128: bool = False) -> bytes: ''' Serialize UUID in little-endian byte-order ''' @@ -227,7 +232,7 @@ class UUID: else: assert False, "unreachable" - def to_pdu_bytes(self): + def to_pdu_bytes(self) -> bytes: ''' Convert to bytes for use in an ATT PDU. According to Vol 3, Part F - 3.2.1 Attribute Type: @@ -236,11 +241,11 @@ class UUID: ''' return self.to_bytes(force_128=(len(self.uuid_bytes) == 4)) - def to_hex_str(self) -> str: + def to_hex_str(self, separator: str = '') -> str: if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4: return bytes(reversed(self.uuid_bytes)).hex().upper() - return ''.join( + return separator.join( [ bytes(reversed(self.uuid_bytes[12:16])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(), @@ -250,10 +255,10 @@ class UUID: ] ).upper() - def __bytes__(self): + def __bytes__(self) -> bytes: return self.to_bytes() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, UUID): return self.to_bytes(force_128=True) == other.to_bytes(force_128=True) @@ -262,35 +267,19 @@ class UUID: return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.uuid_bytes) - def __str__(self): + def __str__(self) -> str: + result = self.to_hex_str(separator='-') if len(self.uuid_bytes) == 2: - uuid = struct.unpack(' List[UUID]: uuids = [] offset = 0 - while (uuid_size * (offset + 1)) <= len(ad_data): + while (offset + uuid_size) <= len(ad_data): uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size])) offset += uuid_size return uuids diff --git a/bumble/device.py b/bumble/device.py index 6159435..72fd755 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,7 +23,7 @@ import asyncio import logging from contextlib import asynccontextmanager, AsyncExitStack from dataclasses import dataclass -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from .colors import color from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU @@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter): transport: int self_address: Address peer_address: Address + peer_resolvable_address: Optional[Address] role: int encryption: int authenticated: bool @@ -2196,13 +2197,23 @@ class Device(CompositeEventEmitter): await self.stop_discovery() @property - def pairing_config_factory(self): + def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: return self.smp_manager.pairing_config_factory @pairing_config_factory.setter - def pairing_config_factory(self, pairing_config_factory): + def pairing_config_factory( + self, pairing_config_factory: Callable[[Connection], PairingConfig] + ) -> None: self.smp_manager.pairing_config_factory = pairing_config_factory + @property + def smp_session_proxy(self) -> Type[smp.Session]: + return self.smp_manager.session_proxy + + @smp_session_proxy.setter + def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None: + self.smp_manager.session_proxy = session_proxy + async def pair(self, connection): return await self.smp_manager.pair(connection) @@ -2232,7 +2243,7 @@ class Device(CompositeEventEmitter): if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: return keys.ltk_peripheral.value - async def get_link_key(self, address): + async def get_link_key(self, address: Address) -> Optional[bytes]: # Look for the key in the keystore if self.keystore is not None: keys = await self.keystore.get(str(address)) @@ -2243,6 +2254,7 @@ class Device(CompositeEventEmitter): return None return keys.link_key.value + return None # [Classic only] async def authenticate(self, connection): @@ -2772,89 +2784,103 @@ class Device(CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_authentication_user_confirmation_request(self, connection, code): + def on_authentication_user_confirmation_request(self, connection, code) -> None: # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) io_capability = pairing_config.delegate.classic_io_capability + peer_io_capability = connection.peer_pairing_io_capability - # Respond - if io_capability == HCI_DISPLAY_YES_NO_IO_CAPABILITY: - if connection.peer_pairing_io_capability in ( - HCI_DISPLAY_YES_NO_IO_CAPABILITY, - HCI_DISPLAY_ONLY_IO_CAPABILITY, - ): - # Display the code and ask the user to compare - async def prompt(): - return ( - await pairing_config.delegate.compare_numbers(code, digits=6), + async def confirm() -> bool: + # Ask the user to confirm the pairing, without display + return await pairing_config.delegate.confirm() + + async def auto_confirm() -> bool: + # Ask the user to auto-confirm the pairing, without display + return await pairing_config.delegate.confirm(auto=True) + + async def display_confirm() -> bool: + # Display the code and ask the user to compare + return await pairing_config.delegate.compare_numbers(code, digits=6) + + async def display_auto_confirm() -> bool: + # Display the code to the user and ask the delegate to auto-confirm + await pairing_config.delegate.display_number(code, digits=6) + return await pairing_config.delegate.confirm(auto=True) + + async def na() -> bool: + assert False, "N/A: unreachable" + + # See Bluetooth spec @ Vol 3, Part C 5.2.2.6 + methods = { + HCI_DISPLAY_ONLY_IO_CAPABILITY: { + HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm, + HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm, + HCI_KEYBOARD_ONLY_IO_CAPABILITY: na, + HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm, + }, + HCI_DISPLAY_YES_NO_IO_CAPABILITY: { + HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm, + HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm, + HCI_KEYBOARD_ONLY_IO_CAPABILITY: na, + HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm, + }, + HCI_KEYBOARD_ONLY_IO_CAPABILITY: { + HCI_DISPLAY_ONLY_IO_CAPABILITY: na, + HCI_DISPLAY_YES_NO_IO_CAPABILITY: na, + HCI_KEYBOARD_ONLY_IO_CAPABILITY: na, + HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm, + }, + HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: { + HCI_DISPLAY_ONLY_IO_CAPABILITY: confirm, + HCI_DISPLAY_YES_NO_IO_CAPABILITY: confirm, + HCI_KEYBOARD_ONLY_IO_CAPABILITY: auto_confirm, + HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm, + }, + } + + method = methods[peer_io_capability][io_capability] + + async def reply() -> None: + if await connection.abort_on('disconnection', method()): + await self.host.send_command( + HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg] + bd_addr=connection.peer_address ) - + ) else: - # Ask the user to confirm the pairing, without showing a code - async def prompt(): - return await pairing_config.delegate.confirm() - - async def confirm(): - if await prompt(): - await self.host.send_command( - HCI_User_Confirmation_Request_Reply_Command( - bd_addr=connection.peer_address - ) - ) - else: - await self.host.send_command( - HCI_User_Confirmation_Request_Negative_Reply_Command( - bd_addr=connection.peer_address - ) + await self.host.send_command( + HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg] + bd_addr=connection.peer_address ) + ) - AsyncRunner.spawn(connection.abort_on('disconnection', confirm())) - return - - if io_capability == HCI_DISPLAY_ONLY_IO_CAPABILITY: - # Display the code to the user - AsyncRunner.spawn(pairing_config.delegate.display_number(code, 6)) - - # Automatic confirmation - self.host.send_command_sync( - HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address) - ) + AsyncRunner.spawn(reply()) # [Classic only] @host_event_handler @with_connection_from_address - def on_authentication_user_passkey_request(self, connection): + def on_authentication_user_passkey_request(self, connection) -> None: # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) - io_capability = pairing_config.delegate.classic_io_capability - # Respond - if io_capability == HCI_KEYBOARD_ONLY_IO_CAPABILITY: - # Ask the user to input a number - async def get_number(): - number = await connection.abort_on( - 'disconnection', pairing_config.delegate.get_number() - ) - if number is not None: - await self.host.send_command( - HCI_User_Passkey_Request_Reply_Command( - bd_addr=connection.peer_address, numeric_value=number - ) - ) - else: - await self.host.send_command( - HCI_User_Passkey_Request_Negative_Reply_Command( - bd_addr=connection.peer_address - ) - ) - - asyncio.create_task(get_number()) - else: - self.host.send_command_sync( - HCI_User_Passkey_Request_Negative_Reply_Command( - bd_addr=connection.peer_address - ) + async def reply() -> None: + number = await connection.abort_on( + 'disconnection', pairing_config.delegate.get_number() ) + if number is not None: + await self.host.send_command( + HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg] + bd_addr=connection.peer_address, numeric_value=number + ) + ) + else: + await self.host.send_command( + HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg] + bd_addr=connection.peer_address + ) + ) + + AsyncRunner.spawn(reply()) # [Classic only] @host_event_handler @@ -3059,18 +3085,15 @@ class Device(CompositeEventEmitter): connection.emit('role_change_failure', error) self.emit('role_change_failure', address, error) - @with_connection_from_handle - def on_pairing_start(self, connection): + def on_pairing_start(self, connection: Connection) -> None: connection.emit('pairing_start') - @with_connection_from_handle - def on_pairing(self, connection, keys, sc): + def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None: connection.sc = sc connection.authenticated = True connection.emit('pairing', keys) - @with_connection_from_handle - def on_pairing_failure(self, connection, reason): + def on_pairing_failure(self, connection: Connection, reason: int) -> None: connection.emit('pairing_failure', reason) @with_connection_from_handle diff --git a/bumble/gatt.py b/bumble/gatt.py index 88b7417..e57f0a6 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -247,7 +247,7 @@ class TemplateService(Service): to expose their UUID as a class property ''' - UUID = None + UUID: Optional[UUID] = None def __init__(self, characteristics, primary=True): super().__init__(self.UUID, characteristics, primary) diff --git a/bumble/keys.py b/bumble/keys.py index bbd46a5..a30e753 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -25,7 +25,7 @@ import asyncio import logging import os import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from .colors import color from .hci import Address @@ -139,19 +139,19 @@ class PairingKeys: # ----------------------------------------------------------------------------- class KeyStore: - async def delete(self, name): + async def delete(self, name: str): pass - async def update(self, name, keys): + async def update(self, name: str, keys: PairingKeys) -> None: pass - async def get(self, _name): - return PairingKeys() + async def get(self, _name: str) -> Optional[PairingKeys]: + return None - async def get_all(self): + async def get_all(self) -> List[Tuple[str, PairingKeys]]: return [] - async def delete_all(self): + async def delete_all(self) -> None: all_keys = await self.get_all() await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) @@ -177,15 +177,15 @@ class KeyStore: separator = '\n' @staticmethod - def create_for_device(device: Device) -> Optional[KeyStore]: + def create_for_device(device: Device) -> KeyStore: if device.config.keystore is None: - return None + return MemoryKeyStore() keystore_type = device.config.keystore.split(':', 1)[0] if keystore_type == 'JsonKeyStore': return JsonKeyStore.from_device(device) - return None + return MemoryKeyStore() # ----------------------------------------------------------------------------- @@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore): return None return PairingKeys.from_dict(keys) + + +# ----------------------------------------------------------------------------- +class MemoryKeyStore(KeyStore): + all_keys: Dict[str, PairingKeys] + + def __init__(self) -> None: + self.all_keys = {} + + async def delete(self, name: str) -> None: + if name in self.all_keys: + del self.all_keys[name] + + async def update(self, name: str, keys: PairingKeys) -> None: + self.all_keys[name] = keys + + async def get(self, name: str) -> Optional[PairingKeys]: + return self.all_keys.get(name) + + async def get_all(self) -> List[Tuple[str, PairingKeys]]: + return list(self.all_keys.items()) diff --git a/bumble/pairing.py b/bumble/pairing.py index eaa4470..ab356ee 100644 --- a/bumble/pairing.py +++ b/bumble/pairing.py @@ -65,8 +65,9 @@ class PairingDelegate: DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG - DEFAULT_KEY_DISTRIBUTION: int = ( - SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG + DEFAULT_KEY_DISTRIBUTION: KeyDistribution = ( + KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY + | KeyDistribution.DISTRIBUTE_IDENTITY_KEY ) # Default mapping from abstract to Classic I/O capabilities. @@ -85,9 +86,9 @@ class PairingDelegate: def __init__( self, - io_capability=NO_OUTPUT_NO_INPUT, - local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION, - local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION, + io_capability: IoCapability = NO_OUTPUT_NO_INPUT, + local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, + local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, ) -> None: self.io_capability = io_capability self.local_initiator_key_distribution = local_initiator_key_distribution @@ -113,8 +114,11 @@ class PairingDelegate: """Accept or reject a Pairing request.""" return True - async def confirm(self) -> bool: - """Respond yes or no to a Pairing confirmation question.""" + async def confirm(self, auto: bool = False) -> bool: + """ + Respond yes or no to a Pairing confirmation question. + The `auto` parameter stands for automatic confirmation. + """ return True # pylint: disable-next=unused-argument @@ -129,7 +133,7 @@ class PairingDelegate: """ return 0 - async def get_string(self, max_length) -> Optional[str]: + async def get_string(self, max_length: int) -> Optional[str]: """ Return a string whose utf-8 encoding is up to max_length bytes. """ diff --git a/bumble/smp.py b/bumble/smp.py index d345f8b..f3fbf27 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -26,16 +26,22 @@ from __future__ import annotations import logging import asyncio import secrets -from typing import Dict, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, +) from pyee import EventEmitter from .colors import color from .hci import ( - HCI_DISPLAY_ONLY_IO_CAPABILITY, - HCI_DISPLAY_YES_NO_IO_CAPABILITY, - HCI_KEYBOARD_ONLY_IO_CAPABILITY, - HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, Address, HCI_LE_Enable_Encryption_Command, HCI_Object, @@ -51,6 +57,10 @@ from .core import ( from .keys import PairingKeys from . import crypto +if TYPE_CHECKING: + from bumble.device import Connection, Device + from bumble.pairing import PairingConfig + # ----------------------------------------------------------------------------- # Logging @@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032' # ----------------------------------------------------------------------------- # Utils # ----------------------------------------------------------------------------- -def error_name(error_code): +def error_name(error_code: int) -> str: return name_or_number(SMP_ERROR_NAMES, error_code) @@ -197,11 +207,12 @@ class SMP_Command: ''' smp_classes: Dict[int, Type[SMP_Command]] = {} + fields: Any code = 0 name = '' @staticmethod - def from_bytes(pdu): + def from_bytes(pdu: bytes) -> "SMP_Command": code = pdu[0] cls = SMP_Command.smp_classes.get(code) @@ -217,11 +228,11 @@ class SMP_Command: return self @staticmethod - def command_name(code): + def command_name(code: int) -> str: return name_or_number(SMP_COMMAND_NAMES, code) @staticmethod - def auth_req_str(value): + def auth_req_str(value: int) -> str: bonding_flags = value & 3 mitm = (value >> 2) & 1 sc = (value >> 3) & 1 @@ -234,12 +245,12 @@ class SMP_Command: ) @staticmethod - def io_capability_name(io_capability): + def io_capability_name(io_capability: int) -> str: return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability) @staticmethod - def key_distribution_str(value): - key_types = [] + def key_distribution_str(value: int) -> str: + key_types: List[str] = [] if value & SMP_ENC_KEY_DISTRIBUTION_FLAG: key_types.append('ENC') if value & SMP_ID_KEY_DISTRIBUTION_FLAG: @@ -251,7 +262,7 @@ class SMP_Command: return ','.join(key_types) @staticmethod - def keypress_notification_type_name(notification_type): + def keypress_notification_type_name(notification_type: int) -> str: return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type) @staticmethod @@ -272,14 +283,14 @@ class SMP_Command: return inner - def __init__(self, pdu=None, **kwargs): + def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None: if hasattr(self, 'fields') and kwargs: HCI_Object.init_from_fields(self, self.fields, kwargs) if pdu is None: pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields) self.pdu = pdu - def init_from_bytes(self, pdu, offset): + def init_from_bytes(self, pdu: bytes, offset: int) -> None: return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) def to_bytes(self): @@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request ''' + io_capability: int + oob_data_flag: int + auth_req: int + maximum_encryption_key_size: int + initiator_key_distribution: int + responder_key_distribution: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response ''' + io_capability: int + oob_data_flag: int + auth_req: int + maximum_encryption_key_size: int + initiator_key_distribution: int + responder_key_distribution: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('confirm_value', 16)]) @@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm ''' + confirm_value: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('random_value', 16)]) @@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random ''' + random_value: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})]) @@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed ''' + reason: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)]) @@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key ''' + public_key_x: bytes + public_key_y: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check ''' + dhkey_check: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification ''' + notification_type: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('long_term_key', 16)]) @@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information ''' + long_term_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('ediv', 2), ('rand', 8)]) @@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification ''' + ediv: int + rand: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('identity_resolving_key', 16)]) @@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information ''' + identity_resolving_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information ''' + addr_type: int + bd_addr: Address + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('signature_key', 16)]) @@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information ''' + signature_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request ''' + auth_req: int + # ----------------------------------------------------------------------------- -def smp_auth_req(bonding, mitm, sc, keypress, ct2): +def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int: value = 0 if bonding: value |= SMP_BONDING_AUTHREQ @@ -574,11 +626,17 @@ class Session: }, } - def __init__(self, manager, connection, pairing_config, is_initiator): + def __init__( + self, + manager: Manager, + connection: Connection, + pairing_config: PairingConfig, + is_initiator: bool, + ) -> None: self.manager = manager self.connection = connection - self.preq = None - self.pres = None + self.preq: Optional[bytes] = None + self.pres: Optional[bytes] = None self.ea = None self.eb = None self.tk = bytes(16) @@ -588,29 +646,29 @@ class Session: self.ltk_ediv = 0 self.ltk_rand = bytes(8) self.link_key = None - self.initiator_key_distribution = 0 - self.responder_key_distribution = 0 - self.peer_random_value = None - self.peer_public_key_x = bytes(32) + self.initiator_key_distribution: int = 0 + self.responder_key_distribution: int = 0 + self.peer_random_value: Optional[bytes] = None + self.peer_public_key_x: bytes = bytes(32) self.peer_public_key_y = bytes(32) self.peer_ltk = None self.peer_ediv = None - self.peer_rand = None + self.peer_rand: Optional[bytes] = None self.peer_identity_resolving_key = None - self.peer_bd_addr = None + self.peer_bd_addr: Optional[Address] = None self.peer_signature_key = None - self.peer_expected_distributions = [] + self.peer_expected_distributions: List[Type[SMP_Command]] = [] self.dh_key = None self.confirm_value = None - self.passkey = None + self.passkey: Optional[int] = None self.passkey_ready = asyncio.Event() self.passkey_step = 0 self.passkey_display = False self.pairing_method = 0 self.pairing_config = pairing_config - self.wait_before_continuing = None + self.wait_before_continuing: Optional[asyncio.Future[None]] = None self.completed = False - self.ctkd_task = None + self.ctkd_task: Optional[Awaitable[None]] = None # Decide if we're the initiator or the responder self.is_initiator = is_initiator @@ -628,7 +686,9 @@ class Session: # Create a future that can be used to wait for the session to complete if self.is_initiator: - self.pairing_result = asyncio.get_running_loop().create_future() + self.pairing_result: Optional[ + asyncio.Future[None] + ] = asyncio.get_running_loop().create_future() else: self.pairing_result = None @@ -641,11 +701,11 @@ class Session: ) # Authentication Requirements Flags - Vol 3, Part H, Figure 3.3 - self.bonding = pairing_config.bonding - self.sc = pairing_config.sc - self.mitm = pairing_config.mitm + self.bonding: bool = pairing_config.bonding + self.sc: bool = pairing_config.sc + self.mitm: bool = pairing_config.mitm self.keypress = False - self.ct2 = False + self.ct2: bool = False # I/O Capabilities self.io_capability = pairing_config.delegate.io_capability @@ -669,34 +729,35 @@ class Session: self.iat = 1 if peer_address.is_random else 0 @property - def pkx(self): + def pkx(self) -> Tuple[bytes, bytes]: return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x) @property - def pka(self): + def pka(self) -> bytes: return self.pkx[0 if self.is_initiator else 1] @property - def pkb(self): + def pkb(self) -> bytes: return self.pkx[0 if self.is_responder else 1] @property - def nx(self): + def nx(self) -> Tuple[bytes, bytes]: + assert self.peer_random_value return (self.r, self.peer_random_value) @property - def na(self): + def na(self) -> bytes: return self.nx[0 if self.is_initiator else 1] @property - def nb(self): + def nb(self) -> bytes: return self.nx[0 if self.is_responder else 1] @property - def auth_req(self): + def auth_req(self) -> int: return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2) - def get_long_term_key(self, rand, ediv): + def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]: if not self.sc and not self.completed: if rand == self.ltk_rand and ediv == self.ltk_ediv: return self.stk @@ -706,13 +767,13 @@ class Session: return None def decide_pairing_method( - self, auth_req, initiator_io_capability, responder_io_capability - ): + self, auth_req: int, initiator_io_capability: int, responder_io_capability: int + ) -> None: if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): self.pairing_method = self.JUST_WORKS return - details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] + details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index] if isinstance(details, tuple) and len(details) == 2: # One entry for legacy pairing and one for secure connections details = details[1 if self.sc else 0] @@ -724,7 +785,9 @@ class Session: self.pairing_method = details[0] self.passkey_display = details[1 if self.is_initiator else 2] - def check_expected_value(self, expected, received, error): + def check_expected_value( + self, expected: bytes, received: bytes, error: int + ) -> bool: logger.debug(f'expected={expected.hex()} got={received.hex()}') if expected != received: logger.info(color('pairing confirm/check mismatch', 'red')) @@ -732,8 +795,8 @@ class Session: return False return True - def prompt_user_for_confirmation(self, next_steps): - async def prompt(): + def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None: + async def prompt() -> None: logger.debug('ask for confirmation') try: response = await self.pairing_config.delegate.confirm() @@ -747,8 +810,10 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def prompt_user_for_numeric_comparison(self, code, next_steps): - async def prompt(): + def prompt_user_for_numeric_comparison( + self, code: int, next_steps: Callable[[], None] + ) -> None: + async def prompt() -> None: logger.debug(f'verification code: {code}') try: response = await self.pairing_config.delegate.compare_numbers( @@ -764,11 +829,15 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def prompt_user_for_number(self, next_steps): - async def prompt(): + def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None: + async def prompt() -> None: logger.debug('prompting user for passkey') try: passkey = await self.pairing_config.delegate.get_number() + if passkey is None: + logger.debug('Passkey request rejected') + self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR) + return logger.debug(f'user input: {passkey}') next_steps(passkey) except Exception as error: @@ -777,9 +846,10 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def display_passkey(self): + def display_passkey(self) -> None: # Generate random Passkey/PIN code self.passkey = secrets.randbelow(1000000) + assert self.passkey is not None logger.debug(f'Pairing PIN CODE: {self.passkey:06}') self.passkey_ready.set() @@ -793,9 +863,9 @@ class Session: self.pairing_config.delegate.display_number(self.passkey, digits=6), ) - def input_passkey(self, next_steps=None): + def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None: # Prompt the user for the passkey displayed on the peer - def after_input(passkey): + def after_input(passkey: int) -> None: self.passkey = passkey if not self.sc: @@ -809,7 +879,9 @@ class Session: self.prompt_user_for_number(after_input) - def display_or_input_passkey(self, next_steps=None): + def display_or_input_passkey( + self, next_steps: Optional[Callable[[], None]] = None + ) -> None: if self.passkey_display: self.display_passkey() if next_steps is not None: @@ -817,14 +889,14 @@ class Session: else: self.input_passkey(next_steps) - def send_command(self, command): + def send_command(self, command: SMP_Command) -> None: self.manager.send_command(self.connection, command) - def send_pairing_failed(self, error): + def send_pairing_failed(self, error: int) -> None: self.send_command(SMP_Pairing_Failed_Command(reason=error)) self.on_pairing_failure(error) - def send_pairing_request_command(self): + def send_pairing_request_command(self) -> None: self.manager.on_session_start(self) command = SMP_Pairing_Request_Command( @@ -838,7 +910,7 @@ class Session: self.preq = bytes(command) self.send_command(command) - def send_pairing_response_command(self): + def send_pairing_response_command(self) -> None: response = SMP_Pairing_Response_Command( io_capability=self.io_capability, oob_data_flag=0, @@ -850,18 +922,19 @@ class Session: self.pres = bytes(response) self.send_command(response) - def send_pairing_confirm_command(self): + def send_pairing_confirm_command(self) -> None: self.r = crypto.r() logger.debug(f'generated random: {self.r.hex()}') if self.sc: - async def next_steps(): + async def next_steps() -> None: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): z = 0 elif self.pairing_method == self.PASSKEY: # We need a passkey await self.passkey_ready.wait() + assert self.passkey z = 0x80 + ((self.passkey >> self.passkey_step) & 1) else: @@ -892,10 +965,10 @@ class Session: self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) - def send_pairing_random_command(self): + def send_pairing_random_command(self) -> None: self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) - def send_public_key_command(self): + def send_public_key_command(self) -> None: self.send_command( SMP_Pairing_Public_Key_Command( public_key_x=bytes(reversed(self.manager.ecc_key.x)), @@ -903,18 +976,18 @@ class Session: ) ) - def send_pairing_dhkey_check_command(self): + def send_pairing_dhkey_check_command(self) -> None: self.send_command( SMP_Pairing_DHKey_Check_Command( dhkey_check=self.ea if self.is_initiator else self.eb ) ) - def start_encryption(self, key): + def start_encryption(self, key: bytes) -> None: # We can now encrypt the connection with the short term key, so that we can # distribute the long term and/or other keys over an encrypted connection self.manager.device.host.send_command_sync( - HCI_LE_Enable_Encryption_Command( + HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg] connection_handle=self.connection.handle, random_number=bytes(8), encrypted_diversifier=0, @@ -922,7 +995,7 @@ class Session: ) ) - async def derive_ltk(self): + async def derive_ltk(self) -> None: link_key = await self.manager.device.get_link_key(self.connection.peer_address) assert link_key is not None ilk = ( @@ -932,7 +1005,7 @@ class Session: ) self.ltk = crypto.h6(ilk, b'brle') - def distribute_keys(self): + def distribute_keys(self) -> None: # Distribute the keys as required if self.is_initiator: # CTKD: Derive LTK from LinkKey @@ -1032,7 +1105,7 @@ class Session: ) self.link_key = crypto.h6(ilk, b'lebr') - def compute_peer_expected_distributions(self, key_distribution_flags): + def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None: # Set our expectations for what to wait for in the key distribution phase self.peer_expected_distributions = [] if not self.sc and self.connection.transport == BT_LE_TRANSPORT: @@ -1055,7 +1128,7 @@ class Session: f'{[c.__name__ for c in self.peer_expected_distributions]}' ) - def check_key_distribution(self, command_class): + def check_key_distribution(self, command_class: Type[SMP_Command]) -> None: # First, check that the connection is encrypted if not self.connection.is_encrypted: logger.warning( @@ -1083,7 +1156,7 @@ class Session: ) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) - async def pair(self): + async def pair(self) -> None: # Start pairing as an initiator # TODO: check that this session isn't already active @@ -1091,9 +1164,10 @@ class Session: self.send_pairing_request_command() # Wait for the pairing process to finish + assert self.pairing_result await self.connection.abort_on('disconnection', self.pairing_result) - def on_disconnection(self, _): + def on_disconnection(self, _: int) -> None: self.connection.remove_listener('disconnection', self.on_disconnection) self.connection.remove_listener( 'connection_encryption_change', self.on_connection_encryption_change @@ -1104,14 +1178,14 @@ class Session: ) self.manager.on_session_end(self) - def on_peer_key_distribution_complete(self): + def on_peer_key_distribution_complete(self) -> None: # The initiator can now send its keys if self.is_initiator: self.distribute_keys() self.connection.abort_on('disconnection', self.on_pairing()) - def on_connection_encryption_change(self): + def on_connection_encryption_change(self) -> None: if self.connection.is_encrypted: if self.is_responder: # The responder distributes its keys first, the initiator later @@ -1121,11 +1195,11 @@ class Session: if not self.peer_expected_distributions: self.on_peer_key_distribution_complete() - def on_connection_encryption_key_refresh(self): + def on_connection_encryption_key_refresh(self) -> None: # Do as if the connection had just been encrypted self.on_connection_encryption_change() - async def on_pairing(self): + async def on_pairing(self) -> None: logger.debug('pairing complete') if self.completed: @@ -1137,7 +1211,7 @@ class Session: self.pairing_result.set_result(None) # Use the peer address from the pairing protocol or the connection - if self.peer_bd_addr: + if self.peer_bd_addr is not None: peer_address = self.peer_bd_addr else: peer_address = self.connection.peer_address @@ -1186,7 +1260,7 @@ class Session: ) self.manager.on_pairing(self, peer_address, keys) - def on_pairing_failure(self, reason): + def on_pairing_failure(self, reason: int) -> None: logger.warning(f'pairing failure ({error_name(reason)})') if self.completed: @@ -1199,7 +1273,7 @@ class Session: self.pairing_result.set_exception(error) self.manager.on_pairing_failure(self, reason) - def on_smp_command(self, command): + def on_smp_command(self, command: SMP_Command) -> None: # Find the handler method handler_name = f'on_{command.name.lower()}' handler = getattr(self, handler_name, None) @@ -1215,12 +1289,16 @@ class Session: else: logger.error(color('SMP command not handled???', 'red')) - def on_smp_pairing_request_command(self, command): + def on_smp_pairing_request_command( + self, command: SMP_Pairing_Request_Command + ) -> None: self.connection.abort_on( 'disconnection', self.on_smp_pairing_request_command_async(command) ) - async def on_smp_pairing_request_command_async(self, command): + async def on_smp_pairing_request_command_async( + self, command: SMP_Pairing_Request_Command + ) -> None: # Check if the request should proceed accepted = await self.pairing_config.delegate.accept() if not accepted: @@ -1280,7 +1358,9 @@ class Session: ): self.distribute_keys() - def on_smp_pairing_response_command(self, command): + def on_smp_pairing_response_command( + self, command: SMP_Pairing_Response_Command + ) -> None: if self.is_responder: logger.warning(color('received pairing response as a responder', 'red')) return @@ -1331,7 +1411,9 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command_legacy(self, _): + def on_smp_pairing_confirm_command_legacy( + self, _: SMP_Pairing_Confirm_Command + ) -> None: if self.is_initiator: self.send_pairing_random_command() else: @@ -1341,7 +1423,9 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command_secure_connections(self, _): + def on_smp_pairing_confirm_command_secure_connections( + self, _: SMP_Pairing_Confirm_Command + ) -> None: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.is_initiator: self.r = crypto.r() @@ -1352,14 +1436,18 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command(self, command): + def on_smp_pairing_confirm_command( + self, command: SMP_Pairing_Confirm_Command + ) -> None: self.confirm_value = command.confirm_value if self.sc: self.on_smp_pairing_confirm_command_secure_connections(command) else: self.on_smp_pairing_confirm_command_legacy(command) - def on_smp_pairing_random_command_legacy(self, command): + def on_smp_pairing_random_command_legacy( + self, command: SMP_Pairing_Random_Command + ) -> None: # Check that the confirmation values match confirm_verifier = crypto.c1( self.tk, @@ -1371,6 +1459,7 @@ class Session: self.ia, self.ra, ) + assert self.confirm_value if not self.check_expected_value( self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): @@ -1394,7 +1483,9 @@ class Session: else: self.send_pairing_random_command() - def on_smp_pairing_random_command_secure_connections(self, command): + def on_smp_pairing_random_command_secure_connections( + self, command: SMP_Pairing_Random_Command + ) -> None: if self.pairing_method == self.PASSKEY and self.passkey is None: logger.warning('no passkey entered, ignoring command') return @@ -1402,6 +1493,7 @@ class Session: # pylint: disable=too-many-return-statements if self.is_initiator: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): + assert self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pkb, self.pka, command.random_value, bytes([0]) @@ -1411,6 +1503,7 @@ class Session: ): return elif self.pairing_method == self.PASSKEY: + assert self.passkey and self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pkb, @@ -1435,6 +1528,7 @@ class Session: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): self.send_pairing_random_command() elif self.pairing_method == self.PASSKEY: + assert self.passkey and self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pka, @@ -1468,19 +1562,21 @@ class Session: ra = bytes(16) rb = ra elif self.pairing_method == self.PASSKEY: + assert self.passkey ra = self.passkey.to_bytes(16, byteorder='little') rb = ra else: # OOB not implemented yet return + assert self.preq and self.pres io_cap_a = self.preq[1:4] io_cap_b = self.pres[1:4] self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b) self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a) # Next steps to be performed after possible user confirmation - def next_steps(): + def next_steps() -> None: # The initiator sends the DH Key check to the responder if self.is_initiator: self.send_pairing_dhkey_check_command() @@ -1502,14 +1598,18 @@ class Session: else: next_steps() - def on_smp_pairing_random_command(self, command): + def on_smp_pairing_random_command( + self, command: SMP_Pairing_Random_Command + ) -> None: self.peer_random_value = command.random_value if self.sc: self.on_smp_pairing_random_command_secure_connections(command) else: self.on_smp_pairing_random_command_legacy(command) - def on_smp_pairing_public_key_command(self, command): + def on_smp_pairing_public_key_command( + self, command: SMP_Pairing_Public_Key_Command + ) -> None: # Store the public key so that we can compute the confirmation value later self.peer_public_key_x = command.public_key_x self.peer_public_key_y = command.public_key_y @@ -1538,9 +1638,12 @@ class Session: # We can now send the confirmation value self.send_pairing_confirm_command() - def on_smp_pairing_dhkey_check_command(self, command): + def on_smp_pairing_dhkey_check_command( + self, command: SMP_Pairing_DHKey_Check_Command + ) -> None: # Check that what we received matches what we computed earlier expected = self.eb if self.is_initiator else self.ea + assert expected if not self.check_expected_value( expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR ): @@ -1549,7 +1652,8 @@ class Session: if self.is_responder: if self.wait_before_continuing is not None: - async def next_steps(): + async def next_steps() -> None: + assert self.wait_before_continuing await self.wait_before_continuing self.wait_before_continuing = None self.send_pairing_dhkey_check_command() @@ -1558,29 +1662,42 @@ class Session: else: self.send_pairing_dhkey_check_command() else: + assert self.ltk self.start_encryption(self.ltk) - def on_smp_pairing_failed_command(self, command): + def on_smp_pairing_failed_command( + self, command: SMP_Pairing_Failed_Command + ) -> None: self.on_pairing_failure(command.reason) - def on_smp_encryption_information_command(self, command): + def on_smp_encryption_information_command( + self, command: SMP_Encryption_Information_Command + ) -> None: self.peer_ltk = command.long_term_key self.check_key_distribution(SMP_Encryption_Information_Command) - def on_smp_master_identification_command(self, command): + def on_smp_master_identification_command( + self, command: SMP_Master_Identification_Command + ) -> None: self.peer_ediv = command.ediv self.peer_rand = command.rand self.check_key_distribution(SMP_Master_Identification_Command) - def on_smp_identity_information_command(self, command): + def on_smp_identity_information_command( + self, command: SMP_Identity_Information_Command + ) -> None: self.peer_identity_resolving_key = command.identity_resolving_key self.check_key_distribution(SMP_Identity_Information_Command) - def on_smp_identity_address_information_command(self, command): + def on_smp_identity_address_information_command( + self, command: SMP_Identity_Address_Information_Command + ) -> None: self.peer_bd_addr = command.bd_addr self.check_key_distribution(SMP_Identity_Address_Information_Command) - def on_smp_signing_information_command(self, command): + def on_smp_signing_information_command( + self, command: SMP_Signing_Information_Command + ) -> None: self.peer_signature_key = command.signature_key self.check_key_distribution(SMP_Signing_Information_Command) @@ -1591,14 +1708,24 @@ class Manager(EventEmitter): Implements the Initiator and Responder roles of the Security Manager Protocol ''' - def __init__(self, device, pairing_config_factory): + device: Device + sessions: Dict[int, Session] + pairing_config_factory: Callable[[Connection], PairingConfig] + session_proxy: Type[Session] + + def __init__( + self, + device: Device, + pairing_config_factory: Callable[[Connection], PairingConfig], + ) -> None: super().__init__() self.device = device self.sessions = {} self._ecc_key = None self.pairing_config_factory = pairing_config_factory + self.session_proxy = Session - def send_command(self, connection, command): + def send_command(self, connection: Connection, command: SMP_Command) -> None: logger.debug( f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] ' f'{connection.peer_address}: {command}' @@ -1606,20 +1733,15 @@ class Manager(EventEmitter): cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID connection.send_l2cap_pdu(cid, command.to_bytes()) - def on_smp_pdu(self, connection, pdu): + def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: # Look for a session with this connection, and create one if none exists if not (session := self.sessions.get(connection.handle)): if connection.role == BT_CENTRAL_ROLE: logger.warning('Remote starts pairing as Peripheral!') pairing_config = self.pairing_config_factory(connection) - if pairing_config is None: - # Pairing disabled - self.send_command( - connection, - SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR), - ) - return - session = Session(self, connection, pairing_config, is_initiator=False) + session = self.session_proxy( + self, connection, pairing_config, is_initiator=False + ) self.sessions[connection.handle] = session # Parse the L2CAP payload into an SMP Command object @@ -1633,23 +1755,24 @@ class Manager(EventEmitter): session.on_smp_command(command) @property - def ecc_key(self): + def ecc_key(self) -> crypto.EccKey: if self._ecc_key is None: self._ecc_key = crypto.EccKey.generate() + assert self._ecc_key return self._ecc_key - async def pair(self, connection): + async def pair(self, connection: Connection) -> None: # TODO: check if there's already a session for this connection if connection.role != BT_CENTRAL_ROLE: logger.warning('Start pairing as Peripheral!') pairing_config = self.pairing_config_factory(connection) - if pairing_config is None: - raise ValueError('pairing config must not be None when initiating') - session = Session(self, connection, pairing_config, is_initiator=True) + session = self.session_proxy( + self, connection, pairing_config, is_initiator=True + ) self.sessions[connection.handle] = session return await session.pair() - def request_pairing(self, connection): + def request_pairing(self, connection: Connection) -> None: pairing_config = self.pairing_config_factory(connection) if pairing_config: auth_req = smp_auth_req( @@ -1663,15 +1786,18 @@ class Manager(EventEmitter): auth_req = 0 self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req)) - def on_session_start(self, session): - self.device.on_pairing_start(session.connection.handle) + def on_session_start(self, session: Session) -> None: + self.device.on_pairing_start(session.connection) - def on_pairing(self, session, identity_address, keys): + def on_pairing( + self, session: Session, identity_address: Optional[Address], keys: PairingKeys + ) -> None: # Store the keys in the key store if self.device.keystore and identity_address is not None: async def store_keys(): try: + assert self.device.keystore await self.device.keystore.update(str(identity_address), keys) except Exception as error: logger.warning(f'!!! error while storing keys: {error}') @@ -1679,17 +1805,19 @@ class Manager(EventEmitter): self.device.abort_on('flush', store_keys()) # Notify the device - self.device.on_pairing(session.connection.handle, keys, session.sc) + self.device.on_pairing(session.connection, keys, session.sc) - def on_pairing_failure(self, session, reason): - self.device.on_pairing_failure(session.connection.handle, reason) + def on_pairing_failure(self, session: Session, reason: int) -> None: + self.device.on_pairing_failure(session.connection, reason) - def on_session_end(self, session): + def on_session_end(self, session: Session) -> None: logger.debug(f'session end for connection 0x{session.connection.handle:04X}') if session.connection.handle in self.sessions: del self.sessions[session.connection.handle] - def get_long_term_key(self, connection, rand, ediv): + def get_long_term_key( + self, connection: Connection, rand: bytes, ediv: int + ) -> Optional[bytes]: if session := self.sessions.get(connection.handle): return session.get_long_term_key(rand, ediv) diff --git a/tests/core_test.py b/tests/core_test.py index 7ee2dfd..6c9d0c3 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -15,7 +15,7 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -from bumble.core import AdvertisingData, get_dict_key_by_value +from bumble.core import AdvertisingData, UUID, get_dict_key_by_value # ----------------------------------------------------------------------------- def test_ad_data(): @@ -49,6 +49,24 @@ def test_get_dict_key_by_value(): assert get_dict_key_by_value(dictionary, 3) is None +# ----------------------------------------------------------------------------- +def test_uuid_to_hex_str() -> None: + assert UUID("b5ea").to_hex_str() == "B5EA" + assert UUID("df5ce654").to_hex_str() == "DF5CE654" + assert ( + UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str() + == "DF5CE654E05911EDB5EA0242AC120002" + ) + assert UUID("b5ea").to_hex_str('-') == "B5EA" + assert UUID("df5ce654").to_hex_str('-') == "DF5CE654" + assert ( + UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str('-') + == "DF5CE654-E059-11ED-B5EA-0242AC120002" + ) + + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_ad_data() + test_get_dict_key_by_value() + test_uuid_to_hex_str()