resolve when bonded

This commit is contained in:
Gilles Boccon-Gibod
2026-01-29 19:41:41 -08:00
4 changed files with 330 additions and 268 deletions

View File

@@ -22,7 +22,7 @@ import click
import bumble.logging
from bumble import data_types
from bumble.colors import color
from bumble.device import Advertisement, Device
from bumble.device import Advertisement, Device, DeviceConfiguration
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
@@ -144,8 +144,14 @@ async def scan(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
device = Device.from_config_with_hci(
DeviceConfiguration(
name='Bumble',
address=Address('F0:F1:F2:F3:F4:F5'),
keystore='JsonKeyStore',
),
hci_source,
hci_sink,
)
await device.power_on()

View File

@@ -3737,6 +3737,292 @@ class Device(utils.CompositeEventEmitter):
page_scan_enabled=self.connectable,
)
async def connect_le(
self,
peer_address: hci.Address | str,
connection_parameters_preferences: (
dict[hci.Phy, ConnectionParametersPreferences] | None
) = None,
own_address_type: hci.OwnAddressType = hci.OwnAddressType.RANDOM,
timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
# Check that there isn't already a pending connection
if self.is_le_connecting:
raise InvalidStateError('connection already pending')
try_resolve = True
if isinstance(peer_address, str):
try:
peer_address = hci.Address.from_string_for_transport(
peer_address, PhysicalTransport.LE
)
except (InvalidArgumentError, ValueError):
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.LE
) # TODO: timeout
try_resolve = False
assert isinstance(peer_address, hci.Address)
if (
try_resolve
and self.address_resolver is not None
and self.address_resolver.can_resolve_to(peer_address)
):
# If we have an IRK for this address, we should resolve.
logger.debug('have IRK for address, resolving...')
peer_address = await self.find_peer_by_identity_address(
peer_address
) # TODO: timeout
def on_connection(connection):
pending_connection.set_result(connection)
def on_connection_failure(error: core.ConnectionError):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection result
pending_connection = asyncio.get_running_loop().create_future()
self.on(self.EVENT_CONNECTION, on_connection)
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
try:
# Tell the controller to connect
if connection_parameters_preferences is None:
connection_parameters_preferences = {
hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default
}
self.connect_own_address_type = own_address_type
if self.host.supports_command(
hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND
):
# Only keep supported PHYs
phys = sorted(
list(
set(
filter(
self.supports_le_phy,
connection_parameters_preferences.keys(),
)
)
)
)
if not phys:
raise InvalidArgumentError('at least one supported PHY needed')
phy_count = len(phys)
initiating_phys = hci.phy_list_to_bits(phys)
connection_interval_mins = [
int(
connection_parameters_preferences[phy].connection_interval_min
/ 1.25
)
for phy in phys
]
connection_interval_maxs = [
int(
connection_parameters_preferences[phy].connection_interval_max
/ 1.25
)
for phy in phys
]
max_latencies = [
connection_parameters_preferences[phy].max_latency for phy in phys
]
supervision_timeouts = [
int(connection_parameters_preferences[phy].supervision_timeout / 10)
for phy in phys
]
min_ce_lengths = [
int(connection_parameters_preferences[phy].min_ce_length / 0.625)
for phy in phys
]
max_ce_lengths = [
int(connection_parameters_preferences[phy].max_ce_length / 0.625)
for phy in phys
]
await self.send_async_command(
hci.HCI_LE_Extended_Create_Connection_Command(
initiator_filter_policy=0,
own_address_type=own_address_type,
peer_address_type=peer_address.address_type,
peer_address=peer_address,
initiating_phys=initiating_phys,
scan_intervals=(
int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625),
)
* phy_count,
scan_windows=(int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),)
* phy_count,
connection_interval_mins=connection_interval_mins,
connection_interval_maxs=connection_interval_maxs,
max_latencies=max_latencies,
supervision_timeouts=supervision_timeouts,
min_ce_lengths=min_ce_lengths,
max_ce_lengths=max_ce_lengths,
)
)
else:
if hci.HCI_LE_1M_PHY not in connection_parameters_preferences:
raise InvalidArgumentError('1M PHY preferences required')
prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY]
await self.send_async_command(
hci.HCI_LE_Create_Connection_Command(
le_scan_interval=int(
DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625
),
le_scan_window=int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),
initiator_filter_policy=0,
peer_address_type=peer_address.address_type,
peer_address=peer_address,
own_address_type=own_address_type,
connection_interval_min=int(
prefs.connection_interval_min / 1.25
),
connection_interval_max=int(
prefs.connection_interval_max / 1.25
),
max_latency=prefs.max_latency,
supervision_timeout=int(prefs.supervision_timeout / 10),
min_ce_length=int(prefs.min_ce_length / 0.625),
max_ce_length=int(prefs.max_ce_length / 0.625),
)
)
# Wait for the connection process to complete
self.le_connecting = True
if timeout is None:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
try:
return await asyncio.wait_for(
asyncio.shield(pending_connection), timeout
)
except asyncio.TimeoutError:
await self.send_sync_command(
hci.HCI_LE_Create_Connection_Cancel_Command()
)
try:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
except core.ConnectionError as error:
raise core.TimeoutError() from error
finally:
self.remove_listener(self.EVENT_CONNECTION, on_connection)
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
self.le_connecting = False
self.connect_own_address_type = None
async def connect_classic(
self,
peer_address: hci.Address | str,
timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
if isinstance(peer_address, str):
try:
peer_address = hci.Address.from_string_for_transport(
peer_address, PhysicalTransport.BR_EDR
)
except (InvalidArgumentError, ValueError):
# If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, PhysicalTransport.BR_EDR
) # TODO: timeout
else:
# All BR/EDR addresses should be public addresses
if peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS:
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, hci.Address)
def on_connection(connection):
if (
# match BR/EDR connection event against peer address
connection.transport == PhysicalTransport.BR_EDR
and connection.peer_address == peer_address
):
pending_connection.set_result(connection)
def on_connection_failure(error: core.ConnectionError):
if (
# match BR/EDR connection failure event against peer address
error.transport == PhysicalTransport.BR_EDR
and error.peer_address == peer_address
):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection result
pending_connection = asyncio.get_running_loop().create_future()
self.on(self.EVENT_CONNECTION, on_connection)
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
try:
# Save pending connection
self.pending_connections[peer_address] = Connection(
device=self,
handle=0,
transport=core.PhysicalTransport.BR_EDR,
self_address=self.public_address,
self_resolvable_address=None,
peer_address=peer_address,
peer_resolvable_address=None,
role=hci.Role.CENTRAL,
parameters=Connection.Parameters(0, 0, 0),
)
# TODO: allow passing other settings
await self.send_async_command(
hci.HCI_Create_Connection_Command(
bd_addr=peer_address,
packet_type=0xCC18, # FIXME: change
page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE,
clock_offset=0x0000,
allow_role_switch=0x01,
reserved=0,
)
)
# Wait for the connection process to complete
if timeout is None:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
try:
return await asyncio.wait_for(
asyncio.shield(pending_connection), timeout
)
except asyncio.TimeoutError:
await self.send_sync_command(
hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
)
try:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
except core.ConnectionError as error:
raise core.TimeoutError() from error
finally:
self.remove_listener(self.EVENT_CONNECTION, on_connection)
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
self.pending_connections.pop(peer_address, None)
async def connect(
self,
peer_address: hci.Address | str,
@@ -3758,9 +4044,9 @@ class Device(utils.CompositeEventEmitter):
peer_address:
hci.Address or name of the device to connect to.
If a string is passed:
If the string is an address followed by a `@` suffix, the `always_resolve`
argument is implicitly set to True, so the connection is made to the
address after resolution.
[deprecated] If the string is an address followed by a `@` suffix, the
`always_resolve`argument is implicitly set to True, so the connection is
made to the address after resolution.
If the string is any other address, the connection is made to that
address (with or without address resolution, depending on the
`always_resolve` argument).
@@ -3784,271 +4070,32 @@ class Device(utils.CompositeEventEmitter):
Pass None for an unlimited time.
always_resolve:
(BLE only, ignored for BR/EDR)
If True, always initiate a scan, resolving addresses, and connect to the
address that resolves to `peer_address`.
[deprecated] (ignore)
'''
# Check parameters
if transport not in (PhysicalTransport.LE, PhysicalTransport.BR_EDR):
raise InvalidArgumentError('invalid transport')
transport = core.PhysicalTransport(transport)
# Adjust the transport automatically if we need to
if transport == PhysicalTransport.LE and not self.le_enabled:
transport = PhysicalTransport.BR_EDR
elif transport == PhysicalTransport.BR_EDR and not self.classic_enabled:
transport = PhysicalTransport.LE
# Check that there isn't already a pending connection
if transport == PhysicalTransport.LE and self.is_le_connecting:
raise InvalidStateError('connection already pending')
if isinstance(peer_address, str):
try:
if transport == PhysicalTransport.LE and peer_address.endswith('@'):
peer_address = hci.Address.from_string_for_transport(
peer_address[:-1], transport
)
always_resolve = True
logger.debug('forcing address resolution')
else:
peer_address = hci.Address.from_string_for_transport(
peer_address, transport
)
except (InvalidArgumentError, ValueError):
# If the address is not parsable, assume it is a name instead
always_resolve = False
logger.debug('looking for peer by name')
assert isinstance(peer_address, str)
peer_address = await self.find_peer_by_name(
peer_address, transport
) # TODO: timeout
# Connect using the appropriate transport
# (auto-correct the transport based on declared capabilities)
if transport == PhysicalTransport.LE or (
self.le_enabled and not self.classic_enabled
):
return await self.connect_le(
peer_address=peer_address,
connection_parameters_preferences=connection_parameters_preferences,
own_address_type=own_address_type,
timeout=timeout,
)
elif transport == PhysicalTransport.BR_EDR or (
self.classic_enabled and not self.le_enabled
):
return await self.connect_classic(
peer_address=peer_address, timeout=timeout
)
else:
# All BR/EDR addresses should be public addresses
if (
transport == PhysicalTransport.BR_EDR
and peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS
):
raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, hci.Address)
if transport == PhysicalTransport.LE and always_resolve:
logger.debug('resolving address')
peer_address = await self.find_peer_by_identity_address(
peer_address
) # TODO: timeout
def on_connection(connection):
if transport == PhysicalTransport.LE or (
# match BR/EDR connection event against peer address
connection.transport == transport
and connection.peer_address == peer_address
):
pending_connection.set_result(connection)
def on_connection_failure(error: core.ConnectionError):
if transport == PhysicalTransport.LE or (
# match BR/EDR connection failure event against peer address
error.transport == transport
and error.peer_address == peer_address
):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
self.on(self.EVENT_CONNECTION, on_connection)
self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
try:
# Tell the controller to connect
if transport == PhysicalTransport.LE:
if connection_parameters_preferences is None:
if connection_parameters_preferences is None:
connection_parameters_preferences = {
hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default
}
self.connect_own_address_type = own_address_type
if self.host.supports_command(
hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND
):
# Only keep supported PHYs
phys = sorted(
list(
set(
filter(
self.supports_le_phy,
connection_parameters_preferences.keys(),
)
)
)
)
if not phys:
raise InvalidArgumentError('at least one supported PHY needed')
phy_count = len(phys)
initiating_phys = hci.phy_list_to_bits(phys)
connection_interval_mins = [
int(
connection_parameters_preferences[
phy
].connection_interval_min
/ 1.25
)
for phy in phys
]
connection_interval_maxs = [
int(
connection_parameters_preferences[
phy
].connection_interval_max
/ 1.25
)
for phy in phys
]
max_latencies = [
connection_parameters_preferences[phy].max_latency
for phy in phys
]
supervision_timeouts = [
int(
connection_parameters_preferences[phy].supervision_timeout
/ 10
)
for phy in phys
]
min_ce_lengths = [
int(
connection_parameters_preferences[phy].min_ce_length / 0.625
)
for phy in phys
]
max_ce_lengths = [
int(
connection_parameters_preferences[phy].max_ce_length / 0.625
)
for phy in phys
]
await self.send_async_command(
hci.HCI_LE_Extended_Create_Connection_Command(
initiator_filter_policy=0,
own_address_type=own_address_type,
peer_address_type=peer_address.address_type,
peer_address=peer_address,
initiating_phys=initiating_phys,
scan_intervals=(
int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625),
)
* phy_count,
scan_windows=(
int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),
)
* phy_count,
connection_interval_mins=connection_interval_mins,
connection_interval_maxs=connection_interval_maxs,
max_latencies=max_latencies,
supervision_timeouts=supervision_timeouts,
min_ce_lengths=min_ce_lengths,
max_ce_lengths=max_ce_lengths,
)
)
else:
if hci.HCI_LE_1M_PHY not in connection_parameters_preferences:
raise InvalidArgumentError('1M PHY preferences required')
prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY]
await self.send_async_command(
hci.HCI_LE_Create_Connection_Command(
le_scan_interval=int(
DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625
),
le_scan_window=int(
DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625
),
initiator_filter_policy=0,
peer_address_type=peer_address.address_type,
peer_address=peer_address,
own_address_type=own_address_type,
connection_interval_min=int(
prefs.connection_interval_min / 1.25
),
connection_interval_max=int(
prefs.connection_interval_max / 1.25
),
max_latency=prefs.max_latency,
supervision_timeout=int(prefs.supervision_timeout / 10),
min_ce_length=int(prefs.min_ce_length / 0.625),
max_ce_length=int(prefs.max_ce_length / 0.625),
)
)
else:
# Save pending connection
self.pending_connections[peer_address] = Connection(
device=self,
handle=0,
transport=core.PhysicalTransport.BR_EDR,
self_address=self.public_address,
self_resolvable_address=None,
peer_address=peer_address,
peer_resolvable_address=None,
role=hci.Role.CENTRAL,
parameters=Connection.Parameters(0, 0, 0),
)
# TODO: allow passing other settings
await self.send_async_command(
hci.HCI_Create_Connection_Command(
bd_addr=peer_address,
packet_type=0xCC18, # FIXME: change
page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE,
clock_offset=0x0000,
allow_role_switch=0x01,
reserved=0,
)
)
# Wait for the connection process to complete
if transport == PhysicalTransport.LE:
self.le_connecting = True
if timeout is None:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
try:
return await asyncio.wait_for(
asyncio.shield(pending_connection), timeout
)
except asyncio.TimeoutError:
if transport == PhysicalTransport.LE:
await self.send_sync_command(
hci.HCI_LE_Create_Connection_Cancel_Command()
)
else:
await self.send_sync_command(
hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
)
try:
return await utils.cancel_on_event(
self, Device.EVENT_FLUSH, pending_connection
)
except core.ConnectionError as error:
raise core.TimeoutError() from error
finally:
self.remove_listener(self.EVENT_CONNECTION, on_connection)
self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure)
if transport == PhysicalTransport.LE:
self.le_connecting = False
self.connect_own_address_type = None
else:
self.pending_connections.pop(peer_address, None)
raise InvalidArgumentError('no supported transport for request')
async def accept(
self,
@@ -4695,6 +4742,8 @@ class Device(utils.CompositeEventEmitter):
Scan for a peer with a resolvable address that can be resolved to a given
identity address.
"""
if self.address_resolver is None:
raise InvalidStateError('no resolver')
# Create a future to wait for an address to be found
peer_address = asyncio.get_running_loop().create_future()

View File

@@ -803,7 +803,9 @@ class Host(utils.EventEmitter):
data=pdu,
)
logger.debug(
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu
'>>> ACL packet enqueue: (handle=0x%04X) %s',
connection_handle,
pdu.hex(),
)
packet_queue.enqueue(acl_packet, connection_handle)

View File

@@ -27,7 +27,7 @@ from __future__ import annotations
import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
@@ -507,10 +507,15 @@ def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool)
# -----------------------------------------------------------------------------
class AddressResolver:
def __init__(self, resolving_keys):
def __init__(self, resolving_keys: Sequence[tuple[bytes, Address]]) -> None:
self.resolving_keys = resolving_keys
def resolve(self, address):
def can_resolve_to(self, address: Address) -> bool:
return any(
resolved_address == address for _, resolved_address in self.resolving_keys
)
def resolve(self, address: Address) -> Address | None:
address_bytes = bytes(address)
hash_part = address_bytes[0:3]
prand = address_bytes[3:6]