diff --git a/bumble/device.py b/bumble/device.py index 3e83153..13217b7 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -589,6 +589,17 @@ def with_connection_from_address(function): return wrapper +# Decorator that tries to convert the first argument from a bluetooth address to a connection +def try_with_connection_from_address(function): + @functools.wraps(function) + def wrapper(self, address, *args, **kwargs): + for connection in self.connections.values(): + if connection.peer_address == address: + return function(self, connection, address, *args, **kwargs) + return function(self, None, address, *args, **kwargs) + return wrapper + + # Decorator that adds a method to the list of event handlers for host events. # This assumes that the method name starts with `on_` def host_event_handler(function): @@ -1586,23 +1597,31 @@ class Device(CompositeEventEmitter): connection.remove_listener('connection_encryption_failure', on_encryption_failure) # [Classic only] - async def request_remote_name(self, connection): + async def request_remote_name(self, remote: Connection | Address): # Set up event handlers pending_name = asyncio.get_running_loop().create_future() - def on_remote_name(): - pending_name.set_result(connection.peer_name) - - def on_remote_name_failure(error_code): - pending_name.set_exception(HCI_Error(error_code)) - - connection.on('remote_name', on_remote_name) - connection.on('remote_name_failure', on_remote_name_failure) + if type(remote) == Address: + peer_address = remote + handler = self.on('remote_name', + lambda address, remote_name: + pending_name.set_result(remote_name) if address == remote else None) + failure_handler = self.on('remote_name_failure', + lambda address, error_code: + pending_name.set_exception(HCI_Error(error_code)) if address == remote else None) + else: + peer_address = remote.peer_address + handler = remote.on('remote_name', + lambda: + pending_name.set_result(remote.peer_name)) + failure_handler = remote.on('remote_name_failure', + lambda error_code: + pending_name.set_exception(HCI_Error(error_code))) try: result = await self.send_command( HCI_Remote_Name_Request_Command( - bd_addr = connection.peer_address, + bd_addr = peer_address, page_scan_repetition_mode = HCI_Remote_Name_Request_Command.R0, # TODO investigate other options reserved = 0, clock_offset = 0 # TODO investigate non-0 values @@ -1616,8 +1635,12 @@ class Device(CompositeEventEmitter): # Wait for the result return await pending_name finally: - connection.remove_listener('remote_name', on_remote_name) - connection.remove_listener('remote_name_failure', on_remote_name_failure) + if type(remote) == Address: + self.remove_listener('remote_name', handler) + self.remove_listener('remote_name_failure', failure_handler) + else: + remote.remove_listener('remote_name', handler) + remote.remove_listener('remote_name_failure', failure_handler) # [Classic only] @host_event_handler @@ -1899,21 +1922,32 @@ class Device(CompositeEventEmitter): # [Classic only] @host_event_handler - @with_connection_from_address - def on_remote_name(self, connection, remote_name): + @try_with_connection_from_address + def on_remote_name(self, connection, address, remote_name): # Try to decode the name try: - connection.peer_name = remote_name.decode('utf-8') - connection.emit('remote_name') + remote_name = remote_name.decode('utf-8') + if connection: + connection.peer_name = remote_name + connection.emit('remote_name') + else: + self.emit('remote_name', address, remote_name) except UnicodeDecodeError as error: logger.warning('peer name is not valid UTF-8') - connection.emit('remote_name_failure', error) + if connection: + connection.emit('remote_name_failure', error) + else: + self.emit('remote_name_failure', address, error) + # [Classic only] @host_event_handler - @with_connection_from_address - def on_remote_name_failure(self, connection, error): - connection.emit('remote_name_failure', error) + @try_with_connection_from_address + def on_remote_name_failure(self, connection, address, error): + if connection: + connection.emit('remote_name_failure', error) + else: + self.emit('remote_name_failure', address, error) @host_event_handler @with_connection_from_handle