diff --git a/bumble/device.py b/bumble/device.py index 775046ff..d9410a4e 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1589,7 +1589,6 @@ class Connection(utils.CompositeEventEmitter): encryption_key_size: int authenticated: bool sc: bool - link_key_type: Optional[int] # [Classic only] gatt_client: gatt_client.Client pairing_peer_io_capability: Optional[int] pairing_peer_authentication_requirements: Optional[int] @@ -1692,7 +1691,6 @@ class Connection(utils.CompositeEventEmitter): self.encryption_key_size = 0 self.authenticated = False self.sc = False - self.link_key_type = None self.att_mtu = ATT_DEFAULT_MTU self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.gatt_client = gatt_client.Client(self) # Per-connection client @@ -5075,9 +5073,9 @@ class Device(utils.CompositeEventEmitter): hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, ) - pairing_keys = PairingKeys() - pairing_keys.link_key = PairingKeys.Key( - value=link_key, authenticated=authenticated + pairing_keys = PairingKeys( + link_key=PairingKeys.Key(value=link_key, authenticated=authenticated), + link_key_type=key_type, ) utils.cancel_on_event( @@ -5087,7 +5085,6 @@ class Device(utils.CompositeEventEmitter): if connection := self.find_connection_by_bd_addr( bd_addr, transport=PhysicalTransport.BR_EDR ): - connection.link_key_type = key_type connection.emit(connection.EVENT_LINK_KEY) def add_service(self, service): diff --git a/bumble/keys.py b/bumble/keys.py index 8ba1726f..2573b1d5 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -22,14 +22,15 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import dataclasses import logging import os import json -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Any from typing_extensions import Self from bumble.colors import color -from bumble.hci import Address +from bumble import hci if TYPE_CHECKING: from bumble.device import Device @@ -42,16 +43,17 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- +@dataclasses.dataclass class PairingKeys: + @dataclasses.dataclass class Key: - def __init__(self, value, authenticated=False, ediv=None, rand=None): - self.value = value - self.authenticated = authenticated - self.ediv = ediv - self.rand = rand + value: bytes + authenticated: bool = False + ediv: Optional[int] = None + rand: Optional[bytes] = None @classmethod - def from_dict(cls, key_dict): + def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key: value = bytes.fromhex(key_dict['value']) authenticated = key_dict.get('authenticated', False) ediv = key_dict.get('ediv') @@ -61,7 +63,7 @@ class PairingKeys: return cls(value, authenticated, ediv, rand) - def to_dict(self): + def to_dict(self) -> dict[str, Any]: key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} if self.ediv is not None: key_dict['ediv'] = self.ediv @@ -70,39 +72,42 @@ class PairingKeys: return key_dict - def __init__(self): - self.address_type = None - self.ltk = None - self.ltk_central = None - self.ltk_peripheral = None - self.irk = None - self.csrk = None - self.link_key = None # Classic + address_type: Optional[hci.AddressType] = None + ltk: Optional[Key] = None + ltk_central: Optional[Key] = None + ltk_peripheral: Optional[Key] = None + irk: Optional[Key] = None + csrk: Optional[Key] = None + link_key: Optional[Key] = None # Classic + link_key_type: Optional[int] = None # Classic - @staticmethod - def key_from_dict(keys_dict, key_name): + @classmethod + def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]: key_dict = keys_dict.get(key_name) if key_dict is None: return None return PairingKeys.Key.from_dict(key_dict) - @staticmethod - def from_dict(keys_dict): - keys = PairingKeys() + @classmethod + def from_dict(cls, keys_dict: dict[str, Any]) -> PairingKeys: + return PairingKeys( + address_type=( + hci.AddressType(t) + if (t := keys_dict.get('address_type')) is not None + else None + ), + ltk=PairingKeys.key_from_dict(keys_dict, 'ltk'), + ltk_central=PairingKeys.key_from_dict(keys_dict, 'ltk_central'), + ltk_peripheral=PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral'), + irk=PairingKeys.key_from_dict(keys_dict, 'irk'), + csrk=PairingKeys.key_from_dict(keys_dict, 'csrk'), + link_key=PairingKeys.key_from_dict(keys_dict, 'link_key'), + link_key_type=keys_dict.get('link_key_type'), + ) - keys.address_type = keys_dict.get('address_type') - keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') - keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') - keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral') - keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') - keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') - keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') - - return keys - - def to_dict(self): - keys = {} + def to_dict(self) -> dict[str, Any]: + keys: dict[str, Any] = {} if self.address_type is not None: keys['address_type'] = self.address_type @@ -125,9 +130,12 @@ class PairingKeys: if self.link_key is not None: keys['link_key'] = self.link_key.to_dict() + if self.link_key_type is not None: + keys['link_key_type'] = self.link_key_type + return keys - def print(self, prefix=''): + def print(self, prefix: str = '') -> None: keys_dict = self.to_dict() for container_property, value in keys_dict.items(): if isinstance(value, dict): @@ -156,20 +164,28 @@ class KeyStore: all_keys = await self.get_all() await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) - async def get_resolving_keys(self): + async def get_resolving_keys(self) -> list[tuple[bytes, hci.Address]]: all_keys = await self.get_all() resolving_keys = [] for name, keys in all_keys: if keys.irk is not None: - if keys.address_type is None: - address_type = Address.RANDOM_DEVICE_ADDRESS - else: - address_type = keys.address_type - resolving_keys.append((keys.irk.value, Address(name, address_type))) + resolving_keys.append( + ( + keys.irk.value, + hci.Address( + name, + ( + keys.address_type + if keys.address_type is not None + else hci.Address.RANDOM_DEVICE_ADDRESS + ), + ), + ) + ) return resolving_keys - async def print(self, prefix=''): + async def print(self, prefix: str = '') -> None: entries = await self.get_all() separator = '' for name, keys in entries: @@ -177,8 +193,8 @@ class KeyStore: keys.print(prefix=prefix + ' ') separator = '\n' - @staticmethod - def create_for_device(device: Device) -> KeyStore: + @classmethod + def create_for_device(cls, device: Device) -> KeyStore: if device.config.keystore is None: return MemoryKeyStore() @@ -266,9 +282,9 @@ class JsonKeyStore(KeyStore): filename = params[0] # Use a namespace based on the device address - if device.public_address not in (Address.ANY, Address.ANY_RANDOM): + if device.public_address not in (hci.Address.ANY, hci.Address.ANY_RANDOM): namespace = str(device.public_address) - elif device.random_address != Address.ANY_RANDOM: + elif device.random_address != hci.Address.ANY_RANDOM: namespace = str(device.random_address) else: namespace = JsonKeyStore.DEFAULT_NAMESPACE diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py index ca3102d6..1b9f21bd 100644 --- a/bumble/pandora/security.py +++ b/bumble/pandora/security.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio import contextlib +from collections.abc import Awaitable import grpc import logging @@ -24,6 +25,7 @@ from bumble import hci from bumble.core import ( PhysicalTransport, ProtocolError, + InvalidArgumentError, ) import bumble.utils from bumble.device import Connection as BumbleConnection, Device @@ -188,35 +190,6 @@ class PairingDelegate(BasePairingDelegate): self.service.event_queue.put_nowait(event) -BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = { - LEVEL0: lambda connection: True, - LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated, - LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated, - LEVEL3: lambda connection: connection.encryption != 0 - and connection.authenticated - and connection.link_key_type - in ( - hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, - hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, - ), - LEVEL4: lambda connection: connection.encryption - == hci.HCI_Encryption_Change_Event.AES_CCM - and connection.authenticated - and connection.link_key_type - == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, -} - -LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = { - LE_LEVEL1: lambda connection: True, - LE_LEVEL2: lambda connection: connection.encryption != 0, - LE_LEVEL3: lambda connection: connection.encryption != 0 - and connection.authenticated, - LE_LEVEL4: lambda connection: connection.encryption != 0 - and connection.authenticated - and connection.sc, -} - - class SecurityService(SecurityServicer): def __init__(self, device: Device, config: Config) -> None: self.log = utils.BumbleServerLoggerAdapter( @@ -248,6 +221,59 @@ class SecurityService(SecurityServicer): self.device.pairing_config_factory = pairing_config_factory + async def _classic_level_reached( + self, level: SecurityLevel, connection: BumbleConnection + ) -> bool: + if level == LEVEL0: + return True + if level == LEVEL1: + return connection.encryption == 0 or connection.authenticated + if level == LEVEL2: + return connection.encryption != 0 and connection.authenticated + + link_key_type: Optional[int] = None + if (keystore := connection.device.keystore) and ( + keys := await keystore.get(str(connection.peer_address)) + ): + link_key_type = keys.link_key_type + self.log.debug("link_key_type: %d", link_key_type) + + if level == LEVEL3: + return ( + connection.encryption != 0 + and connection.authenticated + and link_key_type + in ( + hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE, + hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE, + ) + ) + if level == LEVEL4: + return ( + connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM + and connection.authenticated + and link_key_type + == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE + ) + raise InvalidArgumentError(f"Unexpected level {level}") + + def _le_level_reached( + self, level: LESecurityLevel, connection: BumbleConnection + ) -> bool: + if level == LE_LEVEL1: + return True + if level == LE_LEVEL2: + return connection.encryption != 0 + if level == LE_LEVEL3: + return connection.encryption != 0 and connection.authenticated + if level == LE_LEVEL4: + return ( + connection.encryption != 0 + and connection.authenticated + and connection.sc + ) + raise InvalidArgumentError(f"Unexpected level {level}") + @utils.rpc async def OnPairing( self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext @@ -290,7 +316,7 @@ class SecurityService(SecurityServicer): ] == oneof # security level already reached - if self.reached_security_level(connection, level): + if await self.reached_security_level(connection, level): return SecureResponse(success=empty_pb2.Empty()) # trigger pairing if needed @@ -361,7 +387,7 @@ class SecurityService(SecurityServicer): return SecureResponse(encryption_failure=empty_pb2.Empty()) # security level has been reached ? - if self.reached_security_level(connection, level): + if await self.reached_security_level(connection, level): return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(not_reached=empty_pb2.Empty()) @@ -388,13 +414,10 @@ class SecurityService(SecurityServicer): pair_task: Optional[asyncio.Future[None]] = None async def authenticate() -> None: - assert connection if (encryption := connection.encryption) != 0: self.log.debug('Disable encryption...') - try: + with contextlib.suppress(Exception): await connection.encrypt(enable=False) - except: - pass self.log.debug('Disable encryption: done') self.log.debug('Authenticate...') @@ -413,15 +436,13 @@ class SecurityService(SecurityServicer): return wrapper - def try_set_success(*_: Any) -> None: - assert connection - if self.reached_security_level(connection, level): + async def try_set_success(*_: Any) -> None: + if await self.reached_security_level(connection, level): self.log.debug('Wait for security: done') wait_for_security.set_result('success') - def on_encryption_change(*_: Any) -> None: - assert connection - if self.reached_security_level(connection, level): + async def on_encryption_change(*_: Any) -> None: + if await self.reached_security_level(connection, level): self.log.debug('Wait for security: done') wait_for_security.set_result('success') elif ( @@ -436,7 +457,7 @@ class SecurityService(SecurityServicer): if self.need_pairing(connection, level): pair_task = asyncio.create_task(connection.pair()) - listeners: Dict[str, Callable[..., None]] = { + listeners: Dict[str, Callable[..., Union[None, Awaitable[None]]]] = { 'disconnection': set_failure('connection_died'), 'pairing_failure': set_failure('pairing_failure'), 'connection_authentication_failure': set_failure('authentication_failure'), @@ -455,7 +476,7 @@ class SecurityService(SecurityServicer): watcher.on(connection, event, listener) # security level already reached - if self.reached_security_level(connection, level): + if await self.reached_security_level(connection, level): return WaitSecurityResponse(success=empty_pb2.Empty()) self.log.debug('Wait for security...') @@ -465,24 +486,20 @@ class SecurityService(SecurityServicer): # wait for `authenticate` to finish if any if authenticate_task is not None: self.log.debug('Wait for authentication...') - try: + with contextlib.suppress(Exception): await authenticate_task # type: ignore - except: - pass self.log.debug('Authenticated') # wait for `pair` to finish if any if pair_task is not None: self.log.debug('Wait for authentication...') - try: + with contextlib.suppress(Exception): await pair_task # type: ignore - except: - pass self.log.debug('paired') return WaitSecurityResponse(**kwargs) - def reached_security_level( + async def reached_security_level( self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel] ) -> bool: self.log.debug( @@ -492,15 +509,14 @@ class SecurityService(SecurityServicer): 'encryption': connection.encryption, 'authenticated': connection.authenticated, 'sc': connection.sc, - 'link_key_type': connection.link_key_type, } ) ) if isinstance(level, LESecurityLevel): - return LE_LEVEL_REACHED[level](connection) + return self._le_level_reached(level, connection) - return BR_LEVEL_REACHED[level](connection) + return await self._classic_level_reached(level, connection) def need_pairing(self, connection: BumbleConnection, level: int) -> bool: if connection.transport == PhysicalTransport.LE: diff --git a/bumble/smp.py b/bumble/smp.py index 79579d38..bd47fc82 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1380,8 +1380,10 @@ class Session: ediv=self.ltk_ediv, rand=self.ltk_rand, ) + if not self.peer_ltk: + logger.error("peer_ltk is None") peer_ltk_key = PairingKeys.Key( - value=self.peer_ltk, + value=self.peer_ltk or b'', authenticated=authenticated, ediv=self.peer_ediv, rand=self.peer_rand,