diff --git a/bumble/device.py b/bumble/device.py index 5a16392b..79b305a7 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -535,7 +535,7 @@ class Connection(CompositeEventEmitter): self.on('disconnection_failure', abort.set_exception) try: - await asyncio.wait_for(abort, timeout) + await asyncio.wait_for(self.device.abort_on('flush', abort), timeout) except asyncio.TimeoutError: pass @@ -1592,7 +1592,7 @@ class Device(CompositeEventEmitter): if transport == BT_LE_TRANSPORT: self.le_connecting = True if timeout is None: - return await pending_connection + return await self.abort_on('flush', pending_connection) else: try: return await asyncio.wait_for( @@ -1609,7 +1609,7 @@ class Device(CompositeEventEmitter): ) try: - return await pending_connection + return await self.abort_on('flush', pending_connection) except ConnectionError: raise TimeoutError() finally: @@ -1661,6 +1661,7 @@ class Device(CompositeEventEmitter): try: # Wait for a request or a completed connection + pending_request = self.abort_on('flush', pending_request) result = await ( asyncio.wait_for(pending_request, timeout) if timeout @@ -1682,6 +1683,9 @@ class Device(CompositeEventEmitter): # Otherwise, result came from `on_connection_request` peer_address, class_of_device, link_type = result + # Create a future so that we can wait for the connection's result + pending_connection = asyncio.get_running_loop().create_future() + def on_connection(connection): if ( connection.transport == BT_BR_EDR_TRANSPORT @@ -1696,8 +1700,6 @@ class Device(CompositeEventEmitter): ): pending_connection.set_exception(error) - # Create a future so that we can wait for the connection's result - pending_connection = asyncio.get_running_loop().create_future() self.on('connection', on_connection) self.on('connection_failure', on_connection_failure) @@ -1713,7 +1715,7 @@ class Device(CompositeEventEmitter): ) # Wait for connection complete - return await pending_connection + return await self.abort_on('flush', pending_connection) finally: self.remove_listener('connection', on_connection) @@ -1782,7 +1784,7 @@ class Device(CompositeEventEmitter): # Wait for the disconnection process to complete self.disconnecting = True - return await pending_disconnection + return await self.abort_on('flush', pending_disconnection) finally: connection.remove_listener( 'disconnection', pending_disconnection.set_result @@ -1910,7 +1912,7 @@ class Device(CompositeEventEmitter): else: return None - return await peer_address + return await self.abort_on('flush', peer_address) finally: if handler is not None: self.remove_listener(event_name, handler) @@ -1994,7 +1996,7 @@ class Device(CompositeEventEmitter): connection.authenticating = True # Wait for the authentication to complete - await pending_authentication + await connection.abort_on('disconnection', pending_authentication) finally: connection.authenticating = False connection.remove_listener('connection_authentication', on_authentication) @@ -2068,7 +2070,7 @@ class Device(CompositeEventEmitter): raise HCI_StatusError(result) # Wait for the result - await pending_encryption + await connection.abort_on('disconnection', pending_encryption) finally: connection.remove_listener( 'connection_encryption_change', on_encryption_change @@ -2116,11 +2118,18 @@ class Device(CompositeEventEmitter): raise HCI_StatusError(result) # Wait for the result - return await pending_name + return await self.abort_on('flush', pending_name) finally: self.remove_listener('remote_name', handler) self.remove_listener('remote_name_failure', failure_handler) + @host_event_handler + def on_flush(self): + self.emit('flush') + for _, connection in self.connections.items(): + connection.emit('disconnection', 0) + self.connections = {} + # [Classic only] @host_event_handler def on_link_key(self, bd_addr, link_key, key_type): @@ -2135,7 +2144,7 @@ class Device(CompositeEventEmitter): except Exception as error: logger.warn(f'!!! error while storing keys: {error}') - asyncio.create_task(store_keys()) + self.abort_on('flush', store_keys()) if connection := self.find_connection_by_bd_addr( bd_addr, transport=BT_BR_EDR_TRANSPORT @@ -2227,10 +2236,10 @@ class Device(CompositeEventEmitter): async def new_connection(): # Figure out which PHY we're connected with if self.host.supports_command(HCI_LE_READ_PHY_COMMAND): - result = await self.send_command( + result = await asyncio.shield(self.send_command( HCI_LE_Read_PHY_Command(connection_handle=connection_handle), check_result=True, - ) + )) phy = ConnectionPHY( result.return_parameters.tx_phy, result.return_parameters.rx_phy ) @@ -2261,7 +2270,7 @@ class Device(CompositeEventEmitter): # Emit an event to notify listeners of the new connection self.emit('connection', connection) - asyncio.create_task(new_connection()) + self.abort_on('flush', new_connection()) @host_event_handler def on_connection_failure(self, transport, peer_address, error_code): @@ -2338,7 +2347,7 @@ class Device(CompositeEventEmitter): # Restart advertising if auto-restart is enabled if self.auto_restart_advertising: logger.debug('restarting advertising') - asyncio.create_task( + self.abort_on('flush', self.start_advertising( advertising_type=self.advertising_type, auto_restart=True ) @@ -2460,17 +2469,19 @@ class Device(CompositeEventEmitter): if can_compare: async def compare_numbers(): - numbers_match = await pairing_config.delegate.compare_numbers( - code, digits=6 + numbers_match = await connection.abort_on('disconnection', + pairing_config.delegate.compare_numbers( + code, digits=6 + ) ) if numbers_match: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Confirmation_Request_Reply_Command( bd_addr=connection.peer_address ) ) else: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Confirmation_Request_Negative_Reply_Command( bd_addr=connection.peer_address ) @@ -2480,15 +2491,16 @@ class Device(CompositeEventEmitter): else: async def confirm(): - confirm = await pairing_config.delegate.confirm() + confirm = await connection.abort_on('disconnection', + pairing_config.delegate.confirm()) if confirm: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Confirmation_Request_Reply_Command( bd_addr=connection.peer_address ) ) else: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Confirmation_Request_Negative_Reply_Command( bd_addr=connection.peer_address ) @@ -2512,15 +2524,16 @@ class Device(CompositeEventEmitter): if can_input: async def get_number(): - number = await pairing_config.delegate.get_number() + number = await connection.abort_on('disconnection', + pairing_config.delegate.get_number()) if number is not None: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Passkey_Request_Reply_Command( bd_addr=connection.peer_address, numeric_value=number ) ) else: - self.host.send_command_sync( + await self.host.send_command( HCI_User_Passkey_Request_Negative_Reply_Command( bd_addr=connection.peer_address ) @@ -2541,7 +2554,7 @@ class Device(CompositeEventEmitter): # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) - asyncio.create_task(pairing_config.delegate.display_number(passkey)) + connection.abort_on('disconnection', pairing_config.delegate.display_number(passkey)) # [Classic only] @host_event_handler diff --git a/bumble/host.py b/bumble/host.py index 354d5fbe..d7683796 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -17,7 +17,6 @@ # ----------------------------------------------------------------------------- import asyncio import logging -from pyee import EventEmitter from colors import color from .hci import * @@ -26,6 +25,7 @@ from .att import * from .gatt import * from .smp import * from .core import ConnectionParameters +from .utils import AbortableEventEmitter # ----------------------------------------------------------------------------- # Logging @@ -65,7 +65,7 @@ class Connection: # ----------------------------------------------------------------------------- -class Host(EventEmitter): +class Host(AbortableEventEmitter): def __init__(self, controller_source=None, controller_sink=None): super().__init__() @@ -96,7 +96,19 @@ class Host(EventEmitter): if controller_sink: self.set_packet_sink(controller_sink) + async def flush(self): + # Make sure no command is pending + await self.command_semaphore.acquire() + + # Flush current host state, then release command semaphore + self.emit('flush') + self.command_semaphore.release() + async def reset(self): + if self.ready: + self.ready = False + await self.flush() + await self.send_command(HCI_Reset_Command(), check_result=True) self.ready = True @@ -604,9 +616,9 @@ class Host(EventEmitter): logger.debug('no long term key provider') long_term_key = None else: - long_term_key = await self.long_term_key_provider( + long_term_key = await self.abort_on('flush', self.long_term_key_provider( connection.handle, event.random_number, event.encryption_diversifier - ) + )) if long_term_key: response = HCI_LE_Long_Term_Key_Request_Reply_Command( connection_handle=event.connection_handle, @@ -719,7 +731,7 @@ class Host(EventEmitter): logger.debug('no link key provider') link_key = None else: - link_key = await self.link_key_provider(event.bd_addr) + link_key = await self.abort_on('flush', self.link_key_provider(event.bd_addr)) if link_key: response = HCI_Link_Key_Request_Reply_Command( bd_addr=event.bd_addr, link_key=link_key diff --git a/bumble/smp.py b/bumble/smp.py index e9d6fe33..be3eea37 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -766,7 +766,7 @@ class Session: self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) - asyncio.create_task(prompt()) + self.connection.abort_on('disconnection', prompt()) def prompt_user_for_numeric_comparison(self, code, next_steps): async def prompt(): @@ -783,7 +783,7 @@ class Session: self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) - asyncio.create_task(prompt()) + self.connection.abort_on('disconnection', prompt()) def prompt_user_for_number(self, next_steps): async def prompt(): @@ -796,7 +796,7 @@ class Session: logger.warn(f'exception while prompting: {error}') self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR) - asyncio.create_task(prompt()) + self.connection.abort_on('disconnection', prompt()) def display_passkey(self): # Generate random Passkey/PIN code @@ -808,7 +808,7 @@ class Session: self.tk = self.passkey.to_bytes(16, byteorder='little') logger.debug(f'TK from passkey = {self.tk.hex()}') - asyncio.create_task( + self.connection.abort_on('disconnection', self.pairing_config.delegate.display_number(self.passkey, digits=6) ) @@ -921,14 +921,12 @@ class Session: def start_encryption(self, key): # We can now encrypt the connection with the short term key, so that we can # distribute the long term and/or other keys over an encrypted connection - asyncio.create_task( - self.manager.device.host.send_command( - HCI_LE_Enable_Encryption_Command( - connection_handle=self.connection.handle, - random_number=bytes(8), - encrypted_diversifier=0, - long_term_key=key, - ) + self.manager.device.host.send_command_sync( + HCI_LE_Enable_Encryption_Command( + connection_handle=self.connection.handle, + random_number=bytes(8), + encrypted_diversifier=0, + long_term_key=key ) ) @@ -950,7 +948,7 @@ class Session: self.connection.transport == BT_BR_EDR_TRANSPORT and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG ): - self.ctkd_task = asyncio.create_task(self.derive_ltk()) + self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk()) elif not self.sc: # Distribute the LTK, EDIV and RAND if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: @@ -997,7 +995,7 @@ class Session: self.connection.transport == BT_BR_EDR_TRANSPORT and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG ): - self.ctkd_task = asyncio.create_task(self.derive_ltk()) + self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk()) # Distribute the LTK, EDIV and RAND elif not self.sc: if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: @@ -1094,7 +1092,7 @@ class Session: self.send_pairing_request_command() # Wait for the pairing process to finish - await self.pairing_result + await self.connection.abort_on('disconnection', self.pairing_result) def on_disconnection(self, reason): self.connection.remove_listener('disconnection', self.on_disconnection) @@ -1112,7 +1110,7 @@ class Session: if self.is_initiator: self.distribute_keys() - asyncio.create_task(self.on_pairing()) + self.connection.abort_on('disconnection', self.on_pairing()) def on_connection_encryption_change(self): if self.connection.is_encrypted: @@ -1219,7 +1217,7 @@ class Session: logger.error(color('SMP command not handled???', 'red')) def on_smp_pairing_request_command(self, command): - asyncio.create_task(self.on_smp_pairing_request_command_async(command)) + self.connection.abort_on('disconnection', self.on_smp_pairing_request_command_async(command)) async def on_smp_pairing_request_command_async(self, command): # Check if the request should proceed @@ -1572,7 +1570,7 @@ class Session: self.wait_before_continuing = None self.send_pairing_dhkey_check_command() - asyncio.create_task(next_steps()) + self.connection.abort_on('disconnection', next_steps()) else: self.send_pairing_dhkey_check_command() else: @@ -1688,7 +1686,7 @@ class Manager(EventEmitter): except Exception as error: logger.warn(f'!!! error while storing keys: {error}') - asyncio.create_task(store_keys()) + self.device.abort_on('flush', store_keys()) # Notify the device self.device.on_pairing(session.connection.handle, keys, session.sc) diff --git a/bumble/utils.py b/bumble/utils.py index 92cef633..33456129 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -19,6 +19,8 @@ import asyncio import logging import traceback import collections +import sys +from typing import Awaitable from functools import wraps from colors import color from pyee import EventEmitter @@ -62,7 +64,37 @@ def composite_listener(cls): # ----------------------------------------------------------------------------- -class CompositeEventEmitter(EventEmitter): +class AbortableEventEmitter(EventEmitter): + + def abort_on(self, event: str, awaitable: Awaitable): + """ + Set a coroutine or future to abort when an event occur. + """ + future = asyncio.ensure_future(awaitable) + if future.done(): + return future + + def on_event(*_): + msg = f'abort: {event} event occurred.' + if isinstance(future, asyncio.Task): + # python prior to 3.9 does not support passing a message on `Task.cancel` + if sys.version_info < (3, 9, 0): + future.cancel() + else: + future.cancel(msg) + else: + future.set_exception(asyncio.CancelledError(msg)) + + def on_done(_): + self.remove_listener(event, on_event) + + self.on(event, on_event) + future.add_done_callback(on_done) + return future + + +# ----------------------------------------------------------------------------- +class CompositeEventEmitter(AbortableEventEmitter): def __init__(self): super().__init__() self._listener = None diff --git a/tests/device_test.py b/tests/device_test.py index 123df29e..07aecdd8 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -223,8 +223,16 @@ async def test_device_connect_parallel(): # ----------------------------------------------------------------------------- -async def run_test_device(): - await test_device_connect_parallel() +@pytest.mark.asyncio +async def test_flush(): + d0 = Device(host=Host(None, None)) + task = d0.abort_on('flush', asyncio.sleep(10000)) + await d0.host.flush() + try: + await task + assert False + except asyncio.CancelledError: + pass # ----------------------------------------------------------------------------- @@ -248,6 +256,14 @@ def test_gatt_services_without_gas(): assert len(device.gatt_server.attributes) == 0 +# ----------------------------------------------------------------------------- +async def run_test_device(): + await test_device_connect_parallel() + await test_flush() + await test_gatt_services_with_gas() + await test_gatt_services_without_gas() + + # ----------------------------------------------------------------------------- if __name__ == '__main__': logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())