forked from auracaster/bumble_mirror
Merge pull request #178 from google/uael/pairing
Overall fixes and improvements
This commit is contained in:
@@ -152,7 +152,12 @@ class UUID:
|
||||
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
|
||||
UUIDS: List[UUID] = [] # Registry of all instances created
|
||||
|
||||
def __init__(self, uuid_str_or_int, name=None):
|
||||
uuid_bytes: bytes
|
||||
name: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
|
||||
) -> None:
|
||||
if isinstance(uuid_str_or_int, int):
|
||||
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
|
||||
else:
|
||||
@@ -172,7 +177,7 @@ class UUID:
|
||||
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
|
||||
self.name = name
|
||||
|
||||
def register(self):
|
||||
def register(self) -> UUID:
|
||||
# Register this object in the class registry, and update the entry's name if
|
||||
# it wasn't set already
|
||||
for uuid in self.UUIDS:
|
||||
@@ -196,22 +201,22 @@ class UUID:
|
||||
raise ValueError('only 2, 4 and 16 bytes are allowed')
|
||||
|
||||
@classmethod
|
||||
def from_16_bits(cls, uuid_16, name=None):
|
||||
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
|
||||
return cls.from_bytes(struct.pack('<H', uuid_16), name)
|
||||
|
||||
@classmethod
|
||||
def from_32_bits(cls, uuid_32, name=None):
|
||||
def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
|
||||
return cls.from_bytes(struct.pack('<I', uuid_32), name)
|
||||
|
||||
@classmethod
|
||||
def parse_uuid(cls, uuid_as_bytes, offset):
|
||||
def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
|
||||
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
|
||||
|
||||
@classmethod
|
||||
def parse_uuid_2(cls, uuid_as_bytes, offset):
|
||||
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]:
|
||||
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
|
||||
|
||||
def to_bytes(self, force_128=False):
|
||||
def to_bytes(self, force_128: bool = False) -> bytes:
|
||||
'''
|
||||
Serialize UUID in little-endian byte-order
|
||||
'''
|
||||
@@ -227,7 +232,7 @@ class UUID:
|
||||
else:
|
||||
assert False, "unreachable"
|
||||
|
||||
def to_pdu_bytes(self):
|
||||
def to_pdu_bytes(self) -> bytes:
|
||||
'''
|
||||
Convert to bytes for use in an ATT PDU.
|
||||
According to Vol 3, Part F - 3.2.1 Attribute Type:
|
||||
@@ -236,11 +241,11 @@ class UUID:
|
||||
'''
|
||||
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
|
||||
|
||||
def to_hex_str(self) -> str:
|
||||
def to_hex_str(self, separator: str = '') -> str:
|
||||
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
|
||||
return bytes(reversed(self.uuid_bytes)).hex().upper()
|
||||
|
||||
return ''.join(
|
||||
return separator.join(
|
||||
[
|
||||
bytes(reversed(self.uuid_bytes[12:16])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[10:12])).hex(),
|
||||
@@ -250,10 +255,10 @@ class UUID:
|
||||
]
|
||||
).upper()
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, UUID):
|
||||
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
|
||||
|
||||
@@ -262,35 +267,19 @@ class UUID:
|
||||
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.uuid_bytes)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
result = self.to_hex_str(separator='-')
|
||||
if len(self.uuid_bytes) == 2:
|
||||
uuid = struct.unpack('<H', self.uuid_bytes)[0]
|
||||
result = f'UUID-16:{uuid:04X}'
|
||||
result = 'UUID-16:' + result
|
||||
elif len(self.uuid_bytes) == 4:
|
||||
uuid = struct.unpack('<I', self.uuid_bytes)[0]
|
||||
result = f'UUID-32:{uuid:08X}'
|
||||
else:
|
||||
result = '-'.join(
|
||||
[
|
||||
bytes(reversed(self.uuid_bytes[12:16])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[10:12])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[8:10])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[6:8])).hex(),
|
||||
bytes(reversed(self.uuid_bytes[0:6])).hex(),
|
||||
]
|
||||
).upper()
|
||||
|
||||
result = 'UUID-32:' + result
|
||||
if self.name is not None:
|
||||
return result + f' ({self.name})'
|
||||
|
||||
result += f' ({self.name})'
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Common UUID constants
|
||||
@@ -773,7 +762,7 @@ class AdvertisingData:
|
||||
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
|
||||
uuids = []
|
||||
offset = 0
|
||||
while (uuid_size * (offset + 1)) <= len(ad_data):
|
||||
while (offset + uuid_size) <= len(ad_data):
|
||||
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
|
||||
offset += uuid_size
|
||||
return uuids
|
||||
|
||||
177
bumble/device.py
177
bumble/device.py
@@ -23,7 +23,7 @@ import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager, AsyncExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from .colors import color
|
||||
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
|
||||
@@ -528,6 +528,7 @@ class Connection(CompositeEventEmitter):
|
||||
transport: int
|
||||
self_address: Address
|
||||
peer_address: Address
|
||||
peer_resolvable_address: Optional[Address]
|
||||
role: int
|
||||
encryption: int
|
||||
authenticated: bool
|
||||
@@ -2196,13 +2197,23 @@ class Device(CompositeEventEmitter):
|
||||
await self.stop_discovery()
|
||||
|
||||
@property
|
||||
def pairing_config_factory(self):
|
||||
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
|
||||
return self.smp_manager.pairing_config_factory
|
||||
|
||||
@pairing_config_factory.setter
|
||||
def pairing_config_factory(self, pairing_config_factory):
|
||||
def pairing_config_factory(
|
||||
self, pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
) -> None:
|
||||
self.smp_manager.pairing_config_factory = pairing_config_factory
|
||||
|
||||
@property
|
||||
def smp_session_proxy(self) -> Type[smp.Session]:
|
||||
return self.smp_manager.session_proxy
|
||||
|
||||
@smp_session_proxy.setter
|
||||
def smp_session_proxy(self, session_proxy: Type[smp.Session]) -> None:
|
||||
self.smp_manager.session_proxy = session_proxy
|
||||
|
||||
async def pair(self, connection):
|
||||
return await self.smp_manager.pair(connection)
|
||||
|
||||
@@ -2232,7 +2243,7 @@ class Device(CompositeEventEmitter):
|
||||
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
|
||||
return keys.ltk_peripheral.value
|
||||
|
||||
async def get_link_key(self, address):
|
||||
async def get_link_key(self, address: Address) -> Optional[bytes]:
|
||||
# Look for the key in the keystore
|
||||
if self.keystore is not None:
|
||||
keys = await self.keystore.get(str(address))
|
||||
@@ -2243,6 +2254,7 @@ class Device(CompositeEventEmitter):
|
||||
return None
|
||||
|
||||
return keys.link_key.value
|
||||
return None
|
||||
|
||||
# [Classic only]
|
||||
async def authenticate(self, connection):
|
||||
@@ -2772,89 +2784,103 @@ class Device(CompositeEventEmitter):
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_confirmation_request(self, connection, code):
|
||||
def on_authentication_user_confirmation_request(self, connection, code) -> None:
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
io_capability = pairing_config.delegate.classic_io_capability
|
||||
peer_io_capability = connection.peer_pairing_io_capability
|
||||
|
||||
# Respond
|
||||
if io_capability == HCI_DISPLAY_YES_NO_IO_CAPABILITY:
|
||||
if connection.peer_pairing_io_capability in (
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
):
|
||||
# Display the code and ask the user to compare
|
||||
async def prompt():
|
||||
return (
|
||||
await pairing_config.delegate.compare_numbers(code, digits=6),
|
||||
async def confirm() -> bool:
|
||||
# Ask the user to confirm the pairing, without display
|
||||
return await pairing_config.delegate.confirm()
|
||||
|
||||
async def auto_confirm() -> bool:
|
||||
# Ask the user to auto-confirm the pairing, without display
|
||||
return await pairing_config.delegate.confirm(auto=True)
|
||||
|
||||
async def display_confirm() -> bool:
|
||||
# Display the code and ask the user to compare
|
||||
return await pairing_config.delegate.compare_numbers(code, digits=6)
|
||||
|
||||
async def display_auto_confirm() -> bool:
|
||||
# Display the code to the user and ask the delegate to auto-confirm
|
||||
await pairing_config.delegate.display_number(code, digits=6)
|
||||
return await pairing_config.delegate.confirm(auto=True)
|
||||
|
||||
async def na() -> bool:
|
||||
assert False, "N/A: unreachable"
|
||||
|
||||
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6
|
||||
methods = {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: display_auto_confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: display_confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: na,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: na,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: na,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY: confirm,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY: confirm,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY: auto_confirm,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: auto_confirm,
|
||||
},
|
||||
}
|
||||
|
||||
method = methods[peer_io_capability][io_capability]
|
||||
|
||||
async def reply() -> None:
|
||||
if await connection.abort_on('disconnection', method()):
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
|
||||
)
|
||||
else:
|
||||
# Ask the user to confirm the pairing, without showing a code
|
||||
async def prompt():
|
||||
return await pairing_config.delegate.confirm()
|
||||
|
||||
async def confirm():
|
||||
if await prompt():
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
AsyncRunner.spawn(connection.abort_on('disconnection', confirm()))
|
||||
return
|
||||
|
||||
if io_capability == HCI_DISPLAY_ONLY_IO_CAPABILITY:
|
||||
# Display the code to the user
|
||||
AsyncRunner.spawn(pairing_config.delegate.display_number(code, 6))
|
||||
|
||||
# Automatic confirmation
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Confirmation_Request_Reply_Command(bd_addr=connection.peer_address)
|
||||
)
|
||||
AsyncRunner.spawn(reply())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_authentication_user_passkey_request(self, connection):
|
||||
def on_authentication_user_passkey_request(self, connection) -> None:
|
||||
# Ask what the pairing config should be for this connection
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
io_capability = pairing_config.delegate.classic_io_capability
|
||||
|
||||
# Respond
|
||||
if io_capability == HCI_KEYBOARD_ONLY_IO_CAPABILITY:
|
||||
# Ask the user to input a number
|
||||
async def get_number():
|
||||
number = await connection.abort_on(
|
||||
'disconnection', pairing_config.delegate.get_number()
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address, numeric_value=number
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
asyncio.create_task(get_number())
|
||||
else:
|
||||
self.host.send_command_sync(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
async def reply() -> None:
|
||||
number = await connection.abort_on(
|
||||
'disconnection', pairing_config.delegate.get_number()
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address, numeric_value=number
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
|
||||
AsyncRunner.spawn(reply())
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -3059,18 +3085,15 @@ class Device(CompositeEventEmitter):
|
||||
connection.emit('role_change_failure', error)
|
||||
self.emit('role_change_failure', address, error)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing_start(self, connection):
|
||||
def on_pairing_start(self, connection: Connection) -> None:
|
||||
connection.emit('pairing_start')
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing(self, connection, keys, sc):
|
||||
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
|
||||
connection.sc = sc
|
||||
connection.authenticated = True
|
||||
connection.emit('pairing', keys)
|
||||
|
||||
@with_connection_from_handle
|
||||
def on_pairing_failure(self, connection, reason):
|
||||
def on_pairing_failure(self, connection: Connection, reason: int) -> None:
|
||||
connection.emit('pairing_failure', reason)
|
||||
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -247,7 +247,7 @@ class TemplateService(Service):
|
||||
to expose their UUID as a class property
|
||||
'''
|
||||
|
||||
UUID = None
|
||||
UUID: Optional[UUID] = None
|
||||
|
||||
def __init__(self, characteristics, primary=True):
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
|
||||
@@ -25,7 +25,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
@@ -139,19 +139,19 @@ class PairingKeys:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class KeyStore:
|
||||
async def delete(self, name):
|
||||
async def delete(self, name: str):
|
||||
pass
|
||||
|
||||
async def update(self, name, keys):
|
||||
async def update(self, name: str, keys: PairingKeys) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, _name):
|
||||
return PairingKeys()
|
||||
async def get(self, _name: str) -> Optional[PairingKeys]:
|
||||
return None
|
||||
|
||||
async def get_all(self):
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
return []
|
||||
|
||||
async def delete_all(self):
|
||||
async def delete_all(self) -> None:
|
||||
all_keys = await self.get_all()
|
||||
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
|
||||
|
||||
@@ -177,15 +177,15 @@ class KeyStore:
|
||||
separator = '\n'
|
||||
|
||||
@staticmethod
|
||||
def create_for_device(device: Device) -> Optional[KeyStore]:
|
||||
def create_for_device(device: Device) -> KeyStore:
|
||||
if device.config.keystore is None:
|
||||
return None
|
||||
return MemoryKeyStore()
|
||||
|
||||
keystore_type = device.config.keystore.split(':', 1)[0]
|
||||
if keystore_type == 'JsonKeyStore':
|
||||
return JsonKeyStore.from_device(device)
|
||||
|
||||
return None
|
||||
return MemoryKeyStore()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore):
|
||||
return None
|
||||
|
||||
return PairingKeys.from_dict(keys)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MemoryKeyStore(KeyStore):
|
||||
all_keys: Dict[str, PairingKeys]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all_keys = {}
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
if name in self.all_keys:
|
||||
del self.all_keys[name]
|
||||
|
||||
async def update(self, name: str, keys: PairingKeys) -> None:
|
||||
self.all_keys[name] = keys
|
||||
|
||||
async def get(self, name: str) -> Optional[PairingKeys]:
|
||||
return self.all_keys.get(name)
|
||||
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
return list(self.all_keys.items())
|
||||
|
||||
@@ -65,8 +65,9 @@ class PairingDelegate:
|
||||
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
|
||||
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
|
||||
|
||||
DEFAULT_KEY_DISTRIBUTION: int = (
|
||||
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG
|
||||
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
|
||||
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
|
||||
| KeyDistribution.DISTRIBUTE_IDENTITY_KEY
|
||||
)
|
||||
|
||||
# Default mapping from abstract to Classic I/O capabilities.
|
||||
@@ -85,9 +86,9 @@ class PairingDelegate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability=NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
@@ -113,8 +114,11 @@ class PairingDelegate:
|
||||
"""Accept or reject a Pairing request."""
|
||||
return True
|
||||
|
||||
async def confirm(self) -> bool:
|
||||
"""Respond yes or no to a Pairing confirmation question."""
|
||||
async def confirm(self, auto: bool = False) -> bool:
|
||||
"""
|
||||
Respond yes or no to a Pairing confirmation question.
|
||||
The `auto` parameter stands for automatic confirmation.
|
||||
"""
|
||||
return True
|
||||
|
||||
# pylint: disable-next=unused-argument
|
||||
@@ -129,7 +133,7 @@ class PairingDelegate:
|
||||
"""
|
||||
return 0
|
||||
|
||||
async def get_string(self, max_length) -> Optional[str]:
|
||||
async def get_string(self, max_length: int) -> Optional[str]:
|
||||
"""
|
||||
Return a string whose utf-8 encoding is up to max_length bytes.
|
||||
"""
|
||||
|
||||
376
bumble/smp.py
376
bumble/smp.py
@@ -26,16 +26,22 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import secrets
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .hci import (
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
Address,
|
||||
HCI_LE_Enable_Encryption_Command,
|
||||
HCI_Object,
|
||||
@@ -51,6 +57,10 @@ from .core import (
|
||||
from .keys import PairingKeys
|
||||
from . import crypto
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.pairing import PairingConfig
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -184,7 +194,7 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032'
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
def error_name(error_code):
|
||||
def error_name(error_code: int) -> str:
|
||||
return name_or_number(SMP_ERROR_NAMES, error_code)
|
||||
|
||||
|
||||
@@ -197,11 +207,12 @@ class SMP_Command:
|
||||
'''
|
||||
|
||||
smp_classes: Dict[int, Type[SMP_Command]] = {}
|
||||
fields: Any
|
||||
code = 0
|
||||
name = ''
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
def from_bytes(pdu: bytes) -> "SMP_Command":
|
||||
code = pdu[0]
|
||||
|
||||
cls = SMP_Command.smp_classes.get(code)
|
||||
@@ -217,11 +228,11 @@ class SMP_Command:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def command_name(code):
|
||||
def command_name(code: int) -> str:
|
||||
return name_or_number(SMP_COMMAND_NAMES, code)
|
||||
|
||||
@staticmethod
|
||||
def auth_req_str(value):
|
||||
def auth_req_str(value: int) -> str:
|
||||
bonding_flags = value & 3
|
||||
mitm = (value >> 2) & 1
|
||||
sc = (value >> 3) & 1
|
||||
@@ -234,12 +245,12 @@ class SMP_Command:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def io_capability_name(io_capability):
|
||||
def io_capability_name(io_capability: int) -> str:
|
||||
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
|
||||
|
||||
@staticmethod
|
||||
def key_distribution_str(value):
|
||||
key_types = []
|
||||
def key_distribution_str(value: int) -> str:
|
||||
key_types: List[str] = []
|
||||
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||
key_types.append('ENC')
|
||||
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
|
||||
@@ -251,7 +262,7 @@ class SMP_Command:
|
||||
return ','.join(key_types)
|
||||
|
||||
@staticmethod
|
||||
def keypress_notification_type_name(notification_type):
|
||||
def keypress_notification_type_name(notification_type: int) -> str:
|
||||
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
|
||||
|
||||
@staticmethod
|
||||
@@ -272,14 +283,14 @@ class SMP_Command:
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu=None, **kwargs):
|
||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs: Any) -> None:
|
||||
if hasattr(self, 'fields') and kwargs:
|
||||
HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
pdu = bytes([self.code]) + HCI_Object.dict_to_bytes(kwargs, self.fields)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
@@ -320,6 +331,13 @@ class SMP_Pairing_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
|
||||
'''
|
||||
|
||||
io_capability: int
|
||||
oob_data_flag: int
|
||||
auth_req: int
|
||||
maximum_encryption_key_size: int
|
||||
initiator_key_distribution: int
|
||||
responder_key_distribution: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -343,6 +361,13 @@ class SMP_Pairing_Response_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
|
||||
'''
|
||||
|
||||
io_capability: int
|
||||
oob_data_flag: int
|
||||
auth_req: int
|
||||
maximum_encryption_key_size: int
|
||||
initiator_key_distribution: int
|
||||
responder_key_distribution: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('confirm_value', 16)])
|
||||
@@ -351,6 +376,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
|
||||
'''
|
||||
|
||||
confirm_value: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('random_value', 16)])
|
||||
@@ -359,6 +386,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
|
||||
'''
|
||||
|
||||
random_value: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('reason', {'size': 1, 'mapper': error_name})])
|
||||
@@ -367,6 +396,8 @@ class SMP_Pairing_Failed_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
|
||||
'''
|
||||
|
||||
reason: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('public_key_x', 32), ('public_key_y', 32)])
|
||||
@@ -375,6 +406,9 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
|
||||
'''
|
||||
|
||||
public_key_x: bytes
|
||||
public_key_y: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -387,6 +421,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
|
||||
'''
|
||||
|
||||
dhkey_check: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -402,6 +438,8 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
|
||||
'''
|
||||
|
||||
notification_type: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('long_term_key', 16)])
|
||||
@@ -410,6 +448,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
|
||||
'''
|
||||
|
||||
long_term_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('ediv', 2), ('rand', 8)])
|
||||
@@ -418,6 +458,9 @@ class SMP_Master_Identification_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
|
||||
'''
|
||||
|
||||
ediv: int
|
||||
rand: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('identity_resolving_key', 16)])
|
||||
@@ -426,6 +469,8 @@ class SMP_Identity_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
|
||||
'''
|
||||
|
||||
identity_resolving_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -439,6 +484,9 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
|
||||
'''
|
||||
|
||||
addr_type: int
|
||||
bd_addr: Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass([('signature_key', 16)])
|
||||
@@ -447,6 +495,8 @@ class SMP_Signing_Information_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
|
||||
'''
|
||||
|
||||
signature_key: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SMP_Command.subclass(
|
||||
@@ -459,9 +509,11 @@ class SMP_Security_Request_Command(SMP_Command):
|
||||
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
|
||||
'''
|
||||
|
||||
auth_req: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def smp_auth_req(bonding, mitm, sc, keypress, ct2):
|
||||
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
|
||||
value = 0
|
||||
if bonding:
|
||||
value |= SMP_BONDING_AUTHREQ
|
||||
@@ -574,11 +626,17 @@ class Session:
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, manager, connection, pairing_config, is_initiator):
|
||||
def __init__(
|
||||
self,
|
||||
manager: Manager,
|
||||
connection: Connection,
|
||||
pairing_config: PairingConfig,
|
||||
is_initiator: bool,
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.preq = None
|
||||
self.pres = None
|
||||
self.preq: Optional[bytes] = None
|
||||
self.pres: Optional[bytes] = None
|
||||
self.ea = None
|
||||
self.eb = None
|
||||
self.tk = bytes(16)
|
||||
@@ -588,29 +646,29 @@ class Session:
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key = None
|
||||
self.initiator_key_distribution = 0
|
||||
self.responder_key_distribution = 0
|
||||
self.peer_random_value = None
|
||||
self.peer_public_key_x = bytes(32)
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
self.peer_public_key_x: bytes = bytes(32)
|
||||
self.peer_public_key_y = bytes(32)
|
||||
self.peer_ltk = None
|
||||
self.peer_ediv = None
|
||||
self.peer_rand = None
|
||||
self.peer_rand: Optional[bytes] = None
|
||||
self.peer_identity_resolving_key = None
|
||||
self.peer_bd_addr = None
|
||||
self.peer_bd_addr: Optional[Address] = None
|
||||
self.peer_signature_key = None
|
||||
self.peer_expected_distributions = []
|
||||
self.peer_expected_distributions: List[Type[SMP_Command]] = []
|
||||
self.dh_key = None
|
||||
self.confirm_value = None
|
||||
self.passkey = None
|
||||
self.passkey: Optional[int] = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
self.passkey_step = 0
|
||||
self.passkey_display = False
|
||||
self.pairing_method = 0
|
||||
self.pairing_config = pairing_config
|
||||
self.wait_before_continuing = None
|
||||
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
|
||||
self.completed = False
|
||||
self.ctkd_task = None
|
||||
self.ctkd_task: Optional[Awaitable[None]] = None
|
||||
|
||||
# Decide if we're the initiator or the responder
|
||||
self.is_initiator = is_initiator
|
||||
@@ -628,7 +686,9 @@ class Session:
|
||||
|
||||
# Create a future that can be used to wait for the session to complete
|
||||
if self.is_initiator:
|
||||
self.pairing_result = asyncio.get_running_loop().create_future()
|
||||
self.pairing_result: Optional[
|
||||
asyncio.Future[None]
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
else:
|
||||
self.pairing_result = None
|
||||
|
||||
@@ -641,11 +701,11 @@ class Session:
|
||||
)
|
||||
|
||||
# Authentication Requirements Flags - Vol 3, Part H, Figure 3.3
|
||||
self.bonding = pairing_config.bonding
|
||||
self.sc = pairing_config.sc
|
||||
self.mitm = pairing_config.mitm
|
||||
self.bonding: bool = pairing_config.bonding
|
||||
self.sc: bool = pairing_config.sc
|
||||
self.mitm: bool = pairing_config.mitm
|
||||
self.keypress = False
|
||||
self.ct2 = False
|
||||
self.ct2: bool = False
|
||||
|
||||
# I/O Capabilities
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
@@ -669,34 +729,35 @@ class Session:
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
@property
|
||||
def pkx(self):
|
||||
def pkx(self) -> Tuple[bytes, bytes]:
|
||||
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
|
||||
|
||||
@property
|
||||
def pka(self):
|
||||
def pka(self) -> bytes:
|
||||
return self.pkx[0 if self.is_initiator else 1]
|
||||
|
||||
@property
|
||||
def pkb(self):
|
||||
def pkb(self) -> bytes:
|
||||
return self.pkx[0 if self.is_responder else 1]
|
||||
|
||||
@property
|
||||
def nx(self):
|
||||
def nx(self) -> Tuple[bytes, bytes]:
|
||||
assert self.peer_random_value
|
||||
return (self.r, self.peer_random_value)
|
||||
|
||||
@property
|
||||
def na(self):
|
||||
def na(self) -> bytes:
|
||||
return self.nx[0 if self.is_initiator else 1]
|
||||
|
||||
@property
|
||||
def nb(self):
|
||||
def nb(self) -> bytes:
|
||||
return self.nx[0 if self.is_responder else 1]
|
||||
|
||||
@property
|
||||
def auth_req(self):
|
||||
def auth_req(self) -> int:
|
||||
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
|
||||
|
||||
def get_long_term_key(self, rand, ediv):
|
||||
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
|
||||
if not self.sc and not self.completed:
|
||||
if rand == self.ltk_rand and ediv == self.ltk_ediv:
|
||||
return self.stk
|
||||
@@ -706,13 +767,13 @@ class Session:
|
||||
return None
|
||||
|
||||
def decide_pairing_method(
|
||||
self, auth_req, initiator_io_capability, responder_io_capability
|
||||
):
|
||||
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
|
||||
) -> None:
|
||||
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
|
||||
self.pairing_method = self.JUST_WORKS
|
||||
return
|
||||
|
||||
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability]
|
||||
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
|
||||
if isinstance(details, tuple) and len(details) == 2:
|
||||
# One entry for legacy pairing and one for secure connections
|
||||
details = details[1 if self.sc else 0]
|
||||
@@ -724,7 +785,9 @@ class Session:
|
||||
self.pairing_method = details[0]
|
||||
self.passkey_display = details[1 if self.is_initiator else 2]
|
||||
|
||||
def check_expected_value(self, expected, received, error):
|
||||
def check_expected_value(
|
||||
self, expected: bytes, received: bytes, error: int
|
||||
) -> bool:
|
||||
logger.debug(f'expected={expected.hex()} got={received.hex()}')
|
||||
if expected != received:
|
||||
logger.info(color('pairing confirm/check mismatch', 'red'))
|
||||
@@ -732,8 +795,8 @@ class Session:
|
||||
return False
|
||||
return True
|
||||
|
||||
def prompt_user_for_confirmation(self, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_confirmation(self, next_steps: Callable[[], None]) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug('ask for confirmation')
|
||||
try:
|
||||
response = await self.pairing_config.delegate.confirm()
|
||||
@@ -747,8 +810,10 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def prompt_user_for_numeric_comparison(self, code, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_numeric_comparison(
|
||||
self, code: int, next_steps: Callable[[], None]
|
||||
) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug(f'verification code: {code}')
|
||||
try:
|
||||
response = await self.pairing_config.delegate.compare_numbers(
|
||||
@@ -764,11 +829,15 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def prompt_user_for_number(self, next_steps):
|
||||
async def prompt():
|
||||
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
|
||||
async def prompt() -> None:
|
||||
logger.debug('prompting user for passkey')
|
||||
try:
|
||||
passkey = await self.pairing_config.delegate.get_number()
|
||||
if passkey is None:
|
||||
logger.debug('Passkey request rejected')
|
||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||
return
|
||||
logger.debug(f'user input: {passkey}')
|
||||
next_steps(passkey)
|
||||
except Exception as error:
|
||||
@@ -777,9 +846,10 @@ class Session:
|
||||
|
||||
self.connection.abort_on('disconnection', prompt())
|
||||
|
||||
def display_passkey(self):
|
||||
def display_passkey(self) -> None:
|
||||
# Generate random Passkey/PIN code
|
||||
self.passkey = secrets.randbelow(1000000)
|
||||
assert self.passkey is not None
|
||||
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
|
||||
self.passkey_ready.set()
|
||||
|
||||
@@ -793,9 +863,9 @@ class Session:
|
||||
self.pairing_config.delegate.display_number(self.passkey, digits=6),
|
||||
)
|
||||
|
||||
def input_passkey(self, next_steps=None):
|
||||
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
|
||||
# Prompt the user for the passkey displayed on the peer
|
||||
def after_input(passkey):
|
||||
def after_input(passkey: int) -> None:
|
||||
self.passkey = passkey
|
||||
|
||||
if not self.sc:
|
||||
@@ -809,7 +879,9 @@ class Session:
|
||||
|
||||
self.prompt_user_for_number(after_input)
|
||||
|
||||
def display_or_input_passkey(self, next_steps=None):
|
||||
def display_or_input_passkey(
|
||||
self, next_steps: Optional[Callable[[], None]] = None
|
||||
) -> None:
|
||||
if self.passkey_display:
|
||||
self.display_passkey()
|
||||
if next_steps is not None:
|
||||
@@ -817,14 +889,14 @@ class Session:
|
||||
else:
|
||||
self.input_passkey(next_steps)
|
||||
|
||||
def send_command(self, command):
|
||||
def send_command(self, command: SMP_Command) -> None:
|
||||
self.manager.send_command(self.connection, command)
|
||||
|
||||
def send_pairing_failed(self, error):
|
||||
def send_pairing_failed(self, error: int) -> None:
|
||||
self.send_command(SMP_Pairing_Failed_Command(reason=error))
|
||||
self.on_pairing_failure(error)
|
||||
|
||||
def send_pairing_request_command(self):
|
||||
def send_pairing_request_command(self) -> None:
|
||||
self.manager.on_session_start(self)
|
||||
|
||||
command = SMP_Pairing_Request_Command(
|
||||
@@ -838,7 +910,7 @@ class Session:
|
||||
self.preq = bytes(command)
|
||||
self.send_command(command)
|
||||
|
||||
def send_pairing_response_command(self):
|
||||
def send_pairing_response_command(self) -> None:
|
||||
response = SMP_Pairing_Response_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
@@ -850,18 +922,19 @@ class Session:
|
||||
self.pres = bytes(response)
|
||||
self.send_command(response)
|
||||
|
||||
def send_pairing_confirm_command(self):
|
||||
def send_pairing_confirm_command(self) -> None:
|
||||
self.r = crypto.r()
|
||||
logger.debug(f'generated random: {self.r.hex()}')
|
||||
|
||||
if self.sc:
|
||||
|
||||
async def next_steps():
|
||||
async def next_steps() -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
z = 0
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
# We need a passkey
|
||||
await self.passkey_ready.wait()
|
||||
assert self.passkey
|
||||
|
||||
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
|
||||
else:
|
||||
@@ -892,10 +965,10 @@ class Session:
|
||||
|
||||
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
|
||||
|
||||
def send_pairing_random_command(self):
|
||||
def send_pairing_random_command(self) -> None:
|
||||
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
|
||||
|
||||
def send_public_key_command(self):
|
||||
def send_public_key_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_Public_Key_Command(
|
||||
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
|
||||
@@ -903,18 +976,18 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
def send_pairing_dhkey_check_command(self):
|
||||
def send_pairing_dhkey_check_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_DHKey_Check_Command(
|
||||
dhkey_check=self.ea if self.is_initiator else self.eb
|
||||
)
|
||||
)
|
||||
|
||||
def start_encryption(self, key):
|
||||
def start_encryption(self, key: bytes) -> None:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
@@ -922,7 +995,7 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
async def derive_ltk(self):
|
||||
async def derive_ltk(self) -> None:
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
ilk = (
|
||||
@@ -932,7 +1005,7 @@ class Session:
|
||||
)
|
||||
self.ltk = crypto.h6(ilk, b'brle')
|
||||
|
||||
def distribute_keys(self):
|
||||
def distribute_keys(self) -> None:
|
||||
# Distribute the keys as required
|
||||
if self.is_initiator:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1032,7 +1105,7 @@ class Session:
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags):
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
self.peer_expected_distributions = []
|
||||
if not self.sc and self.connection.transport == BT_LE_TRANSPORT:
|
||||
@@ -1055,7 +1128,7 @@ class Session:
|
||||
f'{[c.__name__ for c in self.peer_expected_distributions]}'
|
||||
)
|
||||
|
||||
def check_key_distribution(self, command_class):
|
||||
def check_key_distribution(self, command_class: Type[SMP_Command]) -> None:
|
||||
# First, check that the connection is encrypted
|
||||
if not self.connection.is_encrypted:
|
||||
logger.warning(
|
||||
@@ -1083,7 +1156,7 @@ class Session:
|
||||
)
|
||||
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
|
||||
|
||||
async def pair(self):
|
||||
async def pair(self) -> None:
|
||||
# Start pairing as an initiator
|
||||
# TODO: check that this session isn't already active
|
||||
|
||||
@@ -1091,9 +1164,10 @@ class Session:
|
||||
self.send_pairing_request_command()
|
||||
|
||||
# Wait for the pairing process to finish
|
||||
assert self.pairing_result
|
||||
await self.connection.abort_on('disconnection', self.pairing_result)
|
||||
|
||||
def on_disconnection(self, _):
|
||||
def on_disconnection(self, _: int) -> None:
|
||||
self.connection.remove_listener('disconnection', self.on_disconnection)
|
||||
self.connection.remove_listener(
|
||||
'connection_encryption_change', self.on_connection_encryption_change
|
||||
@@ -1104,14 +1178,14 @@ class Session:
|
||||
)
|
||||
self.manager.on_session_end(self)
|
||||
|
||||
def on_peer_key_distribution_complete(self):
|
||||
def on_peer_key_distribution_complete(self) -> None:
|
||||
# The initiator can now send its keys
|
||||
if self.is_initiator:
|
||||
self.distribute_keys()
|
||||
|
||||
self.connection.abort_on('disconnection', self.on_pairing())
|
||||
|
||||
def on_connection_encryption_change(self):
|
||||
def on_connection_encryption_change(self) -> None:
|
||||
if self.connection.is_encrypted:
|
||||
if self.is_responder:
|
||||
# The responder distributes its keys first, the initiator later
|
||||
@@ -1121,11 +1195,11 @@ class Session:
|
||||
if not self.peer_expected_distributions:
|
||||
self.on_peer_key_distribution_complete()
|
||||
|
||||
def on_connection_encryption_key_refresh(self):
|
||||
def on_connection_encryption_key_refresh(self) -> None:
|
||||
# Do as if the connection had just been encrypted
|
||||
self.on_connection_encryption_change()
|
||||
|
||||
async def on_pairing(self):
|
||||
async def on_pairing(self) -> None:
|
||||
logger.debug('pairing complete')
|
||||
|
||||
if self.completed:
|
||||
@@ -1137,7 +1211,7 @@ class Session:
|
||||
self.pairing_result.set_result(None)
|
||||
|
||||
# Use the peer address from the pairing protocol or the connection
|
||||
if self.peer_bd_addr:
|
||||
if self.peer_bd_addr is not None:
|
||||
peer_address = self.peer_bd_addr
|
||||
else:
|
||||
peer_address = self.connection.peer_address
|
||||
@@ -1186,7 +1260,7 @@ class Session:
|
||||
)
|
||||
self.manager.on_pairing(self, peer_address, keys)
|
||||
|
||||
def on_pairing_failure(self, reason):
|
||||
def on_pairing_failure(self, reason: int) -> None:
|
||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
||||
|
||||
if self.completed:
|
||||
@@ -1199,7 +1273,7 @@ class Session:
|
||||
self.pairing_result.set_exception(error)
|
||||
self.manager.on_pairing_failure(self, reason)
|
||||
|
||||
def on_smp_command(self, command):
|
||||
def on_smp_command(self, command: SMP_Command) -> None:
|
||||
# Find the handler method
|
||||
handler_name = f'on_{command.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
@@ -1215,12 +1289,16 @@ class Session:
|
||||
else:
|
||||
logger.error(color('SMP command not handled???', 'red'))
|
||||
|
||||
def on_smp_pairing_request_command(self, command):
|
||||
def on_smp_pairing_request_command(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
self.connection.abort_on(
|
||||
'disconnection', self.on_smp_pairing_request_command_async(command)
|
||||
)
|
||||
|
||||
async def on_smp_pairing_request_command_async(self, command):
|
||||
async def on_smp_pairing_request_command_async(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
# Check if the request should proceed
|
||||
accepted = await self.pairing_config.delegate.accept()
|
||||
if not accepted:
|
||||
@@ -1280,7 +1358,9 @@ class Session:
|
||||
):
|
||||
self.distribute_keys()
|
||||
|
||||
def on_smp_pairing_response_command(self, command):
|
||||
def on_smp_pairing_response_command(
|
||||
self, command: SMP_Pairing_Response_Command
|
||||
) -> None:
|
||||
if self.is_responder:
|
||||
logger.warning(color('received pairing response as a responder', 'red'))
|
||||
return
|
||||
@@ -1331,7 +1411,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command_legacy(self, _):
|
||||
def on_smp_pairing_confirm_command_legacy(
|
||||
self, _: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
if self.is_initiator:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
@@ -1341,7 +1423,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command_secure_connections(self, _):
|
||||
def on_smp_pairing_confirm_command_secure_connections(
|
||||
self, _: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.is_initiator:
|
||||
self.r = crypto.r()
|
||||
@@ -1352,14 +1436,18 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_confirm_command(self, command):
|
||||
def on_smp_pairing_confirm_command(
|
||||
self, command: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
self.confirm_value = command.confirm_value
|
||||
if self.sc:
|
||||
self.on_smp_pairing_confirm_command_secure_connections(command)
|
||||
else:
|
||||
self.on_smp_pairing_confirm_command_legacy(command)
|
||||
|
||||
def on_smp_pairing_random_command_legacy(self, command):
|
||||
def on_smp_pairing_random_command_legacy(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
# Check that the confirmation values match
|
||||
confirm_verifier = crypto.c1(
|
||||
self.tk,
|
||||
@@ -1371,6 +1459,7 @@ class Session:
|
||||
self.ia,
|
||||
self.ra,
|
||||
)
|
||||
assert self.confirm_value
|
||||
if not self.check_expected_value(
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
):
|
||||
@@ -1394,7 +1483,9 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_random_command()
|
||||
|
||||
def on_smp_pairing_random_command_secure_connections(self, command):
|
||||
def on_smp_pairing_random_command_secure_connections(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
if self.pairing_method == self.PASSKEY and self.passkey is None:
|
||||
logger.warning('no passkey entered, ignoring command')
|
||||
return
|
||||
@@ -1402,6 +1493,7 @@ class Session:
|
||||
# pylint: disable=too-many-return-statements
|
||||
if self.is_initiator:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
assert self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pkb, self.pka, command.random_value, bytes([0])
|
||||
@@ -1411,6 +1503,7 @@ class Session:
|
||||
):
|
||||
return
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pkb,
|
||||
@@ -1435,6 +1528,7 @@ class Session:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
self.pka,
|
||||
@@ -1468,19 +1562,21 @@ class Session:
|
||||
ra = bytes(16)
|
||||
rb = ra
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
assert self.passkey
|
||||
ra = self.passkey.to_bytes(16, byteorder='little')
|
||||
rb = ra
|
||||
else:
|
||||
# OOB not implemented yet
|
||||
return
|
||||
|
||||
assert self.preq and self.pres
|
||||
io_cap_a = self.preq[1:4]
|
||||
io_cap_b = self.pres[1:4]
|
||||
self.ea = crypto.f6(mac_key, self.na, self.nb, rb, io_cap_a, a, b)
|
||||
self.eb = crypto.f6(mac_key, self.nb, self.na, ra, io_cap_b, b, a)
|
||||
|
||||
# Next steps to be performed after possible user confirmation
|
||||
def next_steps():
|
||||
def next_steps() -> None:
|
||||
# The initiator sends the DH Key check to the responder
|
||||
if self.is_initiator:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
@@ -1502,14 +1598,18 @@ class Session:
|
||||
else:
|
||||
next_steps()
|
||||
|
||||
def on_smp_pairing_random_command(self, command):
|
||||
def on_smp_pairing_random_command(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
self.peer_random_value = command.random_value
|
||||
if self.sc:
|
||||
self.on_smp_pairing_random_command_secure_connections(command)
|
||||
else:
|
||||
self.on_smp_pairing_random_command_legacy(command)
|
||||
|
||||
def on_smp_pairing_public_key_command(self, command):
|
||||
def on_smp_pairing_public_key_command(
|
||||
self, command: SMP_Pairing_Public_Key_Command
|
||||
) -> None:
|
||||
# Store the public key so that we can compute the confirmation value later
|
||||
self.peer_public_key_x = command.public_key_x
|
||||
self.peer_public_key_y = command.public_key_y
|
||||
@@ -1538,9 +1638,12 @@ class Session:
|
||||
# We can now send the confirmation value
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
def on_smp_pairing_dhkey_check_command(self, command):
|
||||
def on_smp_pairing_dhkey_check_command(
|
||||
self, command: SMP_Pairing_DHKey_Check_Command
|
||||
) -> None:
|
||||
# Check that what we received matches what we computed earlier
|
||||
expected = self.eb if self.is_initiator else self.ea
|
||||
assert expected
|
||||
if not self.check_expected_value(
|
||||
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
|
||||
):
|
||||
@@ -1549,7 +1652,8 @@ class Session:
|
||||
if self.is_responder:
|
||||
if self.wait_before_continuing is not None:
|
||||
|
||||
async def next_steps():
|
||||
async def next_steps() -> None:
|
||||
assert self.wait_before_continuing
|
||||
await self.wait_before_continuing
|
||||
self.wait_before_continuing = None
|
||||
self.send_pairing_dhkey_check_command()
|
||||
@@ -1558,29 +1662,42 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
assert self.ltk
|
||||
self.start_encryption(self.ltk)
|
||||
|
||||
def on_smp_pairing_failed_command(self, command):
|
||||
def on_smp_pairing_failed_command(
|
||||
self, command: SMP_Pairing_Failed_Command
|
||||
) -> None:
|
||||
self.on_pairing_failure(command.reason)
|
||||
|
||||
def on_smp_encryption_information_command(self, command):
|
||||
def on_smp_encryption_information_command(
|
||||
self, command: SMP_Encryption_Information_Command
|
||||
) -> None:
|
||||
self.peer_ltk = command.long_term_key
|
||||
self.check_key_distribution(SMP_Encryption_Information_Command)
|
||||
|
||||
def on_smp_master_identification_command(self, command):
|
||||
def on_smp_master_identification_command(
|
||||
self, command: SMP_Master_Identification_Command
|
||||
) -> None:
|
||||
self.peer_ediv = command.ediv
|
||||
self.peer_rand = command.rand
|
||||
self.check_key_distribution(SMP_Master_Identification_Command)
|
||||
|
||||
def on_smp_identity_information_command(self, command):
|
||||
def on_smp_identity_information_command(
|
||||
self, command: SMP_Identity_Information_Command
|
||||
) -> None:
|
||||
self.peer_identity_resolving_key = command.identity_resolving_key
|
||||
self.check_key_distribution(SMP_Identity_Information_Command)
|
||||
|
||||
def on_smp_identity_address_information_command(self, command):
|
||||
def on_smp_identity_address_information_command(
|
||||
self, command: SMP_Identity_Address_Information_Command
|
||||
) -> None:
|
||||
self.peer_bd_addr = command.bd_addr
|
||||
self.check_key_distribution(SMP_Identity_Address_Information_Command)
|
||||
|
||||
def on_smp_signing_information_command(self, command):
|
||||
def on_smp_signing_information_command(
|
||||
self, command: SMP_Signing_Information_Command
|
||||
) -> None:
|
||||
self.peer_signature_key = command.signature_key
|
||||
self.check_key_distribution(SMP_Signing_Information_Command)
|
||||
|
||||
@@ -1591,14 +1708,24 @@ class Manager(EventEmitter):
|
||||
Implements the Initiator and Responder roles of the Security Manager Protocol
|
||||
'''
|
||||
|
||||
def __init__(self, device, pairing_config_factory):
|
||||
device: Device
|
||||
sessions: Dict[int, Session]
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
session_proxy: Type[Session]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device,
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.sessions = {}
|
||||
self._ecc_key = None
|
||||
self.pairing_config_factory = pairing_config_factory
|
||||
self.session_proxy = Session
|
||||
|
||||
def send_command(self, connection, command):
|
||||
def send_command(self, connection: Connection, command: SMP_Command) -> None:
|
||||
logger.debug(
|
||||
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
@@ -1606,20 +1733,15 @@ class Manager(EventEmitter):
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
|
||||
def on_smp_pdu(self, connection, pdu):
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||
# Look for a session with this connection, and create one if none exists
|
||||
if not (session := self.sessions.get(connection.handle)):
|
||||
if connection.role == BT_CENTRAL_ROLE:
|
||||
logger.warning('Remote starts pairing as Peripheral!')
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config is None:
|
||||
# Pairing disabled
|
||||
self.send_command(
|
||||
connection,
|
||||
SMP_Pairing_Failed_Command(reason=SMP_PAIRING_NOT_SUPPORTED_ERROR),
|
||||
)
|
||||
return
|
||||
session = Session(self, connection, pairing_config, is_initiator=False)
|
||||
session = self.session_proxy(
|
||||
self, connection, pairing_config, is_initiator=False
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
@@ -1633,23 +1755,24 @@ class Manager(EventEmitter):
|
||||
session.on_smp_command(command)
|
||||
|
||||
@property
|
||||
def ecc_key(self):
|
||||
def ecc_key(self) -> crypto.EccKey:
|
||||
if self._ecc_key is None:
|
||||
self._ecc_key = crypto.EccKey.generate()
|
||||
assert self._ecc_key
|
||||
return self._ecc_key
|
||||
|
||||
async def pair(self, connection):
|
||||
async def pair(self, connection: Connection) -> None:
|
||||
# TODO: check if there's already a session for this connection
|
||||
if connection.role != BT_CENTRAL_ROLE:
|
||||
logger.warning('Start pairing as Peripheral!')
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config is None:
|
||||
raise ValueError('pairing config must not be None when initiating')
|
||||
session = Session(self, connection, pairing_config, is_initiator=True)
|
||||
session = self.session_proxy(
|
||||
self, connection, pairing_config, is_initiator=True
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
return await session.pair()
|
||||
|
||||
def request_pairing(self, connection):
|
||||
def request_pairing(self, connection: Connection) -> None:
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
if pairing_config:
|
||||
auth_req = smp_auth_req(
|
||||
@@ -1663,15 +1786,18 @@ class Manager(EventEmitter):
|
||||
auth_req = 0
|
||||
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
|
||||
|
||||
def on_session_start(self, session):
|
||||
self.device.on_pairing_start(session.connection.handle)
|
||||
def on_session_start(self, session: Session) -> None:
|
||||
self.device.on_pairing_start(session.connection)
|
||||
|
||||
def on_pairing(self, session, identity_address, keys):
|
||||
def on_pairing(
|
||||
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
|
||||
) -> None:
|
||||
# Store the keys in the key store
|
||||
if self.device.keystore and identity_address is not None:
|
||||
|
||||
async def store_keys():
|
||||
try:
|
||||
assert self.device.keystore
|
||||
await self.device.keystore.update(str(identity_address), keys)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! error while storing keys: {error}')
|
||||
@@ -1679,17 +1805,19 @@ class Manager(EventEmitter):
|
||||
self.device.abort_on('flush', store_keys())
|
||||
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
||||
self.device.on_pairing(session.connection, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session, reason):
|
||||
self.device.on_pairing_failure(session.connection.handle, reason)
|
||||
def on_pairing_failure(self, session: Session, reason: int) -> None:
|
||||
self.device.on_pairing_failure(session.connection, reason)
|
||||
|
||||
def on_session_end(self, session):
|
||||
def on_session_end(self, session: Session) -> None:
|
||||
logger.debug(f'session end for connection 0x{session.connection.handle:04X}')
|
||||
if session.connection.handle in self.sessions:
|
||||
del self.sessions[session.connection.handle]
|
||||
|
||||
def get_long_term_key(self, connection, rand, ediv):
|
||||
def get_long_term_key(
|
||||
self, connection: Connection, rand: bytes, ediv: int
|
||||
) -> Optional[bytes]:
|
||||
if session := self.sessions.get(connection.handle):
|
||||
return session.get_long_term_key(rand, ediv)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user