mirror of
https://github.com/google/bumble.git
synced 2026-06-01 07:37:02 +00:00
Merge pull request #55 from google/uael/device-improvements
Device improvements
This commit is contained in:
+4
-4
@@ -311,7 +311,7 @@ class ConsoleApp:
|
||||
rssi = '' if self.connection_rssi is None else rssi_bar(self.connection_rssi)
|
||||
|
||||
if self.device:
|
||||
if self.device.is_connecting:
|
||||
if self.device.is_le_connecting:
|
||||
connection_state = 'CONNECTING'
|
||||
elif self.connected_peer:
|
||||
connection = self.connected_peer.connection
|
||||
@@ -574,7 +574,7 @@ class ConsoleApp:
|
||||
self.show_error('connection timed out')
|
||||
|
||||
async def do_disconnect(self, params):
|
||||
if self.device.connecting:
|
||||
if self.device.is_le_connecting:
|
||||
await self.device.cancel_connection()
|
||||
else:
|
||||
if not self.connected_peer:
|
||||
@@ -877,9 +877,9 @@ class ScanResult:
|
||||
else:
|
||||
type_color = colors.cyan
|
||||
|
||||
name = self.ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME)
|
||||
name = self.ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
|
||||
if name is None:
|
||||
name = self.ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME)
|
||||
name = self.ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME, raw=True)
|
||||
if name:
|
||||
# Convert to string
|
||||
try:
|
||||
|
||||
+22
-6
@@ -769,17 +769,20 @@ class AdvertisingData:
|
||||
def ad_data_to_object(ad_type, ad_data):
|
||||
if ad_type in {
|
||||
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS
|
||||
}:
|
||||
return AdvertisingData.uuid_list_to_objects(ad_data, 2)
|
||||
elif ad_type in {
|
||||
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS
|
||||
}:
|
||||
return AdvertisingData.uuid_list_to_objects(ad_data, 4)
|
||||
elif ad_type in {
|
||||
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS
|
||||
}:
|
||||
return AdvertisingData.uuid_list_to_objects(ad_data, 16)
|
||||
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
|
||||
@@ -790,11 +793,24 @@ class AdvertisingData:
|
||||
return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
|
||||
elif ad_type in {
|
||||
AdvertisingData.SHORTENED_LOCAL_NAME,
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
AdvertisingData.URI
|
||||
}:
|
||||
return ad_data.decode("utf-8")
|
||||
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
|
||||
elif ad_type in {
|
||||
AdvertisingData.TX_POWER_LEVEL,
|
||||
AdvertisingData.FLAGS
|
||||
}:
|
||||
return ad_data[0]
|
||||
elif ad_type in {
|
||||
AdvertisingData.APPEARANCE,
|
||||
AdvertisingData.ADVERTISING_INTERVAL
|
||||
}:
|
||||
return struct.unpack('<H', ad_data)[0]
|
||||
elif ad_type == AdvertisingData.CLASS_OF_DEVICE:
|
||||
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
|
||||
elif ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
|
||||
return struct.unpack('<HH', ad_data)
|
||||
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
||||
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
|
||||
else:
|
||||
@@ -811,7 +827,7 @@ class AdvertisingData:
|
||||
self.ad_structures.append((ad_type, ad_data))
|
||||
offset += length
|
||||
|
||||
def get(self, type_id, return_all=False, raw=True):
|
||||
def get(self, type_id, return_all=False, raw=False):
|
||||
'''
|
||||
Get Advertising Data Structure(s) with a given type
|
||||
|
||||
|
||||
+234
-42
@@ -519,6 +519,7 @@ class DeviceConfiguration:
|
||||
self.le_simultaneous_enabled = True
|
||||
self.classic_sc_enabled = True
|
||||
self.classic_ssp_enabled = True
|
||||
self.classic_accept_any = True
|
||||
self.connectable = True
|
||||
self.discoverable = True
|
||||
self.advertising_data = bytes(
|
||||
@@ -539,6 +540,7 @@ class DeviceConfiguration:
|
||||
self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled)
|
||||
self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled)
|
||||
self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled)
|
||||
self.classic_accept_any = config.get('classic_accept_any', self.classic_accept_any)
|
||||
self.connectable = config.get('connectable', self.connectable)
|
||||
self.discoverable = config.get('discoverable', self.discoverable)
|
||||
|
||||
@@ -589,6 +591,17 @@ def with_connection_from_address(function):
|
||||
return wrapper
|
||||
|
||||
|
||||
# Decorator that tries to convert the first argument from a bluetooth address to a connection
|
||||
def try_with_connection_from_address(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, address, *args, **kwargs):
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address == address:
|
||||
return function(self, connection, address, *args, **kwargs)
|
||||
return function(self, None, address, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
# Decorator that adds a method to the list of event handlers for host events.
|
||||
# This assumes that the method name starts with `on_`
|
||||
def host_event_handler(function):
|
||||
@@ -619,6 +632,9 @@ class Device(CompositeEventEmitter):
|
||||
def on_connection_failure(self, error):
|
||||
pass
|
||||
|
||||
def on_connection_request(self, bd_addr, class_of_device, link_type):
|
||||
pass
|
||||
|
||||
def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled):
|
||||
pass
|
||||
|
||||
@@ -651,6 +667,7 @@ class Device(CompositeEventEmitter):
|
||||
self.powered_on = False
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_inquiry = True
|
||||
self.auto_restart_advertising = False
|
||||
self.command_timeout = 10 # seconds
|
||||
self.gatt_server = gatt_server.Server(self)
|
||||
@@ -662,12 +679,13 @@ class Device(CompositeEventEmitter):
|
||||
self.scanning = False
|
||||
self.scanning_is_passive = False
|
||||
self.discovering = False
|
||||
self.connecting = False
|
||||
self.le_connecting = False
|
||||
self.disconnecting = False
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.classic_enabled = False
|
||||
self.inquiry_response = None
|
||||
self.address_resolver = None
|
||||
self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY
|
||||
|
||||
# Use the initial config or a default
|
||||
self.public_address = Address('00:00:00:00:00:00')
|
||||
@@ -688,6 +706,7 @@ class Device(CompositeEventEmitter):
|
||||
self.classic_sc_enabled = config.classic_sc_enabled
|
||||
self.discoverable = config.discoverable
|
||||
self.connectable = config.connectable
|
||||
self.classic_accept_any = config.classic_accept_any
|
||||
|
||||
# If a name is passed, override the name from the config
|
||||
if name:
|
||||
@@ -758,9 +777,11 @@ class Device(CompositeEventEmitter):
|
||||
if connection := self.connections.get(connection_handle):
|
||||
return connection
|
||||
|
||||
def find_connection_by_bd_addr(self, bd_addr, transport=None):
|
||||
def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False):
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address.get_bytes() == bd_addr.get_bytes():
|
||||
if check_address_type and connection.peer_address.address_type != bd_addr.address_type:
|
||||
continue
|
||||
if transport is None or connection.transport == transport:
|
||||
return connection
|
||||
|
||||
@@ -875,6 +896,7 @@ class Device(CompositeEventEmitter):
|
||||
self,
|
||||
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target=None,
|
||||
own_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
auto_restart=False
|
||||
):
|
||||
# If we're advertising, stop first
|
||||
@@ -909,7 +931,7 @@ class Device(CompositeEventEmitter):
|
||||
advertising_interval_min = self.advertising_interval_min,
|
||||
advertising_interval_max = self.advertising_interval_max,
|
||||
advertising_type = int(advertising_type),
|
||||
own_address_type = Address.RANDOM_DEVICE_ADDRESS, # TODO: allow using the public address
|
||||
own_address_type = own_address_type,
|
||||
peer_address_type = peer_address_type,
|
||||
peer_address = peer_address,
|
||||
advertising_channel_map = 7,
|
||||
@@ -942,6 +964,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def start_scanning(
|
||||
self,
|
||||
legacy=False,
|
||||
active=True,
|
||||
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
|
||||
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
|
||||
@@ -961,7 +984,7 @@ class Device(CompositeEventEmitter):
|
||||
self.advertisement_accumulator = {}
|
||||
|
||||
# Enable scanning
|
||||
if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
|
||||
if not legacy and self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
|
||||
# Set the scanning parameters
|
||||
scan_type = HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING if active else HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING
|
||||
scanning_filter_policy = HCI_LE_Set_Extended_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY # TODO: support other types
|
||||
@@ -1044,7 +1067,7 @@ class Device(CompositeEventEmitter):
|
||||
if advertisement := accumulator.update(report):
|
||||
self.emit('advertisement', advertisement)
|
||||
|
||||
async def start_discovery(self):
|
||||
async def start_discovery(self, auto_restart=True):
|
||||
await self.send_command(HCI_Write_Inquiry_Mode_Command(
|
||||
inquiry_mode=HCI_EXTENDED_INQUIRY_MODE
|
||||
), check_result=True)
|
||||
@@ -1058,11 +1081,14 @@ class Device(CompositeEventEmitter):
|
||||
self.discovering = False
|
||||
raise HCI_StatusError(response)
|
||||
|
||||
self.discovering = True
|
||||
self.auto_restart_inquiry = auto_restart
|
||||
self.discovering = True
|
||||
|
||||
async def stop_discovery(self):
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
|
||||
self.discovering = False
|
||||
if self.discovering:
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
|
||||
self.auto_restart_inquiry = True
|
||||
self.discovering = False
|
||||
|
||||
@host_event_handler
|
||||
def on_inquiry_result(self, address, class_of_device, data, rssi):
|
||||
@@ -1127,7 +1153,7 @@ class Device(CompositeEventEmitter):
|
||||
):
|
||||
'''
|
||||
Request a connection to a peer.
|
||||
This method cannot be called if there is already a pending connection.
|
||||
When transport is BLE, this method cannot be called if there is already a pending connection.
|
||||
|
||||
connection_parameters_preferences: (BLE only, ignored for BR/EDR)
|
||||
* None: use all PHYs with default parameters
|
||||
@@ -1145,7 +1171,7 @@ class Device(CompositeEventEmitter):
|
||||
transport = BT_LE_TRANSPORT
|
||||
|
||||
# Check that there isn't already a pending connection
|
||||
if transport == BT_LE_TRANSPORT and self.is_connecting:
|
||||
if transport == BT_LE_TRANSPORT and self.is_le_connecting:
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
if type(peer_address) is str:
|
||||
@@ -1262,7 +1288,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# Wait for the connection process to complete
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.connecting = True
|
||||
self.le_connecting = True
|
||||
if timeout is None:
|
||||
return await pending_connection
|
||||
else:
|
||||
@@ -1282,7 +1308,90 @@ class Device(CompositeEventEmitter):
|
||||
self.remove_listener('connection', on_connection)
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
self.connecting = False
|
||||
self.le_connecting = False
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
peer_address=Address.ANY,
|
||||
role=BT_PERIPHERAL_ROLE,
|
||||
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT
|
||||
):
|
||||
'''
|
||||
Wait and accept any incoming connection or a connection from `peer_address` when set.
|
||||
|
||||
Notes:
|
||||
* A `connect` to the same peer will also complete this call.
|
||||
* The `timeout` parameter is only handled while waiting for the connection request,
|
||||
once received and accepeted, the controller shall issue a connection complete event.
|
||||
'''
|
||||
|
||||
if type(peer_address) is str:
|
||||
try:
|
||||
peer_address = Address(peer_address)
|
||||
except ValueError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout
|
||||
|
||||
if peer_address == Address.NIL:
|
||||
raise ValueError('accept on nil address')
|
||||
|
||||
# Create a future so that we can wait for the request
|
||||
pending_request = asyncio.get_running_loop().create_future()
|
||||
|
||||
if peer_address == Address.ANY:
|
||||
self.classic_pending_accepts[Address.ANY].append(pending_request)
|
||||
elif peer_address in self.classic_pending_accepts:
|
||||
raise InvalidStateError('accept connection already pending')
|
||||
else:
|
||||
self.classic_pending_accepts[peer_address] = pending_request
|
||||
|
||||
try:
|
||||
# Wait for a request or a completed connection
|
||||
result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
|
||||
|
||||
except:
|
||||
# Remove future from device context
|
||||
if peer_address == Address.ANY:
|
||||
self.classic_pending_accepts[Address.ANY].remove(pending_request)
|
||||
else:
|
||||
self.classic_pending_accepts.pop(peer_address)
|
||||
raise
|
||||
|
||||
# Result may already be a completed connection,
|
||||
# see `on_connection` for details
|
||||
if isinstance(result, Connection):
|
||||
return result
|
||||
|
||||
# Otherwise, result came from `on_connection_request`
|
||||
peer_address, class_of_device, link_type = result
|
||||
|
||||
def on_connection(connection):
|
||||
if connection.transport == BT_BR_EDR_TRANSPORT and connection.peer_address == peer_address:
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error):
|
||||
if error.transport == BT_BR_EDR_TRANSPORT and error.peer_address == peer_address:
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection's result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
self.on('connection', on_connection)
|
||||
self.on('connection_failure', on_connection_failure)
|
||||
|
||||
try:
|
||||
# Accept connection request
|
||||
await self.send_command(HCI_Accept_Connection_Request_Command(
|
||||
bd_addr = peer_address,
|
||||
role = role
|
||||
))
|
||||
|
||||
# Wait for connection complete
|
||||
return await pending_connection
|
||||
|
||||
finally:
|
||||
self.remove_listener('connection', on_connection)
|
||||
self.remove_listener('connection_failure', on_connection_failure)
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_as_gatt(self, peer_address):
|
||||
@@ -1293,17 +1402,32 @@ class Device(CompositeEventEmitter):
|
||||
yield peer
|
||||
|
||||
@property
|
||||
def is_connecting(self):
|
||||
return self.connecting
|
||||
def is_le_connecting(self):
|
||||
return self.le_connecting
|
||||
|
||||
@property
|
||||
def is_disconnecting(self):
|
||||
return self.disconnecting
|
||||
|
||||
async def cancel_connection(self):
|
||||
if not self.is_connecting:
|
||||
return
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command(), check_result=True)
|
||||
async def cancel_connection(self, peer_address=None):
|
||||
# Low-energy: cancel ongoing connection
|
||||
if peer_address is None:
|
||||
if not self.is_le_connecting:
|
||||
return
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command(), check_result=True)
|
||||
|
||||
# BR/EDR: try to cancel to ongoing connection
|
||||
# NOTE: This API does not prevent from trying to cancel a connection which is not currently being created
|
||||
else:
|
||||
if type(peer_address) is str:
|
||||
try:
|
||||
peer_address = Address(peer_address)
|
||||
except ValueError:
|
||||
# If the address is not parsable, assume it is a name instead
|
||||
logger.debug('looking for peer by name')
|
||||
peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout
|
||||
|
||||
await self.send_command(HCI_Create_Connection_Cancel_Command(bd_addr=peer_address), check_result=True)
|
||||
|
||||
async def disconnect(self, connection, reason):
|
||||
# Create a future so that we can wait for the disconnection's result
|
||||
@@ -1402,9 +1526,9 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# Scan/inquire with event handlers to handle scan/inquiry results
|
||||
def on_peer_found(address, ad_data):
|
||||
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME)
|
||||
local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
|
||||
if local_name is None:
|
||||
local_name = ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME)
|
||||
local_name = ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME, raw=True)
|
||||
if local_name is not None:
|
||||
if local_name.decode('utf-8') == name:
|
||||
peer_address.set_result(address)
|
||||
@@ -1586,23 +1710,31 @@ class Device(CompositeEventEmitter):
|
||||
connection.remove_listener('connection_encryption_failure', on_encryption_failure)
|
||||
|
||||
# [Classic only]
|
||||
async def request_remote_name(self, connection):
|
||||
async def request_remote_name(self, remote: Connection | Address):
|
||||
# Set up event handlers
|
||||
pending_name = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_remote_name():
|
||||
pending_name.set_result(connection.peer_name)
|
||||
|
||||
def on_remote_name_failure(error_code):
|
||||
pending_name.set_exception(HCI_Error(error_code))
|
||||
|
||||
connection.on('remote_name', on_remote_name)
|
||||
connection.on('remote_name_failure', on_remote_name_failure)
|
||||
if type(remote) == Address:
|
||||
peer_address = remote
|
||||
handler = self.on('remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == remote else None)
|
||||
failure_handler = self.on('remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None)
|
||||
else:
|
||||
peer_address = remote.peer_address
|
||||
handler = remote.on('remote_name',
|
||||
lambda:
|
||||
pending_name.set_result(remote.peer_name))
|
||||
failure_handler = remote.on('remote_name_failure',
|
||||
lambda error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)))
|
||||
|
||||
try:
|
||||
result = await self.send_command(
|
||||
HCI_Remote_Name_Request_Command(
|
||||
bd_addr = connection.peer_address,
|
||||
bd_addr = peer_address,
|
||||
page_scan_repetition_mode = HCI_Remote_Name_Request_Command.R0, # TODO investigate other options
|
||||
reserved = 0,
|
||||
clock_offset = 0 # TODO investigate non-0 values
|
||||
@@ -1616,8 +1748,12 @@ class Device(CompositeEventEmitter):
|
||||
# Wait for the result
|
||||
return await pending_name
|
||||
finally:
|
||||
connection.remove_listener('remote_name', on_remote_name)
|
||||
connection.remove_listener('remote_name_failure', on_remote_name_failure)
|
||||
if type(remote) == Address:
|
||||
self.remove_listener('remote_name', handler)
|
||||
self.remove_listener('remote_name_failure', failure_handler)
|
||||
else:
|
||||
remote.remove_listener('remote_name', handler)
|
||||
remote.remove_listener('remote_name_failure', failure_handler)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -1673,6 +1809,14 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
self.connections[connection_handle] = connection
|
||||
|
||||
# We may have an accept ongoing waiting for a connection request for `peer_address`.
|
||||
# Typicaly happen when using `connect` to the same `peer_address` we are waiting with
|
||||
# an `accept` for.
|
||||
# In this case, set the completed `connection` to the `accept` future result.
|
||||
if peer_address in self.classic_pending_accepts:
|
||||
future = self.classic_pending_accepts.pop(peer_address)
|
||||
future.set_result(connection)
|
||||
|
||||
# Emit an event to notify listeners of the new connection
|
||||
self.emit('connection', connection)
|
||||
else:
|
||||
@@ -1736,6 +1880,39 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
self.emit('connection_failure', error)
|
||||
|
||||
# FIXME: Explore a delegate-model for BR/EDR wait connection #56.
|
||||
@host_event_handler
|
||||
def on_connection_request(self, bd_addr, class_of_device, link_type):
|
||||
logger.debug(f'*** Connection request: {bd_addr}')
|
||||
|
||||
# match a pending future using `bd_addr`
|
||||
if bd_addr in self.classic_pending_accepts:
|
||||
future = self.classic_pending_accepts.pop(bd_addr)
|
||||
future.set_result((bd_addr, class_of_device, link_type))
|
||||
|
||||
# match first pending future for ANY address
|
||||
elif len(self.classic_pending_accepts[Address.ANY]) > 0:
|
||||
future = self.classic_pending_accepts[Address.ANY].pop(0)
|
||||
future.set_result((bd_addr, class_of_device, link_type))
|
||||
|
||||
# device configuration is set to accept any incoming connection
|
||||
elif self.classic_accept_any:
|
||||
self.host.send_command_sync(
|
||||
HCI_Accept_Connection_Request_Command(
|
||||
bd_addr = bd_addr,
|
||||
role = 0x01 # Remain the peripheral
|
||||
)
|
||||
)
|
||||
|
||||
# reject incoming connection
|
||||
else:
|
||||
self.host.send_command_sync(
|
||||
HCI_Reject_Connection_Request_Command(
|
||||
bd_addr = bd_addr,
|
||||
reason = HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR
|
||||
)
|
||||
)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
def on_disconnection(self, connection, reason):
|
||||
@@ -1772,9 +1949,13 @@ class Device(CompositeEventEmitter):
|
||||
@host_event_handler
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_inquiry_complete(self):
|
||||
if self.discovering:
|
||||
if self.auto_restart_inquiry:
|
||||
# Inquire again
|
||||
await self.start_discovery()
|
||||
await self.start_discovery(auto_restart=True)
|
||||
else:
|
||||
self.auto_restart_inquiry = True
|
||||
self.discovering = False
|
||||
self.emit('inquiry_complete')
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
@@ -1899,21 +2080,32 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_remote_name(self, connection, remote_name):
|
||||
@try_with_connection_from_address
|
||||
def on_remote_name(self, connection, address, remote_name):
|
||||
# Try to decode the name
|
||||
try:
|
||||
connection.peer_name = remote_name.decode('utf-8')
|
||||
connection.emit('remote_name')
|
||||
remote_name = remote_name.decode('utf-8')
|
||||
if connection:
|
||||
connection.peer_name = remote_name
|
||||
connection.emit('remote_name')
|
||||
else:
|
||||
self.emit('remote_name', address, remote_name)
|
||||
except UnicodeDecodeError as error:
|
||||
logger.warning('peer name is not valid UTF-8')
|
||||
connection.emit('remote_name_failure', error)
|
||||
if connection:
|
||||
connection.emit('remote_name_failure', error)
|
||||
else:
|
||||
self.emit('remote_name_failure', address, error)
|
||||
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_remote_name_failure(self, connection, error):
|
||||
connection.emit('remote_name_failure', error)
|
||||
@try_with_connection_from_address
|
||||
def on_remote_name_failure(self, connection, address, error):
|
||||
if connection:
|
||||
connection.emit('remote_name_failure', error)
|
||||
else:
|
||||
self.emit('remote_name_failure', address, error)
|
||||
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -1652,6 +1652,16 @@ class Address:
|
||||
|
||||
ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def ANY(cls):
|
||||
return cls(b"\xff\xff\xff\xff\xff\xff", cls.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def NIL(cls):
|
||||
return cls(b"\x00\x00\x00\x00\x00\x00", cls.PUBLIC_DEVICE_ADDRESS)
|
||||
|
||||
@staticmethod
|
||||
def address_type_name(address_type):
|
||||
return name_or_number(Address.ADDRESS_TYPE_NAMES, address_type)
|
||||
@@ -1935,6 +1945,17 @@ class HCI_Accept_Connection_Request_Command(HCI_Command):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command([
|
||||
('bd_addr', Address.parse_address),
|
||||
('reason', {'size': 1, 'mapper': HCI_Constant.error_name})
|
||||
])
|
||||
class HCI_Reject_Connection_Request_Command(HCI_Command):
|
||||
'''
|
||||
See Bluetooth spec @ 7.1.9 Reject Connection Request Command
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command([
|
||||
('bd_addr', Address.parse_address),
|
||||
|
||||
+12
-7
@@ -176,6 +176,9 @@ class Host(EventEmitter):
|
||||
if check_result:
|
||||
if type(response.return_parameters) is int:
|
||||
status = response.return_parameters
|
||||
elif type(response.return_parameters) is bytes:
|
||||
# return parameters first field is a one byte status code
|
||||
status = response.return_parameters[0]
|
||||
else:
|
||||
status = response.return_parameters.status
|
||||
|
||||
@@ -344,13 +347,12 @@ class Host(EventEmitter):
|
||||
|
||||
# Classic only
|
||||
def on_hci_connection_request_event(self, event):
|
||||
# For now, just accept everything
|
||||
# TODO: delegate the decision
|
||||
self.send_command_sync(
|
||||
HCI_Accept_Connection_Request_Command(
|
||||
bd_addr = event.bd_addr,
|
||||
role = 0x01 # Remain the peripheral
|
||||
)
|
||||
# Notify the listeners
|
||||
self.emit(
|
||||
'connection_request',
|
||||
event.bd_addr,
|
||||
event.class_of_device,
|
||||
event.link_type,
|
||||
)
|
||||
|
||||
def on_hci_le_connection_complete_event(self, event):
|
||||
@@ -645,3 +647,6 @@ class Host(EventEmitter):
|
||||
self.emit('remote_name_failure', event.bd_addr, event.status)
|
||||
else:
|
||||
self.emit('remote_name', event.bd_addr, event.remote_name)
|
||||
|
||||
def on_hci_remote_host_supported_features_notification_event(self, event):
|
||||
self.emit('remote_host_supported_features', event.bd_addr, event.host_supported_features)
|
||||
|
||||
+8
-8
@@ -24,19 +24,19 @@ def test_ad_data():
|
||||
ad = AdvertisingData.from_bytes(data)
|
||||
ad_bytes = bytes(ad)
|
||||
assert(data == ad_bytes)
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME) is None)
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL) == bytes([123]))
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True) == [])
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True) == [bytes([123])])
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None)
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]))
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [])
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123])])
|
||||
|
||||
data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234])
|
||||
ad.append(data2)
|
||||
ad_bytes = bytes(ad)
|
||||
assert(ad_bytes == data + data2)
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME) is None)
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL) == bytes([123]))
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True) == [])
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True) == [bytes([123]), bytes([234])])
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None)
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]))
|
||||
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [])
|
||||
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123]), bytes([234])])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
+12
-5
@@ -158,16 +158,23 @@ async def test_device_connect_parallel():
|
||||
d1.host.set_packet_sink(Sink(d1_flow()))
|
||||
d2.host.set_packet_sink(Sink(d2_flow()))
|
||||
|
||||
[c1, c2] = await asyncio.gather(*[
|
||||
[c01, c02, a10, a20, a01] = await asyncio.gather(*[
|
||||
asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)),
|
||||
asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)),
|
||||
asyncio.create_task(d1.accept(peer_address=d0.public_address)),
|
||||
asyncio.create_task(d2.accept()),
|
||||
asyncio.create_task(d0.accept(peer_address=d1.public_address)),
|
||||
])
|
||||
|
||||
assert type(c1) == Connection
|
||||
assert type(c2) == Connection
|
||||
assert type(c01) == Connection
|
||||
assert type(c02) == Connection
|
||||
assert type(a10) == Connection
|
||||
assert type(a20) == Connection
|
||||
assert type(a01) == Connection
|
||||
|
||||
assert c1.handle == 0x100
|
||||
assert c2.handle == 0x101
|
||||
assert c01.handle == a10.handle and c01.handle == 0x100
|
||||
assert c02.handle == a20.handle and c02.handle == 0x101
|
||||
assert a01 == c01
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user