diff --git a/bumble/controller.py b/bumble/controller.py index da5d4cf1..cd7de3dc 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -21,7 +21,12 @@ import itertools import random import struct from bumble.colors import color -from bumble.core import BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE +from bumble.core import ( + BT_CENTRAL_ROLE, + BT_PERIPHERAL_ROLE, + BT_LE_TRANSPORT, + BT_BR_EDR_TRANSPORT, +) from bumble.hci import ( HCI_ACL_DATA_PACKET, @@ -29,17 +34,21 @@ from bumble.hci import ( HCI_COMMAND_PACKET, HCI_COMMAND_STATUS_PENDING, HCI_CONNECTION_TIMEOUT_ERROR, + HCI_CONTROLLER_BUSY_ERROR, HCI_EVENT_PACKET, HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR, HCI_LE_1M_PHY, HCI_SUCCESS, HCI_UNKNOWN_HCI_COMMAND_ERROR, + HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_VERSION_BLUETOOTH_CORE_5_0, Address, HCI_AclDataPacket, HCI_AclDataPacketAssembler, HCI_Command_Complete_Event, HCI_Command_Status_Event, + HCI_Connection_Complete_Event, + HCI_Connection_Request_Event, HCI_Disconnection_Complete_Event, HCI_Encryption_Change_Event, HCI_LE_Advertising_Report_Event, @@ -47,7 +56,9 @@ from bumble.hci import ( HCI_LE_Read_Remote_Features_Complete_Event, HCI_Number_Of_Completed_Packets_Event, HCI_Packet, + HCI_Role_Change_Event, ) +from typing import Optional, Union, Dict # ----------------------------------------------------------------------------- @@ -65,13 +76,14 @@ class DataObject: # ----------------------------------------------------------------------------- class Connection: - def __init__(self, controller, handle, role, peer_address, link): + def __init__(self, controller, handle, role, peer_address, link, transport): self.controller = controller self.handle = handle self.role = role self.peer_address = peer_address self.link = link self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) + self.transport = transport def on_hci_acl_data_packet(self, packet): self.assembler.feed_packet(packet) @@ -82,23 +94,33 @@ class Connection: def on_acl_pdu(self, data): if self.link: self.link.send_acl_data( - self.controller.random_address, self.peer_address, data + self.controller, self.peer_address, self.transport, data ) # ----------------------------------------------------------------------------- class Controller: - def __init__(self, name, host_source=None, host_sink=None, link=None): + def __init__( + self, + name, + host_source=None, + host_sink=None, + link=None, + public_address: Optional[Union[bytes, str, Address]] = None, + ): self.name = name self.hci_sink = None self.link = link - self.central_connections = ( - {} - ) # Connections where this controller is the central - self.peripheral_connections = ( - {} - ) # Connections where this controller is the peripheral + self.central_connections: Dict[ + Address, Connection + ] = {} # Connections where this controller is the central + self.peripheral_connections: Dict[ + Address, Connection + ] = {} # Connections where this controller is the peripheral + self.classic_connections: Dict[ + Address, Connection + ] = {} # Connections in BR/EDR self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 self.hci_revision = 0 @@ -148,7 +170,14 @@ class Controller: self.advertising_timer_handle = None self._random_address = Address('00:00:00:00:00:00') - self._public_address = None + if isinstance(public_address, Address): + self._public_address = public_address + elif public_address is not None: + self._public_address = Address( + public_address, Address.PUBLIC_DEVICE_ADDRESS + ) + else: + self._public_address = Address('00:00:00:00:00:00') # Set the source and sink interfaces if host_source: @@ -271,7 +300,9 @@ class Controller: handle = 0 max_handle = 0 for connection in itertools.chain( - self.central_connections.values(), self.peripheral_connections.values() + self.central_connections.values(), + self.peripheral_connections.values(), + self.classic_connections.values(), ): max_handle = max(max_handle, connection.handle) if connection.handle == handle: @@ -279,14 +310,19 @@ class Controller: handle = max_handle + 1 return handle - def find_connection_by_address(self, address): + def find_le_connection_by_address(self, address): return self.central_connections.get(address) or self.peripheral_connections.get( address ) + def find_classic_connection_by_address(self, address): + return self.classic_connections.get(address) + def find_connection_by_handle(self, handle): for connection in itertools.chain( - self.central_connections.values(), self.peripheral_connections.values() + self.central_connections.values(), + self.peripheral_connections.values(), + self.classic_connections.values(), ): if connection.handle == handle: return connection @@ -298,6 +334,12 @@ class Controller: return connection return None + def find_classic_connection_by_handle(self, handle): + for connection in self.classic_connections.values(): + if connection.handle == handle: + return connection + return None + def on_link_central_connected(self, central_address): ''' Called when an incoming connection occurs from a central on the link @@ -310,7 +352,12 @@ class Controller: if connection is None: connection_handle = self.allocate_connection_handle() connection = Connection( - self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link + self, + connection_handle, + BT_PERIPHERAL_ROLE, + peer_address, + self.link, + BT_LE_TRANSPORT, ) self.peripheral_connections[peer_address] = connection logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') @@ -364,7 +411,12 @@ class Controller: if connection is None: connection_handle = self.allocate_connection_handle() connection = Connection( - self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link + self, + connection_handle, + BT_CENTRAL_ROLE, + peer_address, + self.link, + BT_LE_TRANSPORT, ) self.central_connections[peer_address] = connection logger.debug( @@ -432,16 +484,19 @@ class Controller: def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk): # For now, just setup the encryption without asking the host - if connection := self.find_connection_by_address(peer_address): + if connection := self.find_le_connection_by_address(peer_address): self.send_hci_packet( HCI_Encryption_Change_Event( status=0, connection_handle=connection.handle, encryption_enabled=1 ) ) - def on_link_acl_data(self, sender_address, data): + def on_link_acl_data(self, sender_address, transport, data): # Look for the connection to which this data belongs - connection = self.find_connection_by_address(sender_address) + if transport == BT_LE_TRANSPORT: + connection = self.find_le_connection_by_address(sender_address) + else: + connection = self.find_classic_connection_by_address(sender_address) if connection is None: logger.warning(f'!!! no connection for {sender_address}') return @@ -478,6 +533,87 @@ class Controller: ) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) + ############################################################ + # Classic link connections + ############################################################ + + def on_classic_connection_request(self, peer_address, link_type): + self.send_hci_packet( + HCI_Connection_Request_Event( + bd_addr=peer_address, + class_of_device=0, + link_type=link_type, + ) + ) + + def on_classic_connection_complete(self, peer_address, status): + if status == HCI_SUCCESS: + # Allocate (or reuse) a connection handle + peer_address = peer_address + connection = self.classic_connections.get(peer_address) + if connection is None: + connection_handle = self.allocate_connection_handle() + connection = Connection( + controller=self, + handle=connection_handle, + # Role doesn't matter in Classic because they are managed by HCI_Role_Change and HCI_Role_Discovery + role=BT_CENTRAL_ROLE, + peer_address=peer_address, + link=self.link, + transport=BT_BR_EDR_TRANSPORT, + ) + self.classic_connections[peer_address] = connection + logger.debug( + f'New CLASSIC connection handle: 0x{connection_handle:04X}' + ) + else: + connection_handle = connection.handle + self.send_hci_packet( + HCI_Connection_Complete_Event( + status=status, + connection_handle=connection_handle, + bd_addr=peer_address, + encryption_enabled=False, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + ) + ) + else: + connection = None + self.send_hci_packet( + HCI_Connection_Complete_Event( + status=status, + connection_handle=0, + bd_addr=peer_address, + encryption_enabled=False, + link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, + ) + ) + + def on_classic_disconnected(self, peer_address, reason): + # Send a disconnection complete event + if connection := self.classic_connections.get(peer_address): + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=connection.handle, + reason=reason, + ) + ) + + # Remove the connection + del self.classic_connections[peer_address] + else: + logger.warning(f'!!! No classic connection found for {peer_address}') + + def on_classic_role_change(self, peer_address, new_role): + self.send_hci_packet( + HCI_Role_Change_Event( + status=HCI_SUCCESS, + bd_addr=peer_address, + new_role=new_role, + ) + ) + ############################################################ # Advertising support ############################################################ @@ -521,7 +657,31 @@ class Controller: See Bluetooth spec Vol 2, Part E - 7.1.5 Create Connection command ''' - # TODO: classic mode not supported yet + if self.link is None: + return + logger.debug(f'Connection request to {command.bd_addr}') + + # Check that we don't already have a pending connection + if self.link.get_pending_connection(): + self.send_hci_packet( + HCI_Command_Status_Event( + status=HCI_CONTROLLER_BUSY_ERROR, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) + return + + self.link.classic_connect(self, command.bd_addr) + + # Say that the connection is pending + self.send_hci_packet( + HCI_Command_Status_Event( + status=HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) def on_hci_disconnect_command(self, command): ''' @@ -537,19 +697,57 @@ class Controller: ) # Notify the link of the disconnection - if not ( - connection := self.find_central_connection_by_handle( - command.connection_handle - ) - ): - logger.warning('connection not found') - return + handle = command.connection_handle + if connection := self.find_central_connection_by_handle(handle): + if self.link: + self.link.disconnect( + self.random_address, connection.peer_address, command + ) + else: + # Remove the connection + del self.central_connections[connection.peer_address] + elif connection := self.find_classic_connection_by_handle(handle): + if self.link: + self.link.classic_disconnect( + self, + connection.peer_address, + HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, + ) + else: + # Remove the connection + del self.classic_connections[connection.peer_address] - if self.link: - self.link.disconnect(self.random_address, connection.peer_address, command) - else: - # Remove the connection - del self.central_connections[connection.peer_address] + def on_hci_accept_connection_request_command(self, command): + ''' + See Bluetooth spec Vol 2, Part E - 7.1.8 Accept Connection Request command + ''' + + if self.link is None: + return + self.send_hci_packet( + HCI_Command_Status_Event( + status=HCI_SUCCESS, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) + self.link.classic_accept_connection(self, command.bd_addr, command.role) + + def on_hci_switch_role_command(self, command): + ''' + See Bluetooth spec Vol 2, Part E - 7.2.8 Switch Role command + ''' + + if self.link is None: + return + self.send_hci_packet( + HCI_Command_Status_Event( + status=HCI_SUCCESS, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) + self.link.classic_switch_role(self, command.bd_addr, command.role) def on_hci_set_event_mask_command(self, command): ''' @@ -627,6 +825,12 @@ class Controller: ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR return bytes([ret]) + def on_hci_write_extended_inquiry_response_command(self, _command): + ''' + See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command + ''' + return bytes([HCI_SUCCESS]) + def on_hci_write_simple_pairing_mode_command(self, _command): ''' See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command diff --git a/bumble/link.py b/bumble/link.py index 82dd9db5..85ad96e4 100644 --- a/bumble/link.py +++ b/bumble/link.py @@ -19,12 +19,15 @@ import logging import asyncio from functools import partial +from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT from bumble.colors import color from bumble.hci import ( Address, HCI_SUCCESS, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_TIMEOUT_ERROR, + HCI_PAGE_TIMEOUT_ERROR, + HCI_Connection_Complete_Event, ) # ----------------------------------------------------------------------------- @@ -57,6 +60,11 @@ class LocalLink: def __init__(self): self.controllers = set() self.pending_connection = None + self.pending_classic_connection = None + + ############################################################ + # Common utils + ############################################################ def add_controller(self, controller): logger.debug(f'new controller: {controller}') @@ -71,22 +79,39 @@ class LocalLink: return controller return None - def on_address_changed(self, controller): - pass + def find_classic_controller(self, address): + for controller in self.controllers: + if controller.public_address == address: + return controller + return None def get_pending_connection(self): return self.pending_connection + ############################################################ + # LE handlers + ############################################################ + + def on_address_changed(self, controller): + pass + def send_advertising_data(self, sender_address, data): # Send the advertising data to all controllers, except the sender for controller in self.controllers: if controller.random_address != sender_address: controller.on_link_advertising_data(sender_address, data) - def send_acl_data(self, sender_address, destination_address, data): + def send_acl_data(self, sender_controller, destination_address, transport, data): # Send the data to the first controller with a matching address - if controller := self.find_controller(destination_address): - controller.on_link_acl_data(sender_address, data) + if transport == BT_LE_TRANSPORT: + destination_controller = self.find_controller(destination_address) + source_address = sender_controller.random_address + elif transport == BT_BR_EDR_TRANSPORT: + destination_controller = self.find_classic_controller(destination_address) + source_address = sender_controller.public_address + + if destination_controller is not None: + destination_controller.on_link_acl_data(source_address, transport, data) def on_connection_complete(self): # Check that we expect this call @@ -163,6 +188,89 @@ class LocalLink: if peripheral_controller := self.find_controller(peripheral_address): peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk) + ############################################################ + # Classic handlers + ############################################################ + + def classic_connect(self, initiator_controller, responder_address): + logger.debug( + f'[Classic] {initiator_controller.public_address} connects to {responder_address}' + ) + responder_controller = self.find_classic_controller(responder_address) + if responder_controller is None: + initiator_controller.on_classic_connection_complete( + responder_address, HCI_PAGE_TIMEOUT_ERROR + ) + return + self.pending_classic_connection = (initiator_controller, responder_controller) + + responder_controller.on_classic_connection_request( + initiator_controller.public_address, + HCI_Connection_Complete_Event.ACL_LINK_TYPE, + ) + + def classic_accept_connection( + self, responder_controller, initiator_address, responder_role + ): + logger.debug( + f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}' + ) + initiator_controller = self.find_classic_controller(initiator_address) + if initiator_controller is None: + responder_controller.on_classic_connection_complete( + responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR + ) + return + + async def task(): + if responder_role != BT_PERIPHERAL_ROLE: + initiator_controller.on_classic_role_change( + responder_controller.public_address, int(not (responder_role)) + ) + initiator_controller.on_classic_connection_complete( + responder_controller.public_address, HCI_SUCCESS + ) + + asyncio.create_task(task()) + responder_controller.on_classic_role_change( + initiator_controller.public_address, responder_role + ) + responder_controller.on_classic_connection_complete( + initiator_controller.public_address, HCI_SUCCESS + ) + self.pending_classic_connection = None + + def classic_disconnect(self, initiator_controller, responder_address, reason): + logger.debug( + f'[Classic] {initiator_controller.public_address} disconnects {responder_address}' + ) + responder_controller = self.find_classic_controller(responder_address) + + async def task(): + initiator_controller.on_classic_disconnected(responder_address, reason) + + asyncio.create_task(task()) + responder_controller.on_classic_disconnected( + initiator_controller.public_address, reason + ) + + def classic_switch_role( + self, initiator_controller, responder_address, initiator_new_role + ): + responder_controller = self.find_classic_controller(responder_address) + if responder_controller is None: + return + + async def task(): + initiator_controller.on_classic_role_change( + responder_address, initiator_new_role + ) + + asyncio.create_task(task()) + responder_controller.on_classic_role_change( + initiator_controller.public_address, int(not (initiator_new_role)) + ) + # ----------------------------------------------------------------------------- class RemoteLink: @@ -200,6 +308,9 @@ class RemoteLink: def get_pending_connection(self): return self.pending_connection + def get_pending_classic_connection(self): + return self.pending_classic_connection + async def wait_until_connected(self): await self.websocket @@ -366,7 +477,8 @@ class RemoteLink: async def send_acl_data_to_relay(self, peer_address, data): await self.send_targeted_message(peer_address, f'acl:{data.hex()}') - def send_acl_data(self, _, peer_address, data): + def send_acl_data(self, _, peer_address, _transport, data): + # TODO: handle different transport self.execute(partial(self.send_acl_data_to_relay, peer_address, data)) async def send_connection_request_to_relay(self, peer_address): diff --git a/tests/a2dp_test.py b/tests/a2dp_test.py index e4995315..92f7915f 100644 --- a/tests/a2dp_test.py +++ b/tests/a2dp_test.py @@ -21,6 +21,7 @@ import os import pytest from bumble.controller import Controller +from bumble.core import BT_BR_EDR_TRANSPORT from bumble.link import LocalLink from bumble.device import Device from bumble.host import Host @@ -58,18 +59,19 @@ class TwoDevices: def __init__(self): self.connections = [None, None] + addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] self.link = LocalLink() self.controllers = [ - Controller('C1', link=self.link), - Controller('C2', link=self.link), + Controller('C1', link=self.link, public_address=addresses[0]), + Controller('C2', link=self.link, public_address=addresses[1]), ] self.devices = [ Device( - address='F0:F1:F2:F3:F4:F5', + address=addresses[0], host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address='F5:F4:F3:F2:F1:F0', + address=addresses[1], host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), ), ] @@ -79,6 +81,9 @@ class TwoDevices: def on_connection(self, which, connection): self.connections[which] = connection + def on_paired(self, which, keys): + self.paired[which] = keys + # ----------------------------------------------------------------------------- @pytest.mark.asyncio @@ -94,12 +99,21 @@ async def test_self_connection(): 'connection', lambda connection: two_devices.on_connection(1, connection) ) + # Enable Classic connections + two_devices.devices[0].classic_enabled = True + two_devices.devices[1].classic_enabled = True + # Start await two_devices.devices[0].power_on() await two_devices.devices[1].power_on() # Connect the two devices - await two_devices.devices[0].connect(two_devices.devices[1].random_address) + await asyncio.gather( + two_devices.devices[0].connect( + two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT + ), + two_devices.devices[1].accept(two_devices.devices[0].public_address), + ) # Check the post conditions assert two_devices.connections[0] is not None @@ -152,6 +166,9 @@ def sink_codec_capabilities(): @pytest.mark.asyncio async def test_source_sink_1(): two_devices = TwoDevices() + # Enable Classic connections + two_devices.devices[0].classic_enabled = True + two_devices.devices[1].classic_enabled = True await two_devices.devices[0].power_on() await two_devices.devices[1].power_on() @@ -171,9 +188,16 @@ async def test_source_sink_1(): listener = Listener(Listener.create_registrar(two_devices.devices[1])) listener.on('connection', on_avdtp_connection) - connection = await two_devices.devices[0].connect( - two_devices.devices[1].random_address - ) + async def make_connection(): + connections = await asyncio.gather( + two_devices.devices[0].connect( + two_devices.devices[1].public_address, BT_BR_EDR_TRANSPORT + ), + two_devices.devices[1].accept(two_devices.devices[0].public_address), + ) + return connections[0] + + connection = await make_connection() client = await Protocol.connect(connection) endpoints = await client.discover_remote_endpoints() assert len(endpoints) == 1 diff --git a/tests/self_test.py b/tests/self_test.py index 751825f9..4ff2d43b 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -22,6 +22,7 @@ import os import pytest from bumble.controller import Controller +from bumble.core import BT_BR_EDR_TRANSPORT, BT_PERIPHERAL_ROLE, BT_CENTRAL_ROLE from bumble.link import LocalLink from bumble.device import Device, Peer from bumble.host import Host @@ -47,18 +48,19 @@ class TwoDevices: def __init__(self): self.connections = [None, None] + addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] self.link = LocalLink() self.controllers = [ - Controller('C1', link=self.link), - Controller('C2', link=self.link), + Controller('C1', link=self.link, public_address=addresses[0]), + Controller('C2', link=self.link, public_address=addresses[1]), ] self.devices = [ Device( - address='F0:F1:F2:F3:F4:F5', + address=addresses[0], host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), ), Device( - address='F5:F4:F3:F2:F1:F0', + address=addresses[1], host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), ), ] @@ -98,6 +100,49 @@ async def test_self_connection(): assert two_devices.connections[1] is not None +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'responder_role,', + (BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE), +) +async def test_self_classic_connection(responder_role): + # Create two devices, each with a controller, attached to the same link + two_devices = TwoDevices() + + # Attach listeners + two_devices.devices[0].on( + 'connection', lambda connection: two_devices.on_connection(0, connection) + ) + two_devices.devices[1].on( + 'connection', lambda connection: two_devices.on_connection(1, connection) + ) + + # Enable Classic connections + two_devices.devices[0].classic_enabled = True + two_devices.devices[1].classic_enabled = True + + # Start + await two_devices.devices[0].power_on() + await two_devices.devices[1].power_on() + + # Connect the two devices + await asyncio.gather( + two_devices.devices[0].connect( + two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT + ), + two_devices.devices[1].accept( + two_devices.devices[0].public_address, responder_role + ), + ) + + # Check the post conditions + assert two_devices.connections[0] is not None + assert two_devices.connections[1] is not None + + await two_devices.connections[0].disconnect() + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_self_gatt():