Merge pull request #178 from google/uael/pairing

Overall fixes and improvements
This commit is contained in:
Lucas Abel
2023-05-03 21:39:50 -07:00
committed by GitHub
7 changed files with 439 additions and 256 deletions

View File

@@ -152,7 +152,12 @@ class UUID:
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
uuid_bytes: bytes
name: Optional[str]
def __init__(
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
) -> None:
if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
@@ -172,7 +177,7 @@ class UUID:
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name
def register(self):
def register(self) -> UUID:
# Register this object in the class registry, and update the entry's name if
# it wasn't set already
for uuid in self.UUIDS:
@@ -196,22 +201,22 @@ class UUID:
raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16, name=None):
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod
def from_32_bits(cls, uuid_32, name=None):
def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
def parse_uuid(cls, uuid_as_bytes, offset):
def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod
def parse_uuid_2(cls, uuid_as_bytes, offset):
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128=False):
def to_bytes(self, force_128: bool = False) -> bytes:
'''
Serialize UUID in little-endian byte-order
'''
@@ -227,7 +232,7 @@ class UUID:
else:
assert False, "unreachable"
def to_pdu_bytes(self):
def to_pdu_bytes(self) -> bytes:
'''
Convert to bytes for use in an ATT PDU.
According to Vol 3, Part F - 3.2.1 Attribute Type:
@@ -236,11 +241,11 @@ class UUID:
'''
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self) -> str:
def to_hex_str(self, separator: str = '') -> str:
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
return ''.join(
return separator.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
@@ -250,10 +255,10 @@ class UUID:
]
).upper()
def __bytes__(self):
def __bytes__(self) -> bytes:
return self.to_bytes()
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
@@ -262,35 +267,19 @@ class UUID:
return False
def __hash__(self):
def __hash__(self) -> int:
return hash(self.uuid_bytes)
def __str__(self):
def __str__(self) -> str:
result = self.to_hex_str(separator='-')
if len(self.uuid_bytes) == 2:
uuid = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{uuid:04X}'
result = 'UUID-16:' + result
elif len(self.uuid_bytes) == 4:
uuid = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{uuid:08X}'
else:
result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
result = 'UUID-32:' + result
if self.name is not None:
return result + f' ({self.name})'
result += f' ({self.name})'
return result
def __repr__(self):
return str(self)
# -----------------------------------------------------------------------------
# Common UUID constants
@@ -773,7 +762,7 @@ class AdvertisingData:
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
uuids = []
offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data):
while (offset + uuid_size) <= len(ad_data):
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
offset += uuid_size
return uuids

View File

@@ -23,7 +23,7 @@ import asyncio
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter):
transport: int
self_address: Address
peer_address: Address
peer_resolvable_address: Optional[Address]
role: int
encryption: int
authenticated: bool
@@ -2196,13 +2197,23 @@ class Device(CompositeEventEmitter):
await self.stop_discovery()
@property
def pairing_config_factory(self):
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter
def pairing_config_factory(self, pairing_config_factory):
def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory
@property
def smp_session_proxy(self) -> Type[smp.Session]:
return self.smp_manager.session_proxy
@smp_session_proxy.setter
def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy
async def pair(self, connection):
return await self.smp_manager.pair(connection)
@@ -2232,7 +2243,7 @@ class Device(CompositeEventEmitter):
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value
async def get_link_key(self, address):
async def get_link_key(self, address: Address) -> Optional[bytes]:
# Look for the key in the keystore
if self.keystore is not None:
keys = await self.keystore.get(str(address))
@@ -2243,6 +2254,7 @@ class Device(CompositeEventEmitter):
return None
return keys.link_key.value
return None
# [Classic only]
async def authenticate(self, connection):
@@ -2772,89 +2784,103 @@ class Device(CompositeEventEmitter):
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_confirmation_request(self, connection, code):
def on_authentication_user_confirmation_request(self, connection, code) -> None:
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
peer_io_capability = connection.peer_pairing_io_capability
# Respond
if io_capability == HCI_DISPLAY_YES_NO_IO_CAPABILITY:
if connection.peer_pairing_io_capability in (
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
):
# Display the code and ask the user to compare
async def prompt():
return (
await pairing_config.delegate.compare_numbers(code, digits=6),
async def confirm() -> bool:
# Ask the user to confirm the pairing, without display
return await pairing_config.delegate.confirm()
async def auto_confirm() -> bool:
# Ask the user to auto-confirm the pairing, without display
return await pairing_config.delegate.confirm(auto=True)
async def display_confirm() -> bool:
# Display the code and ask the user to compare
return await pairing_config.delegate.compare_numbers(code, digits=6)
async def display_auto_confirm() -> bool:
# Display the code to the user and ask the delegate to auto-confirm
await pairing_config.delegate.display_number(code, digits=6)
return await pairing_config.delegate.confirm(auto=True)
async def na() -> bool:
assert False, "N/A: unreachable"
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
methods = {
HCI_DISPLAY_ONLY_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_DISPLAY_YES_NO_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_KEYBOARD_ONLY_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: na,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: na,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: auto_confirm,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
}
method = methods[peer_io_capability][io_capability]
async def reply() -> None:
if await connection.abort_on('disconnection', method()):
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
else:
# Ask the user to confirm the pairing, without showing a code
async def prompt():
return await pairing_config.delegate.confirm()
async def confirm():
if await prompt():
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address
)
)
else:
await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
AsyncRunner.spawn(connection.abort_on('disconnection', confirm()))
return
if io_capability == HCI_DISPLAY_ONLY_IO_CAPABILITY:
# Display the code to the user
AsyncRunner.spawn(pairing_config.delegate.display_number(code, 6))
# Automatic confirmation
self.host.send_command_sync(
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
)
AsyncRunner.spawn(reply())
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_authentication_user_passkey_request(self, connection):
def on_authentication_user_passkey_request(self, connection) -> None:
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
# Respond
if io_capability == HCI_KEYBOARD_ONLY_IO_CAPABILITY:
# Ask the user to input a number
async def get_number():
number = await connection.abort_on(
'disconnection', pairing_config.delegate.get_number()
)
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command(
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
)
asyncio.create_task(get_number())
else:
self.host.send_command_sync(
HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
async def reply() -> None:
number = await connection.abort_on(
'disconnection', pairing_config.delegate.get_number()
)
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
AsyncRunner.spawn(reply())
# [Classic only]
@host_event_handler
@@ -3059,18 +3085,15 @@ class Device(CompositeEventEmitter):
connection.emit('role_change_failure', error)
self.emit('role_change_failure', address, error)
@with_connection_from_handle
def on_pairing_start(self, connection):
def on_pairing_start(self, connection: Connection) -> None:
connection.emit('pairing_start')
@with_connection_from_handle
def on_pairing(self, connection, keys, sc):
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
connection.sc = sc
connection.authenticated = True
connection.emit('pairing', keys)
@with_connection_from_handle
def on_pairing_failure(self, connection, reason):
def on_pairing_failure(self, connection: Connection, reason: int) -> None:
connection.emit('pairing_failure', reason)
@with_connection_from_handle

View File

@@ -247,7 +247,7 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
UUID = None
UUID: Optional[UUID] = None
def __init__(self, characteristics, primary=True):
super().__init__(self.UUID, characteristics, primary)

View File

@@ -25,7 +25,7 @@ import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from .colors import color
from .hci import Address
@@ -139,19 +139,19 @@ class PairingKeys:
# -----------------------------------------------------------------------------
class KeyStore:
async def delete(self, name):
async def delete(self, name: str):
pass
async def update(self, name, keys):
async def update(self, name: str, keys: PairingKeys) -> None:
pass
async def get(self, _name):
return PairingKeys()
async def get(self, _name: str) -> Optional[PairingKeys]:
return None
async def get_all(self):
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return []
async def delete_all(self):
async def delete_all(self) -> None:
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
@@ -177,15 +177,15 @@ class KeyStore:
separator = '\n'
@staticmethod
def create_for_device(device: Device) -> Optional[KeyStore]:
def create_for_device(device: Device) -> KeyStore:
if device.config.keystore is None:
return None
return MemoryKeyStore()
keystore_type = device.config.keystore.split(':', 1)[0]
if keystore_type == 'JsonKeyStore':
return JsonKeyStore.from_device(device)
return None
return MemoryKeyStore()
# -----------------------------------------------------------------------------
@@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore):
return None
return PairingKeys.from_dict(keys)
# -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore):
all_keys: Dict[str, PairingKeys]
def __init__(self) -> None:
self.all_keys = {}
async def delete(self, name: str) -> None:
if name in self.all_keys:
del self.all_keys[name]
async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys
async def get(self, name: str) -> Optional[PairingKeys]:
return self.all_keys.get(name)
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return list(self.all_keys.items())

View File

@@ -65,8 +65,9 @@ class PairingDelegate:
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: int = (
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
| KeyDistribution.DISTRIBUTE_IDENTITY_KEY
)
# Default mapping from abstract to Classic I/O capabilities.
@@ -85,9 +86,9 @@ class PairingDelegate:
def __init__(
self,
io_capability=NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
@@ -113,8 +114,11 @@ class PairingDelegate:
"""Accept or reject a Pairing request."""
return True
async def confirm(self) -> bool:
"""Respond yes or no to a Pairing confirmation question."""
async def confirm(self, auto: bool = False) -> bool:
"""
Respond yes or no to a Pairing confirmation question.
The `auto` parameter stands for automatic confirmation.
"""
return True
# pylint: disable-next=unused-argument
@@ -129,7 +133,7 @@ class PairingDelegate:
"""
return 0
async def get_string(self, max_length) -> Optional[str]:
async def get_string(self, max_length: int) -> Optional[str]:
"""
Return a string whose utf-8 encoding is up to max_length bytes.
"""

View File

@@ -26,16 +26,22 @@ from __future__ import annotations
import logging
import asyncio
import secrets
from typing import Dict, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
)
from pyee import EventEmitter
from .colors import color
from .hci import (
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
Address,
HCI_LE_Enable_Encryption_Command,
HCI_Object,
@@ -51,6 +57,10 @@ from .core import (
from .keys import PairingKeys
from . import crypto
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.pairing import PairingConfig
# -----------------------------------------------------------------------------
# Logging
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def error_name(error_code):
def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code)
@@ -197,11 +207,12 @@ class SMP_Command:
'''
smp_classes: Dict[int, Type[SMP_Command]] = {}
fields: Any
code = 0
name = ''
@staticmethod
def from_bytes(pdu):
def from_bytes(pdu: bytes) -> "SMP_Command":
code = pdu[0]
cls = SMP_Command.smp_classes.get(code)
@@ -217,11 +228,11 @@ class SMP_Command:
return self
@staticmethod
def command_name(code):
def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod
def auth_req_str(value):
def auth_req_str(value: int) -> str:
bonding_flags = value & 3
mitm = (value >> 2) & 1
sc = (value >> 3) & 1
@@ -234,12 +245,12 @@ class SMP_Command:
)
@staticmethod
def io_capability_name(io_capability):
def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod
def key_distribution_str(value):
key_types = []
def key_distribution_str(value: int) -> str:
key_types: List[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
@@ -251,7 +262,7 @@ class SMP_Command:
return ','.join(key_types)
@staticmethod
def keypress_notification_type_name(notification_type):
def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
@staticmethod
@@ -272,14 +283,14 @@ class SMP_Command:
return inner
def __init__(self, pdu=None, **kwargs):
def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
self.pdu = pdu
def init_from_bytes(self, pdu, offset):
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
'''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
'''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('confirm_value', 16)])
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
'''
confirm_value: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('random_value', 16)])
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
'''
random_value: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
'''
reason: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
'''
public_key_x: bytes
public_key_y: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
'''
dhkey_check: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
'''
notification_type: int
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('long_term_key', 16)])
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
'''
long_term_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('ediv', 2), ('rand', 8)])
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
'''
ediv: int
rand: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('identity_resolving_key', 16)])
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
'''
identity_resolving_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
'''
addr_type: int
bd_addr: Address
# -----------------------------------------------------------------------------
@SMP_Command.subclass([('signature_key', 16)])
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
'''
signature_key: bytes
# -----------------------------------------------------------------------------
@SMP_Command.subclass(
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
'''
auth_req: int
# -----------------------------------------------------------------------------
def smp_auth_req(bonding, mitm, sc, keypress, ct2):
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0
if bonding:
value |= SMP_BONDING_AUTHREQ
@@ -574,11 +626,17 @@ class Session:
},
}
def __init__(self, manager, connection, pairing_config, is_initiator):
def __init__(
self,
manager: Manager,
connection: Connection,
pairing_config: PairingConfig,
is_initiator: bool,
) -> None:
self.manager = manager
self.connection = connection
self.preq = None
self.pres = None
self.preq: Optional[bytes] = None
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
@@ -588,29 +646,29 @@ class Session:
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key = None
self.initiator_key_distribution = 0
self.responder_key_distribution = 0
self.peer_random_value = None
self.peer_public_key_x = bytes(32)
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
self.peer_ltk = None
self.peer_ediv = None
self.peer_rand = None
self.peer_rand: Optional[bytes] = None
self.peer_identity_resolving_key = None
self.peer_bd_addr = None
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions = []
self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = None
self.confirm_value = None
self.passkey = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
self.pairing_config = pairing_config
self.wait_before_continuing = None
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False
self.ctkd_task = None
self.ctkd_task: Optional[Awaitable[None]] = None
# Decide if we're the initiator or the responder
self.is_initiator = is_initiator
@@ -628,7 +686,9 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result = asyncio.get_running_loop().create_future()
self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
else:
self.pairing_result = None
@@ -641,11 +701,11 @@ class Session:
)
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
self.bonding = pairing_config.bonding
self.sc = pairing_config.sc
self.mitm = pairing_config.mitm
self.bonding: bool = pairing_config.bonding
self.sc: bool = pairing_config.sc
self.mitm: bool = pairing_config.mitm
self.keypress = False
self.ct2 = False
self.ct2: bool = False
# I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability
@@ -669,34 +729,35 @@ class Session:
self.iat = 1 if peer_address.is_random else 0
@property
def pkx(self):
def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property
def pka(self):
def pka(self) -> bytes:
return self.pkx[0 if self.is_initiator else 1]
@property
def pkb(self):
def pkb(self) -> bytes:
return self.pkx[0 if self.is_responder else 1]
@property
def nx(self):
def nx(self) -> Tuple[bytes, bytes]:
assert self.peer_random_value
return (self.r, self.peer_random_value)
@property
def na(self):
def na(self) -> bytes:
return self.nx[0 if self.is_initiator else 1]
@property
def nb(self):
def nb(self) -> bytes:
return self.nx[0 if self.is_responder else 1]
@property
def auth_req(self):
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand, ediv):
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk
@@ -706,13 +767,13 @@ class Session:
return None
def decide_pairing_method(
self, auth_req, initiator_io_capability, responder_io_capability
):
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
) -> None:
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS
return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability]
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0]
@@ -724,7 +785,9 @@ class Session:
self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(self, expected, received, error):
def check_expected_value(
self, expected: bytes, received: bytes, error: int
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received:
logger.info(color('pairing confirm/check mismatch', 'red'))
@@ -732,8 +795,8 @@ class Session:
return False
return True
def prompt_user_for_confirmation(self, next_steps):
async def prompt():
def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
async def prompt() -> None:
logger.debug('ask for confirmation')
try:
response = await self.pairing_config.delegate.confirm()
@@ -747,8 +810,10 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt():
def prompt_user_for_numeric_comparison(
self, code: int, next_steps: Callable[[], None]
) -> None:
async def prompt() -> None:
logger.debug(f'verification code: {code}')
try:
response = await self.pairing_config.delegate.compare_numbers(
@@ -764,11 +829,15 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps):
async def prompt():
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt() -> None:
logger.debug('prompting user for passkey')
try:
passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
return
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception as error:
@@ -777,9 +846,10 @@ class Session:
self.connection.abort_on('disconnection', prompt())
def display_passkey(self):
def display_passkey(self) -> None:
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
assert self.passkey is not None
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
@@ -793,9 +863,9 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
def input_passkey(self, next_steps=None):
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
def after_input(passkey):
def after_input(passkey: int) -> None:
self.passkey = passkey
if not self.sc:
@@ -809,7 +879,9 @@ class Session:
self.prompt_user_for_number(after_input)
def display_or_input_passkey(self, next_steps=None):
def display_or_input_passkey(
self, next_steps: Optional[Callable[[], None]] = None
) -> None:
if self.passkey_display:
self.display_passkey()
if next_steps is not None:
@@ -817,14 +889,14 @@ class Session:
else:
self.input_passkey(next_steps)
def send_command(self, command):
def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error):
def send_pairing_failed(self, error: int) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error)
def send_pairing_request_command(self):
def send_pairing_request_command(self) -> None:
self.manager.on_session_start(self)
command = SMP_Pairing_Request_Command(
@@ -838,7 +910,7 @@ class Session:
self.preq = bytes(command)
self.send_command(command)
def send_pairing_response_command(self):
def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command(
io_capability=self.io_capability,
oob_data_flag=0,
@@ -850,18 +922,19 @@ class Session:
self.pres = bytes(response)
self.send_command(response)
def send_pairing_confirm_command(self):
def send_pairing_confirm_command(self) -> None:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
async def next_steps():
async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
assert self.passkey
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
@@ -892,10 +965,10 @@ class Session:
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self):
def send_pairing_random_command(self) -> None:
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
def send_public_key_command(self):
def send_public_key_command(self) -> None:
self.send_command(
SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
@@ -903,18 +976,18 @@ class Session:
)
)
def send_pairing_dhkey_check_command(self):
def send_pairing_dhkey_check_command(self) -> None:
self.send_command(
SMP_Pairing_DHKey_Check_Command(
dhkey_check=self.ea if self.is_initiator else self.eb
)
)
def start_encryption(self, key):
def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command(
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
@@ -922,7 +995,7 @@ class Session:
)
)
async def derive_ltk(self):
async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None
ilk = (
@@ -932,7 +1005,7 @@ class Session:
)
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self):
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1032,7 +1105,7 @@ class Session:
)
self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags):
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = []
if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
@@ -1055,7 +1128,7 @@ class Session:
f'{[c.__name__ for c in self.peer_expected_distributions]}'
)
def check_key_distribution(self, command_class):
def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
# First, check that the connection is encrypted
if not self.connection.is_encrypted:
logger.warning(
@@ -1083,7 +1156,7 @@ class Session:
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
async def pair(self):
async def pair(self) -> None:
# Start pairing as an initiator
# TODO: check that this session isn't already active
@@ -1091,9 +1164,10 @@ class Session:
self.send_pairing_request_command()
# Wait for the pairing process to finish
assert self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, _):
def on_disconnection(self, _: int) -> None:
self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change
@@ -1104,14 +1178,14 @@ class Session:
)
self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self):
def on_peer_key_distribution_complete(self) -> None:
# The initiator can now send its keys
if self.is_initiator:
self.distribute_keys()
self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self):
def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted:
if self.is_responder:
# The responder distributes its keys first, the initiator later
@@ -1121,11 +1195,11 @@ class Session:
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self):
def on_connection_encryption_key_refresh(self) -> None:
# Do as if the connection had just been encrypted
self.on_connection_encryption_change()
async def on_pairing(self):
async def on_pairing(self) -> None:
logger.debug('pairing complete')
if self.completed:
@@ -1137,7 +1211,7 @@ class Session:
self.pairing_result.set_result(None)
# Use the peer address from the pairing protocol or the connection
if self.peer_bd_addr:
if self.peer_bd_addr is not None:
peer_address = self.peer_bd_addr
else:
peer_address = self.connection.peer_address
@@ -1186,7 +1260,7 @@ class Session:
)
self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason):
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
if self.completed:
@@ -1199,7 +1273,7 @@ class Session:
self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command):
def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -1215,12 +1289,16 @@ class Session:
else:
logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command):
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
) -> None:
self.connection.abort_on(
'disconnection', self.on_smp_pairing_request_command_async(command)
)
async def on_smp_pairing_request_command_async(self, command):
async def on_smp_pairing_request_command_async(
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed
accepted = await self.pairing_config.delegate.accept()
if not accepted:
@@ -1280,7 +1358,9 @@ class Session:
):
self.distribute_keys()
def on_smp_pairing_response_command(self, command):
def on_smp_pairing_response_command(
self, command: SMP_Pairing_Response_Command
) -> None:
if self.is_responder:
logger.warning(color('received pairing response as a responder', 'red'))
return
@@ -1331,7 +1411,9 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_legacy(self, _):
def on_smp_pairing_confirm_command_legacy(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.is_initiator:
self.send_pairing_random_command()
else:
@@ -1341,7 +1423,9 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_secure_connections(self, _):
def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.is_initiator:
self.r = crypto.r()
@@ -1352,14 +1436,18 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command(self, command):
def on_smp_pairing_confirm_command(
self, command: SMP_Pairing_Confirm_Command
) -> None:
self.confirm_value = command.confirm_value
if self.sc:
self.on_smp_pairing_confirm_command_secure_connections(command)
else:
self.on_smp_pairing_confirm_command_legacy(command)
def on_smp_pairing_random_command_legacy(self, command):
def on_smp_pairing_random_command_legacy(
self, command: SMP_Pairing_Random_Command
) -> None:
# Check that the confirmation values match
confirm_verifier = crypto.c1(
self.tk,
@@ -1371,6 +1459,7 @@ class Session:
self.ia,
self.ra,
)
assert self.confirm_value
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
@@ -1394,7 +1483,9 @@ class Session:
else:
self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command):
def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
@@ -1402,6 +1493,7 @@ class Session:
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
assert self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pkb, self.pka, command.random_value, bytes([0])
@@ -1411,6 +1503,7 @@ class Session:
):
return
elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pkb,
@@ -1435,6 +1528,7 @@ class Session:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pka,
@@ -1468,19 +1562,21 @@ class Session:
ra = bytes(16)
rb = ra
elif self.pairing_method == self.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
else:
# OOB not implemented yet
return
assert self.preq and self.pres
io_cap_a = self.preq[1:4]
io_cap_b = self.pres[1:4]
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
# Next steps to be performed after possible user confirmation
def next_steps():
def next_steps() -> None:
# The initiator sends the DH Key check to the responder
if self.is_initiator:
self.send_pairing_dhkey_check_command()
@@ -1502,14 +1598,18 @@ class Session:
else:
next_steps()
def on_smp_pairing_random_command(self, command):
def on_smp_pairing_random_command(
self, command: SMP_Pairing_Random_Command
) -> None:
self.peer_random_value = command.random_value
if self.sc:
self.on_smp_pairing_random_command_secure_connections(command)
else:
self.on_smp_pairing_random_command_legacy(command)
def on_smp_pairing_public_key_command(self, command):
def on_smp_pairing_public_key_command(
self, command: SMP_Pairing_Public_Key_Command
) -> None:
# Store the public key so that we can compute the confirmation value later
self.peer_public_key_x = command.public_key_x
self.peer_public_key_y = command.public_key_y
@@ -1538,9 +1638,12 @@ class Session:
# We can now send the confirmation value
self.send_pairing_confirm_command()
def on_smp_pairing_dhkey_check_command(self, command):
def on_smp_pairing_dhkey_check_command(
self, command: SMP_Pairing_DHKey_Check_Command
) -> None:
# Check that what we received matches what we computed earlier
expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
):
@@ -1549,7 +1652,8 @@ class Session:
if self.is_responder:
if self.wait_before_continuing is not None:
async def next_steps():
async def next_steps() -> None:
assert self.wait_before_continuing
await self.wait_before_continuing
self.wait_before_continuing = None
self.send_pairing_dhkey_check_command()
@@ -1558,29 +1662,42 @@ class Session:
else:
self.send_pairing_dhkey_check_command()
else:
assert self.ltk
self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(self, command):
def on_smp_pairing_failed_command(
self, command: SMP_Pairing_Failed_Command
) -> None:
self.on_pairing_failure(command.reason)
def on_smp_encryption_information_command(self, command):
def on_smp_encryption_information_command(
self, command: SMP_Encryption_Information_Command
) -> None:
self.peer_ltk = command.long_term_key
self.check_key_distribution(SMP_Encryption_Information_Command)
def on_smp_master_identification_command(self, command):
def on_smp_master_identification_command(
self, command: SMP_Master_Identification_Command
) -> None:
self.peer_ediv = command.ediv
self.peer_rand = command.rand
self.check_key_distribution(SMP_Master_Identification_Command)
def on_smp_identity_information_command(self, command):
def on_smp_identity_information_command(
self, command: SMP_Identity_Information_Command
) -> None:
self.peer_identity_resolving_key = command.identity_resolving_key
self.check_key_distribution(SMP_Identity_Information_Command)
def on_smp_identity_address_information_command(self, command):
def on_smp_identity_address_information_command(
self, command: SMP_Identity_Address_Information_Command
) -> None:
self.peer_bd_addr = command.bd_addr
self.check_key_distribution(SMP_Identity_Address_Information_Command)
def on_smp_signing_information_command(self, command):
def on_smp_signing_information_command(
self, command: SMP_Signing_Information_Command
) -> None:
self.peer_signature_key = command.signature_key
self.check_key_distribution(SMP_Signing_Information_Command)
@@ -1591,14 +1708,24 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol
'''
def __init__(self, device, pairing_config_factory):
device: Device
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: Type[Session]
def __init__(
self,
device: Device,
pairing_config_factory: Callable[[Connection], PairingConfig],
) -> None:
super().__init__()
self.device = device
self.sessions = {}
self._ecc_key = None
self.pairing_config_factory = pairing_config_factory
self.session_proxy = Session
def send_command(self, connection, command):
def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug(
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
@@ -1606,20 +1733,15 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_pdu(self, connection, pdu):
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
logger.warning('Remote starts pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
# Pairing disabled
self.send_command(
connection,
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
)
return
session = Session(self, connection, pairing_config, is_initiator=False)
session = self.session_proxy(
self, connection, pairing_config, is_initiator=False
)
self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
@@ -1633,23 +1755,24 @@ class Manager(EventEmitter):
session.on_smp_command(command)
@property
def ecc_key(self):
def ecc_key(self) -> crypto.EccKey:
if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
return self._ecc_key
async def pair(self, connection):
async def pair(self, connection: Connection) -> None:
# TODO: check if there's already a session for this connection
if connection.role != BT_CENTRAL_ROLE:
logger.warning('Start pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection)
if pairing_config is None:
raise ValueError('pairing config must not be None when initiating')
session = Session(self, connection, pairing_config, is_initiator=True)
session = self.session_proxy(
self, connection, pairing_config, is_initiator=True
)
self.sessions[connection.handle] = session
return await session.pair()
def request_pairing(self, connection):
def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection)
if pairing_config:
auth_req = smp_auth_req(
@@ -1663,15 +1786,18 @@ class Manager(EventEmitter):
auth_req = 0
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session):
self.device.on_pairing_start(session.connection.handle)
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)
def on_pairing(self, session, identity_address, keys):
def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
@@ -1679,17 +1805,19 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc)
self.device.on_pairing(session.connection, keys, session.sc)
def on_pairing_failure(self, session, reason):
self.device.on_pairing_failure(session.connection.handle, reason)
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session):
def on_session_end(self, session: Session) -> None:
logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
if session.connection.handle in self.sessions:
del self.sessions[session.connection.handle]
def get_long_term_key(self, connection, rand, ediv):
def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> Optional[bytes]:
if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv)