smp: add type hints

This commit is contained in:
uael
2023-05-02 05:27:35 +00:00
parent fdee5ecf70
commit 3c81b248a3
2 changed files with 254 additions and 132 deletions

View File

@@ -23,7 +23,7 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter):
transport: int transport: int
self_address: Address self_address: Address
peer_address: Address peer_address: Address
peer_resolvable_address: Optional[Address]
role: int role: int
encryption: int encryption: int
authenticated: bool authenticated: bool
@@ -2196,11 +2197,13 @@ class Device(CompositeEventEmitter):
await self.stop_discovery() await self.stop_discovery()
@property @property
def pairing_config_factory(self): def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
return self.smp_manager.pairing_config_factory return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter @pairing_config_factory.setter
def pairing_config_factory(self, pairing_config_factory): def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory self.smp_manager.pairing_config_factory = pairing_config_factory
async def pair(self, connection): async def pair(self, connection):
@@ -2232,7 +2235,7 @@ class Device(CompositeEventEmitter):
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value return keys.ltk_peripheral.value
async def get_link_key(self, address): async def get_link_key(self, address: Address) -> Optional[bytes]:
# Look for the key in the keystore # Look for the key in the keystore
if self.keystore is not None: if self.keystore is not None:
keys = await self.keystore.get(str(address)) keys = await self.keystore.get(str(address))
@@ -3074,18 +3077,15 @@ class Device(CompositeEventEmitter):
connection.emit('role_change_failure', error) connection.emit('role_change_failure', error)
self.emit('role_change_failure', address, error) self.emit('role_change_failure', address, error)
@with_connection_from_handle def on_pairing_start(self, connection: Connection) -> None:
def on_pairing_start(self, connection):
connection.emit('pairing_start') connection.emit('pairing_start')
@with_connection_from_handle def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
def on_pairing(self, connection, keys, sc):
connection.sc = sc connection.sc = sc
connection.authenticated = True connection.authenticated = True
connection.emit('pairing', keys) connection.emit('pairing', keys)
@with_connection_from_handle def on_pairing_failure(self, connection: Connection, reason: int) -> None:
def on_pairing_failure(self, connection, reason):
connection.emit('pairing_failure', reason) connection.emit('pairing_failure', reason)
@with_connection_from_handle @with_connection_from_handle

View File

@@ -26,16 +26,22 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import secrets import secrets
from typing import Dict, Optional, Type from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
)
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
from .hci import ( from .hci import (
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
Address, Address,
HCI_LE_Enable_Encryption_Command, HCI_LE_Enable_Encryption_Command,
HCI_Object, HCI_Object,
@@ -51,6 +57,10 @@ from .core import (
from .keys import PairingKeys from .keys import PairingKeys
from . import crypto from . import crypto
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.pairing import PairingConfig
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def error_name(error_code): def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code) return name_or_number(SMP_ERROR_NAMES, error_code)
@@ -197,11 +207,12 @@ class SMP_Command:
''' '''
smp_classes: Dict[int, Type[SMP_Command]] = {} smp_classes: Dict[int, Type[SMP_Command]] = {}
fields: Any
code = 0 code = 0
name = '' name = ''
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu: bytes) -> "SMP_Command":
code = pdu[0] code = pdu[0]
cls = SMP_Command.smp_classes.get(code) cls = SMP_Command.smp_classes.get(code)
@@ -217,11 +228,11 @@ class SMP_Command:
return self return self
@staticmethod @staticmethod
def command_name(code): def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code) return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod @staticmethod
def auth_req_str(value): def auth_req_str(value: int) -> str:
bonding_flags = value & 3 bonding_flags = value & 3
mitm = (value >> 2) & 1 mitm = (value >> 2) & 1
sc = (value >> 3) & 1 sc = (value >> 3) & 1
@@ -234,12 +245,12 @@ class SMP_Command:
) )
@staticmethod @staticmethod
def io_capability_name(io_capability): def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability) return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod @staticmethod
def key_distribution_str(value): def key_distribution_str(value: int) -> str:
key_types = [] key_types: List[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG: if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC') key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG: if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
@@ -251,7 +262,7 @@ class SMP_Command:
return ','.join(key_types) return ','.join(key_types)
@staticmethod @staticmethod
def keypress_notification_type_name(notification_type): def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type) return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
@staticmethod @staticmethod
@@ -272,14 +283,14 @@ class SMP_Command:
return inner return inner
def __init__(self, pdu=None, **kwargs): def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
if hasattr(self, 'fields') and kwargs: if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs) HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None: if pdu is None:
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields) pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
self.pdu = pdu self.pdu = pdu
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self): def to_bytes(self):
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
''' '''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
''' '''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('confirm_value', 16)]) @SMP_Command.subclass([('confirm_value', 16)])
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
''' '''
confirm_value: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('random_value', 16)]) @SMP_Command.subclass([('random_value', 16)])
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
''' '''
random_value: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})]) @SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
''' '''
reason: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)]) @SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
''' '''
public_key_x: bytes
public_key_y: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
''' '''
dhkey_check: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
''' '''
notification_type: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('long_term_key', 16)]) @SMP_Command.subclass([('long_term_key', 16)])
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
''' '''
long_term_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('ediv', 2), ('rand', 8)]) @SMP_Command.subclass([('ediv', 2), ('rand', 8)])
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
''' '''
ediv: int
rand: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('identity_resolving_key', 16)]) @SMP_Command.subclass([('identity_resolving_key', 16)])
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
''' '''
identity_resolving_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
''' '''
addr_type: int
bd_addr: Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('signature_key', 16)]) @SMP_Command.subclass([('signature_key', 16)])
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
''' '''
signature_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
''' '''
auth_req: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def smp_auth_req(bonding, mitm, sc, keypress, ct2): def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0 value = 0
if bonding: if bonding:
value |= SMP_BONDING_AUTHREQ value |= SMP_BONDING_AUTHREQ
@@ -574,11 +626,17 @@ class Session:
}, },
} }
def __init__(self, manager, connection, pairing_config, is_initiator): def __init__(
self,
manager: Manager,
connection: Connection,
pairing_config: PairingConfig,
is_initiator: bool,
) -> None:
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.preq = None self.preq: Optional[bytes] = None
self.pres = None self.pres: Optional[bytes] = None
self.ea = None self.ea = None
self.eb = None self.eb = None
self.tk = bytes(16) self.tk = bytes(16)
@@ -588,29 +646,29 @@ class Session:
self.ltk_ediv = 0 self.ltk_ediv = 0
self.ltk_rand = bytes(8) self.ltk_rand = bytes(8)
self.link_key = None self.link_key = None
self.initiator_key_distribution = 0 self.initiator_key_distribution: int = 0
self.responder_key_distribution = 0 self.responder_key_distribution: int = 0
self.peer_random_value = None self.peer_random_value: Optional[bytes] = None
self.peer_public_key_x = bytes(32) self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32) self.peer_public_key_y = bytes(32)
self.peer_ltk = None self.peer_ltk = None
self.peer_ediv = None self.peer_ediv = None
self.peer_rand = None self.peer_rand: Optional[bytes] = None
self.peer_identity_resolving_key = None self.peer_identity_resolving_key = None
self.peer_bd_addr = None self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None self.peer_signature_key = None
self.peer_expected_distributions = [] self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = None self.dh_key = None
self.confirm_value = None self.confirm_value = None
self.passkey = None self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event() self.passkey_ready = asyncio.Event()
self.passkey_step = 0 self.passkey_step = 0
self.passkey_display = False self.passkey_display = False
self.pairing_method = 0 self.pairing_method = 0
self.pairing_config = pairing_config self.pairing_config = pairing_config
self.wait_before_continuing = None self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False self.completed = False
self.ctkd_task = None self.ctkd_task: Optional[Awaitable[None]] = None
# Decide if we're the initiator or the responder # Decide if we're the initiator or the responder
self.is_initiator = is_initiator self.is_initiator = is_initiator
@@ -628,7 +686,9 @@ class Session:
# Create a future that can be used to wait for the session to complete # Create a future that can be used to wait for the session to complete
if self.is_initiator: if self.is_initiator:
self.pairing_result = asyncio.get_running_loop().create_future() self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
else: else:
self.pairing_result = None self.pairing_result = None
@@ -641,11 +701,11 @@ class Session:
) )
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3 # Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
self.bonding = pairing_config.bonding self.bonding: bool = pairing_config.bonding
self.sc = pairing_config.sc self.sc: bool = pairing_config.sc
self.mitm = pairing_config.mitm self.mitm: bool = pairing_config.mitm
self.keypress = False self.keypress = False
self.ct2 = False self.ct2: bool = False
# I/O Capabilities # I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability self.io_capability = pairing_config.delegate.io_capability
@@ -669,34 +729,35 @@ class Session:
self.iat = 1 if peer_address.is_random else 0 self.iat = 1 if peer_address.is_random else 0
@property @property
def pkx(self): def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x) return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property @property
def pka(self): def pka(self) -> bytes:
return self.pkx[0 if self.is_initiator else 1] return self.pkx[0 if self.is_initiator else 1]
@property @property
def pkb(self): def pkb(self) -> bytes:
return self.pkx[0 if self.is_responder else 1] return self.pkx[0 if self.is_responder else 1]
@property @property
def nx(self): def nx(self) -> Tuple[bytes, bytes]:
assert self.peer_random_value
return (self.r, self.peer_random_value) return (self.r, self.peer_random_value)
@property @property
def na(self): def na(self) -> bytes:
return self.nx[0 if self.is_initiator else 1] return self.nx[0 if self.is_initiator else 1]
@property @property
def nb(self): def nb(self) -> bytes:
return self.nx[0 if self.is_responder else 1] return self.nx[0 if self.is_responder else 1]
@property @property
def auth_req(self): def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2) return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand, ediv): def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
if not self.sc and not self.completed: if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv: if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk return self.stk
@@ -706,13 +767,13 @@ class Session:
return None return None
def decide_pairing_method( def decide_pairing_method(
self, auth_req, initiator_io_capability, responder_io_capability self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
): ) -> None:
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS self.pairing_method = self.JUST_WORKS
return return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2: if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections # One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0] details = details[1 if self.sc else 0]
@@ -724,7 +785,9 @@ class Session:
self.pairing_method = details[0] self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2] self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(self, expected, received, error): def check_expected_value(
self, expected: bytes, received: bytes, error: int
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}') logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received: if expected != received:
logger.info(color('pairing confirm/check mismatch', 'red')) logger.info(color('pairing confirm/check mismatch', 'red'))
@@ -732,8 +795,8 @@ class Session:
return False return False
return True return True
def prompt_user_for_confirmation(self, next_steps): def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
async def prompt(): async def prompt() -> None:
logger.debug('ask for confirmation') logger.debug('ask for confirmation')
try: try:
response = await self.pairing_config.delegate.confirm() response = await self.pairing_config.delegate.confirm()
@@ -747,8 +810,10 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps): def prompt_user_for_numeric_comparison(
async def prompt(): self, code: int, next_steps: Callable[[], None]
) -> None:
async def prompt() -> None:
logger.debug(f'verification code: {code}') logger.debug(f'verification code: {code}')
try: try:
response = await self.pairing_config.delegate.compare_numbers( response = await self.pairing_config.delegate.compare_numbers(
@@ -764,11 +829,15 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps): def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt(): async def prompt() -> None:
logger.debug('prompting user for passkey') logger.debug('prompting user for passkey')
try: try:
passkey = await self.pairing_config.delegate.get_number() passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
return
logger.debug(f'user input: {passkey}') logger.debug(f'user input: {passkey}')
next_steps(passkey) next_steps(passkey)
except Exception as error: except Exception as error:
@@ -777,9 +846,10 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def display_passkey(self): def display_passkey(self) -> None:
# Generate random Passkey/PIN code # Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000) self.passkey = secrets.randbelow(1000000)
assert self.passkey is not None
logger.debug(f'Pairing PIN CODE: {self.passkey:06}') logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set() self.passkey_ready.set()
@@ -793,9 +863,9 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6), self.pairing_config.delegate.display_number(self.passkey, digits=6),
) )
def input_passkey(self, next_steps=None): def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer # Prompt the user for the passkey displayed on the peer
def after_input(passkey): def after_input(passkey: int) -> None:
self.passkey = passkey self.passkey = passkey
if not self.sc: if not self.sc:
@@ -809,7 +879,9 @@ class Session:
self.prompt_user_for_number(after_input) self.prompt_user_for_number(after_input)
def display_or_input_passkey(self, next_steps=None): def display_or_input_passkey(
self, next_steps: Optional[Callable[[], None]] = None
) -> None:
if self.passkey_display: if self.passkey_display:
self.display_passkey() self.display_passkey()
if next_steps is not None: if next_steps is not None:
@@ -817,14 +889,14 @@ class Session:
else: else:
self.input_passkey(next_steps) self.input_passkey(next_steps)
def send_command(self, command): def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command) self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error): def send_pairing_failed(self, error: int) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error)) self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error) self.on_pairing_failure(error)
def send_pairing_request_command(self): def send_pairing_request_command(self) -> None:
self.manager.on_session_start(self) self.manager.on_session_start(self)
command = SMP_Pairing_Request_Command( command = SMP_Pairing_Request_Command(
@@ -838,7 +910,7 @@ class Session:
self.preq = bytes(command) self.preq = bytes(command)
self.send_command(command) self.send_command(command)
def send_pairing_response_command(self): def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command( response = SMP_Pairing_Response_Command(
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=0, oob_data_flag=0,
@@ -850,18 +922,19 @@ class Session:
self.pres = bytes(response) self.pres = bytes(response)
self.send_command(response) self.send_command(response)
def send_pairing_confirm_command(self): def send_pairing_confirm_command(self) -> None:
self.r = crypto.r() self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}') logger.debug(f'generated random: {self.r.hex()}')
if self.sc: if self.sc:
async def next_steps(): async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0 z = 0
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
# We need a passkey # We need a passkey
await self.passkey_ready.wait() await self.passkey_ready.wait()
assert self.passkey
z = 0x80 + ((self.passkey >> self.passkey_step) & 1) z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else: else:
@@ -892,10 +965,10 @@ class Session:
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self): def send_pairing_random_command(self) -> None:
self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
def send_public_key_command(self): def send_public_key_command(self) -> None:
self.send_command( self.send_command(
SMP_Pairing_Public_Key_Command( SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.manager.ecc_key.x)), public_key_x=bytes(reversed(self.manager.ecc_key.x)),
@@ -903,18 +976,18 @@ class Session:
) )
) )
def send_pairing_dhkey_check_command(self): def send_pairing_dhkey_check_command(self) -> None:
self.send_command( self.send_command(
SMP_Pairing_DHKey_Check_Command( SMP_Pairing_DHKey_Check_Command(
dhkey_check=self.ea if self.is_initiator else self.eb dhkey_check=self.ea if self.is_initiator else self.eb
) )
) )
def start_encryption(self, key): def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can # We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection # distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync( self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command( HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle, connection_handle=self.connection.handle,
random_number=bytes(8), random_number=bytes(8),
encrypted_diversifier=0, encrypted_diversifier=0,
@@ -922,7 +995,7 @@ class Session:
) )
) )
async def derive_ltk(self): async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address) link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None assert link_key is not None
ilk = ( ilk = (
@@ -932,7 +1005,7 @@ class Session:
) )
self.ltk = crypto.h6(ilk, b'brle') self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self): def distribute_keys(self) -> None:
# Distribute the keys as required # Distribute the keys as required
if self.is_initiator: if self.is_initiator:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
@@ -1032,7 +1105,7 @@ class Session:
) )
self.link_key = crypto.h6(ilk, b'lebr') self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags): def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase # Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = [] self.peer_expected_distributions = []
if not self.sc and self.connection.transport == BT_LE_TRANSPORT: if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
@@ -1055,7 +1128,7 @@ class Session:
f'{[c.__name__ for c in self.peer_expected_distributions]}' f'{[c.__name__ for c in self.peer_expected_distributions]}'
) )
def check_key_distribution(self, command_class): def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
# First, check that the connection is encrypted # First, check that the connection is encrypted
if not self.connection.is_encrypted: if not self.connection.is_encrypted:
logger.warning( logger.warning(
@@ -1083,7 +1156,7 @@ class Session:
) )
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
async def pair(self): async def pair(self) -> None:
# Start pairing as an initiator # Start pairing as an initiator
# TODO: check that this session isn't already active # TODO: check that this session isn't already active
@@ -1091,9 +1164,10 @@ class Session:
self.send_pairing_request_command() self.send_pairing_request_command()
# Wait for the pairing process to finish # Wait for the pairing process to finish
assert self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result) await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, _): def on_disconnection(self, _: int) -> None:
self.connection.remove_listener('disconnection', self.on_disconnection) self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener( self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change 'connection_encryption_change', self.on_connection_encryption_change
@@ -1104,14 +1178,14 @@ class Session:
) )
self.manager.on_session_end(self) self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self): def on_peer_key_distribution_complete(self) -> None:
# The initiator can now send its keys # The initiator can now send its keys
if self.is_initiator: if self.is_initiator:
self.distribute_keys() self.distribute_keys()
self.connection.abort_on('disconnection', self.on_pairing()) self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self): def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted: if self.connection.is_encrypted:
if self.is_responder: if self.is_responder:
# The responder distributes its keys first, the initiator later # The responder distributes its keys first, the initiator later
@@ -1121,11 +1195,11 @@ class Session:
if not self.peer_expected_distributions: if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete() self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self): def on_connection_encryption_key_refresh(self) -> None:
# Do as if the connection had just been encrypted # Do as if the connection had just been encrypted
self.on_connection_encryption_change() self.on_connection_encryption_change()
async def on_pairing(self): async def on_pairing(self) -> None:
logger.debug('pairing complete') logger.debug('pairing complete')
if self.completed: if self.completed:
@@ -1137,7 +1211,7 @@ class Session:
self.pairing_result.set_result(None) self.pairing_result.set_result(None)
# Use the peer address from the pairing protocol or the connection # Use the peer address from the pairing protocol or the connection
if self.peer_bd_addr: if self.peer_bd_addr is not None:
peer_address = self.peer_bd_addr peer_address = self.peer_bd_addr
else: else:
peer_address = self.connection.peer_address peer_address = self.connection.peer_address
@@ -1186,7 +1260,7 @@ class Session:
) )
self.manager.on_pairing(self, peer_address, keys) self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason): def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})') logger.warning(f'pairing failure ({error_name(reason)})')
if self.completed: if self.completed:
@@ -1199,7 +1273,7 @@ class Session:
self.pairing_result.set_exception(error) self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason) self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command): def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method # Find the handler method
handler_name = f'on_{command.name.lower()}' handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
@@ -1215,12 +1289,16 @@ class Session:
else: else:
logger.error(color('SMP command not handled???', 'red')) logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command): def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
) -> None:
self.connection.abort_on( self.connection.abort_on(
'disconnection', self.on_smp_pairing_request_command_async(command) 'disconnection', self.on_smp_pairing_request_command_async(command)
) )
async def on_smp_pairing_request_command_async(self, command): async def on_smp_pairing_request_command_async(
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed # Check if the request should proceed
accepted = await self.pairing_config.delegate.accept() accepted = await self.pairing_config.delegate.accept()
if not accepted: if not accepted:
@@ -1280,7 +1358,9 @@ class Session:
): ):
self.distribute_keys() self.distribute_keys()
def on_smp_pairing_response_command(self, command): def on_smp_pairing_response_command(
self, command: SMP_Pairing_Response_Command
) -> None:
if self.is_responder: if self.is_responder:
logger.warning(color('received pairing response as a responder', 'red')) logger.warning(color('received pairing response as a responder', 'red'))
return return
@@ -1331,7 +1411,9 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_legacy(self, _): def on_smp_pairing_confirm_command_legacy(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.is_initiator: if self.is_initiator:
self.send_pairing_random_command() self.send_pairing_random_command()
else: else:
@@ -1341,7 +1423,9 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_secure_connections(self, _): def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.is_initiator: if self.is_initiator:
self.r = crypto.r() self.r = crypto.r()
@@ -1352,14 +1436,18 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command(self, command): def on_smp_pairing_confirm_command(
self, command: SMP_Pairing_Confirm_Command
) -> None:
self.confirm_value = command.confirm_value self.confirm_value = command.confirm_value
if self.sc: if self.sc:
self.on_smp_pairing_confirm_command_secure_connections(command) self.on_smp_pairing_confirm_command_secure_connections(command)
else: else:
self.on_smp_pairing_confirm_command_legacy(command) self.on_smp_pairing_confirm_command_legacy(command)
def on_smp_pairing_random_command_legacy(self, command): def on_smp_pairing_random_command_legacy(
self, command: SMP_Pairing_Random_Command
) -> None:
# Check that the confirmation values match # Check that the confirmation values match
confirm_verifier = crypto.c1( confirm_verifier = crypto.c1(
self.tk, self.tk,
@@ -1371,6 +1459,7 @@ class Session:
self.ia, self.ia,
self.ra, self.ra,
) )
assert self.confirm_value
if not self.check_expected_value( if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
): ):
@@ -1394,7 +1483,9 @@ class Session:
else: else:
self.send_pairing_random_command() self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command): def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None: if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command') logger.warning('no passkey entered, ignoring command')
return return
@@ -1402,6 +1493,7 @@ class Session:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
if self.is_initiator: if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
assert self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pkb, self.pka, command.random_value, bytes([0]) self.pkb, self.pka, command.random_value, bytes([0])
@@ -1411,6 +1503,7 @@ class Session:
): ):
return return
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pkb, self.pkb,
@@ -1435,6 +1528,7 @@ class Session:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
self.send_pairing_random_command() self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pka, self.pka,
@@ -1468,19 +1562,21 @@ class Session:
ra = bytes(16) ra = bytes(16)
rb = ra rb = ra
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little') ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra rb = ra
else: else:
# OOB not implemented yet # OOB not implemented yet
return return
assert self.preq and self.pres
io_cap_a = self.preq[1:4] io_cap_a = self.preq[1:4]
io_cap_b = self.pres[1:4] io_cap_b = self.pres[1:4]
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b) self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a) self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
# Next steps to be performed after possible user confirmation # Next steps to be performed after possible user confirmation
def next_steps(): def next_steps() -> None:
# The initiator sends the DH Key check to the responder # The initiator sends the DH Key check to the responder
if self.is_initiator: if self.is_initiator:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
@@ -1502,14 +1598,18 @@ class Session:
else: else:
next_steps() next_steps()
def on_smp_pairing_random_command(self, command): def on_smp_pairing_random_command(
self, command: SMP_Pairing_Random_Command
) -> None:
self.peer_random_value = command.random_value self.peer_random_value = command.random_value
if self.sc: if self.sc:
self.on_smp_pairing_random_command_secure_connections(command) self.on_smp_pairing_random_command_secure_connections(command)
else: else:
self.on_smp_pairing_random_command_legacy(command) self.on_smp_pairing_random_command_legacy(command)
def on_smp_pairing_public_key_command(self, command): def on_smp_pairing_public_key_command(
self, command: SMP_Pairing_Public_Key_Command
) -> None:
# Store the public key so that we can compute the confirmation value later # Store the public key so that we can compute the confirmation value later
self.peer_public_key_x = command.public_key_x self.peer_public_key_x = command.public_key_x
self.peer_public_key_y = command.public_key_y self.peer_public_key_y = command.public_key_y
@@ -1538,9 +1638,12 @@ class Session:
# We can now send the confirmation value # We can now send the confirmation value
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_dhkey_check_command(self, command): def on_smp_pairing_dhkey_check_command(
self, command: SMP_Pairing_DHKey_Check_Command
) -> None:
# Check that what we received matches what we computed earlier # Check that what we received matches what we computed earlier
expected = self.eb if self.is_initiator else self.ea expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value( if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
): ):
@@ -1549,7 +1652,8 @@ class Session:
if self.is_responder: if self.is_responder:
if self.wait_before_continuing is not None: if self.wait_before_continuing is not None:
async def next_steps(): async def next_steps() -> None:
assert self.wait_before_continuing
await self.wait_before_continuing await self.wait_before_continuing
self.wait_before_continuing = None self.wait_before_continuing = None
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
@@ -1558,29 +1662,42 @@ class Session:
else: else:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
else: else:
assert self.ltk
self.start_encryption(self.ltk) self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(self, command): def on_smp_pairing_failed_command(
self, command: SMP_Pairing_Failed_Command
) -> None:
self.on_pairing_failure(command.reason) self.on_pairing_failure(command.reason)
def on_smp_encryption_information_command(self, command): def on_smp_encryption_information_command(
self, command: SMP_Encryption_Information_Command
) -> None:
self.peer_ltk = command.long_term_key self.peer_ltk = command.long_term_key
self.check_key_distribution(SMP_Encryption_Information_Command) self.check_key_distribution(SMP_Encryption_Information_Command)
def on_smp_master_identification_command(self, command): def on_smp_master_identification_command(
self, command: SMP_Master_Identification_Command
) -> None:
self.peer_ediv = command.ediv self.peer_ediv = command.ediv
self.peer_rand = command.rand self.peer_rand = command.rand
self.check_key_distribution(SMP_Master_Identification_Command) self.check_key_distribution(SMP_Master_Identification_Command)
def on_smp_identity_information_command(self, command): def on_smp_identity_information_command(
self, command: SMP_Identity_Information_Command
) -> None:
self.peer_identity_resolving_key = command.identity_resolving_key self.peer_identity_resolving_key = command.identity_resolving_key
self.check_key_distribution(SMP_Identity_Information_Command) self.check_key_distribution(SMP_Identity_Information_Command)
def on_smp_identity_address_information_command(self, command): def on_smp_identity_address_information_command(
self, command: SMP_Identity_Address_Information_Command
) -> None:
self.peer_bd_addr = command.bd_addr self.peer_bd_addr = command.bd_addr
self.check_key_distribution(SMP_Identity_Address_Information_Command) self.check_key_distribution(SMP_Identity_Address_Information_Command)
def on_smp_signing_information_command(self, command): def on_smp_signing_information_command(
self, command: SMP_Signing_Information_Command
) -> None:
self.peer_signature_key = command.signature_key self.peer_signature_key = command.signature_key
self.check_key_distribution(SMP_Signing_Information_Command) self.check_key_distribution(SMP_Signing_Information_Command)
@@ -1591,14 +1708,22 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol Implements the Initiator and Responder roles of the Security Manager Protocol
''' '''
def __init__(self, device, pairing_config_factory): device: Device
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
def __init__(
self,
device: Device,
pairing_config_factory: Callable[[Connection], PairingConfig],
) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.sessions = {} self.sessions = {}
self._ecc_key = None self._ecc_key = None
self.pairing_config_factory = pairing_config_factory self.pairing_config_factory = pairing_config_factory
def send_command(self, connection, command): def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug( logger.debug(
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] ' f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}' f'{connection.peer_address}: {command}'
@@ -1606,19 +1731,12 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes()) connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_pdu(self, connection, pdu): def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Look for a session with this connection, and create one if none exists # Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)): if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE: if connection.role == BT_CENTRAL_ROLE:
logger.warning('Remote starts pairing as Peripheral!') logger.warning('Remote starts pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
# Pairing disabled
self.send_command(
connection,
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
)
return
session = Session(self, connection, pairing_config, is_initiator=False) session = Session(self, connection, pairing_config, is_initiator=False)
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
@@ -1633,23 +1751,22 @@ class Manager(EventEmitter):
session.on_smp_command(command) session.on_smp_command(command)
@property @property
def ecc_key(self): def ecc_key(self) -> crypto.EccKey:
if self._ecc_key is None: if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate() self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
return self._ecc_key return self._ecc_key
async def pair(self, connection): async def pair(self, connection: Connection) -> None:
# TODO: check if there's already a session for this connection # TODO: check if there's already a session for this connection
if connection.role != BT_CENTRAL_ROLE: if connection.role != BT_CENTRAL_ROLE:
logger.warning('Start pairing as Peripheral!') logger.warning('Start pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
raise ValueError('pairing config must not be None when initiating')
session = Session(self, connection, pairing_config, is_initiator=True) session = Session(self, connection, pairing_config, is_initiator=True)
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
return await session.pair() return await session.pair()
def request_pairing(self, connection): def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config: if pairing_config:
auth_req = smp_auth_req( auth_req = smp_auth_req(
@@ -1663,15 +1780,18 @@ class Manager(EventEmitter):
auth_req = 0 auth_req = 0
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req)) self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session): def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection.handle) self.device.on_pairing_start(session.connection)
def on_pairing(self, session, identity_address, keys): def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store # Store the keys in the key store
if self.device.keystore and identity_address is not None: if self.device.keystore and identity_address is not None:
async def store_keys(): async def store_keys():
try: try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys) await self.device.keystore.update(str(identity_address), keys)
except Exception as error: except Exception as error:
logger.warning(f'!!! error while storing keys: {error}') logger.warning(f'!!! error while storing keys: {error}')
@@ -1679,17 +1799,19 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys()) self.device.abort_on('flush', store_keys())
# Notify the device # Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc) self.device.on_pairing(session.connection, keys, session.sc)
def on_pairing_failure(self, session, reason): def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection.handle, reason) self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session): def on_session_end(self, session: Session) -> None:
logger.debug(f'session end for connection 0x{session.connection.handle:04X}') logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
if session.connection.handle in self.sessions: if session.connection.handle in self.sessions:
del self.sessions[session.connection.handle] del self.sessions[session.connection.handle]
def get_long_term_key(self, connection, rand, ediv): def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> Optional[bytes]:
if session := self.sessions.get(connection.handle): if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv) return session.get_long_term_key(rand, ediv)