diff --git a/bumble/hci.py b/bumble/hci.py index 9cfe81f..12ef797 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -6436,7 +6436,9 @@ class HCI_LE_Create_BIG_Complete_Event(HCI_LE_Meta_Event): irc: int = field(metadata=metadata(1)) max_pdu: int = field(metadata=metadata(2)) iso_interval: int = field(metadata=metadata(2)) - connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True)) + connection_handle: Sequence[int] = field( + metadata=metadata(2, list_begin=True, list_end=True) + ) # ----------------------------------------------------------------------------- @@ -6468,7 +6470,9 @@ class HCI_LE_BIG_Sync_Established_Event(HCI_LE_Meta_Event): irc: int = field(metadata=metadata(1)) max_pdu: int = field(metadata=metadata(2)) iso_interval: int = field(metadata=metadata(2)) - connection_handle: int = field(metadata=metadata(2, list_begin=True, list_end=True)) + connection_handle: Sequence[int] = field( + metadata=metadata(2, list_begin=True, list_end=True) + ) # ----------------------------------------------------------------------------- diff --git a/bumble/host.py b/bumble/host.py index 8e3020e..8452f3d 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -22,11 +22,16 @@ import collections import dataclasses import logging import struct -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast from bumble import drivers, hci, utils from bumble.colors import color -from bumble.core import ConnectionParameters, ConnectionPHY, PhysicalTransport +from bumble.core import ( + ConnectionParameters, + ConnectionPHY, + InvalidStateError, + PhysicalTransport, +) from bumble.l2cap import L2CAP_PDU from bumble.snoop import Snooper from bumble.transport.common import TransportLostError @@ -902,10 +907,14 @@ class Host(utils.EventEmitter): def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: self.emit('l2cap_pdu', connection.handle, cid, pdu) - def on_command_processed(self, event): + def on_command_processed( + self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event] + ): if self.pending_response: # Check that it is what we were expecting - if self.pending_command.op_code != event.command_opcode: + if self.pending_command is None: + logger.warning('!!! pending_command is None ') + elif self.pending_command.op_code != event.command_opcode: logger.warning( '!!! command result mismatch, expected ' f'0x{self.pending_command.op_code:X} but got ' @@ -919,10 +928,10 @@ class Host(utils.EventEmitter): ############################################################ # HCI handlers ############################################################ - def on_hci_event(self, event): + def on_hci_event(self, event: hci.HCI_Event): logger.warning(f'{color(f"--- Ignoring event {event}", "red")}') - def on_hci_command_complete_event(self, event): + def on_hci_command_complete_event(self, event: hci.HCI_Command_Complete_Event): if event.command_opcode == 0: # This is used just for the Num_HCI_Command_Packets field, not related to # an actual command @@ -931,7 +940,7 @@ class Host(utils.EventEmitter): return self.on_command_processed(event) - def on_hci_command_status_event(self, event): + def on_hci_command_status_event(self, event: hci.HCI_Command_Status_Event): return self.on_command_processed(event) def on_hci_number_of_completed_packets_event( @@ -951,7 +960,7 @@ class Host(utils.EventEmitter): ) # Classic only - def on_hci_connection_request_event(self, event): + def on_hci_connection_request_event(self, event: hci.HCI_Connection_Request_Event): # Notify the listeners self.emit( 'connection_request', @@ -960,7 +969,14 @@ class Host(utils.EventEmitter): event.link_type, ) - def on_hci_le_connection_complete_event(self, event): + def on_hci_le_connection_complete_event( + self, + event: Union[ + hci.HCI_LE_Connection_Complete_Event, + hci.HCI_LE_Enhanced_Connection_Complete_Event, + hci.HCI_LE_Enhanced_Connection_Complete_V2_Event, + ], + ): # Check if this is a cancellation if event.status == hci.HCI_SUCCESS: # Create/update the connection @@ -1006,15 +1022,25 @@ class Host(utils.EventEmitter): event.status, ) - def on_hci_le_enhanced_connection_complete_event(self, event): + def on_hci_le_enhanced_connection_complete_event( + self, + event: Union[ + hci.HCI_LE_Enhanced_Connection_Complete_Event, + hci.HCI_LE_Enhanced_Connection_Complete_V2_Event, + ], + ): # Just use the same implementation as for the non-enhanced event for now self.on_hci_le_connection_complete_event(event) - def on_hci_le_enhanced_connection_complete_v2_event(self, event): + def on_hci_le_enhanced_connection_complete_v2_event( + self, event: hci.HCI_LE_Enhanced_Connection_Complete_V2_Event + ): # Just use the same implementation as for the v1 event for now self.on_hci_le_enhanced_connection_complete_event(event) - def on_hci_connection_complete_event(self, event): + def on_hci_connection_complete_event( + self, event: hci.HCI_Connection_Complete_Event + ): if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( @@ -1054,7 +1080,9 @@ class Host(utils.EventEmitter): event.status, ) - def on_hci_disconnection_complete_event(self, event): + def on_hci_disconnection_complete_event( + self, event: hci.HCI_Disconnection_Complete_Event + ): # Find the connection handle = event.connection_handle if ( @@ -1093,7 +1121,9 @@ class Host(utils.EventEmitter): # Notify the listeners self.emit('disconnection_failure', handle, event.status) - def on_hci_le_connection_update_complete_event(self, event): + def on_hci_le_connection_update_complete_event( + self, event: hci.HCI_LE_Connection_Update_Complete_Event + ): if (connection := self.connections.get(event.connection_handle)) is None: logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle') return @@ -1113,7 +1143,9 @@ class Host(utils.EventEmitter): 'connection_parameters_update_failure', connection.handle, event.status ) - def on_hci_le_phy_update_complete_event(self, event): + def on_hci_le_phy_update_complete_event( + self, event: hci.HCI_LE_PHY_Update_Complete_Event + ): if (connection := self.connections.get(event.connection_handle)) is None: logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle') return @@ -1143,7 +1175,9 @@ class Host(utils.EventEmitter): ): self.on_hci_le_advertising_report_event(event) - def on_hci_le_advertising_set_terminated_event(self, event): + def on_hci_le_advertising_set_terminated_event( + self, event: hci.HCI_LE_Advertising_Set_Terminated_Event + ): self.emit( 'advertising_set_termination', event.status, @@ -1152,7 +1186,9 @@ class Host(utils.EventEmitter): event.num_completed_extended_advertising_events, ) - def on_hci_le_periodic_advertising_sync_established_event(self, event): + def on_hci_le_periodic_advertising_sync_established_event( + self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_Event + ): self.emit( 'periodic_advertising_sync_establishment', event.status, @@ -1164,16 +1200,22 @@ class Host(utils.EventEmitter): event.advertiser_clock_accuracy, ) - def on_hci_le_periodic_advertising_sync_lost_event(self, event): + def on_hci_le_periodic_advertising_sync_lost_event( + self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event + ): self.emit('periodic_advertising_sync_loss', event.sync_handle) - def on_hci_le_periodic_advertising_report_event(self, event): + def on_hci_le_periodic_advertising_report_event( + self, event: hci.HCI_LE_Periodic_Advertising_Report_Event + ): self.emit('periodic_advertising_report', event.sync_handle, event) - def on_hci_le_biginfo_advertising_report_event(self, event): + def on_hci_le_biginfo_advertising_report_event( + self, event: hci.HCI_LE_BIGInfo_Advertising_Report_Event + ): self.emit('biginfo_advertising_report', event.sync_handle, event) - def on_hci_le_cis_request_event(self, event): + def on_hci_le_cis_request_event(self, event: hci.HCI_LE_CIS_Request_Event): self.emit( 'cis_request', event.acl_connection_handle, @@ -1182,10 +1224,12 @@ class Host(utils.EventEmitter): event.cis_id, ) - def on_hci_le_create_big_complete_event(self, event): + def on_hci_le_create_big_complete_event( + self, event: hci.HCI_LE_Create_BIG_Complete_Event + ): self.bigs[event.big_handle] = set(event.connection_handle) if self.iso_packet_queue is None: - logger.warning("BIS established but ISO packets not supported") + raise InvalidStateError("BIS established but ISO packets not supported") for connection_handle in event.connection_handle: self.bis_links[connection_handle] = IsoLink( @@ -1208,8 +1252,13 @@ class Host(utils.EventEmitter): event.iso_interval, ) - def on_hci_le_big_sync_established_event(self, event): + def on_hci_le_big_sync_established_event( + self, event: hci.HCI_LE_BIG_Sync_Established_Event + ): self.bigs[event.big_handle] = set(event.connection_handle) + if self.iso_packet_queue is None: + raise InvalidStateError("BIS established but ISO packets not supported") + for connection_handle in event.connection_handle: self.bis_links[connection_handle] = IsoLink( connection_handle, self.iso_packet_queue @@ -1229,15 +1278,19 @@ class Host(utils.EventEmitter): event.connection_handle, ) - def on_hci_le_big_sync_lost_event(self, event): + def on_hci_le_big_sync_lost_event(self, event: hci.HCI_LE_BIG_Sync_Lost_Event): self.remove_big(event.big_handle) self.emit('big_sync_lost', event.big_handle, event.reason) - def on_hci_le_terminate_big_complete_event(self, event): + def on_hci_le_terminate_big_complete_event( + self, event: hci.HCI_LE_Terminate_BIG_Complete_Event + ): self.remove_big(event.big_handle) self.emit('big_termination', event.reason, event.big_handle) - def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event): + def on_hci_le_periodic_advertising_sync_transfer_received_event( + self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_Event + ): self.emit( 'periodic_advertising_sync_transfer', event.status, @@ -1250,7 +1303,9 @@ class Host(utils.EventEmitter): event.advertiser_clock_accuracy, ) - def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event): + def on_hci_le_periodic_advertising_sync_transfer_received_v2_event( + self, event: hci.HCI_LE_Periodic_Advertising_Sync_Transfer_Received_V2_Event + ): self.emit( 'periodic_advertising_sync_transfer', event.status, @@ -1263,11 +1318,11 @@ class Host(utils.EventEmitter): event.advertiser_clock_accuracy, ) - def on_hci_le_cis_established_event(self, event): + def on_hci_le_cis_established_event(self, event: hci.HCI_LE_CIS_Established_Event): # The remaining parameters are unused for now. if event.status == hci.HCI_SUCCESS: if self.iso_packet_queue is None: - logger.warning("CIS established but ISO packets not supported") + raise InvalidStateError("CIS established but ISO packets not supported") self.cis_links[event.connection_handle] = IsoLink( handle=event.connection_handle, packet_queue=self.iso_packet_queue ) @@ -1294,7 +1349,9 @@ class Host(utils.EventEmitter): 'cis_establishment_failure', event.connection_handle, event.status ) - def on_hci_le_remote_connection_parameter_request_event(self, event): + def on_hci_le_remote_connection_parameter_request_event( + self, event: hci.HCI_LE_Remote_Connection_Parameter_Request_Event + ): if event.connection_handle not in self.connections: logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle') return @@ -1313,7 +1370,9 @@ class Host(utils.EventEmitter): ) ) - def on_hci_le_long_term_key_request_event(self, event): + def on_hci_le_long_term_key_request_event( + self, event: hci.HCI_LE_Long_Term_Key_Request_Event + ): if (connection := self.connections.get(event.connection_handle)) is None: logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle') return @@ -1347,7 +1406,9 @@ class Host(utils.EventEmitter): asyncio.create_task(send_long_term_key()) - def on_hci_synchronous_connection_complete_event(self, event): + def on_hci_synchronous_connection_complete_event( + self, event: hci.HCI_Synchronous_Connection_Complete_Event + ): if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( @@ -1373,7 +1434,9 @@ class Host(utils.EventEmitter): # Notify the client self.emit('sco_connection_failure', event.bd_addr, event.status) - def on_hci_synchronous_connection_changed_event(self, event): + def on_hci_synchronous_connection_changed_event( + self, event: hci.HCI_Synchronous_Connection_Changed_Event + ): pass def on_hci_mode_change_event(self, event: hci.HCI_Mode_Change_Event): @@ -1385,7 +1448,7 @@ class Host(utils.EventEmitter): event.interval, ) - def on_hci_role_change_event(self, event): + def on_hci_role_change_event(self, event: hci.HCI_Role_Change_Event): if event.status == hci.HCI_SUCCESS: logger.debug( f'role change for {event.bd_addr}: ' @@ -1399,7 +1462,9 @@ class Host(utils.EventEmitter): ) self.emit('role_change_failure', event.bd_addr, event.status) - def on_hci_le_data_length_change_event(self, event): + def on_hci_le_data_length_change_event( + self, event: hci.HCI_LE_Data_Length_Change_Event + ): if (connection := self.connections.get(event.connection_handle)) is None: logger.warning('!!! DATA LENGTH CHANGE: unknown handle') return @@ -1413,7 +1478,9 @@ class Host(utils.EventEmitter): event.max_rx_time, ) - def on_hci_authentication_complete_event(self, event): + def on_hci_authentication_complete_event( + self, event: hci.HCI_Authentication_Complete_Event + ): # Notify the client if event.status == hci.HCI_SUCCESS: self.emit('connection_authentication', event.connection_handle) @@ -1454,7 +1521,9 @@ class Host(utils.EventEmitter): 'connection_encryption_failure', event.connection_handle, event.status ) - def on_hci_encryption_key_refresh_complete_event(self, event): + def on_hci_encryption_key_refresh_complete_event( + self, event: hci.HCI_Encryption_Key_Refresh_Complete_Event + ): # Notify the client if event.status == hci.HCI_SUCCESS: self.emit('connection_encryption_key_refresh', event.connection_handle) @@ -1465,7 +1534,7 @@ class Host(utils.EventEmitter): event.status, ) - def on_hci_qos_setup_complete_event(self, event): + def on_hci_qos_setup_complete_event(self, event: hci.HCI_QOS_Setup_Complete_Event): if event.status == hci.HCI_SUCCESS: self.emit( 'connection_qos_setup', event.connection_handle, event.service_type @@ -1477,23 +1546,31 @@ class Host(utils.EventEmitter): event.status, ) - def on_hci_link_supervision_timeout_changed_event(self, event): + def on_hci_link_supervision_timeout_changed_event( + self, event: hci.HCI_Link_Supervision_Timeout_Changed_Event + ): pass - def on_hci_max_slots_change_event(self, event): + def on_hci_max_slots_change_event(self, event: hci.HCI_Max_Slots_Change_Event): pass - def on_hci_page_scan_repetition_mode_change_event(self, event): + def on_hci_page_scan_repetition_mode_change_event( + self, event: hci.HCI_Page_Scan_Repetition_Mode_Change_Event + ): pass - def on_hci_link_key_notification_event(self, event): + def on_hci_link_key_notification_event( + self, event: hci.HCI_Link_Key_Notification_Event + ): logger.debug( f'link key for {event.bd_addr}: {event.link_key.hex()}, ' f'type={hci.HCI_Constant.link_key_type_name(event.key_type)}' ) self.emit('link_key', event.bd_addr, event.link_key, event.key_type) - def on_hci_simple_pairing_complete_event(self, event): + def on_hci_simple_pairing_complete_event( + self, event: hci.HCI_Simple_Pairing_Complete_Event + ): logger.debug( f'simple pairing complete for {event.bd_addr}: ' f'status={hci.HCI_Constant.status_name(event.status)}' @@ -1503,10 +1580,10 @@ class Host(utils.EventEmitter): else: self.emit('classic_pairing_failure', event.bd_addr, event.status) - def on_hci_pin_code_request_event(self, event): + def on_hci_pin_code_request_event(self, event: hci.HCI_PIN_Code_Request_Event): self.emit('pin_code_request', event.bd_addr) - def on_hci_link_key_request_event(self, event): + def on_hci_link_key_request_event(self, event: hci.HCI_Link_Key_Request_Event): async def send_link_key(): if self.link_key_provider is None: logger.debug('no link key provider') @@ -1531,10 +1608,14 @@ class Host(utils.EventEmitter): asyncio.create_task(send_link_key()) - def on_hci_io_capability_request_event(self, event): + def on_hci_io_capability_request_event( + self, event: hci.HCI_IO_Capability_Request_Event + ): self.emit('authentication_io_capability_request', event.bd_addr) - def on_hci_io_capability_response_event(self, event): + def on_hci_io_capability_response_event( + self, event: hci.HCI_IO_Capability_Response_Event + ): self.emit( 'authentication_io_capability_response', event.bd_addr, @@ -1542,25 +1623,33 @@ class Host(utils.EventEmitter): event.authentication_requirements, ) - def on_hci_user_confirmation_request_event(self, event): + def on_hci_user_confirmation_request_event( + self, event: hci.HCI_User_Confirmation_Request_Event + ): self.emit( 'authentication_user_confirmation_request', event.bd_addr, event.numeric_value, ) - def on_hci_user_passkey_request_event(self, event): + def on_hci_user_passkey_request_event( + self, event: hci.HCI_User_Passkey_Request_Event + ): self.emit('authentication_user_passkey_request', event.bd_addr) - def on_hci_user_passkey_notification_event(self, event): + def on_hci_user_passkey_notification_event( + self, event: hci.HCI_User_Passkey_Notification_Event + ): self.emit( 'authentication_user_passkey_notification', event.bd_addr, event.passkey ) - def on_hci_inquiry_complete_event(self, _event): + def on_hci_inquiry_complete_event(self, _event: hci.HCI_Inquiry_Complete_Event): self.emit('inquiry_complete') - def on_hci_inquiry_result_with_rssi_event(self, event): + def on_hci_inquiry_result_with_rssi_event( + self, event: hci.HCI_Inquiry_Result_With_RSSI_Event + ): for bd_addr, class_of_device, rssi in zip( event.bd_addr, event.class_of_device, event.rssi ): @@ -1572,7 +1661,9 @@ class Host(utils.EventEmitter): rssi, ) - def on_hci_extended_inquiry_result_event(self, event): + def on_hci_extended_inquiry_result_event( + self, event: hci.HCI_Extended_Inquiry_Result_Event + ): self.emit( 'inquiry_result', event.bd_addr, @@ -1581,7 +1672,9 @@ class Host(utils.EventEmitter): event.rssi, ) - def on_hci_remote_name_request_complete_event(self, event): + def on_hci_remote_name_request_complete_event( + self, event: hci.HCI_Remote_Name_Request_Complete_Event + ): if event.status != hci.HCI_SUCCESS: self.emit('remote_name_failure', event.bd_addr, event.status) else: @@ -1592,14 +1685,18 @@ class Host(utils.EventEmitter): self.emit('remote_name', event.bd_addr, utf8_name) - def on_hci_remote_host_supported_features_notification_event(self, event): + def on_hci_remote_host_supported_features_notification_event( + self, event: hci.HCI_Remote_Host_Supported_Features_Notification_Event + ): self.emit( 'remote_host_supported_features', event.bd_addr, event.host_supported_features, ) - def on_hci_le_read_remote_features_complete_event(self, event): + def on_hci_le_read_remote_features_complete_event( + self, event: hci.HCI_LE_Read_Remote_Features_Complete_Event + ): if event.status != hci.HCI_SUCCESS: self.emit( 'le_remote_features_failure', event.connection_handle, event.status @@ -1611,22 +1708,34 @@ class Host(utils.EventEmitter): int.from_bytes(event.le_features, 'little'), ) - def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event): + def on_hci_le_cs_read_remote_supported_capabilities_complete_event( + self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event + ): self.emit('cs_remote_supported_capabilities', event) - def on_hci_le_cs_security_enable_complete_event(self, event): + def on_hci_le_cs_security_enable_complete_event( + self, event: hci.HCI_LE_CS_Security_Enable_Complete_Event + ): self.emit('cs_security', event) - def on_hci_le_cs_config_complete_event(self, event): + def on_hci_le_cs_config_complete_event( + self, event: hci.HCI_LE_CS_Config_Complete_Event + ): self.emit('cs_config', event) - def on_hci_le_cs_procedure_enable_complete_event(self, event): + def on_hci_le_cs_procedure_enable_complete_event( + self, event: hci.HCI_LE_CS_Procedure_Enable_Complete_Event + ): self.emit('cs_procedure', event) - def on_hci_le_cs_subevent_result_event(self, event): + def on_hci_le_cs_subevent_result_event( + self, event: hci.HCI_LE_CS_Subevent_Result_Event + ): self.emit('cs_subevent_result', event) - def on_hci_le_cs_subevent_result_continue_event(self, event): + def on_hci_le_cs_subevent_result_continue_event( + self, event: hci.HCI_LE_CS_Subevent_Result_Continue_Event + ): self.emit('cs_subevent_result_continue', event) def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event): @@ -1639,5 +1748,5 @@ class Host(utils.EventEmitter): event.supervision_timeout, ) - def on_hci_vendor_event(self, event): + def on_hci_vendor_event(self, event: hci.HCI_Vendor_Event): self.emit('vendor_event', event)