diff --git a/bumble/device.py b/bumble/device.py index e67f78d3..b777cb3d 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -48,6 +48,7 @@ from typing_extensions import Self from bumble import ( core, data_types, + gatt, gatt_client, gatt_server, hci, @@ -264,7 +265,7 @@ class ExtendedAdvertisement(Advertisement): # ----------------------------------------------------------------------------- class AdvertisementDataAccumulator: - def __init__(self, passive=False): + def __init__(self, passive: bool = False): self.passive = passive self.last_advertisement = None self.last_data = b'' @@ -1243,7 +1244,7 @@ class LePhyOptions: PREFER_S_2_CODED_PHY = 1 PREFER_S_8_CODED_PHY = 2 - def __init__(self, coded_phy_preference=0): + def __init__(self, coded_phy_preference: int = 0): self.coded_phy_preference = coded_phy_preference def __int__(self): @@ -1691,6 +1692,7 @@ class Connection(utils.CompositeEventEmitter): self_address: hci.Address self_resolvable_address: Optional[hci.Address] peer_address: hci.Address + peer_name: Optional[str] peer_resolvable_address: Optional[hci.Address] peer_le_features: Optional[hci.LeFeatureMask] role: hci.Role @@ -1931,7 +1933,7 @@ class Connection(utils.CompositeEventEmitter): self.remove_listener(self.EVENT_DISCONNECTION, abort.set_result) self.remove_listener(self.EVENT_DISCONNECTION_FAILURE, abort.set_exception) - async def set_data_length(self, tx_octets, tx_time) -> None: + async def set_data_length(self, tx_octets: int, tx_time: int) -> None: return await self.device.set_data_length(self, tx_octets, tx_time) async def update_parameters( @@ -1961,7 +1963,12 @@ class Connection(utils.CompositeEventEmitter): use_l2cap=use_l2cap, ) - async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None): + async def set_phy( + self, + tx_phys: Optional[Iterable[hci.Phy]] = None, + rx_phys: Optional[Iterable[hci.Phy]] = None, + phy_options: int = 0, + ): return await self.device.set_connection_phy(self, tx_phys, rx_phys, phy_options) async def get_phy(self) -> ConnectionPHY: @@ -2156,7 +2163,7 @@ class DeviceConfiguration: # Decorator that converts the first argument from a connection handle to a connection def with_connection_from_handle(function): @functools.wraps(function) - def wrapper(self, connection_handle, *args, **kwargs): + def wrapper(self, connection_handle: int, *args, **kwargs): if (connection := self.lookup_connection(connection_handle)) is None: raise ObjectLookupError( f'no connection for handle: 0x{connection_handle:04x}' @@ -2169,7 +2176,7 @@ def with_connection_from_handle(function): # Decorator that converts the first argument from a bluetooth address to a connection def with_connection_from_address(function): @functools.wraps(function) - def wrapper(self, address, *args, **kwargs): + def wrapper(self, address: hci.Address, *args, **kwargs): if connection := self.pending_connections.get(address, False): return function(self, connection, *args, **kwargs) for connection in self.connections.values(): @@ -2655,7 +2662,7 @@ class Device(utils.CompositeEventEmitter): def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: self.host.send_l2cap_pdu(connection_handle, cid, pdu) - async def send_command(self, command, check_result=False): + async def send_command(self, command: hci.HCI_Command, check_result: bool = False): try: return await asyncio.wait_for( self.host.send_command(command, check_result), self.command_timeout @@ -2912,13 +2919,13 @@ class Device(utils.CompositeEventEmitter): def supports_le_features(self, feature: hci.LeFeatureMask) -> bool: return self.host.supports_le_features(feature) - def supports_le_phy(self, phy: int) -> bool: - if phy == hci.HCI_LE_1M_PHY: + def supports_le_phy(self, phy: hci.Phy) -> bool: + if phy == hci.Phy.LE_1M: return True - feature_map: dict[int, hci.LeFeatureMask] = { - hci.HCI_LE_2M_PHY: hci.LeFeatureMask.LE_2M_PHY, - hci.HCI_LE_CODED_PHY: hci.LeFeatureMask.LE_CODED_PHY, + feature_map: dict[hci.Phy, hci.LeFeatureMask] = { + hci.Phy.LE_2M: hci.LeFeatureMask.LE_2M_PHY, + hci.Phy.LE_CODED: hci.LeFeatureMask.LE_CODED_PHY, } if phy not in feature_map: raise InvalidArgumentError('invalid PHY') @@ -3522,7 +3529,9 @@ class Device(utils.CompositeEventEmitter): self.discovering = False @host_event_handler - def on_inquiry_result(self, address, class_of_device, data, rssi): + def on_inquiry_result( + self, address: hci.Address, class_of_device: int, data: bytes, rssi: int + ): self.emit( self.EVENT_INQUIRY_RESULT, address, @@ -3531,7 +3540,9 @@ class Device(utils.CompositeEventEmitter): rssi, ) - async def set_scan_enable(self, inquiry_scan_enabled, page_scan_enabled): + async def set_scan_enable( + self, inquiry_scan_enabled: bool, page_scan_enabled: bool + ): if inquiry_scan_enabled and page_scan_enabled: scan_enable = 0x03 elif page_scan_enabled: @@ -3657,6 +3668,7 @@ class Device(utils.CompositeEventEmitter): # If the address is not parsable, assume it is a name instead always_resolve = False logger.debug('looking for peer by name') + assert isinstance(peer_address, str) peer_address = await self.find_peer_by_name( peer_address, transport ) # TODO: timeout @@ -3684,7 +3696,7 @@ class Device(utils.CompositeEventEmitter): ): pending_connection.set_result(connection) - def on_connection_failure(error): + def on_connection_failure(error: core.ConnectionError): if transport == PhysicalTransport.LE or ( # match BR/EDR connection failure event against peer address error.transport == transport @@ -3904,6 +3916,7 @@ class Device(utils.CompositeEventEmitter): except InvalidArgumentError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') + assert isinstance(peer_address, str) peer_address = await self.find_peer_by_name( peer_address, PhysicalTransport.BR_EDR ) # TODO: timeout @@ -3962,7 +3975,7 @@ class Device(utils.CompositeEventEmitter): ): pending_connection.set_result(connection) - def on_connection_failure(error): + def on_connection_failure(error: core.ConnectionError): if ( error.transport == PhysicalTransport.BR_EDR and error.peer_address == peer_address @@ -3999,7 +4012,7 @@ class Device(utils.CompositeEventEmitter): self.pending_connections.pop(peer_address, None) @asynccontextmanager - async def connect_as_gatt(self, peer_address): + async def connect_as_gatt(self, peer_address: Union[hci.Address, str]): async with AsyncExitStack() as stack: connection = await stack.enter_async_context( await self.connect(peer_address) @@ -4035,6 +4048,7 @@ class Device(utils.CompositeEventEmitter): except InvalidArgumentError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') + assert isinstance(peer_address, str) peer_address = await self.find_peer_by_name( peer_address, PhysicalTransport.BR_EDR ) # TODO: timeout @@ -4080,7 +4094,9 @@ class Device(utils.CompositeEventEmitter): ) self.disconnecting = False - async def set_data_length(self, connection, tx_octets, tx_time) -> None: + async def set_data_length( + self, connection: Connection, tx_octets: int, tx_time: int + ) -> None: if tx_octets < 0x001B or tx_octets > 0x00FB: raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB') @@ -4183,7 +4199,11 @@ class Device(utils.CompositeEventEmitter): ) async def set_connection_phy( - self, connection, tx_phys=None, rx_phys=None, phy_options=None + self, + connection: Connection, + tx_phys: Optional[Iterable[hci.Phy]] = None, + rx_phys: Optional[Iterable[hci.Phy]] = None, + phy_options: int = 0, ): if not self.host.supports_command(hci.HCI_LE_SET_PHY_COMMAND): logger.warning('ignoring request, command not supported') @@ -4199,7 +4219,7 @@ class Device(utils.CompositeEventEmitter): all_phys=all_phys_bits, tx_phys=hci.phy_list_to_bits(tx_phys), rx_phys=hci.phy_list_to_bits(rx_phys), - phy_options=0 if phy_options is None else int(phy_options), + phy_options=phy_options, ) ) @@ -4210,7 +4230,11 @@ class Device(utils.CompositeEventEmitter): ) raise hci.HCI_StatusError(result) - async def set_default_phy(self, tx_phys=None, rx_phys=None): + async def set_default_phy( + self, + tx_phys: Optional[Iterable[hci.Phy]] = None, + rx_phys: Optional[Iterable[hci.Phy]] = None, + ): all_phys_bits = (1 if tx_phys is None else 0) | ( (1 if rx_phys is None else 0) << 1 ) @@ -4248,7 +4272,7 @@ class Device(utils.CompositeEventEmitter): check_result=True, ) - async def find_peer_by_name(self, name, transport=PhysicalTransport.LE): + async def find_peer_by_name(self, name: str, transport=PhysicalTransport.LE): """ Scan for a peer with a given name and return its address. """ @@ -4263,7 +4287,7 @@ class Device(utils.CompositeEventEmitter): if local_name == name: peer_address.set_result(address) - listener = None + listener: Optional[Callable[..., None]] = None was_scanning = self.scanning was_discovering = self.discovering try: @@ -4369,10 +4393,10 @@ class Device(utils.CompositeEventEmitter): def smp_session_proxy(self, session_proxy: type[smp.Session]) -> None: self.smp_manager.session_proxy = session_proxy - async def pair(self, connection): + async def pair(self, connection: Connection): return await self.smp_manager.pair(connection) - def request_pairing(self, connection): + def request_pairing(self, connection: Connection): return self.smp_manager.request_pairing(connection) async def get_long_term_key( @@ -4460,7 +4484,7 @@ class Device(utils.CompositeEventEmitter): on_authentication_failure, ) - async def encrypt(self, connection, enable=True): + async def encrypt(self, connection: Connection, enable: bool = True): if not enable and connection.transport == PhysicalTransport.LE: raise InvalidArgumentError('`enable` parameter is classic only.') @@ -4470,7 +4494,7 @@ class Device(utils.CompositeEventEmitter): def on_encryption_change(): pending_encryption.set_result(None) - def on_encryption_failure(error_code): + def on_encryption_failure(error_code: int): pending_encryption.set_exception(hci.HCI_Error(error_code)) connection.on( @@ -4562,10 +4586,10 @@ class Device(utils.CompositeEventEmitter): async def switch_role(self, connection: Connection, role: hci.Role): pending_role_change = asyncio.get_running_loop().create_future() - def on_role_change(new_role): + def on_role_change(new_role: hci.Role): pending_role_change.set_result(new_role) - def on_role_change_failure(error_code): + def on_role_change_failure(error_code: int): pending_role_change.set_exception(hci.HCI_Error(error_code)) connection.on(connection.EVENT_ROLE_CHANGE, on_role_change) @@ -5155,10 +5179,10 @@ class Device(utils.CompositeEventEmitter): ): connection.emit(connection.EVENT_LINK_KEY) - def add_service(self, service): + def add_service(self, service: gatt.Service): self.gatt_server.add_service(service) - def add_services(self, services): + def add_services(self, services: Iterable[gatt.Service]): self.gatt_server.add_services(services) def add_default_services( @@ -5254,10 +5278,10 @@ class Device(utils.CompositeEventEmitter): @host_event_handler def on_advertising_set_termination( self, - status, - advertising_handle, - connection_handle, - number_of_completed_extended_advertising_events, + status: int, + advertising_handle: int, + connection_handle: int, + number_of_completed_extended_advertising_events: int, ): # Legacy advertising set is also one of extended advertising sets. if not ( @@ -5556,7 +5580,12 @@ class Device(utils.CompositeEventEmitter): ) @host_event_handler - def on_connection_failure(self, transport, peer_address, error_code): + def on_connection_failure( + self, + transport: hci.PhysicalTransport, + peer_address: hci.Address, + error_code: int, + ): logger.debug( f'*** Connection failed: {hci.HCI_Constant.error_name(error_code)}' ) @@ -5675,7 +5704,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_authentication(self, connection): + def on_connection_authentication(self, connection: Connection): logger.debug( f'*** Connection Authentication: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}' @@ -5685,7 +5714,9 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_authentication_failure(self, connection, error): + def on_connection_authentication_failure( + self, connection: Connection, error: core.ConnectionError + ): logger.debug( f'*** Connection Authentication Failure: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, error={error}' @@ -5727,10 +5758,13 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_address def on_authentication_io_capability_response( - self, connection, io_capability, authentication_requirements + self, + connection: Connection, + io_capability: int, + authentication_requirements: int, ): - connection.peer_pairing_io_capability = io_capability - connection.peer_pairing_authentication_requirements = ( + connection.pairing_peer_io_capability = io_capability + connection.pairing_peer_authentication_requirements = ( authentication_requirements ) @@ -5741,7 +5775,7 @@ class Device(utils.CompositeEventEmitter): # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) io_capability = pairing_config.delegate.classic_io_capability - peer_io_capability = connection.peer_pairing_io_capability + peer_io_capability = connection.pairing_peer_io_capability async def confirm() -> bool: # Ask the user to confirm the pairing, without display @@ -5816,7 +5850,7 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_authentication_user_passkey_request(self, connection) -> None: + def on_authentication_user_passkey_request(self, connection: Connection) -> None: # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) @@ -5859,7 +5893,7 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_pin_code_request(self, connection): + def on_pin_code_request(self, connection: Connection): # Classic legacy pairing # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) @@ -5903,7 +5937,9 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_authentication_user_passkey_notification(self, connection, passkey): + def on_authentication_user_passkey_notification( + self, connection: Connection, passkey: int + ): # Ask what the pairing config should be for this connection pairing_config = self.pairing_config_factory(connection) @@ -5915,14 +5951,15 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @try_with_connection_from_address - def on_remote_name(self, connection: Connection, address, remote_name): + def on_remote_name( + self, connection: Connection, address: hci.Address, remote_name: bytes + ): # Try to decode the name try: - remote_name = remote_name.decode('utf-8') if connection: - connection.peer_name = remote_name + connection.peer_name = remote_name.decode('utf-8') connection.emit(connection.EVENT_REMOTE_NAME) - self.emit(self.EVENT_REMOTE_NAME, address, remote_name) + self.emit(self.EVENT_REMOTE_NAME, address, remote_name.decode('utf-8')) except UnicodeDecodeError as error: logger.warning('peer name is not valid UTF-8') if connection: @@ -5933,7 +5970,9 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @try_with_connection_from_address - def on_remote_name_failure(self, connection: Connection, address, error): + def on_remote_name_failure( + self, connection: Connection, address: hci.Address, error: int + ): if connection: connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error) self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error) @@ -6134,7 +6173,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_encryption_key_refresh(self, connection): + def on_connection_encryption_key_refresh(self, connection: Connection): logger.debug( f'*** Connection Key Refresh: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}' @@ -6172,7 +6211,9 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_parameters_update_failure(self, connection, error): + def on_connection_parameters_update_failure( + self, connection: Connection, error: int + ): logger.debug( f'*** Connection Parameters Update Failed: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, ' @@ -6182,7 +6223,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_phy_update(self, connection, phy): + def on_connection_phy_update(self, connection: Connection, phy: core.ConnectionPHY): logger.debug( f'*** Connection PHY Update: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, ' @@ -6192,7 +6233,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_phy_update_failure(self, connection, error): + def on_connection_phy_update_failure(self, connection: Connection, error: int): logger.debug( f'*** Connection PHY Update Failed: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, ' @@ -6221,7 +6262,7 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle - def on_connection_att_mtu_update(self, connection, att_mtu): + def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int): logger.debug( f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] ' f'{connection.peer_address} as {connection.role_name}, ' @@ -6233,7 +6274,12 @@ class Device(utils.CompositeEventEmitter): @host_event_handler @with_connection_from_handle def on_connection_data_length_change( - self, connection, max_tx_octets, max_tx_time, max_rx_octets, max_rx_time + self, + connection: Connection, + max_tx_octets: int, + max_tx_time: int, + max_rx_octets: int, + max_rx_time: int, ): logger.debug( f'*** Connection Data Length Change: [0x{connection.handle:04X}] ' @@ -6358,14 +6404,16 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_role_change(self, connection, new_role): + def on_role_change(self, connection: Connection, new_role: hci.Role): connection.role = new_role connection.emit(connection.EVENT_ROLE_CHANGE, new_role) # [Classic only] @host_event_handler @try_with_connection_from_address - def on_role_change_failure(self, connection, address, error): + def on_role_change_failure( + self, connection: Connection, address: hci.Address, error: int + ): if connection: connection.emit(connection.EVENT_ROLE_CHANGE_FAILURE, error) self.emit(self.EVENT_ROLE_CHANGE_FAILURE, address, error) @@ -6379,7 +6427,7 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler @with_connection_from_address - def on_classic_pairing_failure(self, connection: Connection, status) -> None: + def on_classic_pairing_failure(self, connection: Connection, status: int) -> None: connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status) def on_pairing_start(self, connection: Connection) -> None: @@ -6403,7 +6451,7 @@ class Device(utils.CompositeEventEmitter): connection.emit(connection.EVENT_PAIRING_FAILURE, reason) @with_connection_from_handle - def on_gatt_pdu(self, connection, pdu): + def on_gatt_pdu(self, connection: Connection, pdu: bytes): # Parse the L2CAP payload into an ATT PDU object att_pdu = ATT_PDU.from_bytes(pdu) @@ -6425,7 +6473,7 @@ class Device(utils.CompositeEventEmitter): connection.gatt_server.on_gatt_pdu(connection, att_pdu) @with_connection_from_handle - def on_smp_pdu(self, connection, pdu): + def on_smp_pdu(self, connection: Connection, pdu: bytes): self.smp_manager.on_smp_pdu(connection, pdu) @host_event_handler