From 56ed46adfa1003fff0bfc7cd9aeb25fec9fcce5d Mon Sep 17 00:00:00 2001 From: Abel Lucas Date: Wed, 19 Oct 2022 19:00:03 +0000 Subject: [PATCH] classic: add BR/EDR accept connection logic --- bumble/device.py | 131 +++++++++++++++++++++++++++++++++++++++++++ bumble/hci.py | 21 +++++++ bumble/host.py | 13 ++--- tests/device_test.py | 17 ++++-- 4 files changed, 170 insertions(+), 12 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index f68e330..13097a1 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -519,6 +519,7 @@ class DeviceConfiguration: self.le_simultaneous_enabled = True self.classic_sc_enabled = True self.classic_ssp_enabled = True + self.classic_accept_any = True self.connectable = True self.discoverable = True self.advertising_data = bytes( @@ -539,6 +540,7 @@ class DeviceConfiguration: self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled) self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled) self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled) + self.classic_accept_any = config.get('classic_accept_any', self.classic_accept_any) self.connectable = config.get('connectable', self.connectable) self.discoverable = config.get('discoverable', self.discoverable) @@ -630,6 +632,9 @@ class Device(CompositeEventEmitter): def on_connection_failure(self, error): pass + def on_connection_request(self, bd_addr, class_of_device, link_type): + pass + def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled): pass @@ -680,6 +685,7 @@ class Device(CompositeEventEmitter): self.classic_enabled = False self.inquiry_response = None self.address_resolver = None + self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY # Use the initial config or a default self.public_address = Address('00:00:00:00:00:00') @@ -700,6 +706,7 @@ class Device(CompositeEventEmitter): self.classic_sc_enabled = config.classic_sc_enabled self.discoverable = config.discoverable self.connectable = config.connectable + self.classic_accept_any = config.classic_accept_any # If a name is passed, override the name from the config if name: @@ -1300,6 +1307,89 @@ class Device(CompositeEventEmitter): if transport == BT_LE_TRANSPORT: self.le_connecting = False + async def accept( + self, + peer_address=Address.ANY, + role=BT_PERIPHERAL_ROLE, + timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT + ): + ''' + Wait and accept any incoming connection or a connection from `peer_address` when set. + + Notes: + * A `connect` to the same peer will also complete this call. + * The `timeout` parameter is only handled while waiting for the connection request, + once received and accepeted, the controller shall issue a connection complete event. + ''' + + if type(peer_address) is str: + try: + peer_address = Address(peer_address) + except ValueError: + # If the address is not parsable, assume it is a name instead + logger.debug('looking for peer by name') + peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout + + if peer_address == Address.NIL: + raise ValueError('accept on nil address') + + # Create a future so that we can wait for the request + pending_request = asyncio.get_running_loop().create_future() + + if peer_address == Address.ANY: + self.classic_pending_accepts[Address.ANY].append(pending_request) + elif peer_address in self.classic_pending_accepts: + raise InvalidStateError('accept connection already pending') + else: + self.classic_pending_accepts[peer_address] = pending_request + + try: + # Wait for a request or a completed connection + result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request) + + except: + # Remove future from device context + if peer_address == Address.ANY: + self.classic_pending_accepts[Address.ANY].remove(pending_request) + else: + self.classic_pending_accepts.pop(peer_address) + raise + + # Result may already be a completed connection, + # see `on_connection` for details + if isinstance(result, Connection): + return result + + # Otherwise, result came from `on_connection_request` + peer_address, class_of_device, link_type = result + + def on_connection(connection): + if connection.transport == BT_BR_EDR_TRANSPORT and connection.peer_address == peer_address: + pending_connection.set_result(connection) + + def on_connection_failure(error): + if error.transport == BT_BR_EDR_TRANSPORT and error.peer_address == peer_address: + pending_connection.set_exception(error) + + # Create a future so that we can wait for the connection's result + pending_connection = asyncio.get_running_loop().create_future() + self.on('connection', on_connection) + self.on('connection_failure', on_connection_failure) + + try: + # Accept connection request + await self.send_command(HCI_Accept_Connection_Request_Command( + bd_addr = peer_address, + role = role + )) + + # Wait for connection complete + return await pending_connection + + finally: + self.remove_listener('connection', on_connection) + self.remove_listener('connection_failure', on_connection_failure) + @asynccontextmanager async def connect_as_gatt(self, peer_address): async with AsyncExitStack() as stack: @@ -1716,6 +1806,14 @@ class Device(CompositeEventEmitter): ) self.connections[connection_handle] = connection + # We may have an accept ongoing waiting for a connection request for `peer_address`. + # Typicaly happen when using `connect` to the same `peer_address` we are waiting with + # an `accept` for. + # In this case, set the completed `connection` to the `accept` future result. + if peer_address in self.classic_pending_accepts: + future = self.classic_pending_accepts.pop(peer_address) + future.set_result(connection) + # Emit an event to notify listeners of the new connection self.emit('connection', connection) else: @@ -1779,6 +1877,39 @@ class Device(CompositeEventEmitter): ) self.emit('connection_failure', error) + # FIXME: Explore a delegate-model for BR/EDR wait connection #56. + @host_event_handler + def on_connection_request(self, bd_addr, class_of_device, link_type): + logger.debug(f'*** Connection request: {bd_addr}') + + # match a pending future using `bd_addr` + if bd_addr in self.classic_pending_accepts: + future = self.classic_pending_accepts.pop(bd_addr) + future.set_result((bd_addr, class_of_device, link_type)) + + # match first pending future for ANY address + elif len(self.classic_pending_accepts[Address.ANY]) > 0: + future = self.classic_pending_accepts[Address.ANY].pop(0) + future.set_result((bd_addr, class_of_device, link_type)) + + # device configuration is set to accept any incoming connection + elif self.classic_accept_any: + self.host.send_command_sync( + HCI_Accept_Connection_Request_Command( + bd_addr = bd_addr, + role = 0x01 # Remain the peripheral + ) + ) + + # reject incoming connection + else: + self.host.send_command_sync( + HCI_Reject_Connection_Request_Command( + bd_addr = bd_addr, + reason = HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR + ) + ) + @host_event_handler @with_connection_from_handle def on_disconnection(self, connection, reason): diff --git a/bumble/hci.py b/bumble/hci.py index af26374..d4cf7cc 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1652,6 +1652,16 @@ class Address: ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)} + @classmethod + @property + def ANY(cls): + return cls(b"\xff\xff\xff\xff\xff\xff", cls.PUBLIC_DEVICE_ADDRESS) + + @classmethod + @property + def NIL(cls): + return cls(b"\x00\x00\x00\x00\x00\x00", cls.PUBLIC_DEVICE_ADDRESS) + @staticmethod def address_type_name(address_type): return name_or_number(Address.ADDRESS_TYPE_NAMES, address_type) @@ -1935,6 +1945,17 @@ class HCI_Accept_Connection_Request_Command(HCI_Command): ''' +# ----------------------------------------------------------------------------- +@HCI_Command.command([ + ('bd_addr', Address.parse_address), + ('reason', {'size': 1, 'mapper': HCI_Constant.error_name}) +]) +class HCI_Reject_Connection_Request_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.1.9 Reject Connection Request Command + ''' + + # ----------------------------------------------------------------------------- @HCI_Command.command([ ('bd_addr', Address.parse_address), diff --git a/bumble/host.py b/bumble/host.py index 01c25a4..32b2194 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -347,13 +347,12 @@ class Host(EventEmitter): # Classic only def on_hci_connection_request_event(self, event): - # For now, just accept everything - # TODO: delegate the decision - self.send_command_sync( - HCI_Accept_Connection_Request_Command( - bd_addr = event.bd_addr, - role = 0x01 # Remain the peripheral - ) + # Notify the listeners + self.emit( + 'connection_request', + event.bd_addr, + event.class_of_device, + event.link_type, ) def on_hci_le_connection_complete_event(self, event): diff --git a/tests/device_test.py b/tests/device_test.py index cd72c4c..acf4446 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -158,16 +158,23 @@ async def test_device_connect_parallel(): d1.host.set_packet_sink(Sink(d1_flow())) d2.host.set_packet_sink(Sink(d2_flow())) - [c1, c2] = await asyncio.gather(*[ + [c01, c02, a10, a20, a01] = await asyncio.gather(*[ asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)), asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)), + asyncio.create_task(d1.accept(peer_address=d0.public_address)), + asyncio.create_task(d2.accept()), + asyncio.create_task(d0.accept(peer_address=d1.public_address)), ]) - assert type(c1) == Connection - assert type(c2) == Connection + assert type(c01) == Connection + assert type(c02) == Connection + assert type(a10) == Connection + assert type(a20) == Connection + assert type(a01) == Connection - assert c1.handle == 0x100 - assert c2.handle == 0x101 + assert c01.handle == a10.handle and c01.handle == 0x100 + assert c02.handle == a20.handle and c02.handle == 0x101 + assert a01 == c01 # -----------------------------------------------------------------------------