diff --git a/bumble/device.py b/bumble/device.py index 56b747f6..2762daf8 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -2169,10 +2169,12 @@ def with_connection_from_handle(function): # Decorator that converts the first argument from a bluetooth address to a connection def with_connection_from_address(function): @functools.wraps(function) - def wrapper(self, address: hci.Address, *args, **kwargs): - for connection in self.connections.values(): + def wrapper(device: Device, address: hci.Address, *args, **kwargs): + if connection := device.pending_connections.get(address): + return function(device, connection, address, *args, **kwargs) + for connection in device.connections.values(): if connection.peer_address == address: - return function(self, connection, *args, **kwargs) + return function(device, connection, *args, **kwargs) raise ObjectLookupError('no connection for address') return wrapper @@ -2182,11 +2184,13 @@ def with_connection_from_address(function): # connection def try_with_connection_from_address(function): @functools.wraps(function) - def wrapper(self, address, *args, **kwargs): - for connection in self.connections.values(): + def wrapper(device: Device, address: hci.Address, *args, **kwargs): + if connection := device.pending_connections.get(address): + return function(device, connection, address, *args, **kwargs) + for connection in device.connections.values(): if connection.peer_address == address: - return function(self, connection, address, *args, **kwargs) - return function(self, None, address, *args, **kwargs) + return function(device, connection, address, *args, **kwargs) + return function(device, None, address, *args, **kwargs) return wrapper @@ -2234,7 +2238,7 @@ class Device(utils.CompositeEventEmitter): scan_response_data: bytes cs_capabilities: ChannelSoundingCapabilities | None = None connections: dict[int, Connection] - connection_roles: dict[hci.Address, hci.Role] + pending_connections: dict[hci.Address, Connection] classic_pending_accepts: dict[ hci.Address, list[asyncio.Future[Union[Connection, tuple[hci.Address, int, int]]]], @@ -2356,9 +2360,9 @@ class Device(utils.CompositeEventEmitter): self.le_connecting = False self.disconnecting = False self.connections = {} # Connections, by connection handle - self.connection_roles = ( + self.pending_connections = ( {} - ) # Local connection roles, by BD address (BR/EDR only) + ) # Pending connections, by BD address (BR/EDR only) self.sco_links = {} # ScoLinks, by connection handle (BR/EDR only) self.cis_links = {} # CisLinks, by connection handle (LE only) self._pending_cis = {} # (CIS_ID, CIG_ID), by CIS_handle @@ -3827,7 +3831,17 @@ class Device(utils.CompositeEventEmitter): ) else: # Save pending connection - self.connection_roles[peer_address] = hci.Role.CENTRAL + self.pending_connections[peer_address] = Connection( + device=self, + handle=0, + transport=core.PhysicalTransport.BR_EDR, + self_address=self.public_address, + self_resolvable_address=None, + peer_address=peer_address, + peer_resolvable_address=None, + role=hci.Role.CENTRAL, + parameters=Connection.Parameters(0, 0, 0), + ) # TODO: allow passing other settings result = await self.send_command( @@ -3880,7 +3894,7 @@ class Device(utils.CompositeEventEmitter): self.le_connecting = False self.connect_own_address_type = None else: - self.connection_roles.pop(peer_address, None) + self.pending_connections.pop(peer_address, None) async def accept( self, @@ -3978,7 +3992,17 @@ class Device(utils.CompositeEventEmitter): # Even if we requested a role switch in the hci.HCI_Accept_Connection_Request # command, this connection is still considered Peripheral until an eventual # role change event. - self.connection_roles[peer_address] = hci.Role.PERIPHERAL + self.pending_connections[peer_address] = Connection( + device=self, + handle=0, + transport=core.PhysicalTransport.BR_EDR, + self_address=self.public_address, + self_resolvable_address=None, + peer_address=peer_address, + peer_resolvable_address=None, + role=hci.Role.PERIPHERAL, + parameters=Connection.Parameters(0, 0, 0), + ) try: # Accept connection request @@ -3996,7 +4020,7 @@ class Device(utils.CompositeEventEmitter): finally: self.remove_listener(self.EVENT_CONNECTION, on_connection) self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure) - self.connection_roles.pop(peer_address, None) + self.pending_connections.pop(peer_address, None) @asynccontextmanager async def connect_as_gatt(self, peer_address: Union[hci.Address, str]): @@ -5441,29 +5465,27 @@ class Device(utils.CompositeEventEmitter): connection_handle: int, peer_address: hci.Address, ) -> None: - connection_role = self.connection_roles.pop(peer_address, hci.Role.PERIPHERAL) + if connection := self.pending_connections.pop(peer_address, None): + connection.handle = connection_handle + else: + # Create a new connection + connection = Connection( + device=self, + handle=connection_handle, + transport=PhysicalTransport.BR_EDR, + self_address=self.public_address, + self_resolvable_address=None, + peer_address=peer_address, + peer_resolvable_address=None, + role=hci.Role.PERIPHERAL, + parameters=Connection.Parameters(0.0, 0, 0.0), + ) - logger.debug( - f'*** Connection: [0x{connection_handle:04X}] ' - f'{peer_address} {hci.HCI_Constant.role_name(connection_role)}' - ) + logger.debug('*** %s', connection) if connection_handle in self.connections: logger.warning( 'new connection reuses the same handle as a previous connection' ) - - # Create a new connection - connection = Connection( - device=self, - handle=connection_handle, - transport=PhysicalTransport.BR_EDR, - self_address=self.public_address, - self_resolvable_address=None, - peer_address=peer_address, - peer_resolvable_address=None, - role=connection_role, - parameters=Connection.Parameters(0.0, 0, 0.0), - ) self.connections[connection_handle] = connection self.emit(self.EVENT_CONNECTION, connection) @@ -5618,7 +5640,9 @@ class Device(utils.CompositeEventEmitter): # FIXME: Explore a delegate-model for BR/EDR wait connection #56. @host_event_handler - def on_connection_request(self, bd_addr, class_of_device, link_type): + def on_connection_request( + self, bd_addr: hci.Address, class_of_device: int, link_type: int + ): logger.debug(f'*** Connection request: {bd_addr}') # Handle SCO request. @@ -5647,7 +5671,17 @@ class Device(utils.CompositeEventEmitter): # device configuration is set to accept any incoming connection elif self.classic_accept_any: # Save pending connection - self.connection_roles[bd_addr] = hci.Role.PERIPHERAL + self.pending_connections[bd_addr] = Connection( + device=self, + handle=0, + transport=core.PhysicalTransport.BR_EDR, + self_address=self.public_address, + self_resolvable_address=None, + peer_address=bd_addr, + peer_resolvable_address=None, + role=hci.Role.PERIPHERAL, + parameters=Connection.Parameters(0, 0, 0), + ) self.host.send_command_sync( hci.HCI_Accept_Connection_Request_Command( @@ -5958,7 +5992,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @try_with_connection_from_address def on_remote_name( - self, connection: Connection, address: hci.Address, remote_name: bytes + self, connection: Optional[Connection], address: hci.Address, remote_name: bytes ): # Try to decode the name try: @@ -5977,7 +6011,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @try_with_connection_from_address def on_remote_name_failure( - self, connection: Connection, address: hci.Address, error: int + self, connection: Optional[Connection], address: hci.Address, error: int ): if connection: connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error) @@ -6411,19 +6445,22 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @try_with_connection_from_address def on_role_change( - self, connection: Connection, peer_address: hci.Address, new_role: hci.Role + self, + connection: Optional[Connection], + peer_address: hci.Address, + new_role: hci.Role, ): if connection: connection.role = new_role connection.emit(connection.EVENT_ROLE_CHANGE, new_role) else: - self.connection_roles[peer_address] = new_role + logger.warning("Role change to unknown connection %s", peer_address) # [Classic only] @host_event_handler @try_with_connection_from_address def on_role_change_failure( - self, connection: Connection, address: hci.Address, error: int + self, connection: Optional[Connection], address: hci.Address, error: int ): if connection: connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error)