Merge pull request #178 from google/uael/pairing

Overall fixes and improvements
This commit is contained in:
Lucas Abel
2023-05-03 21:39:50 -07:00
committed by GitHub
7 changed files with 439 additions and 256 deletions

View File

@@ -152,7 +152,12 @@ class UUID:
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None): uuid_bytes: bytes
name: Optional[str]
def __init__(
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
) -> None:
if isinstance(uuid_str_or_int, int): if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int) self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else: else:
@@ -172,7 +177,7 @@ class UUID:
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name self.name = name
def register(self): def register(self) -> UUID:
# Register this object in the class registry, and update the entry's name if # Register this object in the class registry, and update the entry's name if
# it wasn't set already # it wasn't set already
for uuid in self.UUIDS: for uuid in self.UUIDS:
@@ -196,22 +201,22 @@ class UUID:
raise ValueError('only 2, 4 and 16 bytes are allowed') raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod @classmethod
def from_16_bits(cls, uuid_16, name=None): def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<H', uuid_16), name) return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod @classmethod
def from_32_bits(cls, uuid_32, name=None): def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name) return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod @classmethod
def parse_uuid(cls, uuid_as_bytes, offset): def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:]) return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod @classmethod
def parse_uuid_2(cls, uuid_as_bytes, offset): def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2]) return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128=False): def to_bytes(self, force_128: bool = False) -> bytes:
''' '''
Serialize UUID in little-endian byte-order Serialize UUID in little-endian byte-order
''' '''
@@ -227,7 +232,7 @@ class UUID:
else: else:
assert False, "unreachable" assert False, "unreachable"
def to_pdu_bytes(self): def to_pdu_bytes(self) -> bytes:
''' '''
Convert to bytes for use in an ATT PDU. Convert to bytes for use in an ATT PDU.
According to Vol 3, Part F - 3.2.1 Attribute Type: According to Vol 3, Part F - 3.2.1 Attribute Type:
@@ -236,11 +241,11 @@ class UUID:
''' '''
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4)) return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self) -> str: def to_hex_str(self, separator: str = '') -> str:
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4: if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper() return bytes(reversed(self.uuid_bytes)).hex().upper()
return ''.join( return separator.join(
[ [
bytes(reversed(self.uuid_bytes[12:16])).hex(), bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(),
@@ -250,10 +255,10 @@ class UUID:
] ]
).upper() ).upper()
def __bytes__(self): def __bytes__(self) -> bytes:
return self.to_bytes() return self.to_bytes()
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if isinstance(other, UUID): if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True) return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
@@ -262,35 +267,19 @@ class UUID:
return False return False
def __hash__(self): def __hash__(self) -> int:
return hash(self.uuid_bytes) return hash(self.uuid_bytes)
def __str__(self): def __str__(self) -> str:
result = self.to_hex_str(separator='-')
if len(self.uuid_bytes) == 2: if len(self.uuid_bytes) == 2:
uuid = struct.unpack('<H', self.uuid_bytes)[0] result = 'UUID-16:' + result
result = f'UUID-16:{uuid:04X}'
elif len(self.uuid_bytes) == 4: elif len(self.uuid_bytes) == 4:
uuid = struct.unpack('<I', self.uuid_bytes)[0] result = 'UUID-32:' + result
result = f'UUID-32:{uuid:08X}'
else:
result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
if self.name is not None: if self.name is not None:
return result + f' ({self.name})' result += f' ({self.name})'
return result return result
def __repr__(self):
return str(self)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Common UUID constants # Common UUID constants
@@ -773,7 +762,7 @@ class AdvertisingData:
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]: def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
uuids = [] uuids = []
offset = 0 offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data): while (offset + uuid_size) <= len(ad_data):
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size])) uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
offset += uuid_size offset += uuid_size
return uuids return uuids

View File

@@ -23,7 +23,7 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter):
transport: int transport: int
self_address: Address self_address: Address
peer_address: Address peer_address: Address
peer_resolvable_address: Optional[Address]
role: int role: int
encryption: int encryption: int
authenticated: bool authenticated: bool
@@ -2196,13 +2197,23 @@ class Device(CompositeEventEmitter):
await self.stop_discovery() await self.stop_discovery()
@property @property
def pairing_config_factory(self): def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
return self.smp_manager.pairing_config_factory return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter @pairing_config_factory.setter
def pairing_config_factory(self, pairing_config_factory): def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory self.smp_manager.pairing_config_factory = pairing_config_factory
@property
def smp_session_proxy(self) -> Type[smp.Session]:
return self.smp_manager.session_proxy
@smp_session_proxy.setter
def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None:
self.smp_manager.session_proxy = session_proxy
async def pair(self, connection): async def pair(self, connection):
return await self.smp_manager.pair(connection) return await self.smp_manager.pair(connection)
@@ -2232,7 +2243,7 @@ class Device(CompositeEventEmitter):
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value return keys.ltk_peripheral.value
async def get_link_key(self, address): async def get_link_key(self, address: Address) -> Optional[bytes]:
# Look for the key in the keystore # Look for the key in the keystore
if self.keystore is not None: if self.keystore is not None:
keys = await self.keystore.get(str(address)) keys = await self.keystore.get(str(address))
@@ -2243,6 +2254,7 @@ class Device(CompositeEventEmitter):
return None return None
return keys.link_key.value return keys.link_key.value
return None
# [Classic only] # [Classic only]
async def authenticate(self, connection): async def authenticate(self, connection):
@@ -2772,89 +2784,103 @@ class Device(CompositeEventEmitter):
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
@with_connection_from_address @with_connection_from_address
def on_authentication_user_confirmation_request(self, connection, code): def on_authentication_user_confirmation_request(self, connection, code) -> None:
# Ask what the pairing config should be for this connection # Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability io_capability = pairing_config.delegate.classic_io_capability
peer_io_capability = connection.peer_pairing_io_capability
# Respond async def confirm() -> bool:
if io_capability == HCI_DISPLAY_YES_NO_IO_CAPABILITY: # Ask the user to confirm the pairing, without display
if connection.peer_pairing_io_capability in ( return await pairing_config.delegate.confirm()
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY, async def auto_confirm() -> bool:
): # Ask the user to auto-confirm the pairing, without display
# Display the code and ask the user to compare return await pairing_config.delegate.confirm(auto=True)
async def prompt():
return ( async def display_confirm() -> bool:
await pairing_config.delegate.compare_numbers(code, digits=6), # Display the code and ask the user to compare
return await pairing_config.delegate.compare_numbers(code, digits=6)
async def display_auto_confirm() -> bool:
# Display the code to the user and ask the delegate to auto-confirm
await pairing_config.delegate.display_number(code, digits=6)
return await pairing_config.delegate.confirm(auto=True)
async def na() -> bool:
assert False, "N/A: unreachable"
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
methods = {
HCI_DISPLAY_ONLY_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_DISPLAY_YES_NO_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_KEYBOARD_ONLY_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: na,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: na,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
HCI_DISPLAY_ONLY_IO_CAPABILITY: confirm,
HCI_DISPLAY_YES_NO_IO_CAPABILITY: confirm,
HCI_KEYBOARD_ONLY_IO_CAPABILITY: auto_confirm,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
},
}
method = methods[peer_io_capability][io_capability]
async def reply() -> None:
if await connection.abort_on('disconnection', method()):
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
) )
)
else: else:
# Ask the user to confirm the pairing, without showing a code await self.host.send_command(
async def prompt(): HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
return await pairing_config.delegate.confirm() bd_addr=connection.peer_address
async def confirm():
if await prompt():
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address
)
)
else:
await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
) )
)
AsyncRunner.spawn(connection.abort_on('disconnection', confirm())) AsyncRunner.spawn(reply())
return
if io_capability == HCI_DISPLAY_ONLY_IO_CAPABILITY:
# Display the code to the user
AsyncRunner.spawn(pairing_config.delegate.display_number(code, 6))
# Automatic confirmation
self.host.send_command_sync(
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
)
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
@with_connection_from_address @with_connection_from_address
def on_authentication_user_passkey_request(self, connection): def on_authentication_user_passkey_request(self, connection) -> None:
# Ask what the pairing config should be for this connection # Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
io_capability = pairing_config.delegate.classic_io_capability
# Respond async def reply() -> None:
if io_capability == HCI_KEYBOARD_ONLY_IO_CAPABILITY: number = await connection.abort_on(
# Ask the user to input a number 'disconnection', pairing_config.delegate.get_number()
async def get_number():
number = await connection.abort_on(
'disconnection', pairing_config.delegate.get_number()
)
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command(
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
)
asyncio.create_task(get_number())
else:
self.host.send_command_sync(
HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
) )
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
AsyncRunner.spawn(reply())
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
@@ -3059,18 +3085,15 @@ class Device(CompositeEventEmitter):
connection.emit('role_change_failure', error) connection.emit('role_change_failure', error)
self.emit('role_change_failure', address, error) self.emit('role_change_failure', address, error)
@with_connection_from_handle def on_pairing_start(self, connection: Connection) -> None:
def on_pairing_start(self, connection):
connection.emit('pairing_start') connection.emit('pairing_start')
@with_connection_from_handle def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
def on_pairing(self, connection, keys, sc):
connection.sc = sc connection.sc = sc
connection.authenticated = True connection.authenticated = True
connection.emit('pairing', keys) connection.emit('pairing', keys)
@with_connection_from_handle def on_pairing_failure(self, connection: Connection, reason: int) -> None:
def on_pairing_failure(self, connection, reason):
connection.emit('pairing_failure', reason) connection.emit('pairing_failure', reason)
@with_connection_from_handle @with_connection_from_handle

View File

@@ -247,7 +247,7 @@ class TemplateService(Service):
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID = None UUID: Optional[UUID] = None
def __init__(self, characteristics, primary=True): def __init__(self, characteristics, primary=True):
super().__init__(self.UUID, characteristics, primary) super().__init__(self.UUID, characteristics, primary)

View File

@@ -25,7 +25,7 @@ import asyncio
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from .colors import color from .colors import color
from .hci import Address from .hci import Address
@@ -139,19 +139,19 @@ class PairingKeys:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class KeyStore: class KeyStore:
async def delete(self, name): async def delete(self, name: str):
pass pass
async def update(self, name, keys): async def update(self, name: str, keys: PairingKeys) -> None:
pass pass
async def get(self, _name): async def get(self, _name: str) -> Optional[PairingKeys]:
return PairingKeys() return None
async def get_all(self): async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return [] return []
async def delete_all(self): async def delete_all(self) -> None:
all_keys = await self.get_all() all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
@@ -177,15 +177,15 @@ class KeyStore:
separator = '\n' separator = '\n'
@staticmethod @staticmethod
def create_for_device(device: Device) -> Optional[KeyStore]: def create_for_device(device: Device) -> KeyStore:
if device.config.keystore is None: if device.config.keystore is None:
return None return MemoryKeyStore()
keystore_type = device.config.keystore.split(':', 1)[0] keystore_type = device.config.keystore.split(':', 1)[0]
if keystore_type == 'JsonKeyStore': if keystore_type == 'JsonKeyStore':
return JsonKeyStore.from_device(device) return JsonKeyStore.from_device(device)
return None return MemoryKeyStore()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore):
return None return None
return PairingKeys.from_dict(keys) return PairingKeys.from_dict(keys)
# -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore):
all_keys: Dict[str, PairingKeys]
def __init__(self) -> None:
self.all_keys = {}
async def delete(self, name: str) -> None:
if name in self.all_keys:
del self.all_keys[name]
async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys
async def get(self, name: str) -> Optional[PairingKeys]:
return self.all_keys.get(name)
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return list(self.all_keys.items())

View File

@@ -65,8 +65,9 @@ class PairingDelegate:
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: int = ( DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
| KeyDistribution.DISTRIBUTE_IDENTITY_KEY
) )
# Default mapping from abstract to Classic I/O capabilities. # Default mapping from abstract to Classic I/O capabilities.
@@ -85,9 +86,9 @@ class PairingDelegate:
def __init__( def __init__(
self, self,
io_capability=NO_OUTPUT_NO_INPUT, io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION, local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION, local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
) -> None: ) -> None:
self.io_capability = io_capability self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution self.local_initiator_key_distribution = local_initiator_key_distribution
@@ -113,8 +114,11 @@ class PairingDelegate:
"""Accept or reject a Pairing request.""" """Accept or reject a Pairing request."""
return True return True
async def confirm(self) -> bool: async def confirm(self, auto: bool = False) -> bool:
"""Respond yes or no to a Pairing confirmation question.""" """
Respond yes or no to a Pairing confirmation question.
The `auto` parameter stands for automatic confirmation.
"""
return True return True
# pylint: disable-next=unused-argument # pylint: disable-next=unused-argument
@@ -129,7 +133,7 @@ class PairingDelegate:
""" """
return 0 return 0
async def get_string(self, max_length) -> Optional[str]: async def get_string(self, max_length: int) -> Optional[str]:
""" """
Return a string whose utf-8 encoding is up to max_length bytes. Return a string whose utf-8 encoding is up to max_length bytes.
""" """

View File

@@ -26,16 +26,22 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import secrets import secrets
from typing import Dict, Optional, Type from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
)
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
from .hci import ( from .hci import (
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
Address, Address,
HCI_LE_Enable_Encryption_Command, HCI_LE_Enable_Encryption_Command,
HCI_Object, HCI_Object,
@@ -51,6 +57,10 @@ from .core import (
from .keys import PairingKeys from .keys import PairingKeys
from . import crypto from . import crypto
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.pairing import PairingConfig
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def error_name(error_code): def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code) return name_or_number(SMP_ERROR_NAMES, error_code)
@@ -197,11 +207,12 @@ class SMP_Command:
''' '''
smp_classes: Dict[int, Type[SMP_Command]] = {} smp_classes: Dict[int, Type[SMP_Command]] = {}
fields: Any
code = 0 code = 0
name = '' name = ''
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu: bytes) -> "SMP_Command":
code = pdu[0] code = pdu[0]
cls = SMP_Command.smp_classes.get(code) cls = SMP_Command.smp_classes.get(code)
@@ -217,11 +228,11 @@ class SMP_Command:
return self return self
@staticmethod @staticmethod
def command_name(code): def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code) return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod @staticmethod
def auth_req_str(value): def auth_req_str(value: int) -> str:
bonding_flags = value & 3 bonding_flags = value & 3
mitm = (value >> 2) & 1 mitm = (value >> 2) & 1
sc = (value >> 3) & 1 sc = (value >> 3) & 1
@@ -234,12 +245,12 @@ class SMP_Command:
) )
@staticmethod @staticmethod
def io_capability_name(io_capability): def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability) return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod @staticmethod
def key_distribution_str(value): def key_distribution_str(value: int) -> str:
key_types = [] key_types: List[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG: if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC') key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG: if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
@@ -251,7 +262,7 @@ class SMP_Command:
return ','.join(key_types) return ','.join(key_types)
@staticmethod @staticmethod
def keypress_notification_type_name(notification_type): def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type) return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
@staticmethod @staticmethod
@@ -272,14 +283,14 @@ class SMP_Command:
return inner return inner
def __init__(self, pdu=None, **kwargs): def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
if hasattr(self, 'fields') and kwargs: if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs) HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None: if pdu is None:
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields) pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
self.pdu = pdu self.pdu = pdu
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self): def to_bytes(self):
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
''' '''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
''' '''
io_capability: int
oob_data_flag: int
auth_req: int
maximum_encryption_key_size: int
initiator_key_distribution: int
responder_key_distribution: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('confirm_value', 16)]) @SMP_Command.subclass([('confirm_value', 16)])
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
''' '''
confirm_value: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('random_value', 16)]) @SMP_Command.subclass([('random_value', 16)])
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
''' '''
random_value: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})]) @SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
''' '''
reason: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)]) @SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
''' '''
public_key_x: bytes
public_key_y: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
''' '''
dhkey_check: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
''' '''
notification_type: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('long_term_key', 16)]) @SMP_Command.subclass([('long_term_key', 16)])
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
''' '''
long_term_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('ediv', 2), ('rand', 8)]) @SMP_Command.subclass([('ediv', 2), ('rand', 8)])
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
''' '''
ediv: int
rand: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('identity_resolving_key', 16)]) @SMP_Command.subclass([('identity_resolving_key', 16)])
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
''' '''
identity_resolving_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
''' '''
addr_type: int
bd_addr: Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass([('signature_key', 16)]) @SMP_Command.subclass([('signature_key', 16)])
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
''' '''
signature_key: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SMP_Command.subclass( @SMP_Command.subclass(
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
''' '''
auth_req: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def smp_auth_req(bonding, mitm, sc, keypress, ct2): def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0 value = 0
if bonding: if bonding:
value |= SMP_BONDING_AUTHREQ value |= SMP_BONDING_AUTHREQ
@@ -574,11 +626,17 @@ class Session:
}, },
} }
def __init__(self, manager, connection, pairing_config, is_initiator): def __init__(
self,
manager: Manager,
connection: Connection,
pairing_config: PairingConfig,
is_initiator: bool,
) -> None:
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.preq = None self.preq: Optional[bytes] = None
self.pres = None self.pres: Optional[bytes] = None
self.ea = None self.ea = None
self.eb = None self.eb = None
self.tk = bytes(16) self.tk = bytes(16)
@@ -588,29 +646,29 @@ class Session:
self.ltk_ediv = 0 self.ltk_ediv = 0
self.ltk_rand = bytes(8) self.ltk_rand = bytes(8)
self.link_key = None self.link_key = None
self.initiator_key_distribution = 0 self.initiator_key_distribution: int = 0
self.responder_key_distribution = 0 self.responder_key_distribution: int = 0
self.peer_random_value = None self.peer_random_value: Optional[bytes] = None
self.peer_public_key_x = bytes(32) self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32) self.peer_public_key_y = bytes(32)
self.peer_ltk = None self.peer_ltk = None
self.peer_ediv = None self.peer_ediv = None
self.peer_rand = None self.peer_rand: Optional[bytes] = None
self.peer_identity_resolving_key = None self.peer_identity_resolving_key = None
self.peer_bd_addr = None self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None self.peer_signature_key = None
self.peer_expected_distributions = [] self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = None self.dh_key = None
self.confirm_value = None self.confirm_value = None
self.passkey = None self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event() self.passkey_ready = asyncio.Event()
self.passkey_step = 0 self.passkey_step = 0
self.passkey_display = False self.passkey_display = False
self.pairing_method = 0 self.pairing_method = 0
self.pairing_config = pairing_config self.pairing_config = pairing_config
self.wait_before_continuing = None self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False self.completed = False
self.ctkd_task = None self.ctkd_task: Optional[Awaitable[None]] = None
# Decide if we're the initiator or the responder # Decide if we're the initiator or the responder
self.is_initiator = is_initiator self.is_initiator = is_initiator
@@ -628,7 +686,9 @@ class Session:
# Create a future that can be used to wait for the session to complete # Create a future that can be used to wait for the session to complete
if self.is_initiator: if self.is_initiator:
self.pairing_result = asyncio.get_running_loop().create_future() self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
else: else:
self.pairing_result = None self.pairing_result = None
@@ -641,11 +701,11 @@ class Session:
) )
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3 # Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
self.bonding = pairing_config.bonding self.bonding: bool = pairing_config.bonding
self.sc = pairing_config.sc self.sc: bool = pairing_config.sc
self.mitm = pairing_config.mitm self.mitm: bool = pairing_config.mitm
self.keypress = False self.keypress = False
self.ct2 = False self.ct2: bool = False
# I/O Capabilities # I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability self.io_capability = pairing_config.delegate.io_capability
@@ -669,34 +729,35 @@ class Session:
self.iat = 1 if peer_address.is_random else 0 self.iat = 1 if peer_address.is_random else 0
@property @property
def pkx(self): def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x) return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property @property
def pka(self): def pka(self) -> bytes:
return self.pkx[0 if self.is_initiator else 1] return self.pkx[0 if self.is_initiator else 1]
@property @property
def pkb(self): def pkb(self) -> bytes:
return self.pkx[0 if self.is_responder else 1] return self.pkx[0 if self.is_responder else 1]
@property @property
def nx(self): def nx(self) -> Tuple[bytes, bytes]:
assert self.peer_random_value
return (self.r, self.peer_random_value) return (self.r, self.peer_random_value)
@property @property
def na(self): def na(self) -> bytes:
return self.nx[0 if self.is_initiator else 1] return self.nx[0 if self.is_initiator else 1]
@property @property
def nb(self): def nb(self) -> bytes:
return self.nx[0 if self.is_responder else 1] return self.nx[0 if self.is_responder else 1]
@property @property
def auth_req(self): def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2) return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand, ediv): def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
if not self.sc and not self.completed: if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv: if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk return self.stk
@@ -706,13 +767,13 @@ class Session:
return None return None
def decide_pairing_method( def decide_pairing_method(
self, auth_req, initiator_io_capability, responder_io_capability self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
): ) -> None:
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS self.pairing_method = self.JUST_WORKS
return return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2: if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections # One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0] details = details[1 if self.sc else 0]
@@ -724,7 +785,9 @@ class Session:
self.pairing_method = details[0] self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2] self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(self, expected, received, error): def check_expected_value(
self, expected: bytes, received: bytes, error: int
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}') logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received: if expected != received:
logger.info(color('pairing confirm/check mismatch', 'red')) logger.info(color('pairing confirm/check mismatch', 'red'))
@@ -732,8 +795,8 @@ class Session:
return False return False
return True return True
def prompt_user_for_confirmation(self, next_steps): def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
async def prompt(): async def prompt() -> None:
logger.debug('ask for confirmation') logger.debug('ask for confirmation')
try: try:
response = await self.pairing_config.delegate.confirm() response = await self.pairing_config.delegate.confirm()
@@ -747,8 +810,10 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps): def prompt_user_for_numeric_comparison(
async def prompt(): self, code: int, next_steps: Callable[[], None]
) -> None:
async def prompt() -> None:
logger.debug(f'verification code: {code}') logger.debug(f'verification code: {code}')
try: try:
response = await self.pairing_config.delegate.compare_numbers( response = await self.pairing_config.delegate.compare_numbers(
@@ -764,11 +829,15 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps): def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt(): async def prompt() -> None:
logger.debug('prompting user for passkey') logger.debug('prompting user for passkey')
try: try:
passkey = await self.pairing_config.delegate.get_number() passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
return
logger.debug(f'user input: {passkey}') logger.debug(f'user input: {passkey}')
next_steps(passkey) next_steps(passkey)
except Exception as error: except Exception as error:
@@ -777,9 +846,10 @@ class Session:
self.connection.abort_on('disconnection', prompt()) self.connection.abort_on('disconnection', prompt())
def display_passkey(self): def display_passkey(self) -> None:
# Generate random Passkey/PIN code # Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000) self.passkey = secrets.randbelow(1000000)
assert self.passkey is not None
logger.debug(f'Pairing PIN CODE: {self.passkey:06}') logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set() self.passkey_ready.set()
@@ -793,9 +863,9 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6), self.pairing_config.delegate.display_number(self.passkey, digits=6),
) )
def input_passkey(self, next_steps=None): def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer # Prompt the user for the passkey displayed on the peer
def after_input(passkey): def after_input(passkey: int) -> None:
self.passkey = passkey self.passkey = passkey
if not self.sc: if not self.sc:
@@ -809,7 +879,9 @@ class Session:
self.prompt_user_for_number(after_input) self.prompt_user_for_number(after_input)
def display_or_input_passkey(self, next_steps=None): def display_or_input_passkey(
self, next_steps: Optional[Callable[[], None]] = None
) -> None:
if self.passkey_display: if self.passkey_display:
self.display_passkey() self.display_passkey()
if next_steps is not None: if next_steps is not None:
@@ -817,14 +889,14 @@ class Session:
else: else:
self.input_passkey(next_steps) self.input_passkey(next_steps)
def send_command(self, command): def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command) self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error): def send_pairing_failed(self, error: int) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error)) self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error) self.on_pairing_failure(error)
def send_pairing_request_command(self): def send_pairing_request_command(self) -> None:
self.manager.on_session_start(self) self.manager.on_session_start(self)
command = SMP_Pairing_Request_Command( command = SMP_Pairing_Request_Command(
@@ -838,7 +910,7 @@ class Session:
self.preq = bytes(command) self.preq = bytes(command)
self.send_command(command) self.send_command(command)
def send_pairing_response_command(self): def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command( response = SMP_Pairing_Response_Command(
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=0, oob_data_flag=0,
@@ -850,18 +922,19 @@ class Session:
self.pres = bytes(response) self.pres = bytes(response)
self.send_command(response) self.send_command(response)
def send_pairing_confirm_command(self): def send_pairing_confirm_command(self) -> None:
self.r = crypto.r() self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}') logger.debug(f'generated random: {self.r.hex()}')
if self.sc: if self.sc:
async def next_steps(): async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0 z = 0
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
# We need a passkey # We need a passkey
await self.passkey_ready.wait() await self.passkey_ready.wait()
assert self.passkey
z = 0x80 + ((self.passkey >> self.passkey_step) & 1) z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else: else:
@@ -892,10 +965,10 @@ class Session:
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self): def send_pairing_random_command(self) -> None:
self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
def send_public_key_command(self): def send_public_key_command(self) -> None:
self.send_command( self.send_command(
SMP_Pairing_Public_Key_Command( SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.manager.ecc_key.x)), public_key_x=bytes(reversed(self.manager.ecc_key.x)),
@@ -903,18 +976,18 @@ class Session:
) )
) )
def send_pairing_dhkey_check_command(self): def send_pairing_dhkey_check_command(self) -> None:
self.send_command( self.send_command(
SMP_Pairing_DHKey_Check_Command( SMP_Pairing_DHKey_Check_Command(
dhkey_check=self.ea if self.is_initiator else self.eb dhkey_check=self.ea if self.is_initiator else self.eb
) )
) )
def start_encryption(self, key): def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can # We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection # distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync( self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command( HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle, connection_handle=self.connection.handle,
random_number=bytes(8), random_number=bytes(8),
encrypted_diversifier=0, encrypted_diversifier=0,
@@ -922,7 +995,7 @@ class Session:
) )
) )
async def derive_ltk(self): async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address) link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None assert link_key is not None
ilk = ( ilk = (
@@ -932,7 +1005,7 @@ class Session:
) )
self.ltk = crypto.h6(ilk, b'brle') self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self): def distribute_keys(self) -> None:
# Distribute the keys as required # Distribute the keys as required
if self.is_initiator: if self.is_initiator:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
@@ -1032,7 +1105,7 @@ class Session:
) )
self.link_key = crypto.h6(ilk, b'lebr') self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags): def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase # Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = [] self.peer_expected_distributions = []
if not self.sc and self.connection.transport == BT_LE_TRANSPORT: if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
@@ -1055,7 +1128,7 @@ class Session:
f'{[c.__name__ for c in self.peer_expected_distributions]}' f'{[c.__name__ for c in self.peer_expected_distributions]}'
) )
def check_key_distribution(self, command_class): def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
# First, check that the connection is encrypted # First, check that the connection is encrypted
if not self.connection.is_encrypted: if not self.connection.is_encrypted:
logger.warning( logger.warning(
@@ -1083,7 +1156,7 @@ class Session:
) )
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR) self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
async def pair(self): async def pair(self) -> None:
# Start pairing as an initiator # Start pairing as an initiator
# TODO: check that this session isn't already active # TODO: check that this session isn't already active
@@ -1091,9 +1164,10 @@ class Session:
self.send_pairing_request_command() self.send_pairing_request_command()
# Wait for the pairing process to finish # Wait for the pairing process to finish
assert self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result) await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, _): def on_disconnection(self, _: int) -> None:
self.connection.remove_listener('disconnection', self.on_disconnection) self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener( self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change 'connection_encryption_change', self.on_connection_encryption_change
@@ -1104,14 +1178,14 @@ class Session:
) )
self.manager.on_session_end(self) self.manager.on_session_end(self)
def on_peer_key_distribution_complete(self): def on_peer_key_distribution_complete(self) -> None:
# The initiator can now send its keys # The initiator can now send its keys
if self.is_initiator: if self.is_initiator:
self.distribute_keys() self.distribute_keys()
self.connection.abort_on('disconnection', self.on_pairing()) self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self): def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted: if self.connection.is_encrypted:
if self.is_responder: if self.is_responder:
# The responder distributes its keys first, the initiator later # The responder distributes its keys first, the initiator later
@@ -1121,11 +1195,11 @@ class Session:
if not self.peer_expected_distributions: if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete() self.on_peer_key_distribution_complete()
def on_connection_encryption_key_refresh(self): def on_connection_encryption_key_refresh(self) -> None:
# Do as if the connection had just been encrypted # Do as if the connection had just been encrypted
self.on_connection_encryption_change() self.on_connection_encryption_change()
async def on_pairing(self): async def on_pairing(self) -> None:
logger.debug('pairing complete') logger.debug('pairing complete')
if self.completed: if self.completed:
@@ -1137,7 +1211,7 @@ class Session:
self.pairing_result.set_result(None) self.pairing_result.set_result(None)
# Use the peer address from the pairing protocol or the connection # Use the peer address from the pairing protocol or the connection
if self.peer_bd_addr: if self.peer_bd_addr is not None:
peer_address = self.peer_bd_addr peer_address = self.peer_bd_addr
else: else:
peer_address = self.connection.peer_address peer_address = self.connection.peer_address
@@ -1186,7 +1260,7 @@ class Session:
) )
self.manager.on_pairing(self, peer_address, keys) self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason): def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})') logger.warning(f'pairing failure ({error_name(reason)})')
if self.completed: if self.completed:
@@ -1199,7 +1273,7 @@ class Session:
self.pairing_result.set_exception(error) self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason) self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command): def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method # Find the handler method
handler_name = f'on_{command.name.lower()}' handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
@@ -1215,12 +1289,16 @@ class Session:
else: else:
logger.error(color('SMP command not handled???', 'red')) logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command): def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
) -> None:
self.connection.abort_on( self.connection.abort_on(
'disconnection', self.on_smp_pairing_request_command_async(command) 'disconnection', self.on_smp_pairing_request_command_async(command)
) )
async def on_smp_pairing_request_command_async(self, command): async def on_smp_pairing_request_command_async(
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed # Check if the request should proceed
accepted = await self.pairing_config.delegate.accept() accepted = await self.pairing_config.delegate.accept()
if not accepted: if not accepted:
@@ -1280,7 +1358,9 @@ class Session:
): ):
self.distribute_keys() self.distribute_keys()
def on_smp_pairing_response_command(self, command): def on_smp_pairing_response_command(
self, command: SMP_Pairing_Response_Command
) -> None:
if self.is_responder: if self.is_responder:
logger.warning(color('received pairing response as a responder', 'red')) logger.warning(color('received pairing response as a responder', 'red'))
return return
@@ -1331,7 +1411,9 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_legacy(self, _): def on_smp_pairing_confirm_command_legacy(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.is_initiator: if self.is_initiator:
self.send_pairing_random_command() self.send_pairing_random_command()
else: else:
@@ -1341,7 +1423,9 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_secure_connections(self, _): def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.is_initiator: if self.is_initiator:
self.r = crypto.r() self.r = crypto.r()
@@ -1352,14 +1436,18 @@ class Session:
else: else:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command(self, command): def on_smp_pairing_confirm_command(
self, command: SMP_Pairing_Confirm_Command
) -> None:
self.confirm_value = command.confirm_value self.confirm_value = command.confirm_value
if self.sc: if self.sc:
self.on_smp_pairing_confirm_command_secure_connections(command) self.on_smp_pairing_confirm_command_secure_connections(command)
else: else:
self.on_smp_pairing_confirm_command_legacy(command) self.on_smp_pairing_confirm_command_legacy(command)
def on_smp_pairing_random_command_legacy(self, command): def on_smp_pairing_random_command_legacy(
self, command: SMP_Pairing_Random_Command
) -> None:
# Check that the confirmation values match # Check that the confirmation values match
confirm_verifier = crypto.c1( confirm_verifier = crypto.c1(
self.tk, self.tk,
@@ -1371,6 +1459,7 @@ class Session:
self.ia, self.ia,
self.ra, self.ra,
) )
assert self.confirm_value
if not self.check_expected_value( if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
): ):
@@ -1394,7 +1483,9 @@ class Session:
else: else:
self.send_pairing_random_command() self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command): def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None: if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command') logger.warning('no passkey entered, ignoring command')
return return
@@ -1402,6 +1493,7 @@ class Session:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
if self.is_initiator: if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
assert self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pkb, self.pka, command.random_value, bytes([0]) self.pkb, self.pka, command.random_value, bytes([0])
@@ -1411,6 +1503,7 @@ class Session:
): ):
return return
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pkb, self.pkb,
@@ -1435,6 +1528,7 @@ class Session:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
self.send_pairing_random_command() self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier # Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4( confirm_verifier = crypto.f4(
self.pka, self.pka,
@@ -1468,19 +1562,21 @@ class Session:
ra = bytes(16) ra = bytes(16)
rb = ra rb = ra
elif self.pairing_method == self.PASSKEY: elif self.pairing_method == self.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little') ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra rb = ra
else: else:
# OOB not implemented yet # OOB not implemented yet
return return
assert self.preq and self.pres
io_cap_a = self.preq[1:4] io_cap_a = self.preq[1:4]
io_cap_b = self.pres[1:4] io_cap_b = self.pres[1:4]
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b) self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a) self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
# Next steps to be performed after possible user confirmation # Next steps to be performed after possible user confirmation
def next_steps(): def next_steps() -> None:
# The initiator sends the DH Key check to the responder # The initiator sends the DH Key check to the responder
if self.is_initiator: if self.is_initiator:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
@@ -1502,14 +1598,18 @@ class Session:
else: else:
next_steps() next_steps()
def on_smp_pairing_random_command(self, command): def on_smp_pairing_random_command(
self, command: SMP_Pairing_Random_Command
) -> None:
self.peer_random_value = command.random_value self.peer_random_value = command.random_value
if self.sc: if self.sc:
self.on_smp_pairing_random_command_secure_connections(command) self.on_smp_pairing_random_command_secure_connections(command)
else: else:
self.on_smp_pairing_random_command_legacy(command) self.on_smp_pairing_random_command_legacy(command)
def on_smp_pairing_public_key_command(self, command): def on_smp_pairing_public_key_command(
self, command: SMP_Pairing_Public_Key_Command
) -> None:
# Store the public key so that we can compute the confirmation value later # Store the public key so that we can compute the confirmation value later
self.peer_public_key_x = command.public_key_x self.peer_public_key_x = command.public_key_x
self.peer_public_key_y = command.public_key_y self.peer_public_key_y = command.public_key_y
@@ -1538,9 +1638,12 @@ class Session:
# We can now send the confirmation value # We can now send the confirmation value
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
def on_smp_pairing_dhkey_check_command(self, command): def on_smp_pairing_dhkey_check_command(
self, command: SMP_Pairing_DHKey_Check_Command
) -> None:
# Check that what we received matches what we computed earlier # Check that what we received matches what we computed earlier
expected = self.eb if self.is_initiator else self.ea expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value( if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
): ):
@@ -1549,7 +1652,8 @@ class Session:
if self.is_responder: if self.is_responder:
if self.wait_before_continuing is not None: if self.wait_before_continuing is not None:
async def next_steps(): async def next_steps() -> None:
assert self.wait_before_continuing
await self.wait_before_continuing await self.wait_before_continuing
self.wait_before_continuing = None self.wait_before_continuing = None
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
@@ -1558,29 +1662,42 @@ class Session:
else: else:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
else: else:
assert self.ltk
self.start_encryption(self.ltk) self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(self, command): def on_smp_pairing_failed_command(
self, command: SMP_Pairing_Failed_Command
) -> None:
self.on_pairing_failure(command.reason) self.on_pairing_failure(command.reason)
def on_smp_encryption_information_command(self, command): def on_smp_encryption_information_command(
self, command: SMP_Encryption_Information_Command
) -> None:
self.peer_ltk = command.long_term_key self.peer_ltk = command.long_term_key
self.check_key_distribution(SMP_Encryption_Information_Command) self.check_key_distribution(SMP_Encryption_Information_Command)
def on_smp_master_identification_command(self, command): def on_smp_master_identification_command(
self, command: SMP_Master_Identification_Command
) -> None:
self.peer_ediv = command.ediv self.peer_ediv = command.ediv
self.peer_rand = command.rand self.peer_rand = command.rand
self.check_key_distribution(SMP_Master_Identification_Command) self.check_key_distribution(SMP_Master_Identification_Command)
def on_smp_identity_information_command(self, command): def on_smp_identity_information_command(
self, command: SMP_Identity_Information_Command
) -> None:
self.peer_identity_resolving_key = command.identity_resolving_key self.peer_identity_resolving_key = command.identity_resolving_key
self.check_key_distribution(SMP_Identity_Information_Command) self.check_key_distribution(SMP_Identity_Information_Command)
def on_smp_identity_address_information_command(self, command): def on_smp_identity_address_information_command(
self, command: SMP_Identity_Address_Information_Command
) -> None:
self.peer_bd_addr = command.bd_addr self.peer_bd_addr = command.bd_addr
self.check_key_distribution(SMP_Identity_Address_Information_Command) self.check_key_distribution(SMP_Identity_Address_Information_Command)
def on_smp_signing_information_command(self, command): def on_smp_signing_information_command(
self, command: SMP_Signing_Information_Command
) -> None:
self.peer_signature_key = command.signature_key self.peer_signature_key = command.signature_key
self.check_key_distribution(SMP_Signing_Information_Command) self.check_key_distribution(SMP_Signing_Information_Command)
@@ -1591,14 +1708,24 @@ class Manager(EventEmitter):
Implements the Initiator and Responder roles of the Security Manager Protocol Implements the Initiator and Responder roles of the Security Manager Protocol
''' '''
def __init__(self, device, pairing_config_factory): device: Device
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: Type[Session]
def __init__(
self,
device: Device,
pairing_config_factory: Callable[[Connection], PairingConfig],
) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.sessions = {} self.sessions = {}
self._ecc_key = None self._ecc_key = None
self.pairing_config_factory = pairing_config_factory self.pairing_config_factory = pairing_config_factory
self.session_proxy = Session
def send_command(self, connection, command): def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug( logger.debug(
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] ' f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}' f'{connection.peer_address}: {command}'
@@ -1606,20 +1733,15 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes()) connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_pdu(self, connection, pdu): def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Look for a session with this connection, and create one if none exists # Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)): if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE: if connection.role == BT_CENTRAL_ROLE:
logger.warning('Remote starts pairing as Peripheral!') logger.warning('Remote starts pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config is None: session = self.session_proxy(
# Pairing disabled self, connection, pairing_config, is_initiator=False
self.send_command( )
connection,
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
)
return
session = Session(self, connection, pairing_config, is_initiator=False)
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object # Parse the L2CAP payload into an SMP Command object
@@ -1633,23 +1755,24 @@ class Manager(EventEmitter):
session.on_smp_command(command) session.on_smp_command(command)
@property @property
def ecc_key(self): def ecc_key(self) -> crypto.EccKey:
if self._ecc_key is None: if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate() self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
return self._ecc_key return self._ecc_key
async def pair(self, connection): async def pair(self, connection: Connection) -> None:
# TODO: check if there's already a session for this connection # TODO: check if there's already a session for this connection
if connection.role != BT_CENTRAL_ROLE: if connection.role != BT_CENTRAL_ROLE:
logger.warning('Start pairing as Peripheral!') logger.warning('Start pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config is None: session = self.session_proxy(
raise ValueError('pairing config must not be None when initiating') self, connection, pairing_config, is_initiator=True
session = Session(self, connection, pairing_config, is_initiator=True) )
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
return await session.pair() return await session.pair()
def request_pairing(self, connection): def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
if pairing_config: if pairing_config:
auth_req = smp_auth_req( auth_req = smp_auth_req(
@@ -1663,15 +1786,18 @@ class Manager(EventEmitter):
auth_req = 0 auth_req = 0
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req)) self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session): def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection.handle) self.device.on_pairing_start(session.connection)
def on_pairing(self, session, identity_address, keys): def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store # Store the keys in the key store
if self.device.keystore and identity_address is not None: if self.device.keystore and identity_address is not None:
async def store_keys(): async def store_keys():
try: try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys) await self.device.keystore.update(str(identity_address), keys)
except Exception as error: except Exception as error:
logger.warning(f'!!! error while storing keys: {error}') logger.warning(f'!!! error while storing keys: {error}')
@@ -1679,17 +1805,19 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys()) self.device.abort_on('flush', store_keys())
# Notify the device # Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc) self.device.on_pairing(session.connection, keys, session.sc)
def on_pairing_failure(self, session, reason): def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection.handle, reason) self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session): def on_session_end(self, session: Session) -> None:
logger.debug(f'session end for connection 0x{session.connection.handle:04X}') logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
if session.connection.handle in self.sessions: if session.connection.handle in self.sessions:
del self.sessions[session.connection.handle] del self.sessions[session.connection.handle]
def get_long_term_key(self, connection, rand, ediv): def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> Optional[bytes]:
if session := self.sessions.get(connection.handle): if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv) return session.get_long_term_key(rand, ediv)

View File

@@ -15,7 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from bumble.core import AdvertisingData, get_dict_key_by_value from bumble.core import AdvertisingData, UUID, get_dict_key_by_value
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_ad_data(): def test_ad_data():
@@ -49,6 +49,24 @@ def test_get_dict_key_by_value():
assert get_dict_key_by_value(dictionary, 3) is None assert get_dict_key_by_value(dictionary, 3) is None
# -----------------------------------------------------------------------------
def test_uuid_to_hex_str() -> None:
assert UUID("b5ea").to_hex_str() == "B5EA"
assert UUID("df5ce654").to_hex_str() == "DF5CE654"
assert (
UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str()
== "DF5CE654E05911EDB5EA0242AC120002"
)
assert UUID("b5ea").to_hex_str('-') == "B5EA"
assert UUID("df5ce654").to_hex_str('-') == "DF5CE654"
assert (
UUID("df5ce654-e059-11ed-b5ea-0242ac120002").to_hex_str('-')
== "DF5CE654-E059-11ED-B5EA-0242AC120002"
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_ad_data() test_ad_data()
test_get_dict_key_by_value()
test_uuid_to_hex_str()