diff --git a/bumble/device.py b/bumble/device.py index c2fbc522..fde4ab42 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -413,12 +413,36 @@ class Connection(CompositeEventEmitter): self.parameters = parameters self.encryption = 0 self.authenticated = False + self.sc = False + self.link_key_type = None + self.authenticating = False self.phy = phy self.att_mtu = ATT_DEFAULT_MTU self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.gatt_client = None # Per-connection client self.gatt_server = device.gatt_server # By default, use the device's shared server + # [Classic only] + @classmethod + def incomplete(cls, device, peer_address): + """ + Instantiate an incomplete connection (ie. one waiting for a HCI Connection Complete event). + Once received it shall be completed using the `.complete` method. + """ + return cls(device, None, BT_BR_EDR_TRANSPORT, device.public_address, peer_address, None, None, None, None) + + # [Classic only] + def complete(self, handle, peer_resolvable_address, role, parameters): + """ + Finish an incomplete connection upon completion. + """ + assert self.handle is None + assert self.transport == BT_BR_EDR_TRANSPORT + self.handle = handle + self.peer_resolvable_address = peer_resolvable_address + self.role = role + self.parameters = parameters + @property def role_name(self): return 'CENTRAL' if self.role == BT_CENTRAL_ROLE else 'PERIPHERAL' @@ -600,6 +624,8 @@ def with_connection_from_handle(function): def with_connection_from_address(function): @functools.wraps(function) def wrapper(self, address, *args, **kwargs): + if (connection := self.pending_connections.get(address, False)): + return function(self, connection, *args, **kwargs) for connection in self.connections.values(): if connection.peer_address == address: return function(self, connection, *args, **kwargs) @@ -611,6 +637,8 @@ def with_connection_from_address(function): def try_with_connection_from_address(function): @functools.wraps(function) def wrapper(self, address, *args, **kwargs): + if (connection := self.pending_connections.get(address, False)): + return function(self, connection, address, *args, **kwargs) for connection in self.connections.values(): if connection.peer_address == address: return function(self, connection, address, *args, **kwargs) @@ -698,6 +726,7 @@ class Device(CompositeEventEmitter): self.le_connecting = False self.disconnecting = False self.connections = {} # Connections, by connection handle + self.pending_connections = {} # Connections, by BD address (BR/EDR only) self.classic_enabled = False self.inquiry_response = None self.address_resolver = None @@ -818,7 +847,7 @@ class Device(CompositeEventEmitter): def find_connection_by_bd_addr(self, bd_addr, transport=None, check_address_type=False): for connection in self.connections.values(): - if connection.peer_address.get_bytes() == bd_addr.get_bytes(): + if connection.peer_address.to_bytes() == bd_addr.to_bytes(): if check_address_type and connection.peer_address.address_type != bd_addr.address_type: continue if transport is None or connection.transport == transport: @@ -1345,6 +1374,9 @@ class Device(CompositeEventEmitter): max_ce_length = int(prefs.max_ce_length / 0.625), )) else: + # Save pending connection + self.pending_connections[peer_address] = Connection.incomplete(self, peer_address) + # TODO: allow passing other settings result = await self.send_command(HCI_Create_Connection_Command( bd_addr = peer_address, @@ -1382,6 +1414,8 @@ class Device(CompositeEventEmitter): if transport == BT_LE_TRANSPORT: self.le_connecting = False self.connect_own_address_type = None + else: + self.pending_connections.pop(peer_address, None) async def accept( self, @@ -1395,7 +1429,7 @@ class Device(CompositeEventEmitter): Notes: * A `connect` to the same peer will also complete this call. * The `timeout` parameter is only handled while waiting for the connection request, - once received and accepeted, the controller shall issue a connection complete event. + once received and accepted, the controller shall issue a connection complete event. ''' if type(peer_address) is str: @@ -1451,6 +1485,9 @@ class Device(CompositeEventEmitter): self.on('connection', on_connection) self.on('connection_failure', on_connection_failure) + # Save pending connection + self.pending_connections[peer_address] = Connection.incomplete(self, peer_address) + try: # Accept connection request await self.send_command(HCI_Accept_Connection_Request_Command( @@ -1464,6 +1501,7 @@ class Device(CompositeEventEmitter): finally: self.remove_listener('connection', on_connection) self.remove_listener('connection_failure', on_connection_failure) + self.pending_connections.pop(peer_address, None) @asynccontextmanager async def connect_as_gatt(self, peer_address): @@ -1707,9 +1745,13 @@ class Device(CompositeEventEmitter): logger.warn(f'HCI_Authentication_Requested_Command failed: {HCI_Constant.error_name(result.status)}') raise HCI_StatusError(result) + # Save in connection we are trying to authenticate + connection.authenticating = True + # Wait for the authentication to complete await pending_authentication finally: + connection.authenticating = False connection.remove_listener('connection_authentication', on_authentication) connection.remove_listener('connection_authentication_failure', on_authentication_failure) @@ -1786,28 +1828,18 @@ class Device(CompositeEventEmitter): # Set up event handlers pending_name = asyncio.get_running_loop().create_future() - if type(remote) == Address: - peer_address = remote - handler = self.on( - 'remote_name', - lambda address, remote_name: - pending_name.set_result(remote_name) if address == remote else None - ) - failure_handler = self.on( - 'remote_name_failure', - lambda address, error_code: - pending_name.set_exception(HCI_Error(error_code)) if address == remote else None - ) - else: - peer_address = remote.peer_address - handler = remote.on( - 'remote_name', - lambda: pending_name.set_result(remote.peer_name) - ) - failure_handler = remote.on( - 'remote_name_failure', - lambda error_code: pending_name.set_exception(HCI_Error(error_code)) - ) + peer_address = remote if type(remote) == Address else remote.peer_address + + handler = self.on( + 'remote_name', + lambda address, remote_name: + pending_name.set_result(remote_name) if address == peer_address else None + ) + failure_handler = self.on( + 'remote_name_failure', + lambda address, error_code: + pending_name.set_exception(HCI_Error(error_code)) if address == peer_address else None + ) try: result = await self.send_command( @@ -1826,12 +1858,8 @@ class Device(CompositeEventEmitter): # Wait for the result return await pending_name finally: - if type(remote) == Address: - self.remove_listener('remote_name', handler) - self.remove_listener('remote_name_failure', failure_handler) - else: - remote.remove_listener('remote_name', handler) - remote.remove_listener('remote_name_failure', failure_handler) + self.remove_listener('remote_name', handler) + self.remove_listener('remote_name_failure', failure_handler) # [Classic only] @host_event_handler @@ -1849,6 +1877,9 @@ class Device(CompositeEventEmitter): asyncio.create_task(store_keys()) + if (connection := self.find_connection_by_bd_addr(bd_addr, transport=BT_BR_EDR_TRANSPORT)): + connection.link_key_type = key_type + def add_service(self, service): self.gatt_server.add_service(service) @@ -1875,17 +1906,8 @@ class Device(CompositeEventEmitter): if transport == BT_BR_EDR_TRANSPORT: # Create a new connection - connection = Connection( - self, - connection_handle, - transport, - self.public_address, - peer_address, - peer_resolvable_address, - role, - connection_parameters, - phy=None - ) + connection: Connection = self.pending_connections.pop(peer_address) + connection.complete(connection_handle, peer_resolvable_address, role, connection_parameters) self.connections[connection_handle] = connection # We may have an accept ongoing waiting for a connection request for `peer_address`. @@ -1991,6 +2013,9 @@ class Device(CompositeEventEmitter): # device configuration is set to accept any incoming connection elif self.classic_accept_any: + # Save pending connection + self.pending_connections[bd_addr] = Connection.incomplete(self, bd_addr) + self.host.send_command_sync( HCI_Accept_Connection_Request_Command( bd_addr = bd_addr, @@ -2064,6 +2089,17 @@ class Device(CompositeEventEmitter): logger.debug(f'*** Connection Authentication Failure: [0x{connection.handle:04X}] {connection.peer_address} as {connection.role_name}, error={error}') connection.emit('connection_authentication_failure', error) + @host_event_handler + @with_connection_from_address + def on_ssp_complete(self, connection): + # On Secure Simple Pairing complete, in case: + # - Connection isn't already authenticated + # - AND We are not the initiator of the authentication + # We must trigger authentication to known if we are truly authenticated + if not connection.authenticating and not connection.authenticated: + logger.debug(f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] {connection.peer_address}') + asyncio.create_task(connection.authenticate()) + # [Classic only] @host_event_handler @with_connection_from_address @@ -2200,8 +2236,7 @@ class Device(CompositeEventEmitter): if connection: connection.peer_name = remote_name connection.emit('remote_name') - else: - self.emit('remote_name', address, remote_name) + self.emit('remote_name', address, remote_name) except UnicodeDecodeError as error: logger.warning('peer name is not valid UTF-8') if connection: @@ -2215,8 +2250,7 @@ class Device(CompositeEventEmitter): def on_remote_name_failure(self, connection, address, error): if connection: connection.emit('remote_name_failure', error) - else: - self.emit('remote_name_failure', address, error) + self.emit('remote_name_failure', address, error) @host_event_handler @with_connection_from_handle @@ -2282,7 +2316,9 @@ class Device(CompositeEventEmitter): connection.emit('pairing_start') @with_connection_from_handle - def on_pairing(self, connection, keys): + def on_pairing(self, connection, keys, sc): + connection.sc = sc + connection.authenticated = True connection.emit('pairing', keys) @with_connection_from_handle diff --git a/bumble/host.py b/bumble/host.py index 8e43c50b..ae4cc666 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -599,6 +599,9 @@ class Host(EventEmitter): def on_hci_simple_pairing_complete_event(self, event): logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') + # Notify the client + if event.status == HCI_SUCCESS: + self.emit('ssp_complete', event.bd_addr) def on_hci_pin_code_request_event(self, event): # For now, just refuse all requests diff --git a/bumble/keys.py b/bumble/keys.py index f51cfe65..b8c05b48 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -20,6 +20,7 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +import asyncio import logging import os import json @@ -143,6 +144,10 @@ class KeyStore: async def get_all(self): return [] + async def delete_all(self): + all_keys = await self.get_all() + await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) + async def get_resolving_keys(self): all_keys = await self.get_all() resolving_keys = [] @@ -259,6 +264,13 @@ class JsonKeyStore(KeyStore): return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()] + async def delete_all(self): + db = await self.load() + + db.pop(self.namespace, None) + + await self.save(db) + async def get(self, name): db = await self.load() diff --git a/bumble/smp.py b/bumble/smp.py index 8f0ea0bd..4c6ca4eb 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1583,7 +1583,7 @@ class Manager(EventEmitter): asyncio.create_task(store_keys()) # Notify the device - self.device.on_pairing(session.connection.handle, keys) + self.device.on_pairing(session.connection.handle, keys, session.sc) def on_pairing_failure(self, session, reason): self.device.on_pairing_failure(session.connection.handle, reason)