forked from auracaster/bumble_mirror
Merge pull request #48 from google/uael/classic-parallel-connect
classic: update `Device.connect` to allow parallels connection creation
This commit is contained in:
@@ -104,6 +104,11 @@ class ConnectionError(BaseError):
|
||||
FAILURE = 0x01
|
||||
CONNECTION_REFUSED = 0x02
|
||||
|
||||
def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''):
|
||||
super().__init__(error_code, error_namespace, error_name, details)
|
||||
self.transport = transport
|
||||
self.peer_address = peer_address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# UUID
|
||||
|
||||
+27
-9
@@ -1145,7 +1145,7 @@ class Device(CompositeEventEmitter):
|
||||
transport = BT_LE_TRANSPORT
|
||||
|
||||
# Check that there isn't already a pending connection
|
||||
if self.is_connecting:
|
||||
if transport == BT_LE_TRANSPORT and self.is_connecting:
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
if type(peer_address) is str:
|
||||
@@ -1156,10 +1156,22 @@ class Device(CompositeEventEmitter):
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(peer_address, transport) # TODO: timeout
|
||||
|
||||
def on_connection(connection):
|
||||
if transport == BT_LE_TRANSPORT or (
|
||||
# match BR/EDR connection event against peer address
|
||||
connection.transport == transport and connection.peer_address == peer_address):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error):
|
||||
if transport == BT_LE_TRANSPORT or (
|
||||
# match BR/EDR connection failure event against peer address
|
||||
error.transport == transport and error.peer_address == peer_address):
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection's result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
self.on('connection', pending_connection.set_result)
|
||||
self.on('connection_failure', pending_connection.set_exception)
|
||||
self.on('connection', on_connection)
|
||||
self.on('connection_failure', on_connection_failure)
|
||||
|
||||
try:
|
||||
# Tell the controller to connect
|
||||
@@ -1249,7 +1261,8 @@ class Device(CompositeEventEmitter):
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
# Wait for the connection process to complete
|
||||
self.connecting = True
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.connecting = True
|
||||
if timeout is None:
|
||||
return await pending_connection
|
||||
else:
|
||||
@@ -1266,9 +1279,10 @@ class Device(CompositeEventEmitter):
|
||||
except ConnectionError:
|
||||
raise TimeoutError()
|
||||
finally:
|
||||
self.remove_listener('connection', pending_connection.set_result)
|
||||
self.remove_listener('connection_failure', pending_connection.set_exception)
|
||||
self.connecting = False
|
||||
self.remove_listener('connection', on_connection)
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.connecting = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_as_gatt(self, peer_address):
|
||||
@@ -1705,16 +1719,18 @@ class Device(CompositeEventEmitter):
|
||||
asyncio.create_task(new_connection())
|
||||
|
||||
@host_event_handler
|
||||
def on_connection_failure(self, connection_handle, error_code):
|
||||
def on_connection_failure(self, transport, peer_address, error_code):
|
||||
logger.debug(f'*** Connection failed: {HCI_Constant.error_name(error_code)}')
|
||||
|
||||
# For directed advertising, this means a timeout
|
||||
if self.advertising and self.advertising_type.is_directed:
|
||||
if transport == BT_LE_TRANSPORT and self.advertising and self.advertising_type.is_directed:
|
||||
self.advertising = False
|
||||
|
||||
# Notify listeners
|
||||
error = ConnectionError(
|
||||
error_code,
|
||||
transport,
|
||||
peer_address,
|
||||
'hci',
|
||||
HCI_Constant.error_name(error_code)
|
||||
)
|
||||
@@ -1746,6 +1762,8 @@ class Device(CompositeEventEmitter):
|
||||
logger.debug(f'*** Disconnection failed: {error_code}')
|
||||
error = ConnectionError(
|
||||
error_code,
|
||||
connection.transport,
|
||||
connection.peer_address,
|
||||
'hci',
|
||||
HCI_Constant.error_name(error_code)
|
||||
)
|
||||
|
||||
+2
-2
@@ -383,7 +383,7 @@ class Host(EventEmitter):
|
||||
logger.debug(f'### CONNECTION FAILED: {event.status}')
|
||||
|
||||
# Notify the listeners
|
||||
self.emit('connection_failure', event.connection_handle, event.status)
|
||||
self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status)
|
||||
|
||||
def on_hci_le_enhanced_connection_complete_event(self, event):
|
||||
# Just use the same implementation as for the non-enhanced event for now
|
||||
@@ -413,7 +413,7 @@ class Host(EventEmitter):
|
||||
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
|
||||
|
||||
# Notify the client
|
||||
self.emit('connection_failure', event.connection_handle, event.status)
|
||||
self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status)
|
||||
|
||||
def on_hci_disconnection_complete_event(self, event):
|
||||
# Find the connection
|
||||
|
||||
+7
-2
@@ -21,7 +21,7 @@ import asyncio
|
||||
from colors import color
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .core import InvalidStateError, ProtocolError, ConnectionError
|
||||
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -634,7 +634,12 @@ class Multiplexer(EventEmitter):
|
||||
if self.state == Multiplexer.OPENING:
|
||||
self.change_state(Multiplexer.CONNECTED)
|
||||
if self.open_result:
|
||||
self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED))
|
||||
self.open_result.set_exception(ConnectionError(
|
||||
ConnectionError.CONNECTION_REFUSED,
|
||||
self.l2cap_channel.connection.peer_address,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
'rfcomm'
|
||||
))
|
||||
else:
|
||||
logger.warn(f'unexpected state for DM: {self}')
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from bumble.sdp import Client as SDP_Client, SDP_PUBLIC_BROWSE_ROOT, SDP_ALL_ATT
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main():
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-address>')
|
||||
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-addresses..>')
|
||||
print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8')
|
||||
return
|
||||
|
||||
@@ -43,8 +43,7 @@ async def main():
|
||||
device.classic_enabled = True
|
||||
await device.power_on()
|
||||
|
||||
# Connect to a peer
|
||||
target_address = sys.argv[3]
|
||||
async def connect(target_address):
|
||||
print(f'=== Connecting to {target_address}...')
|
||||
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
|
||||
print(f'=== Connected to {connection.peer_address}!')
|
||||
@@ -76,6 +75,10 @@ async def main():
|
||||
await sdp_client.disconnect()
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
# Connect to a peer
|
||||
target_addresses = sys.argv[3:]
|
||||
await asyncio.wait([asyncio.create_task(connect(target_address)) for target_address in target_addresses])
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from types import LambdaType
|
||||
import pytest
|
||||
|
||||
from bumble.core import BT_BR_EDR_TRANSPORT
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.host import Host
|
||||
from bumble.hci import (
|
||||
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS,
|
||||
Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Sink:
|
||||
def __init__(self, flow):
|
||||
self.flow = flow
|
||||
next(self.flow)
|
||||
|
||||
def on_packet(self, packet):
|
||||
self.flow.send(packet)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_device_connect_parallel():
|
||||
d0 = Device(host=Host(None, None))
|
||||
d1 = Device(host=Host(None, None))
|
||||
d2 = Device(host=Host(None, None))
|
||||
|
||||
# enable classic
|
||||
d0.classic_enabled = True
|
||||
d1.classic_enabled = True
|
||||
d2.classic_enabled = True
|
||||
|
||||
# set public addresses
|
||||
d0.public_address = Address('F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
|
||||
d1.public_address = Address('F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS)
|
||||
d2.public_address = Address('F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
def d0_flow():
|
||||
packet = HCI_Packet.from_bytes((yield))
|
||||
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
|
||||
assert packet.bd_addr == d1.public_address
|
||||
|
||||
d0.host.on_hci_packet(HCI_Command_Status_Event(
|
||||
status = HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets = 1,
|
||||
command_opcode = HCI_CREATE_CONNECTION_COMMAND
|
||||
))
|
||||
|
||||
d1.host.on_hci_packet(HCI_Connection_Request_Event(
|
||||
bd_addr = d0.public_address,
|
||||
class_of_device = 0,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
|
||||
))
|
||||
|
||||
packet = HCI_Packet.from_bytes((yield))
|
||||
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
|
||||
assert packet.bd_addr == d2.public_address
|
||||
|
||||
d0.host.on_hci_packet(HCI_Command_Status_Event(
|
||||
status = HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets = 1,
|
||||
command_opcode = HCI_CREATE_CONNECTION_COMMAND
|
||||
))
|
||||
|
||||
d2.host.on_hci_packet(HCI_Connection_Request_Event(
|
||||
bd_addr = d0.public_address,
|
||||
class_of_device = 0,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
|
||||
))
|
||||
|
||||
assert (yield) == None
|
||||
|
||||
def d1_flow():
|
||||
packet = HCI_Packet.from_bytes((yield))
|
||||
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
|
||||
|
||||
d1.host.on_hci_packet(HCI_Command_Complete_Event(
|
||||
num_hci_command_packets = 1,
|
||||
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
|
||||
return_parameters = b"\x00"
|
||||
))
|
||||
|
||||
d1.host.on_hci_packet(HCI_Connection_Complete_Event(
|
||||
status = HCI_SUCCESS,
|
||||
connection_handle = 0x100,
|
||||
bd_addr = d0.public_address,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
encryption_enabled = True,
|
||||
))
|
||||
|
||||
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
|
||||
status = HCI_SUCCESS,
|
||||
connection_handle = 0x100,
|
||||
bd_addr = d1.public_address,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
encryption_enabled = True,
|
||||
))
|
||||
|
||||
assert (yield) == None
|
||||
|
||||
def d2_flow():
|
||||
packet = HCI_Packet.from_bytes((yield))
|
||||
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
|
||||
|
||||
d2.host.on_hci_packet(HCI_Command_Complete_Event(
|
||||
num_hci_command_packets = 1,
|
||||
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
|
||||
return_parameters = b"\x00"
|
||||
))
|
||||
|
||||
d2.host.on_hci_packet(HCI_Connection_Complete_Event(
|
||||
status = HCI_SUCCESS,
|
||||
connection_handle = 0x101,
|
||||
bd_addr = d0.public_address,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
encryption_enabled = True,
|
||||
))
|
||||
|
||||
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
|
||||
status = HCI_SUCCESS,
|
||||
connection_handle = 0x101,
|
||||
bd_addr = d2.public_address,
|
||||
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
encryption_enabled = True,
|
||||
))
|
||||
|
||||
assert (yield) == None
|
||||
|
||||
d0.host.set_packet_sink(Sink(d0_flow()))
|
||||
d1.host.set_packet_sink(Sink(d1_flow()))
|
||||
d2.host.set_packet_sink(Sink(d2_flow()))
|
||||
|
||||
[c1, c2] = await asyncio.gather(*[
|
||||
asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)),
|
||||
asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)),
|
||||
])
|
||||
|
||||
assert type(c1) == Connection
|
||||
assert type(c2) == Connection
|
||||
|
||||
assert c1.handle == 0x100
|
||||
assert c2.handle == 0x101
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run_test_device():
|
||||
await test_device_connect_parallel()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run_test_device())
|
||||
Reference in New Issue
Block a user