This commit is contained in:
Gilles Boccon-Gibod
2025-07-28 13:36:55 -07:00
parent 0665e9ca5c
commit a6ead0147e
4 changed files with 86 additions and 153 deletions

View File

@@ -370,6 +370,12 @@ class Controller:
return connection
return None
def find_peripheral_connection_by_handle(self, handle):
for connection in self.peripheral_connections.values():
if connection.handle == handle:
return connection
return None
def find_classic_connection_by_handle(self, handle):
for connection in self.classic_connections.values():
if connection.handle == handle:
@@ -414,7 +420,7 @@ class Controller:
)
)
def on_link_central_disconnected(self, peer_address, reason):
def on_link_disconnected(self, peer_address, reason):
'''
Called when an active disconnection occurs from a peer
'''
@@ -431,6 +437,17 @@ class Controller:
# Remove the connection
del self.peripheral_connections[peer_address]
elif connection := self.central_connections.get(peer_address):
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason=reason,
)
)
# Remove the connection
del self.central_connections[peer_address]
else:
logger.warning(f'!!! No peripheral connection found for {peer_address}')
@@ -479,7 +496,7 @@ class Controller:
)
)
def on_link_peripheral_disconnection_complete(self, disconnection_command, status):
def on_link_disconnection_complete(self, disconnection_command, status):
'''
Called when a disconnection has been completed
'''
@@ -499,26 +516,11 @@ class Controller:
):
logger.debug(f'CENTRAL Connection removed: {connection}')
del self.central_connections[connection.peer_address]
def on_link_peripheral_disconnected(self, peer_address):
'''
Called when a connection to a peripheral is broken
'''
# Send a disconnection complete event
if connection := self.central_connections.get(peer_address):
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason=HCI_CONNECTION_TIMEOUT_ERROR,
)
)
# Remove the connection
del self.central_connections[peer_address]
else:
logger.warning(f'!!! No central connection found for {peer_address}')
elif connection := self.find_peripheral_connection_by_handle(
disconnection_command.connection_handle
):
logger.debug(f'PERIPHERAL Connection removed: {connection}')
del self.peripheral_connections[connection.peer_address]
def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk):
# For now, just setup the encryption without asking the host
@@ -877,6 +879,14 @@ class Controller:
else:
# Remove the connection
del self.central_connections[connection.peer_address]
elif connection := self.find_peripheral_connection_by_handle(handle):
if self.link:
self.link.disconnect(
self.random_address, connection.peer_address, command
)
else:
# Remove the connection
del self.peripheral_connections[connection.peer_address]
elif connection := self.find_classic_connection_by_handle(handle):
if self.link:
self.link.classic_disconnect(

View File

@@ -159,29 +159,29 @@ class LocalLink:
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
self, initiating_address, target_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
if not (initiating_controller := self.find_controller(initiating_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
if target_controller := self.find_controller(target_address):
target_controller.on_link_disconnected(
initiating_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(
initiating_controller.on_link_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
def disconnect(self, initiating_address, target_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
f'$$$ DISCONNECTION {initiating_address} -> '
f'{target_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
args = [initiating_address, target_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments

View File

@@ -23,13 +23,9 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from bumble.controller import Controller
from bumble.core import PhysicalTransport
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
from bumble.device import Peer
from bumble.gatt import Service, Characteristic
from bumble.transport.common import AsyncPipeSink
from bumble.pairing import PairingConfig, PairingDelegate
from bumble.smp import (
SMP_PAIRING_NOT_SUPPORTED_ERROR,
@@ -38,9 +34,10 @@ from bumble.smp import (
OobLegacyContext,
)
from bumble.core import ProtocolError
from bumble.keys import PairingKeys
from bumble.hci import Role
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
@@ -49,63 +46,26 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = LocalLink()
self.controllers = [
Controller('C1', link=self.link, public_address=addresses[0]),
Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
Device(
address=addresses[0],
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address=addresses[1],
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which: int, keys: PairingKeys):
self.paired[which].set_result(keys)
@pytest.mark.asyncio
async def test_self_connection():
two_devices = TwoDevices()
await two_devices.setup_connection()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_connection():
# Create two devices, each with a controller, attached to the same link
async def test_self_disconnection():
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[0].disconnect()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
# Attach listeners
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
# Check the post conditions
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[1].disconnect()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
# -----------------------------------------------------------------------------
@@ -115,24 +75,14 @@ async def test_self_connection():
(Role.CENTRAL, Role.PERIPHERAL),
)
async def test_self_classic_connection(responder_role):
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Enable Classic connections
two_devices.devices[0].classic_enabled = True
two_devices.devices[1].classic_enabled = True
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
await two_devices.setup_connection()
# Connect the two devices
await asyncio.gather(
@@ -203,15 +153,9 @@ async def test_self_gatt():
s4 = Service('3A12C182-14E2-4FE0-8C5B-65D7C569F9DB', [], included_services=[s2, s3])
two_devices.devices[1].add_services([s1, s2, s4])
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
connection = await two_devices.devices[0].connect(
two_devices.devices[1].random_address
)
peer = Peer(connection)
await two_devices.setup_connection()
peer = Peer(two_devices.connections[0])
bogus_uuid = 'A0AA6007-0B48-4BBE-80AC-0DE9AAF541EA'
result = await peer.discover_services([bogus_uuid])
@@ -264,15 +208,9 @@ async def test_self_gatt_long_read():
service = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', characteristics)
two_devices.devices[1].add_service(service)
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
connection = await two_devices.devices[0].connect(
two_devices.devices[1].random_address
)
peer = Peer(connection)
await two_devices.setup_connection()
peer = Peer(two_devices.connections[0])
result = await peer.discover_service(service.uuid)
assert len(result) == 1
@@ -289,25 +227,12 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2):
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Attach listeners
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Connect the two devices
connection = await two_devices.devices[0].connect(
two_devices.devices[1].random_address
)
await two_devices.setup_connection()
connection = two_devices.connections[0]
assert not connection.is_encrypted
# Attach connection listeners
# Attach pairing listeners
two_devices.connections[0].on(
'pairing', lambda keys: two_devices.on_paired(0, keys)
)
@@ -488,23 +413,13 @@ async def test_self_smp_over_classic():
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Enable Classic connections
two_devices.devices[0].classic_enabled = True
two_devices.devices[1].classic_enabled = True
# Start
# Connect the two devices
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
await asyncio.gather(
two_devices.devices[0].connect(
two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR
@@ -650,6 +565,7 @@ async def test_self_smp_oob_legacy():
# -----------------------------------------------------------------------------
async def run_test_self():
await test_self_connection()
await test_self_disconnection()
await test_self_gatt()
await test_self_gatt_long_read()
await test_self_smp()

View File

@@ -25,6 +25,7 @@ from bumble.device import Device, Connection
from bumble.host import Host
from bumble.transport.common import AsyncPipeSink
from bumble.hci import Address
from bumble.keys import PairingKeys
# -----------------------------------------------------------------------------
@@ -51,16 +52,6 @@ class TwoDevices:
),
]
self.paired = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which, keys):
self.paired[which] = keys
async def setup_connection(self) -> None:
# Attach listeners
self.devices[0].on(
'connection', lambda connection: self.on_connection(0, connection)
)
@@ -68,6 +59,22 @@ class TwoDevices:
'connection', lambda connection: self.on_connection(1, connection)
)
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]
def on_connection(self, which, connection):
self.connections[which] = connection
connection.on('disconnection', lambda code: self.on_disconnection(which))
def on_disconnection(self, which):
self.connections[which] = None
def on_paired(self, which: int, keys: PairingKeys) -> None:
self.paired[which].set_result(keys)
async def setup_connection(self) -> None:
# Start
await self.devices[0].power_on()
await self.devices[1].power_on()