Merge pull request #48 from google/uael/classic-parallel-connect

classic: update `Device.connect` to allow parallels connection creation
This commit is contained in:
Gilles Boccon-Gibod
2022-10-23 20:52:08 -07:00
committed by GitHub
6 changed files with 228 additions and 16 deletions
+5
View File
@@ -104,6 +104,11 @@ class ConnectionError(BaseError):
FAILURE = 0x01
CONNECTION_REFUSED = 0x02
def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
# -----------------------------------------------------------------------------
# UUID
+27 -9
View File
@@ -1145,7 +1145,7 @@ class Device(CompositeEventEmitter):
transport = BT_LE_TRANSPORT
# Check that there isn't already a pending connection
if self.is_connecting:
if transport == BT_LE_TRANSPORT and self.is_connecting:
raise InvalidStateError('connection already pending')
if type(peer_address) is str:
@@ -1156,10 +1156,22 @@ class Device(CompositeEventEmitter):
logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name(peer_address, transport) # TODO: timeout
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
connection.transport == transport and connection.peer_address == peer_address):
pending_connection.set_result(connection)
def on_connection_failure(error):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection failure event against peer address
error.transport == transport and error.peer_address == peer_address):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
self.on('connection', pending_connection.set_result)
self.on('connection_failure', pending_connection.set_exception)
self.on('connection', on_connection)
self.on('connection_failure', on_connection_failure)
try:
# Tell the controller to connect
@@ -1249,7 +1261,8 @@ class Device(CompositeEventEmitter):
raise HCI_StatusError(result)
# Wait for the connection process to complete
self.connecting = True
if transport == BT_LE_TRANSPORT:
self.connecting = True
if timeout is None:
return await pending_connection
else:
@@ -1266,9 +1279,10 @@ class Device(CompositeEventEmitter):
except ConnectionError:
raise TimeoutError()
finally:
self.remove_listener('connection', pending_connection.set_result)
self.remove_listener('connection_failure', pending_connection.set_exception)
self.connecting = False
self.remove_listener('connection', on_connection)
self.remove_listener('connection_failure', on_connection_failure)
if transport == BT_LE_TRANSPORT:
self.connecting = False
@asynccontextmanager
async def connect_as_gatt(self, peer_address):
@@ -1705,16 +1719,18 @@ class Device(CompositeEventEmitter):
asyncio.create_task(new_connection())
@host_event_handler
def on_connection_failure(self, connection_handle, error_code):
def on_connection_failure(self, transport, peer_address, error_code):
logger.debug(f'*** Connection failed: {HCI_Constant.error_name(error_code)}')
# For directed advertising, this means a timeout
if self.advertising and self.advertising_type.is_directed:
if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed:
self.advertising = False
# Notify listeners
error = ConnectionError(
error_code,
transport,
peer_address,
'hci',
HCI_Constant.error_name(error_code)
)
@@ -1746,6 +1762,8 @@ class Device(CompositeEventEmitter):
logger.debug(f'*** Disconnection failed: {error_code}')
error = ConnectionError(
error_code,
connection.transport,
connection.peer_address,
'hci',
HCI_Constant.error_name(error_code)
)
+2 -2
View File
@@ -383,7 +383,7 @@ class Host(EventEmitter):
logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('connection_failure', event.connection_handle, event.status)
self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status)
def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now
@@ -413,7 +413,7 @@ class Host(EventEmitter):
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client
self.emit('connection_failure', event.connection_handle, event.status)
self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status)
def on_hci_disconnection_complete_event(self, event):
# Find the connection
+7 -2
View File
@@ -21,7 +21,7 @@ import asyncio
from colors import color
from pyee import EventEmitter
from .core import InvalidStateError, ProtocolError, ConnectionError
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError
# -----------------------------------------------------------------------------
# Logging
@@ -634,7 +634,12 @@ class Multiplexer(EventEmitter):
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED))
self.open_result.set_exception(ConnectionError(
ConnectionError.CONNECTION_REFUSED,
self.l2cap_channel.connection.peer_address,
BT_BR_EDR_TRANSPORT,
'rfcomm'
))
else:
logger.warn(f'unexpected state for DM: {self}')
+6 -3
View File
@@ -30,7 +30,7 @@ from bumble.sdp import Client as SDP_Client, SDP_PUBLIC_BROWSE_ROOT, SDP_ALL_ATT
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-address>')
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-addresses..>')
print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8')
return
@@ -43,8 +43,7 @@ async def main():
device.classic_enabled = True
await device.power_on()
# Connect to a peer
target_address = sys.argv[3]
async def connect(target_address):
print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
print(f'=== Connected to {connection.peer_address}!')
@@ -76,6 +75,10 @@ async def main():
await sdp_client.disconnect()
await hci_source.wait_for_termination()
# Connect to a peer
target_addresses = sys.argv[3:]
await asyncio.wait([asyncio.create_task(connect(target_address)) for target_address in target_addresses])
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+181
View File
@@ -0,0 +1,181 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
from types import LambdaType
import pytest
from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS,
Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class Sink:
def __init__(self, flow):
self.flow = flow
next(self.flow)
def on_packet(self, packet):
self.flow.send(packet)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
d0 = Device(host=Host(None, None))
d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None))
# enable classic
d0.classic_enabled = True
d1.classic_enabled = True
d2.classic_enabled = True
# set public addresses
d0.public_address = Address('F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
d1.public_address = Address('F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS)
d2.public_address = Address('F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
def d0_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d1.public_address
d0.host.on_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = HCI_CREATE_CONNECTION_COMMAND
))
d1.host.on_hci_packet(HCI_Connection_Request_Event(
bd_addr = d0.public_address,
class_of_device = 0,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
))
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d2.public_address
d0.host.on_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = HCI_CREATE_CONNECTION_COMMAND
))
d2.host.on_hci_packet(HCI_Connection_Request_Event(
bd_addr = d0.public_address,
class_of_device = 0,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
))
assert (yield) == None
def d1_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters = b"\x00"
))
d1.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x100,
bd_addr = d0.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x100,
bd_addr = d1.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
assert (yield) == None
def d2_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters = b"\x00"
))
d2.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x101,
bd_addr = d0.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x101,
bd_addr = d2.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
assert (yield) == None
d0.host.set_packet_sink(Sink(d0_flow()))
d1.host.set_packet_sink(Sink(d1_flow()))
d2.host.set_packet_sink(Sink(d2_flow()))
[c1, c2] = await asyncio.gather(*[
asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)),
asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)),
])
assert type(c1) == Connection
assert type(c2) == Connection
assert c1.handle == 0x100
assert c2.handle == 0x101
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_device())