From 4bee8d5287c050ca1386a7df65632bd0b4db0335 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 6 Nov 2025 18:19:02 +0800 Subject: [PATCH] Use EventWatcher and send_command(check_result=True) in all similar patterns --- bumble/controller.py | 10 ++ bumble/device.py | 288 ++++++++++++++++--------------------------- 2 files changed, 118 insertions(+), 180 deletions(-) diff --git a/bumble/controller.py b/bumble/controller.py index 1153ec00..db6334b5 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -184,6 +184,7 @@ class Controller: advertising_interval: int = 2000 advertising_data: Optional[bytes] = None advertising_timer_handle: Optional[asyncio.Handle] = None + classic_scan_enable: int = 0 _random_address: 'Address' = Address('00:00:00:00:00:00') @@ -1630,6 +1631,15 @@ class Controller: ''' return bytes([HCI_SUCCESS]) + def on_hci_write_scan_enable_command( + self, command: hci.HCI_Write_Scan_Enable_Command + ) -> Optional[bytes]: + ''' + See Bluetooth spec Vol 4, Part E - 7.3.18 Write Scan Enable Command + ''' + self.classic_scan_enable = command.scan_enable + return bytes([HCI_SUCCESS]) + def on_hci_le_read_remote_features_command( self, command: hci.HCI_LE_Read_Remote_Features_Command ) -> Optional[bytes]: diff --git a/bumble/device.py b/bumble/device.py index f80be302..75765d60 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -907,7 +907,7 @@ class PeriodicAdvertisingSync(utils.EventEmitter): hci.HCI_LE_Periodic_Advertising_Create_Sync_Command.Options.DUPLICATE_FILTERING_INITIALLY_ENABLED ) - response = await self.device.send_command( + await self.device.send_command( hci.HCI_LE_Periodic_Advertising_Create_Sync_Command( options=options, advertising_sid=self.sid, @@ -916,10 +916,9 @@ class PeriodicAdvertisingSync(utils.EventEmitter): skip=self.skip, sync_timeout=int(self.sync_timeout * 100), sync_cte_type=0, - ) + ), + check_result=True, ) - if response.status != hci.HCI_Command_Status_Event.PENDING: - raise hci.HCI_StatusError(response) self.state = self.State.PENDING @@ -1915,16 +1914,13 @@ class Connection(utils.CompositeEventEmitter): """Idles the current task waiting for a disconnect or timeout""" abort = asyncio.get_running_loop().create_future() - self.on(self.EVENT_DISCONNECTION, abort.set_result) - self.on(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception) + with closing(utils.EventWatcher()) as watcher: + watcher.on(self, self.EVENT_DISCONNECTION, abort.set_result) + watcher.on(self, self.EVENT_DISCONNECTION_FAILURE, abort.set_exception) - try: await asyncio.wait_for( utils.cancel_on_event(self.device, Device.EVENT_FLUSH, abort), timeout ) - finally: - self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result) - self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception) async def set_data_length(self, tx_octets: int, tx_time: int) -> None: return await self.device.set_data_length(self, tx_octets, tx_time) @@ -2887,7 +2883,9 @@ class Device(utils.CompositeEventEmitter): self.address_resolver = smp.AddressResolver(resolving_keys) if self.address_resolution_offload or self.address_generation_offload: - await self.send_command(hci.HCI_LE_Clear_Resolving_List_Command()) + await self.send_command( + hci.HCI_LE_Clear_Resolving_List_Command(), check_result=True + ) # Add an empty entry for non-directed address generation. await self.send_command( @@ -2896,7 +2894,8 @@ class Device(utils.CompositeEventEmitter): peer_identity_address=hci.Address.ANY, peer_irk=bytes(16), local_irk=self.irk, - ) + ), + check_result=True, ) for irk, address in resolving_keys: @@ -2906,7 +2905,8 @@ class Device(utils.CompositeEventEmitter): peer_identity_address=address, peer_irk=irk, local_irk=self.irk, - ) + ), + check_result=True, ) def supports_le_features(self, feature: hci.LeFeatureMask) -> bool: @@ -3501,16 +3501,15 @@ class Device(utils.CompositeEventEmitter): check_result=True, ) - response = await self.send_command( + self.discovering = False + await self.send_command( hci.HCI_Inquiry_Command( lap=hci.HCI_GENERAL_INQUIRY_LAP, inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH, num_responses=0, # Unlimited number of responses. - ) + ), + check_result=True, ) - if response.status != hci.HCI_Command_Status_Event.PENDING: - self.discovering = False - raise hci.HCI_StatusError(response) self.auto_restart_inquiry = auto_restart self.discovering = True @@ -3546,7 +3545,8 @@ class Device(utils.CompositeEventEmitter): scan_enable = 0x00 return await self.send_command( - hci.HCI_Write_Scan_Enable_Command(scan_enable=scan_enable) + hci.HCI_Write_Scan_Enable_Command(scan_enable=scan_enable), + check_result=True, ) async def set_discoverable(self, discoverable: bool = True) -> None: @@ -3775,7 +3775,7 @@ class Device(utils.CompositeEventEmitter): for phy in phys ] - result = await self.send_command( + await self.send_command( hci.HCI_LE_Extended_Create_Connection_Command( initiator_filter_policy=0, own_address_type=own_address_type, @@ -3796,14 +3796,15 @@ class Device(utils.CompositeEventEmitter): supervision_timeouts=supervision_timeouts, min_ce_lengths=min_ce_lengths, max_ce_lengths=max_ce_lengths, - ) + ), + check_result=True, ) else: if hci.HCI_LE_1M_PHY not in connection_parameters_preferences: raise InvalidArgumentError('1M PHY preferences required') prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY] - result = await self.send_command( + await self.send_command( hci.HCI_LE_Create_Connection_Command( le_scan_interval=int( DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625 @@ -3825,7 +3826,8 @@ class Device(utils.CompositeEventEmitter): supervision_timeout=int(prefs.supervision_timeout / 10), min_ce_length=int(prefs.min_ce_length / 0.625), max_ce_length=int(prefs.max_ce_length / 0.625), - ) + ), + check_result=True, ) else: # Save pending connection @@ -3842,7 +3844,7 @@ class Device(utils.CompositeEventEmitter): ) # TODO: allow passing other settings - result = await self.send_command( + await self.send_command( hci.HCI_Create_Connection_Command( bd_addr=peer_address, packet_type=0xCC18, # FIXME: change @@ -3850,12 +3852,10 @@ class Device(utils.CompositeEventEmitter): clock_offset=0x0000, allow_role_switch=0x01, reserved=0, - ) + ), + check_result=True, ) - if result.status != hci.HCI_Command_Status_Event.PENDING: - raise hci.HCI_StatusError(result) - # Wait for the connection process to complete if transport == PhysicalTransport.LE: self.le_connecting = True @@ -4007,7 +4007,8 @@ class Device(utils.CompositeEventEmitter): await self.send_command( hci.HCI_Accept_Connection_Request_Command( bd_addr=peer_address, role=role - ) + ), + check_result=True, ) # Wait for connection complete @@ -4077,19 +4078,17 @@ class Device(utils.CompositeEventEmitter): connection.EVENT_DISCONNECTION_FAILURE, pending_disconnection.set_exception ) - # Request a disconnection - result = await self.send_command( - hci.HCI_Disconnect_Command( - connection_handle=connection.handle, reason=reason - ) - ) - try: - if result.status != hci.HCI_Command_Status_Event.PENDING: - raise hci.HCI_StatusError(result) - # Wait for the disconnection process to complete self.disconnecting = True + + # Request a disconnection + await self.send_command( + hci.HCI_Disconnect_Command( + connection_handle=connection.handle, reason=reason + ), + check_result=True, + ) return await utils.cancel_on_event( self, Device.EVENT_FLUSH, pending_disconnection ) @@ -4175,7 +4174,7 @@ class Device(utils.CompositeEventEmitter): return - result = await self.send_command( + await self.send_command( hci.HCI_LE_Connection_Update_Command( connection_handle=connection.handle, connection_interval_min=connection_interval_min, @@ -4184,10 +4183,9 @@ class Device(utils.CompositeEventEmitter): supervision_timeout=supervision_timeout, min_ce_length=min_ce_length, max_ce_length=max_ce_length, - ) + ), + check_result=True, ) - if result.status != hci.HCI_Command_Status_Event.PENDING: - raise hci.HCI_StatusError(result) async def get_connection_rssi(self, connection): result = await self.send_command( @@ -4222,23 +4220,17 @@ class Device(utils.CompositeEventEmitter): (1 if rx_phys is None else 0) << 1 ) - result = await self.send_command( + await self.send_command( hci.HCI_LE_Set_PHY_Command( connection_handle=connection.handle, all_phys=all_phys_bits, tx_phys=hci.phy_list_to_bits(tx_phys), rx_phys=hci.phy_list_to_bits(rx_phys), phy_options=phy_options, - ) + ), + check_result=True, ) - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Set_PHY_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) - async def set_default_phy( self, tx_phys: Optional[Iterable[hci.Phy]] = None, @@ -4455,43 +4447,26 @@ class Device(utils.CompositeEventEmitter): async def authenticate(self, connection: Connection) -> None: # Set up event handlers pending_authentication = asyncio.get_running_loop().create_future() + with closing(utils.EventWatcher()) as watcher: - def on_authentication(): - pending_authentication.set_result(None) + @watcher.on(connection, connection.EVENT_CONNECTION_AUTHENTICATION) + def on_authentication() -> None: + pending_authentication.set_result(None) - def on_authentication_failure(error_code): - pending_authentication.set_exception(hci.HCI_Error(error_code)) + @watcher.on(connection, connection.EVENT_CONNECTION_AUTHENTICATION_FAILURE) + def on_authentication_failure(error_code: int) -> None: + pending_authentication.set_exception(hci.HCI_Error(error_code)) - connection.on(connection.EVENT_CONNECTION_AUTHENTICATION, on_authentication) - connection.on( - connection.EVENT_CONNECTION_AUTHENTICATION_FAILURE, - on_authentication_failure, - ) - - # Request the authentication - try: - result = await self.send_command( + # Request the authentication + await self.send_command( hci.HCI_Authentication_Requested_Command( connection_handle=connection.handle - ) + ), + check_result=True, ) - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_Authentication_Requested_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) # Wait for the authentication to complete await connection.cancel_on_disconnection(pending_authentication) - finally: - connection.remove_listener( - connection.EVENT_CONNECTION_AUTHENTICATION, on_authentication - ) - connection.remove_listener( - connection.EVENT_CONNECTION_AUTHENTICATION_FAILURE, - on_authentication_failure, - ) async def encrypt(self, connection: Connection, enable: bool = True): if not enable and connection.transport == PhysicalTransport.LE: @@ -4500,21 +4475,17 @@ class Device(utils.CompositeEventEmitter): # Set up event handlers pending_encryption = asyncio.get_running_loop().create_future() - def on_encryption_change(): - pending_encryption.set_result(None) - - def on_encryption_failure(error_code: int): - pending_encryption.set_exception(hci.HCI_Error(error_code)) - - connection.on( - connection.EVENT_CONNECTION_ENCRYPTION_CHANGE, on_encryption_change - ) - connection.on( - connection.EVENT_CONNECTION_ENCRYPTION_FAILURE, on_encryption_failure - ) - # Request the encryption - try: + with closing(utils.EventWatcher()) as watcher: + + @watcher.on(connection, connection.EVENT_CONNECTION_ENCRYPTION_CHANGE) + def _() -> None: + pending_encryption.set_result(None) + + @watcher.on(connection, connection.EVENT_CONNECTION_ENCRYPTION_FAILURE) + def _(error_code: int): + pending_encryption.set_exception(hci.HCI_Error(error_code)) + if connection.transport == PhysicalTransport.LE: # Look for a key in the key store if self.keystore is None: @@ -4539,45 +4510,26 @@ class Device(utils.CompositeEventEmitter): if connection.role != hci.Role.CENTRAL: raise InvalidStateError('only centrals can start encryption') - result = await self.send_command( + await self.send_command( hci.HCI_LE_Enable_Encryption_Command( connection_handle=connection.handle, random_number=rand, encrypted_diversifier=ediv, long_term_key=ltk, - ) + ), + check_result=True, ) - - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Enable_Encryption_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) else: - result = await self.send_command( + await self.send_command( hci.HCI_Set_Connection_Encryption_Command( connection_handle=connection.handle, encryption_enable=0x01 if enable else 0x00, - ) + ), + check_result=True, ) - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_Set_Connection_Encryption_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) - # Wait for the result await connection.cancel_on_disconnection(pending_encryption) - finally: - connection.remove_listener( - connection.EVENT_CONNECTION_ENCRYPTION_CHANGE, on_encryption_change - ) - connection.remove_listener( - connection.EVENT_CONNECTION_ENCRYPTION_FAILURE, on_encryption_failure - ) async def update_keys(self, address: str, keys: PairingKeys) -> None: if self.keystore is None: @@ -4595,80 +4547,55 @@ class Device(utils.CompositeEventEmitter): async def switch_role(self, connection: Connection, role: hci.Role): pending_role_change = asyncio.get_running_loop().create_future() - def on_role_change(new_role: hci.Role): - pending_role_change.set_result(new_role) + with closing(utils.EventWatcher()) as watcher: - def on_role_change_failure(error_code: int): - pending_role_change.set_exception(hci.HCI_Error(error_code)) + @watcher.on(connection, connection.EVENT_ROLE_CHANGE) + def _(new_role: hci.Role): + pending_role_change.set_result(new_role) - connection.on(connection.EVENT_ROLE_CHANGE, on_role_change) - connection.on(connection.EVENT_ROLE_CHANGE_FAILURE, on_role_change_failure) + @watcher.on(connection, connection.EVENT_ROLE_CHANGE_FAILURE) + def _(error_code: int): + pending_role_change.set_exception(hci.HCI_Error(error_code)) - try: - result = await self.send_command( - hci.HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role) + await self.send_command( + hci.HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role), + check_result=True, ) - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_Switch_Role_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) await connection.cancel_on_disconnection(pending_role_change) - finally: - connection.remove_listener(connection.EVENT_ROLE_CHANGE, on_role_change) - connection.remove_listener( - connection.EVENT_ROLE_CHANGE_FAILURE, on_role_change_failure - ) # [Classic only] async def request_remote_name(self, remote: Union[hci.Address, Connection]) -> str: # Set up event handlers - pending_name = asyncio.get_running_loop().create_future() + pending_name: asyncio.Future[str] = asyncio.get_running_loop().create_future() peer_address = ( remote if isinstance(remote, hci.Address) else remote.peer_address ) - handler = self.on( - self.EVENT_REMOTE_NAME, - lambda address, remote_name: ( - pending_name.set_result(remote_name) - if address == peer_address - else None - ), - ) - failure_handler = self.on( - self.EVENT_REMOTE_NAME_FAILURE, - lambda address, error_code: ( - pending_name.set_exception(hci.HCI_Error(error_code)) - if address == peer_address - else None - ), - ) + with closing(utils.EventWatcher()) as watcher: - try: - result = await self.send_command( + @watcher.on(self, self.EVENT_REMOTE_NAME) + def _(address: hci.Address, remote_name: str) -> None: + if address == peer_address: + pending_name.set_result(remote_name) + + @watcher.on(self, self.EVENT_REMOTE_NAME_FAILURE) + def _(address: hci.Address, error_code: int) -> None: + if address == peer_address: + pending_name.set_exception(hci.HCI_Error(error_code)) + + await self.send_command( hci.HCI_Remote_Name_Request_Command( bd_addr=peer_address, page_scan_repetition_mode=hci.HCI_Remote_Name_Request_Command.R2, reserved=0, clock_offset=0, # TODO investigate non-0 values - ) + ), + check_result=True, ) - if result.status != hci.HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_Remote_Name_Request_Command failed: ' - f'{hci.HCI_Constant.error_name(result.status)}' - ) - raise hci.HCI_StatusError(result) - # Wait for the result return await utils.cancel_on_event(self, Device.EVENT_FLUSH, pending_name) - finally: - self.remove_listener(self.EVENT_REMOTE_NAME, handler) - self.remove_listener(self.EVENT_REMOTE_NAME_FAILURE, failure_handler) # [LE only] @utils.experimental('Only for testing.') @@ -4684,8 +4611,6 @@ class Device(utils.CompositeEventEmitter): Returns: List of created CIS handles corresponding to the same order of [cid_id]. """ - num_cis = len(parameters.cis_parameters) - response = await self.send_command( hci.HCI_LE_Set_CIG_Parameters_Command( cig_id=parameters.cig_id, @@ -5753,9 +5678,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_authentication_failure( - self, connection: Connection, error: core.ConnectionError - ): + def on_connection_authentication_failure(self, connection: Connection, error: int): logger.debug( f'*** Connection Authentication Failure: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, error={error}' @@ -5810,11 +5733,15 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_authentication_user_confirmation_request(self, connection, code) -> None: + def on_authentication_user_confirmation_request( + self, connection: Connection, code: int + ) -> None: # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) io_capability = pairing_config.delegate.classic_io_capability peer_io_capability = connection.pairing_peer_io_capability + if peer_io_capability is None: + raise core.InvalidStateError("Unknown pairing_peer_io_capability") async def confirm() -> bool: # Ask the user to confirm the pairing, without display @@ -5941,15 +5868,16 @@ class Device(utils.CompositeEventEmitter): # Respond if io_capability == hci.IoCapability.KEYBOARD_ONLY: # Ask the user to enter a string - async def get_pin_code(): - pin_code = await connection.cancel_on_disconnection( + async def get_pin_code() -> None: + pin_code_str = await connection.cancel_on_disconnection( pairing_config.delegate.get_string(16) ) - if pin_code is not None: - pin_code = bytes(pin_code, encoding='utf-8') + if pin_code_str is not None: + pin_code = bytes(pin_code_str, encoding='utf-8') pin_code_len = len(pin_code) - assert 0 < pin_code_len <= 16, "pin_code should be 1-16 bytes" + if not 1 <= pin_code_len <= 16: + raise core.InvalidArgumentError("pin_code should be 1-16 bytes") await self.host.send_command( hci.HCI_PIN_Code_Request_Reply_Command( bd_addr=connection.peer_address,