diff --git a/bumble/controller.py b/bumble/controller.py index 356b7e4..5f74828 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -370,6 +370,12 @@ class Controller: return connection return None + def find_peripheral_connection_by_handle(self, handle): + for connection in self.peripheral_connections.values(): + if connection.handle == handle: + return connection + return None + def find_classic_connection_by_handle(self, handle): for connection in self.classic_connections.values(): if connection.handle == handle: @@ -414,7 +420,7 @@ class Controller: ) ) - def on_link_central_disconnected(self, peer_address, reason): + def on_link_disconnected(self, peer_address, reason): ''' Called when an active disconnection occurs from a peer ''' @@ -431,6 +437,17 @@ class Controller: # Remove the connection del self.peripheral_connections[peer_address] + elif connection := self.central_connections.get(peer_address): + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=connection.handle, + reason=reason, + ) + ) + + # Remove the connection + del self.central_connections[peer_address] else: logger.warning(f'!!! No peripheral connection found for {peer_address}') @@ -479,7 +496,7 @@ class Controller: ) ) - def on_link_peripheral_disconnection_complete(self, disconnection_command, status): + def on_link_disconnection_complete(self, disconnection_command, status): ''' Called when a disconnection has been completed ''' @@ -499,26 +516,11 @@ class Controller: ): logger.debug(f'CENTRAL Connection removed: {connection}') del self.central_connections[connection.peer_address] - - def on_link_peripheral_disconnected(self, peer_address): - ''' - Called when a connection to a peripheral is broken - ''' - - # Send a disconnection complete event - if connection := self.central_connections.get(peer_address): - self.send_hci_packet( - HCI_Disconnection_Complete_Event( - status=HCI_SUCCESS, - connection_handle=connection.handle, - reason=HCI_CONNECTION_TIMEOUT_ERROR, - ) - ) - - # Remove the connection - del self.central_connections[peer_address] - else: - logger.warning(f'!!! No central connection found for {peer_address}') + elif connection := self.find_peripheral_connection_by_handle( + disconnection_command.connection_handle + ): + logger.debug(f'PERIPHERAL Connection removed: {connection}') + del self.peripheral_connections[connection.peer_address] def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk): # For now, just setup the encryption without asking the host @@ -877,6 +879,14 @@ class Controller: else: # Remove the connection del self.central_connections[connection.peer_address] + elif connection := self.find_peripheral_connection_by_handle(handle): + if self.link: + self.link.disconnect( + self.random_address, connection.peer_address, command + ) + else: + # Remove the connection + del self.peripheral_connections[connection.peer_address] elif connection := self.find_classic_connection_by_handle(handle): if self.link: self.link.classic_disconnect( diff --git a/bumble/link.py b/bumble/link.py index c40e4b8..606df2c 100644 --- a/bumble/link.py +++ b/bumble/link.py @@ -159,29 +159,29 @@ class LocalLink: asyncio.get_running_loop().call_soon(self.on_connection_complete) def on_disconnection_complete( - self, central_address, peripheral_address, disconnect_command + self, initiating_address, target_address, disconnect_command ): # Find the controller that initiated the disconnection - if not (central_controller := self.find_controller(central_address)): + if not (initiating_controller := self.find_controller(initiating_address)): logger.warning('!!! Initiating controller not found') return # Disconnect from the first controller with a matching address - if peripheral_controller := self.find_controller(peripheral_address): - peripheral_controller.on_link_central_disconnected( - central_address, disconnect_command.reason + if target_controller := self.find_controller(target_address): + target_controller.on_link_disconnected( + initiating_address, disconnect_command.reason ) - central_controller.on_link_peripheral_disconnection_complete( + initiating_controller.on_link_disconnection_complete( disconnect_command, HCI_SUCCESS ) - def disconnect(self, central_address, peripheral_address, disconnect_command): + def disconnect(self, initiating_address, target_address, disconnect_command): logger.debug( - f'$$$ DISCONNECTION {central_address} -> ' - f'{peripheral_address}: reason = {disconnect_command.reason}' + f'$$$ DISCONNECTION {initiating_address} -> ' + f'{target_address}: reason = {disconnect_command.reason}' ) - args = [central_address, peripheral_address, disconnect_command] + args = [initiating_address, target_address, disconnect_command] asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args) # pylint: disable=too-many-arguments diff --git a/tests/self_test.py b/tests/self_test.py index 55efa6a..9098299 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -23,13 +23,9 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from bumble.controller import Controller from bumble.core import PhysicalTransport -from bumble.link import LocalLink -from bumble.device import Device, Peer -from bumble.host import Host +from bumble.device import Peer from bumble.gatt import Service, Characteristic -from bumble.transport.common import AsyncPipeSink from bumble.pairing import PairingConfig, PairingDelegate from bumble.smp import ( SMP_PAIRING_NOT_SUPPORTED_ERROR, @@ -38,9 +34,10 @@ from bumble.smp import ( OobLegacyContext, ) from bumble.core import ProtocolError -from bumble.keys import PairingKeys from bumble.hci import Role +from .test_utils import TwoDevices + # ----------------------------------------------------------------------------- # Logging @@ -49,63 +46,26 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -class TwoDevices: - def __init__(self): - self.connections = [None, None] - - addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] - self.link = LocalLink() - self.controllers = [ - Controller('C1', link=self.link, public_address=addresses[0]), - Controller('C2', link=self.link, public_address=addresses[1]), - ] - self.devices = [ - Device( - address=addresses[0], - host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), - ), - Device( - address=addresses[1], - host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), - ), - ] - - self.paired = [ - asyncio.get_event_loop().create_future(), - asyncio.get_event_loop().create_future(), - ] - - def on_connection(self, which, connection): - self.connections[which] = connection - - def on_paired(self, which: int, keys: PairingKeys): - self.paired[which].set_result(keys) +@pytest.mark.asyncio +async def test_self_connection(): + two_devices = TwoDevices() + await two_devices.setup_connection() # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_self_connection(): - # Create two devices, each with a controller, attached to the same link +async def test_self_disconnection(): two_devices = TwoDevices() + await two_devices.setup_connection() + await two_devices.connections[0].disconnect() + assert two_devices.connections[0] is None + assert two_devices.connections[1] is None - # Attach listeners - two_devices.devices[0].on( - 'connection', lambda connection: two_devices.on_connection(0, connection) - ) - two_devices.devices[1].on( - 'connection', lambda connection: two_devices.on_connection(1, connection) - ) - - # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() - - # Connect the two devices - await two_devices.devices[0].connect(two_devices.devices[1].random_address) - - # Check the post conditions - assert two_devices.connections[0] is not None - assert two_devices.connections[1] is not None + two_devices = TwoDevices() + await two_devices.setup_connection() + await two_devices.connections[1].disconnect() + assert two_devices.connections[0] is None + assert two_devices.connections[1] is None # ----------------------------------------------------------------------------- @@ -115,24 +75,14 @@ async def test_self_connection(): (Role.CENTRAL, Role.PERIPHERAL), ) async def test_self_classic_connection(responder_role): - # Create two devices, each with a controller, attached to the same link two_devices = TwoDevices() - # Attach listeners - two_devices.devices[0].on( - 'connection', lambda connection: two_devices.on_connection(0, connection) - ) - two_devices.devices[1].on( - 'connection', lambda connection: two_devices.on_connection(1, connection) - ) - # Enable Classic connections two_devices.devices[0].classic_enabled = True two_devices.devices[1].classic_enabled = True # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() + await two_devices.setup_connection() # Connect the two devices await asyncio.gather( @@ -203,15 +153,9 @@ async def test_self_gatt(): s4 = Service('3A12C182-14E2-4FE0-8C5B-65D7C569F9DB', [], included_services=[s2, s3]) two_devices.devices[1].add_services([s1, s2, s4]) - # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() - # Connect the two devices - connection = await two_devices.devices[0].connect( - two_devices.devices[1].random_address - ) - peer = Peer(connection) + await two_devices.setup_connection() + peer = Peer(two_devices.connections[0]) bogus_uuid = 'A0AA6007-0B48-4BBE-80AC-0DE9AAF541EA' result = await peer.discover_services([bogus_uuid]) @@ -264,15 +208,9 @@ async def test_self_gatt_long_read(): service = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', characteristics) two_devices.devices[1].add_service(service) - # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() - # Connect the two devices - connection = await two_devices.devices[0].connect( - two_devices.devices[1].random_address - ) - peer = Peer(connection) + await two_devices.setup_connection() + peer = Peer(two_devices.connections[0]) result = await peer.discover_service(service.uuid) assert len(result) == 1 @@ -289,25 +227,12 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2): # Create two devices, each with a controller, attached to the same link two_devices = TwoDevices() - # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() - - # Attach listeners - two_devices.devices[0].on( - 'connection', lambda connection: two_devices.on_connection(0, connection) - ) - two_devices.devices[1].on( - 'connection', lambda connection: two_devices.on_connection(1, connection) - ) - # Connect the two devices - connection = await two_devices.devices[0].connect( - two_devices.devices[1].random_address - ) + await two_devices.setup_connection() + connection = two_devices.connections[0] assert not connection.is_encrypted - # Attach connection listeners + # Attach pairing listeners two_devices.connections[0].on( 'pairing', lambda keys: two_devices.on_paired(0, keys) ) @@ -488,23 +413,13 @@ async def test_self_smp_over_classic(): # Create two devices, each with a controller, attached to the same link two_devices = TwoDevices() - # Attach listeners - two_devices.devices[0].on( - 'connection', lambda connection: two_devices.on_connection(0, connection) - ) - two_devices.devices[1].on( - 'connection', lambda connection: two_devices.on_connection(1, connection) - ) - # Enable Classic connections two_devices.devices[0].classic_enabled = True two_devices.devices[1].classic_enabled = True - # Start + # Connect the two devices await two_devices.devices[0].power_on() await two_devices.devices[1].power_on() - - # Connect the two devices await asyncio.gather( two_devices.devices[0].connect( two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR @@ -650,6 +565,7 @@ async def test_self_smp_oob_legacy(): # ----------------------------------------------------------------------------- async def run_test_self(): await test_self_connection() + await test_self_disconnection() await test_self_gatt() await test_self_gatt_long_read() await test_self_smp() diff --git a/tests/test_utils.py b/tests/test_utils.py index 136a3c6..be40ecb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -25,6 +25,7 @@ from bumble.device import Device, Connection from bumble.host import Host from bumble.transport.common import AsyncPipeSink from bumble.hci import Address +from bumble.keys import PairingKeys # ----------------------------------------------------------------------------- @@ -51,16 +52,6 @@ class TwoDevices: ), ] - self.paired = [None, None] - - def on_connection(self, which, connection): - self.connections[which] = connection - - def on_paired(self, which, keys): - self.paired[which] = keys - - async def setup_connection(self) -> None: - # Attach listeners self.devices[0].on( 'connection', lambda connection: self.on_connection(0, connection) ) @@ -68,6 +59,22 @@ class TwoDevices: 'connection', lambda connection: self.on_connection(1, connection) ) + self.paired = [ + asyncio.get_event_loop().create_future(), + asyncio.get_event_loop().create_future(), + ] + + def on_connection(self, which, connection): + self.connections[which] = connection + connection.on('disconnection', lambda code: self.on_disconnection(which)) + + def on_disconnection(self, which): + self.connections[which] = None + + def on_paired(self, which: int, keys: PairingKeys) -> None: + self.paired[which].set_result(keys) + + async def setup_connection(self) -> None: # Start await self.devices[0].power_on() await self.devices[1].power_on()