smp: add type hints

This commit is contained in:
uael
2023-05-02 05:27:35 +00:00
parent fdee5ecf70
commit 3c81b248a3
2 changed files with 254 additions and 132 deletions

View File

@@ -26,16 +26,22 @@ from __future__ import annotations
import logging
import asyncio
import secrets
from typing import Dict, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
)
from pyee import EventEmitter
from .colors import color
from .hci import (
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
Address,
HCI_LE_Enable_Encryption_Command,
HCI_Object,
@@ -51,6 +57,10 @@ from .core import (
from .keys import PairingKeys
from . import crypto
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.pairing import PairingConfig
# -----------------------------------------------------------------------------
# Logging
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def error_name(error_code):
def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code)
@@ -197,11 +207,12 @@ class SMP_Command:
'''
smp_classes: Dict[int, Type[SMP_Command]] = {}
fields: Any
code = 0
name = ''
@staticmethod
def from_bytes(pdu):
def from_bytes(pdu: bytes) -> "SMP_Command":
code = pdu[0]
cls = SMP_Command.smp_classes.get(code)
@@ -217,11 +228,11 @@ class SMP_Command:
return self
@staticmethod
def command_name(code):
def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod
def auth_req_str(value):
def auth_req_str(value: int) -> str:
bonding_flags = value & 3
mitm = (value >> 2) & 1
sc = (value >> 3) & 1
@@ -234,12 +245,12 @@ class SMP_Command:
)
@staticmethod
def io_capability_name(io_capability):
def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod
def key_distribution_str(value):
key_types = []
def key_distribution_str(value: int) -> str:
key_types: List[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
@@ -251,7 +262,7 @@ class SMP_Command:
return ','.join(key_types)
@staticmethod
def keypress_notification_type_name(notification_type):
def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
@staticmethod
@@ -272,14 +283,14 @@ class SMP_Command:
return inner
def __init__(self, pdu=None, **kwargs):
def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
self.pdu = pdu
def init_from_bytes(self, pdu, offset):
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
'''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
'''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('confirm_value', 16)])
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
'''
confirm_value: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('random_value', 16)])
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
'''
random_value: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
'''
reason: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
'''
public_key_x: bytes
public_key_y: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
'''
dhkey_check: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
'''
notification_type: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('long_term_key', 16)])
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
'''
long_term_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('ediv', 2), ('rand', 8)])
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
'''
ediv: int
rand: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('identity_resolving_key', 16)])
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
'''
identity_resolving_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
'''
addr_type: int
bd_addr: Address
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('signature_key', 16)])
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
'''
signature_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
'''
auth_req: int
# -----------------------------------------------------------------------------
def smp_auth_req(bonding, mitm, sc, keypress, ct2):
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0
if bonding:
value |= SMP_BONDING_AUTHREQ
@@ -574,11 +626,17 @@ class Session:
},
}
def __init__(self, manager, connection, pairing_config, is_initiator):
def __init__(
self,
manager: Manager,
connection: Connection,
pairing_config: PairingConfig,
is_initiator: bool,
) -> None:
self.manager = manager
self.connection = connection
self.preq = None
self.pres = None
self.preq: Optional[bytes] = None
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
@@ -588,29 +646,29 @@ class Session:
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key = None
self.initiator_key_distribution = 0
self.responder_key_distribution = 0
self.peer_random_value = None
self.peer_public_key_x = bytes(32)
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
self.peer_ltk = None
self.peer_ediv = None
self.peer_rand = None
self.peer_rand: Optional[bytes] = None
self.peer_identity_resolving_key = None
self.peer_bd_addr = None
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions = []
self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = None
self.confirm_value = None
self.passkey = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
self.pairing_config = pairing_config
self.wait_before_continuing = None
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False
self.ctkd_task = None
self.ctkd_task: Optional[Awaitable[None]] = None
# Decide if we're the initiator or the responder
self.is_initiator = is_initiator
@@ -628,7 +686,9 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result = asyncio.get_running_loop().create_future()
self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
else:
self.pairing_result = None
@@ -641,11 +701,11 @@ class Session:
)
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
self.bonding = pairing_config.bonding
self.sc = pairing_config.sc
self.mitm = pairing_config.mitm
self.bonding: bool = pairing_config.bonding
self.sc: bool = pairing_config.sc
self.mitm: bool = pairing_config.mitm
self.keypress = False
self.ct2 = False
self.ct2: bool = False
# I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability
@@ -669,34 +729,35 @@ class Session:
self.iat = 1 if peer_address.is_random else 0
@property
def pkx(self):
def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property
def pka(self):
def pka(self) -> bytes:
return self.pkx[0 if self.is_initiator else 1]
@property
def pkb(self):
def pkb(self) -> bytes:
return self.pkx[0 if self.is_responder else 1]
@property
def nx(self):
def nx(self) -> Tuple[bytes, bytes]:
assert self.peer_random_value
return (self.r, self.peer_random_value)
@property
def na(self):
def na(self) -> bytes:
return self.nx[0 if self.is_initiator else 1]
@property
def nb(self):
def nb(self) -> bytes:
return self.nx[0 if self.is_responder else 1]
@property
def auth_req(self):
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand, ediv):
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk
@@ -706,13 +767,13 @@ class Session:
return None
def decide_pairing_method(
self, auth_req, initiator_io_capability, responder_io_capability
):
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
) -> None:
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS
return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability]
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0]
@@ -724,7 +785,9 @@ class Session:
self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(self, expected, received, error):
def check_expected_value(
self, expected: bytes, received: bytes, error: int
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received:
logger.info(color('pairing confirm/check mismatch', 'red'))
@@ -732,8 +795,8 @@ class Session:
return False
return True
def prompt_user_for_confirmation(self, next_steps):
async def prompt():
def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
async def prompt() -> None:
logger.debug('ask for confirmation')
try:
response = await self.pairing_config.delegate.confirm()
@@ -747,8 +810,10 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt():
def prompt_user_for_numeric_comparison(
self, code: int, next_steps: Callable[[], None]
) -> None:
async def prompt() -> None:
logger.debug(f'verification code: {code}')
try:
response = await self.pairing_config.delegate.compare_numbers(
@@ -764,11 +829,15 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps):
async def prompt():
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt() -> None:
logger.debug('prompting user for passkey')
try:
passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
return
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception as error:
@@ -777,9 +846,10 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def display_passkey(self):
def display_passkey(self) -> None:
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
assert self.passkey is not None
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
@@ -793,9 +863,9 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
def input_passkey(self, next_steps=None):
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
def after_input(passkey):
def after_input(passkey: int) -> None:
self.passkey = passkey
if not self.sc:
@@ -809,7 +879,9 @@ class Session:
self.prompt_user_for_number(after_input)
def display_or_input_passkey(self, next_steps=None):
def display_or_input_passkey(
self, next_steps: Optional[Callable[[], None]] = None
) -> None:
if self.passkey_display:
self.display_passkey()
if next_steps is not None:
@@ -817,14 +889,14 @@ class Session:
else:
self.input_passkey(next_steps)
def send_command(self, command):
def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error):
def send_pairing_failed(self, error: int) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error)
def send_pairing_request_command(self):
def send_pairing_request_command(self) -> None:
self.manager.on_session_start(self)
command = SMP_Pairing_Request_Command(
@@ -838,7 +910,7 @@ class Session:
self.preq = bytes(command)
self.send_command(command)
def send_pairing_response_command(self):
def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command(
io_capability=self.io_capability,
oob_data_flag=0,
@@ -850,18 +922,19 @@ class Session:
self.pres = bytes(response)
self.send_command(response)
def send_pairing_confirm_command(self):
def send_pairing_confirm_command(self) -> None:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
async def next_steps():
async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
assert self.passkey
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
@@ -892,10 +965,10 @@ class Session:
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self):
def send_pairing_random_command(self) -> None:
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
def send_public_key_command(self):
def send_public_key_command(self) -> None:
self.send_command(
SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
@@ -903,18 +976,18 @@ class Session:
)
)
def send_pairing_dhkey_check_command(self):
def send_pairing_dhkey_check_command(self) -> None:
self.send_command(
SMP_Pairing_DHKey_Check_Command(
dhkey_check=self.ea if self.is_initiator else self.eb
)
)
def start_encryption(self, key):
def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command(
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
@@ -922,7 +995,7 @@ class Session:
)
)
async def derive_ltk(self):
async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None
ilk = (
@@ -932,7 +1005,7 @@ class Session:
)
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self):
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1032,7 +1105,7 @@ class Session:
)
self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags):
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = []
if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
@@ -1055,7 +1128,7 @@ class Session:
f'{[c.__name__ for c in self.peer_expected_distributions]}'
)
def check_key_distribution(self, command_class):
def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
# First, check that the connection is encrypted
if not self.connection.is_encrypted:
logger.warning(
@@ -1083,7 +1156,7 @@ class Session:
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
async def pair(self):
async def pair(self) -> None:
# Start pairing as an initiator
# TODO: check that this session isn't already active
@@ -1091,9 +1164,10 @@ class Session:
self.send_pairing_request_command()
# Wait for the pairing process to finish
assert self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, _):
def on_disconnection(self, _: int) -> None:
self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change
@@ -1104,14 +1178,14 @@ class Session:
)
self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self):
def on_peer_key_distribution_complete(self) -> None:
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self):
def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted:
if self.is_responder:
# The responder distributes its keys first, the initiator later
@@ -1121,11 +1195,11 @@ class Session:
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self):
def on_connection_encryption_key_refresh(self) -> None:
# Do as if the connection had just been encrypted
self.on_connection_encryption_change()
async def on_pairing(self):
async def on_pairing(self) -> None:
logger.debug('pairing complete')
if self.completed:
@@ -1137,7 +1211,7 @@ class Session:
self.pairing_result.set_result(None)
# Use the peer address from the pairing protocol or the connection
if self.peer_bd_addr:
if self.peer_bd_addr is not None:
peer_address = self.peer_bd_addr
else:
peer_address = self.connection.peer_address
@@ -1186,7 +1260,7 @@ class Session:
)
self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason):
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
if self.completed:
@@ -1199,7 +1273,7 @@ class Session:
self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command):
def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -1215,12 +1289,16 @@ class Session:
else:
logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command):
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
) -> None:
self.connection.abort_on(
'disconnection', self.on_smp_pairing_request_command_async(command)
)
async def on_smp_pairing_request_command_async(self, command):
async def on_smp_pairing_request_command_async(
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed
accepted = await self.pairing_config.delegate.accept()
if not accepted:
@@ -1280,7 +1358,9 @@ class Session:
):
self.distribute_keys()
def on_smp_pairing_response_command(self, command):
def on_smp_pairing_response_command(
self, command: SMP_Pairing_Response_Command
) -> None:
if self.is_responder:
logger.warning(color('received pairing response as a responder', 'red'))
return
@@ -1331,7 +1411,9 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_legacy(self, _):
def on_smp_pairing_confirm_command_legacy(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.is_initiator:
self.send_pairing_random_command()
else:
@@ -1341,7 +1423,9 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_secure_connections(self, _):
def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.is_initiator:
self.r = crypto.r()
@@ -1352,14 +1436,18 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command(self, command):
def on_smp_pairing_confirm_command(
self, command: SMP_Pairing_Confirm_Command
) -> None:
self.confirm_value = command.confirm_value
if self.sc:
self.on_smp_pairing_confirm_command_secure_connections(command)
else:
self.on_smp_pairing_confirm_command_legacy(command)
def on_smp_pairing_random_command_legacy(self, command):
def on_smp_pairing_random_command_legacy(
self, command: SMP_Pairing_Random_Command
) -> None:
# Check that the confirmation values match
confirm_verifier = crypto.c1(
self.tk,
@@ -1371,6 +1459,7 @@ class Session:
self.ia,
self.ra,
)
assert self.confirm_value
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
@@ -1394,7 +1483,9 @@ class Session:
else:
self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command):
def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
@@ -1402,6 +1493,7 @@ class Session:
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
assert self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pkb, self.pka, command.random_value, bytes([0])
@@ -1411,6 +1503,7 @@ class Session:
):
return
elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pkb,
@@ -1435,6 +1528,7 @@ class Session:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pka,
@@ -1468,19 +1562,21 @@ class Session:
ra = bytes(16)
rb = ra
elif self.pairing_method == self.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
else:
# OOB not implemented yet
return
assert self.preq and self.pres
io_cap_a = self.preq[1:4]
io_cap_b = self.pres[1:4]
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
# Next steps to be performed after possible user confirmation
def next_steps():
def next_steps() -> None:
# The initiator sends the DH Key check to the responder
if self.is_initiator:
self.send_pairing_dhkey_check_command()
@@ -1502,14 +1598,18 @@ class Session:
else:
next_steps()
def on_smp_pairing_random_command(self, command):
def on_smp_pairing_random_command(
self, command: SMP_Pairing_Random_Command
) -> None:
self.peer_random_value = command.random_value
if self.sc:
self.on_smp_pairing_random_command_secure_connections(command)
else:
self.on_smp_pairing_random_command_legacy(command)
def on_smp_pairing_public_key_command(self, command):
def on_smp_pairing_public_key_command(
self, command: SMP_Pairing_Public_Key_Command
) -> None:
# Store the public key so that we can compute the confirmation value later
self.peer_public_key_x = command.public_key_x
self.peer_public_key_y = command.public_key_y
@@ -1538,9 +1638,12 @@ class Session:
# We can now send the confirmation value
self.send_pairing_confirm_command()
def on_smp_pairing_dhkey_check_command(self, command):
def on_smp_pairing_dhkey_check_command(
self, command: SMP_Pairing_DHKey_Check_Command
) -> None:
# Check that what we received matches what we computed earlier
expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
):
@@ -1549,7 +1652,8 @@ class Session:
if self.is_responder:
if self.wait_before_continuing is not None:
async def next_steps():
async def next_steps() -> None:
assert self.wait_before_continuing
await self.wait_before_continuing
self.wait_before_continuing = None
self.send_pairing_dhkey_check_command()
@@ -1558,29 +1662,42 @@ class Session:
else:
self.send_pairing_dhkey_check_command()
else:
assert self.ltk
self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(self, command):
def on_smp_pairing_failed_command(
self, command: SMP_Pairing_Failed_Command
) -> None:
self.on_pairing_failure(command.reason)
def on_smp_encryption_information_command(self, command):
def on_smp_encryption_information_command(
self, command: SMP_Encryption_Information_Command
) -> None:
self.peer_ltk = command.long_term_key
self.check_key_distribution(SMP_Encryption_Information_Command)
def on_smp_master_identification_command(self, command):
def on_smp_master_identification_command(
self, command: SMP_Master_Identification_Command
) -> None:
self.peer_ediv = command.ediv
self.peer_rand = command.rand
self.check_key_distribution(SMP_Master_Identification_Command)
def on_smp_identity_information_command(self, command):
def on_smp_identity_information_command(
self, command: SMP_Identity_Information_Command
) -> None:
self.peer_identity_resolving_key = command.identity_resolving_key
self.check_key_distribution(SMP_Identity_Information_Command)
def on_smp_identity_address_information_command(self, command):
def on_smp_identity_address_information_command(
self, command: SMP_Identity_Address_Information_Command
) -> None:
self.peer_bd_addr = command.bd_addr
self.check_key_distribution(SMP_Identity_Address_Information_Command)
def on_smp_signing_information_command(self, command):
def on_smp_signing_information_command(
self, command: SMP_Signing_Information_Command
) -> None:
self.peer_signature_key = command.signature_key
self.check_key_distribution(SMP_Signing_Information_Command)
@@ -1591,14 +1708,22 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol
'''
def __init__(self, device, pairing_config_factory):
device: Device
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
def __init__(
self,
device: Device,
pairing_config_factory: Callable[[Connection], PairingConfig],
) -> None:
super().__init__()
self.device = device
self.sessions = {}
self._ecc_key = None
self.pairing_config_factory = pairing_config_factory
def send_command(self, connection, command):
def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug(
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
@@ -1606,19 +1731,12 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_pdu(self, connection, pdu):
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
logger.warning('Remote starts pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
# Pairing disabled
self.send_command(
connection,
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
)
return
session = Session(self, connection, pairing_config, is_initiator=False)
self.sessions[connection.handle] = session
@@ -1633,23 +1751,22 @@ class Manager(EventEmitter):
session.on_smp_command(command)
@property
def ecc_key(self):
def ecc_key(self) -> crypto.EccKey:
if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
return self._ecc_key
async def pair(self, connection):
async def pair(self, connection: Connection) -> None:
# TODO: check if there's already a session for this connection
if connection.role != BT_CENTRAL_ROLE:
logger.warning('Start pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
raise ValueError('pairing config must not be None when initiating')
session = Session(self, connection, pairing_config, is_initiator=True)
self.sessions[connection.handle] = session
return await session.pair()
def request_pairing(self, connection):
def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection)
if pairing_config:
auth_req = smp_auth_req(
@@ -1663,15 +1780,18 @@ class Manager(EventEmitter):
auth_req = 0
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session):
self.device.on_pairing_start(session.connection.handle)
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)
def on_pairing(self, session, identity_address, keys):
def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
@@ -1679,17 +1799,19 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc)
self.device.on_pairing(session.connection, keys, session.sc)
def on_pairing_failure(self, session, reason):
self.device.on_pairing_failure(session.connection.handle, reason)
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session):
def on_session_end(self, session: Session) -> None:
logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
if session.connection.handle in self.sessions:
del self.sessions[session.connection.handle]
def get_long_term_key(self, connection, rand, ediv):
def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> Optional[bytes]:
if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv)