From 7208fd6642f6628b805c357dd3cc6142d5d02706 Mon Sep 17 00:00:00 2001 From: Abel Lucas Date: Wed, 19 Oct 2022 16:39:06 +0000 Subject: [PATCH 1/4] classic: update `Device.connect` to allow parallels connection creation According to the specification nothing prevent the Host from creating multiple connections at the same time. This commit add this mechanisme by matching the `connection` and `connection_failure` events against the peer address. --- bumble/device.py | 35 ++++++++++++++++++++++++++--------- bumble/host.py | 4 ++-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index ae15088..5393573 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1144,7 +1144,7 @@ class Device(CompositeEventEmitter): transport = BT_LE_TRANSPORT # Check that there isn't already a pending connection - if self.is_connecting: + if transport == BT_LE_TRANSPORT and self.is_connecting: raise InvalidStateError('connection already pending') if type(peer_address) is str: @@ -1155,10 +1155,22 @@ class Device(CompositeEventEmitter): logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name(peer_address, transport) # TODO: timeout + def on_connection(connection): + if transport == BT_LE_TRANSPORT or ( + # match BR/EDR connection event against peer address + connection.transport == transport and connection.peer_address == peer_address): + pending_connection.set_result(connection) + + def on_connection_failure(error): + if transport == BT_LE_TRANSPORT or ( + # match BR/EDR connection failure event against peer address + error.transport == transport and error.peer_address == peer_address): + pending_connection.set_exception(error) + # Create a future so that we can wait for the connection's result pending_connection = asyncio.get_running_loop().create_future() - self.on('connection', pending_connection.set_result) - self.on('connection_failure', pending_connection.set_exception) + self.on('connection', on_connection) + self.on('connection_failure', on_connection_failure) try: # Tell the controller to connect @@ -1248,7 +1260,8 @@ class Device(CompositeEventEmitter): raise HCI_StatusError(result) # Wait for the connection process to complete - self.connecting = True + if transport == BT_LE_TRANSPORT: + self.connecting = True if timeout is None: return await pending_connection else: @@ -1265,9 +1278,10 @@ class Device(CompositeEventEmitter): except ConnectionError: raise TimeoutError() finally: - self.remove_listener('connection', pending_connection.set_result) - self.remove_listener('connection_failure', pending_connection.set_exception) - self.connecting = False + self.remove_listener('connection', on_connection) + self.remove_listener('connection_failure', on_connection_failure) + if transport == BT_LE_TRANSPORT: + self.connecting = False @asynccontextmanager async def connect_as_gatt(self, peer_address): @@ -1704,11 +1718,11 @@ class Device(CompositeEventEmitter): asyncio.create_task(new_connection()) @host_event_handler - def on_connection_failure(self, connection_handle, error_code): + def on_connection_failure(self, transport, connection_handle, peer_address, error_code): logger.debug(f'*** Connection failed: {HCI_Constant.error_name(error_code)}') # For directed advertising, this means a timeout - if self.advertising and self.advertising_type.is_directed: + if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed: self.advertising = False # Notify listeners @@ -1717,6 +1731,9 @@ class Device(CompositeEventEmitter): 'hci', HCI_Constant.error_name(error_code) ) + error.transport = transport + error.connection_handle = connection_handle # FIXME: Connection handle sounds to be a dummy value here + error.peer_address = peer_address self.emit('connection_failure', error) @host_event_handler diff --git a/bumble/host.py b/bumble/host.py index 35efad4..276d334 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -383,7 +383,7 @@ class Host(EventEmitter): logger.debug(f'### CONNECTION FAILED: {event.status}') # Notify the listeners - self.emit('connection_failure', event.connection_handle, event.status) + self.emit('connection_failure', BT_LE_TRANSPORT, event.connection_handle, event.peer_address, event.status) def on_hci_le_enhanced_connection_complete_event(self, event): # Just use the same implementation as for the non-enhanced event for now @@ -413,7 +413,7 @@ class Host(EventEmitter): logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') # Notify the client - self.emit('connection_failure', event.connection_handle, event.status) + self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.connection_handle, event.bd_addr, event.status) def on_hci_disconnection_complete_event(self, event): # Find the connection From 45dd849d9f8cfb82182db830e5975f49d24dc643 Mon Sep 17 00:00:00 2001 From: Abel Lucas Date: Thu, 20 Oct 2022 14:51:52 +0000 Subject: [PATCH 2/4] classic: update `ConnectionError` to take transport and peer address --- bumble/core.py | 5 +++++ bumble/device.py | 7 ++++--- bumble/host.py | 4 ++-- bumble/rfcomm.py | 9 +++++++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/bumble/core.py b/bumble/core.py index f024c12..b4c640c 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -100,6 +100,11 @@ class ConnectionError(BaseError): FAILURE = 0x01 CONNECTION_REFUSED = 0x02 + def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''): + self.transport = transport + self.peer_address = peer_address + super().__init__(error_code, error_namespace, error_name, details) + # ----------------------------------------------------------------------------- # UUID diff --git a/bumble/device.py b/bumble/device.py index 5393573..67bb3f3 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1728,12 +1728,11 @@ class Device(CompositeEventEmitter): # Notify listeners error = ConnectionError( error_code, + transport, + peer_address, 'hci', HCI_Constant.error_name(error_code) ) - error.transport = transport - error.connection_handle = connection_handle # FIXME: Connection handle sounds to be a dummy value here - error.peer_address = peer_address self.emit('connection_failure', error) @host_event_handler @@ -1762,6 +1761,8 @@ class Device(CompositeEventEmitter): logger.debug(f'*** Disconnection failed: {error_code}') error = ConnectionError( error_code, + connection.transport, + connection.peer_address, 'hci', HCI_Constant.error_name(error_code) ) diff --git a/bumble/host.py b/bumble/host.py index 276d334..20a92bf 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -383,7 +383,7 @@ class Host(EventEmitter): logger.debug(f'### CONNECTION FAILED: {event.status}') # Notify the listeners - self.emit('connection_failure', BT_LE_TRANSPORT, event.connection_handle, event.peer_address, event.status) + self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status) def on_hci_le_enhanced_connection_complete_event(self, event): # Just use the same implementation as for the non-enhanced event for now @@ -413,7 +413,7 @@ class Host(EventEmitter): logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') # Notify the client - self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.connection_handle, event.bd_addr, event.status) + self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status) def on_hci_disconnection_complete_event(self, event): # Find the connection diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index be4d406..e47a260 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -21,7 +21,7 @@ import asyncio from colors import color from pyee import EventEmitter -from .core import InvalidStateError, ProtocolError, ConnectionError +from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError # ----------------------------------------------------------------------------- # Logging @@ -634,7 +634,12 @@ class Multiplexer(EventEmitter): if self.state == Multiplexer.OPENING: self.change_state(Multiplexer.CONNECTED) if self.open_result: - self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED)) + self.open_result.set_exception(ConnectionError( + ConnectionError.CONNECTION_REFUSED, + self.l2cap_channel.connection.peer_address, + BT_BR_EDR_TRANSPORT, + 'rfcomm' + )) else: logger.warn(f'unexpected state for DM: {self}') From 915405a9bd360435c28d0da059029d62b3774001 Mon Sep 17 00:00:00 2001 From: Abel Lucas Date: Wed, 19 Oct 2022 23:54:00 +0000 Subject: [PATCH 3/4] examples: update `run_classic_connect` example to take multiple addresses instead of one --- examples/run_classic_connect.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/run_classic_connect.py b/examples/run_classic_connect.py index d6842fe..8395a23 100644 --- a/examples/run_classic_connect.py +++ b/examples/run_classic_connect.py @@ -30,7 +30,7 @@ from bumble.sdp import Client as SDP_Client, SDP_PUBLIC_BROWSE_ROOT, SDP_ALL_ATT # ----------------------------------------------------------------------------- async def main(): if len(sys.argv) < 3: - print('Usage: run_classic_connect.py ') + print('Usage: run_classic_connect.py ') print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8') return @@ -43,8 +43,7 @@ async def main(): device.classic_enabled = True await device.power_on() - # Connect to a peer - target_address = sys.argv[3] + async def connect(target_address): print(f'=== Connecting to {target_address}...') connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) print(f'=== Connected to {connection.peer_address}!') @@ -76,6 +75,10 @@ async def main(): await sdp_client.disconnect() await hci_source.wait_for_termination() + # Connect to a peer + target_addresses = sys.argv[3:] + await asyncio.wait([asyncio.create_task(connect(target_address)) for target_address in target_addresses]) + # ----------------------------------------------------------------------------- logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) asyncio.run(main()) From 16b4f18c9265296430c6b1d1bba04adca039d6b5 Mon Sep 17 00:00:00 2001 From: Abel Lucas Date: Thu, 20 Oct 2022 17:41:15 +0000 Subject: [PATCH 4/4] tests: add parallel device connection test --- bumble/core.py | 2 +- bumble/device.py | 2 +- tests/device_test.py | 181 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 tests/device_test.py diff --git a/bumble/core.py b/bumble/core.py index b4c640c..ed52572 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -101,9 +101,9 @@ class ConnectionError(BaseError): CONNECTION_REFUSED = 0x02 def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''): + super().__init__(error_code, error_namespace, error_name, details) self.transport = transport self.peer_address = peer_address - super().__init__(error_code, error_namespace, error_name, details) # ----------------------------------------------------------------------------- diff --git a/bumble/device.py b/bumble/device.py index 67bb3f3..3f231cd 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1718,7 +1718,7 @@ class Device(CompositeEventEmitter): asyncio.create_task(new_connection()) @host_event_handler - def on_connection_failure(self, transport, connection_handle, peer_address, error_code): + def on_connection_failure(self, transport, peer_address, error_code): logger.debug(f'*** Connection failed: {HCI_Constant.error_name(error_code)}') # For directed advertising, this means a timeout diff --git a/tests/device_test.py b/tests/device_test.py new file mode 100644 index 0000000..cd72c4c --- /dev/null +++ b/tests/device_test.py @@ -0,0 +1,181 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import os +from types import LambdaType +import pytest + +from bumble.core import BT_BR_EDR_TRANSPORT +from bumble.device import Connection, Device +from bumble.host import Host +from bumble.hci import ( + HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS, + Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet +) + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +class Sink: + def __init__(self, flow): + self.flow = flow + next(self.flow) + + def on_packet(self, packet): + self.flow.send(packet) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_device_connect_parallel(): + d0 = Device(host=Host(None, None)) + d1 = Device(host=Host(None, None)) + d2 = Device(host=Host(None, None)) + + # enable classic + d0.classic_enabled = True + d1.classic_enabled = True + d2.classic_enabled = True + + # set public addresses + d0.public_address = Address('F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS) + d1.public_address = Address('F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS) + d2.public_address = Address('F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS) + + def d0_flow(): + packet = HCI_Packet.from_bytes((yield)) + assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND' + assert packet.bd_addr == d1.public_address + + d0.host.on_hci_packet(HCI_Command_Status_Event( + status = HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets = 1, + command_opcode = HCI_CREATE_CONNECTION_COMMAND + )) + + d1.host.on_hci_packet(HCI_Connection_Request_Event( + bd_addr = d0.public_address, + class_of_device = 0, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE + )) + + packet = HCI_Packet.from_bytes((yield)) + assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND' + assert packet.bd_addr == d2.public_address + + d0.host.on_hci_packet(HCI_Command_Status_Event( + status = HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets = 1, + command_opcode = HCI_CREATE_CONNECTION_COMMAND + )) + + d2.host.on_hci_packet(HCI_Connection_Request_Event( + bd_addr = d0.public_address, + class_of_device = 0, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE + )) + + assert (yield) == None + + def d1_flow(): + packet = HCI_Packet.from_bytes((yield)) + assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' + + d1.host.on_hci_packet(HCI_Command_Complete_Event( + num_hci_command_packets = 1, + command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, + return_parameters = b"\x00" + )) + + d1.host.on_hci_packet(HCI_Connection_Complete_Event( + status = HCI_SUCCESS, + connection_handle = 0x100, + bd_addr = d0.public_address, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled = True, + )) + + d0.host.on_hci_packet(HCI_Connection_Complete_Event( + status = HCI_SUCCESS, + connection_handle = 0x100, + bd_addr = d1.public_address, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled = True, + )) + + assert (yield) == None + + def d2_flow(): + packet = HCI_Packet.from_bytes((yield)) + assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' + + d2.host.on_hci_packet(HCI_Command_Complete_Event( + num_hci_command_packets = 1, + command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, + return_parameters = b"\x00" + )) + + d2.host.on_hci_packet(HCI_Connection_Complete_Event( + status = HCI_SUCCESS, + connection_handle = 0x101, + bd_addr = d0.public_address, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled = True, + )) + + d0.host.on_hci_packet(HCI_Connection_Complete_Event( + status = HCI_SUCCESS, + connection_handle = 0x101, + bd_addr = d2.public_address, + link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE, + encryption_enabled = True, + )) + + assert (yield) == None + + d0.host.set_packet_sink(Sink(d0_flow())) + d1.host.set_packet_sink(Sink(d1_flow())) + d2.host.set_packet_sink(Sink(d2_flow())) + + [c1, c2] = await asyncio.gather(*[ + asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)), + asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)), + ]) + + assert type(c1) == Connection + assert type(c2) == Connection + + assert c1.handle == 0x100 + assert c2.handle == 0x101 + + +# ----------------------------------------------------------------------------- +async def run_test_device(): + await test_device_connect_parallel() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run_test_device())