From 3c81b248a383174ccbafa275767d7876cc6dc963 Mon Sep 17 00:00:00 2001 From: uael Date: Tue, 2 May 2023 05:27:35 +0000 Subject: [PATCH] smp: add type hints --- bumble/device.py | 20 +-- bumble/smp.py | 366 +++++++++++++++++++++++++++++++---------------- 2 files changed, 254 insertions(+), 132 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index c09e061..7448bd8 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,7 +23,7 @@ import asyncio import logging from contextlib import asynccontextmanager, AsyncExitStack from dataclasses import dataclass -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union from .colors import color from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU @@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter): transport: int self_address: Address peer_address: Address + peer_resolvable_address: Optional[Address] role: int encryption: int authenticated: bool @@ -2196,11 +2197,13 @@ class Device(CompositeEventEmitter): await self.stop_discovery() @property - def pairing_config_factory(self): + def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: return self.smp_manager.pairing_config_factory @pairing_config_factory.setter - def pairing_config_factory(self, pairing_config_factory): + def pairing_config_factory( + self, pairing_config_factory: Callable[[Connection], PairingConfig] + ) -> None: self.smp_manager.pairing_config_factory = pairing_config_factory async def pair(self, connection): @@ -2232,7 +2235,7 @@ class Device(CompositeEventEmitter): if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: return keys.ltk_peripheral.value - async def get_link_key(self, address): + async def get_link_key(self, address: Address) -> Optional[bytes]: # Look for the key in the keystore if self.keystore is not None: keys = await self.keystore.get(str(address)) @@ -3074,18 +3077,15 @@ class Device(CompositeEventEmitter): connection.emit('role_change_failure', error) self.emit('role_change_failure', address, error) - @with_connection_from_handle - def on_pairing_start(self, connection): + def on_pairing_start(self, connection: Connection) -> None: connection.emit('pairing_start') - @with_connection_from_handle - def on_pairing(self, connection, keys, sc): + def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None: connection.sc = sc connection.authenticated = True connection.emit('pairing', keys) - @with_connection_from_handle - def on_pairing_failure(self, connection, reason): + def on_pairing_failure(self, connection: Connection, reason: int) -> None: connection.emit('pairing_failure', reason) @with_connection_from_handle diff --git a/bumble/smp.py b/bumble/smp.py index d345f8b..ade51ec 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -26,16 +26,22 @@ from __future__ import annotations import logging import asyncio import secrets -from typing import Dict, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, +) from pyee import EventEmitter from .colors import color from .hci import ( - HCI_DISPLAY_ONLY_IO_CAPABILITY, - HCI_DISPLAY_YES_NO_IO_CAPABILITY, - HCI_KEYBOARD_ONLY_IO_CAPABILITY, - HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, Address, HCI_LE_Enable_Encryption_Command, HCI_Object, @@ -51,6 +57,10 @@ from .core import ( from .keys import PairingKeys from . import crypto +if TYPE_CHECKING: + from bumble.device import Connection, Device + from bumble.pairing import PairingConfig + # ----------------------------------------------------------------------------- # Logging @@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032' # ----------------------------------------------------------------------------- # Utils # ----------------------------------------------------------------------------- -def error_name(error_code): +def error_name(error_code: int) -> str: return name_or_number(SMP_ERROR_NAMES, error_code) @@ -197,11 +207,12 @@ class SMP_Command: ''' smp_classes: Dict[int, Type[SMP_Command]] = {} + fields: Any code = 0 name = '' @staticmethod - def from_bytes(pdu): + def from_bytes(pdu: bytes) -> "SMP_Command": code = pdu[0] cls = SMP_Command.smp_classes.get(code) @@ -217,11 +228,11 @@ class SMP_Command: return self @staticmethod - def command_name(code): + def command_name(code: int) -> str: return name_or_number(SMP_COMMAND_NAMES, code) @staticmethod - def auth_req_str(value): + def auth_req_str(value: int) -> str: bonding_flags = value & 3 mitm = (value >> 2) & 1 sc = (value >> 3) & 1 @@ -234,12 +245,12 @@ class SMP_Command: ) @staticmethod - def io_capability_name(io_capability): + def io_capability_name(io_capability: int) -> str: return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability) @staticmethod - def key_distribution_str(value): - key_types = [] + def key_distribution_str(value: int) -> str: + key_types: List[str] = [] if value & SMP_ENC_KEY_DISTRIBUTION_FLAG: key_types.append('ENC') if value & SMP_ID_KEY_DISTRIBUTION_FLAG: @@ -251,7 +262,7 @@ class SMP_Command: return ','.join(key_types) @staticmethod - def keypress_notification_type_name(notification_type): + def keypress_notification_type_name(notification_type: int) -> str: return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type) @staticmethod @@ -272,14 +283,14 @@ class SMP_Command: return inner - def __init__(self, pdu=None, **kwargs): + def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None: if hasattr(self, 'fields') and kwargs: HCI_Object.init_from_fields(self, self.fields, kwargs) if pdu is None: pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields) self.pdu = pdu - def init_from_bytes(self, pdu, offset): + def init_from_bytes(self, pdu: bytes, offset: int) -> None: return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) def to_bytes(self): @@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request ''' + io_capability: int + oob_data_flag: int + auth_req: int + maximum_encryption_key_size: int + initiator_key_distribution: int + responder_key_distribution: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response ''' + io_capability: int + oob_data_flag: int + auth_req: int + maximum_encryption_key_size: int + initiator_key_distribution: int + responder_key_distribution: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('confirm_value', 16)]) @@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm ''' + confirm_value: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('random_value', 16)]) @@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random ''' + random_value: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})]) @@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed ''' + reason: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)]) @@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key ''' + public_key_x: bytes + public_key_y: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check ''' + dhkey_check: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification ''' + notification_type: int + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('long_term_key', 16)]) @@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information ''' + long_term_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('ediv', 2), ('rand', 8)]) @@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification ''' + ediv: int + rand: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('identity_resolving_key', 16)]) @@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information ''' + identity_resolving_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information ''' + addr_type: int + bd_addr: Address + # ----------------------------------------------------------------------------- @SMP_Command.subclass([('signature_key', 16)]) @@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information ''' + signature_key: bytes + # ----------------------------------------------------------------------------- @SMP_Command.subclass( @@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command): See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request ''' + auth_req: int + # ----------------------------------------------------------------------------- -def smp_auth_req(bonding, mitm, sc, keypress, ct2): +def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int: value = 0 if bonding: value |= SMP_BONDING_AUTHREQ @@ -574,11 +626,17 @@ class Session: }, } - def __init__(self, manager, connection, pairing_config, is_initiator): + def __init__( + self, + manager: Manager, + connection: Connection, + pairing_config: PairingConfig, + is_initiator: bool, + ) -> None: self.manager = manager self.connection = connection - self.preq = None - self.pres = None + self.preq: Optional[bytes] = None + self.pres: Optional[bytes] = None self.ea = None self.eb = None self.tk = bytes(16) @@ -588,29 +646,29 @@ class Session: self.ltk_ediv = 0 self.ltk_rand = bytes(8) self.link_key = None - self.initiator_key_distribution = 0 - self.responder_key_distribution = 0 - self.peer_random_value = None - self.peer_public_key_x = bytes(32) + self.initiator_key_distribution: int = 0 + self.responder_key_distribution: int = 0 + self.peer_random_value: Optional[bytes] = None + self.peer_public_key_x: bytes = bytes(32) self.peer_public_key_y = bytes(32) self.peer_ltk = None self.peer_ediv = None - self.peer_rand = None + self.peer_rand: Optional[bytes] = None self.peer_identity_resolving_key = None - self.peer_bd_addr = None + self.peer_bd_addr: Optional[Address] = None self.peer_signature_key = None - self.peer_expected_distributions = [] + self.peer_expected_distributions: List[Type[SMP_Command]] = [] self.dh_key = None self.confirm_value = None - self.passkey = None + self.passkey: Optional[int] = None self.passkey_ready = asyncio.Event() self.passkey_step = 0 self.passkey_display = False self.pairing_method = 0 self.pairing_config = pairing_config - self.wait_before_continuing = None + self.wait_before_continuing: Optional[asyncio.Future[None]] = None self.completed = False - self.ctkd_task = None + self.ctkd_task: Optional[Awaitable[None]] = None # Decide if we're the initiator or the responder self.is_initiator = is_initiator @@ -628,7 +686,9 @@ class Session: # Create a future that can be used to wait for the session to complete if self.is_initiator: - self.pairing_result = asyncio.get_running_loop().create_future() + self.pairing_result: Optional[ + asyncio.Future[None] + ] = asyncio.get_running_loop().create_future() else: self.pairing_result = None @@ -641,11 +701,11 @@ class Session: ) # Authentication Requirements Flags - Vol 3, Part H, Figure 3.3 - self.bonding = pairing_config.bonding - self.sc = pairing_config.sc - self.mitm = pairing_config.mitm + self.bonding: bool = pairing_config.bonding + self.sc: bool = pairing_config.sc + self.mitm: bool = pairing_config.mitm self.keypress = False - self.ct2 = False + self.ct2: bool = False # I/O Capabilities self.io_capability = pairing_config.delegate.io_capability @@ -669,34 +729,35 @@ class Session: self.iat = 1 if peer_address.is_random else 0 @property - def pkx(self): + def pkx(self) -> Tuple[bytes, bytes]: return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x) @property - def pka(self): + def pka(self) -> bytes: return self.pkx[0 if self.is_initiator else 1] @property - def pkb(self): + def pkb(self) -> bytes: return self.pkx[0 if self.is_responder else 1] @property - def nx(self): + def nx(self) -> Tuple[bytes, bytes]: + assert self.peer_random_value return (self.r, self.peer_random_value) @property - def na(self): + def na(self) -> bytes: return self.nx[0 if self.is_initiator else 1] @property - def nb(self): + def nb(self) -> bytes: return self.nx[0 if self.is_responder else 1] @property - def auth_req(self): + def auth_req(self) -> int: return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2) - def get_long_term_key(self, rand, ediv): + def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]: if not self.sc and not self.completed: if rand == self.ltk_rand and ediv == self.ltk_ediv: return self.stk @@ -706,13 +767,13 @@ class Session: return None def decide_pairing_method( - self, auth_req, initiator_io_capability, responder_io_capability - ): + self, auth_req: int, initiator_io_capability: int, responder_io_capability: int + ) -> None: if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): self.pairing_method = self.JUST_WORKS return - details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] + details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index] if isinstance(details, tuple) and len(details) == 2: # One entry for legacy pairing and one for secure connections details = details[1 if self.sc else 0] @@ -724,7 +785,9 @@ class Session: self.pairing_method = details[0] self.passkey_display = details[1 if self.is_initiator else 2] - def check_expected_value(self, expected, received, error): + def check_expected_value( + self, expected: bytes, received: bytes, error: int + ) -> bool: logger.debug(f'expected={expected.hex()} got={received.hex()}') if expected != received: logger.info(color('pairing confirm/check mismatch', 'red')) @@ -732,8 +795,8 @@ class Session: return False return True - def prompt_user_for_confirmation(self, next_steps): - async def prompt(): + def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None: + async def prompt() -> None: logger.debug('ask for confirmation') try: response = await self.pairing_config.delegate.confirm() @@ -747,8 +810,10 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def prompt_user_for_numeric_comparison(self, code, next_steps): - async def prompt(): + def prompt_user_for_numeric_comparison( + self, code: int, next_steps: Callable[[], None] + ) -> None: + async def prompt() -> None: logger.debug(f'verification code: {code}') try: response = await self.pairing_config.delegate.compare_numbers( @@ -764,11 +829,15 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def prompt_user_for_number(self, next_steps): - async def prompt(): + def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None: + async def prompt() -> None: logger.debug('prompting user for passkey') try: passkey = await self.pairing_config.delegate.get_number() + if passkey is None: + logger.debug('Passkey request rejected') + self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR) + return logger.debug(f'user input: {passkey}') next_steps(passkey) except Exception as error: @@ -777,9 +846,10 @@ class Session: self.connection.abort_on('disconnection', prompt()) - def display_passkey(self): + def display_passkey(self) -> None: # Generate random Passkey/PIN code self.passkey = secrets.randbelow(1000000) + assert self.passkey is not None logger.debug(f'Pairing PIN CODE: {self.passkey:06}') self.passkey_ready.set() @@ -793,9 +863,9 @@ class Session: self.pairing_config.delegate.display_number(self.passkey, digits=6), ) - def input_passkey(self, next_steps=None): + def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None: # Prompt the user for the passkey displayed on the peer - def after_input(passkey): + def after_input(passkey: int) -> None: self.passkey = passkey if not self.sc: @@ -809,7 +879,9 @@ class Session: self.prompt_user_for_number(after_input) - def display_or_input_passkey(self, next_steps=None): + def display_or_input_passkey( + self, next_steps: Optional[Callable[[], None]] = None + ) -> None: if self.passkey_display: self.display_passkey() if next_steps is not None: @@ -817,14 +889,14 @@ class Session: else: self.input_passkey(next_steps) - def send_command(self, command): + def send_command(self, command: SMP_Command) -> None: self.manager.send_command(self.connection, command) - def send_pairing_failed(self, error): + def send_pairing_failed(self, error: int) -> None: self.send_command(SMP_Pairing_Failed_Command(reason=error)) self.on_pairing_failure(error) - def send_pairing_request_command(self): + def send_pairing_request_command(self) -> None: self.manager.on_session_start(self) command = SMP_Pairing_Request_Command( @@ -838,7 +910,7 @@ class Session: self.preq = bytes(command) self.send_command(command) - def send_pairing_response_command(self): + def send_pairing_response_command(self) -> None: response = SMP_Pairing_Response_Command( io_capability=self.io_capability, oob_data_flag=0, @@ -850,18 +922,19 @@ class Session: self.pres = bytes(response) self.send_command(response) - def send_pairing_confirm_command(self): + def send_pairing_confirm_command(self) -> None: self.r = crypto.r() logger.debug(f'generated random: {self.r.hex()}') if self.sc: - async def next_steps(): + async def next_steps() -> None: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): z = 0 elif self.pairing_method == self.PASSKEY: # We need a passkey await self.passkey_ready.wait() + assert self.passkey z = 0x80 + ((self.passkey >> self.passkey_step) & 1) else: @@ -892,10 +965,10 @@ class Session: self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) - def send_pairing_random_command(self): + def send_pairing_random_command(self) -> None: self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) - def send_public_key_command(self): + def send_public_key_command(self) -> None: self.send_command( SMP_Pairing_Public_Key_Command( public_key_x=bytes(reversed(self.manager.ecc_key.x)), @@ -903,18 +976,18 @@ class Session: ) ) - def send_pairing_dhkey_check_command(self): + def send_pairing_dhkey_check_command(self) -> None: self.send_command( SMP_Pairing_DHKey_Check_Command( dhkey_check=self.ea if self.is_initiator else self.eb ) ) - def start_encryption(self, key): + def start_encryption(self, key: bytes) -> None: # We can now encrypt the connection with the short term key, so that we can # distribute the long term and/or other keys over an encrypted connection self.manager.device.host.send_command_sync( - HCI_LE_Enable_Encryption_Command( + HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg] connection_handle=self.connection.handle, random_number=bytes(8), encrypted_diversifier=0, @@ -922,7 +995,7 @@ class Session: ) ) - async def derive_ltk(self): + async def derive_ltk(self) -> None: link_key = await self.manager.device.get_link_key(self.connection.peer_address) assert link_key is not None ilk = ( @@ -932,7 +1005,7 @@ class Session: ) self.ltk = crypto.h6(ilk, b'brle') - def distribute_keys(self): + def distribute_keys(self) -> None: # Distribute the keys as required if self.is_initiator: # CTKD: Derive LTK from LinkKey @@ -1032,7 +1105,7 @@ class Session: ) self.link_key = crypto.h6(ilk, b'lebr') - def compute_peer_expected_distributions(self, key_distribution_flags): + def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None: # Set our expectations for what to wait for in the key distribution phase self.peer_expected_distributions = [] if not self.sc and self.connection.transport == BT_LE_TRANSPORT: @@ -1055,7 +1128,7 @@ class Session: f'{[c.__name__ for c in self.peer_expected_distributions]}' ) - def check_key_distribution(self, command_class): + def check_key_distribution(self, command_class: Type[SMP_Command]) -> None: # First, check that the connection is encrypted if not self.connection.is_encrypted: logger.warning( @@ -1083,7 +1156,7 @@ class Session: ) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) - async def pair(self): + async def pair(self) -> None: # Start pairing as an initiator # TODO: check that this session isn't already active @@ -1091,9 +1164,10 @@ class Session: self.send_pairing_request_command() # Wait for the pairing process to finish + assert self.pairing_result await self.connection.abort_on('disconnection', self.pairing_result) - def on_disconnection(self, _): + def on_disconnection(self, _: int) -> None: self.connection.remove_listener('disconnection', self.on_disconnection) self.connection.remove_listener( 'connection_encryption_change', self.on_connection_encryption_change @@ -1104,14 +1178,14 @@ class Session: ) self.manager.on_session_end(self) - def on_peer_key_distribution_complete(self): + def on_peer_key_distribution_complete(self) -> None: # The initiator can now send its keys if self.is_initiator: self.distribute_keys() self.connection.abort_on('disconnection', self.on_pairing()) - def on_connection_encryption_change(self): + def on_connection_encryption_change(self) -> None: if self.connection.is_encrypted: if self.is_responder: # The responder distributes its keys first, the initiator later @@ -1121,11 +1195,11 @@ class Session: if not self.peer_expected_distributions: self.on_peer_key_distribution_complete() - def on_connection_encryption_key_refresh(self): + def on_connection_encryption_key_refresh(self) -> None: # Do as if the connection had just been encrypted self.on_connection_encryption_change() - async def on_pairing(self): + async def on_pairing(self) -> None: logger.debug('pairing complete') if self.completed: @@ -1137,7 +1211,7 @@ class Session: self.pairing_result.set_result(None) # Use the peer address from the pairing protocol or the connection - if self.peer_bd_addr: + if self.peer_bd_addr is not None: peer_address = self.peer_bd_addr else: peer_address = self.connection.peer_address @@ -1186,7 +1260,7 @@ class Session: ) self.manager.on_pairing(self, peer_address, keys) - def on_pairing_failure(self, reason): + def on_pairing_failure(self, reason: int) -> None: logger.warning(f'pairing failure ({error_name(reason)})') if self.completed: @@ -1199,7 +1273,7 @@ class Session: self.pairing_result.set_exception(error) self.manager.on_pairing_failure(self, reason) - def on_smp_command(self, command): + def on_smp_command(self, command: SMP_Command) -> None: # Find the handler method handler_name = f'on_{command.name.lower()}' handler = getattr(self, handler_name, None) @@ -1215,12 +1289,16 @@ class Session: else: logger.error(color('SMP command not handled???', 'red')) - def on_smp_pairing_request_command(self, command): + def on_smp_pairing_request_command( + self, command: SMP_Pairing_Request_Command + ) -> None: self.connection.abort_on( 'disconnection', self.on_smp_pairing_request_command_async(command) ) - async def on_smp_pairing_request_command_async(self, command): + async def on_smp_pairing_request_command_async( + self, command: SMP_Pairing_Request_Command + ) -> None: # Check if the request should proceed accepted = await self.pairing_config.delegate.accept() if not accepted: @@ -1280,7 +1358,9 @@ class Session: ): self.distribute_keys() - def on_smp_pairing_response_command(self, command): + def on_smp_pairing_response_command( + self, command: SMP_Pairing_Response_Command + ) -> None: if self.is_responder: logger.warning(color('received pairing response as a responder', 'red')) return @@ -1331,7 +1411,9 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command_legacy(self, _): + def on_smp_pairing_confirm_command_legacy( + self, _: SMP_Pairing_Confirm_Command + ) -> None: if self.is_initiator: self.send_pairing_random_command() else: @@ -1341,7 +1423,9 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command_secure_connections(self, _): + def on_smp_pairing_confirm_command_secure_connections( + self, _: SMP_Pairing_Confirm_Command + ) -> None: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.is_initiator: self.r = crypto.r() @@ -1352,14 +1436,18 @@ class Session: else: self.send_pairing_confirm_command() - def on_smp_pairing_confirm_command(self, command): + def on_smp_pairing_confirm_command( + self, command: SMP_Pairing_Confirm_Command + ) -> None: self.confirm_value = command.confirm_value if self.sc: self.on_smp_pairing_confirm_command_secure_connections(command) else: self.on_smp_pairing_confirm_command_legacy(command) - def on_smp_pairing_random_command_legacy(self, command): + def on_smp_pairing_random_command_legacy( + self, command: SMP_Pairing_Random_Command + ) -> None: # Check that the confirmation values match confirm_verifier = crypto.c1( self.tk, @@ -1371,6 +1459,7 @@ class Session: self.ia, self.ra, ) + assert self.confirm_value if not self.check_expected_value( self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR ): @@ -1394,7 +1483,9 @@ class Session: else: self.send_pairing_random_command() - def on_smp_pairing_random_command_secure_connections(self, command): + def on_smp_pairing_random_command_secure_connections( + self, command: SMP_Pairing_Random_Command + ) -> None: if self.pairing_method == self.PASSKEY and self.passkey is None: logger.warning('no passkey entered, ignoring command') return @@ -1402,6 +1493,7 @@ class Session: # pylint: disable=too-many-return-statements if self.is_initiator: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): + assert self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pkb, self.pka, command.random_value, bytes([0]) @@ -1411,6 +1503,7 @@ class Session: ): return elif self.pairing_method == self.PASSKEY: + assert self.passkey and self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pkb, @@ -1435,6 +1528,7 @@ class Session: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): self.send_pairing_random_command() elif self.pairing_method == self.PASSKEY: + assert self.passkey and self.confirm_value # Check that the random value matches what was committed to earlier confirm_verifier = crypto.f4( self.pka, @@ -1468,19 +1562,21 @@ class Session: ra = bytes(16) rb = ra elif self.pairing_method == self.PASSKEY: + assert self.passkey ra = self.passkey.to_bytes(16, byteorder='little') rb = ra else: # OOB not implemented yet return + assert self.preq and self.pres io_cap_a = self.preq[1:4] io_cap_b = self.pres[1:4] self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b) self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a) # Next steps to be performed after possible user confirmation - def next_steps(): + def next_steps() -> None: # The initiator sends the DH Key check to the responder if self.is_initiator: self.send_pairing_dhkey_check_command() @@ -1502,14 +1598,18 @@ class Session: else: next_steps() - def on_smp_pairing_random_command(self, command): + def on_smp_pairing_random_command( + self, command: SMP_Pairing_Random_Command + ) -> None: self.peer_random_value = command.random_value if self.sc: self.on_smp_pairing_random_command_secure_connections(command) else: self.on_smp_pairing_random_command_legacy(command) - def on_smp_pairing_public_key_command(self, command): + def on_smp_pairing_public_key_command( + self, command: SMP_Pairing_Public_Key_Command + ) -> None: # Store the public key so that we can compute the confirmation value later self.peer_public_key_x = command.public_key_x self.peer_public_key_y = command.public_key_y @@ -1538,9 +1638,12 @@ class Session: # We can now send the confirmation value self.send_pairing_confirm_command() - def on_smp_pairing_dhkey_check_command(self, command): + def on_smp_pairing_dhkey_check_command( + self, command: SMP_Pairing_DHKey_Check_Command + ) -> None: # Check that what we received matches what we computed earlier expected = self.eb if self.is_initiator else self.ea + assert expected if not self.check_expected_value( expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR ): @@ -1549,7 +1652,8 @@ class Session: if self.is_responder: if self.wait_before_continuing is not None: - async def next_steps(): + async def next_steps() -> None: + assert self.wait_before_continuing await self.wait_before_continuing self.wait_before_continuing = None self.send_pairing_dhkey_check_command() @@ -1558,29 +1662,42 @@ class Session: else: self.send_pairing_dhkey_check_command() else: + assert self.ltk self.start_encryption(self.ltk) - def on_smp_pairing_failed_command(self, command): + def on_smp_pairing_failed_command( + self, command: SMP_Pairing_Failed_Command + ) -> None: self.on_pairing_failure(command.reason) - def on_smp_encryption_information_command(self, command): + def on_smp_encryption_information_command( + self, command: SMP_Encryption_Information_Command + ) -> None: self.peer_ltk = command.long_term_key self.check_key_distribution(SMP_Encryption_Information_Command) - def on_smp_master_identification_command(self, command): + def on_smp_master_identification_command( + self, command: SMP_Master_Identification_Command + ) -> None: self.peer_ediv = command.ediv self.peer_rand = command.rand self.check_key_distribution(SMP_Master_Identification_Command) - def on_smp_identity_information_command(self, command): + def on_smp_identity_information_command( + self, command: SMP_Identity_Information_Command + ) -> None: self.peer_identity_resolving_key = command.identity_resolving_key self.check_key_distribution(SMP_Identity_Information_Command) - def on_smp_identity_address_information_command(self, command): + def on_smp_identity_address_information_command( + self, command: SMP_Identity_Address_Information_Command + ) -> None: self.peer_bd_addr = command.bd_addr self.check_key_distribution(SMP_Identity_Address_Information_Command) - def on_smp_signing_information_command(self, command): + def on_smp_signing_information_command( + self, command: SMP_Signing_Information_Command + ) -> None: self.peer_signature_key = command.signature_key self.check_key_distribution(SMP_Signing_Information_Command) @@ -1591,14 +1708,22 @@ class Manager(EventEmitter): Implements the Initiator and Responder roles of the Security Manager Protocol ''' - def __init__(self, device, pairing_config_factory): + device: Device + sessions: Dict[int, Session] + pairing_config_factory: Callable[[Connection], PairingConfig] + + def __init__( + self, + device: Device, + pairing_config_factory: Callable[[Connection], PairingConfig], + ) -> None: super().__init__() self.device = device self.sessions = {} self._ecc_key = None self.pairing_config_factory = pairing_config_factory - def send_command(self, connection, command): + def send_command(self, connection: Connection, command: SMP_Command) -> None: logger.debug( f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] ' f'{connection.peer_address}: {command}' @@ -1606,19 +1731,12 @@ class Manager(EventEmitter): cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID connection.send_l2cap_pdu(cid, command.to_bytes()) - def on_smp_pdu(self, connection, pdu): + def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: # Look for a session with this connection, and create one if none exists if not (session := self.sessions.get(connection.handle)): if connection.role == BT_CENTRAL_ROLE: logger.warning('Remote starts pairing as Peripheral!') pairing_config = self.pairing_config_factory(connection) - if pairing_config is None: - # Pairing disabled - self.send_command( - connection, - SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR), - ) - return session = Session(self, connection, pairing_config, is_initiator=False) self.sessions[connection.handle] = session @@ -1633,23 +1751,22 @@ class Manager(EventEmitter): session.on_smp_command(command) @property - def ecc_key(self): + def ecc_key(self) -> crypto.EccKey: if self._ecc_key is None: self._ecc_key = crypto.EccKey.generate() + assert self._ecc_key return self._ecc_key - async def pair(self, connection): + async def pair(self, connection: Connection) -> None: # TODO: check if there's already a session for this connection if connection.role != BT_CENTRAL_ROLE: logger.warning('Start pairing as Peripheral!') pairing_config = self.pairing_config_factory(connection) - if pairing_config is None: - raise ValueError('pairing config must not be None when initiating') session = Session(self, connection, pairing_config, is_initiator=True) self.sessions[connection.handle] = session return await session.pair() - def request_pairing(self, connection): + def request_pairing(self, connection: Connection) -> None: pairing_config = self.pairing_config_factory(connection) if pairing_config: auth_req = smp_auth_req( @@ -1663,15 +1780,18 @@ class Manager(EventEmitter): auth_req = 0 self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req)) - def on_session_start(self, session): - self.device.on_pairing_start(session.connection.handle) + def on_session_start(self, session: Session) -> None: + self.device.on_pairing_start(session.connection) - def on_pairing(self, session, identity_address, keys): + def on_pairing( + self, session: Session, identity_address: Optional[Address], keys: PairingKeys + ) -> None: # Store the keys in the key store if self.device.keystore and identity_address is not None: async def store_keys(): try: + assert self.device.keystore await self.device.keystore.update(str(identity_address), keys) except Exception as error: logger.warning(f'!!! error while storing keys: {error}') @@ -1679,17 +1799,19 @@ class Manager(EventEmitter): self.device.abort_on('flush', store_keys()) # Notify the device - self.device.on_pairing(session.connection.handle, keys, session.sc) + self.device.on_pairing(session.connection, keys, session.sc) - def on_pairing_failure(self, session, reason): - self.device.on_pairing_failure(session.connection.handle, reason) + def on_pairing_failure(self, session: Session, reason: int) -> None: + self.device.on_pairing_failure(session.connection, reason) - def on_session_end(self, session): + def on_session_end(self, session: Session) -> None: logger.debug(f'session end for connection 0x{session.connection.handle:04X}') if session.connection.handle in self.sessions: del self.sessions[session.connection.handle] - def get_long_term_key(self, connection, rand, ediv): + def get_long_term_key( + self, connection: Connection, rand: bytes, ediv: int + ) -> Optional[bytes]: if session := self.sessions.get(connection.handle): return session.get_long_term_key(rand, ediv)