From b731f6f5569f7b5b46220171faf56608d19a0597 Mon Sep 17 00:00:00 2001 From: uael Date: Thu, 2 Feb 2023 17:36:23 +0000 Subject: [PATCH] overall: add types hints to the small subset used by avatar --- bumble/core.py | 55 ++++---- bumble/device.py | 238 +++++++++++++++++++++-------------- bumble/gatt.py | 6 +- bumble/gatt_server.py | 5 +- bumble/hci.py | 6 +- bumble/host.py | 2 +- bumble/keys.py | 5 +- bumble/sdp.py | 34 ++--- bumble/smp.py | 28 +++-- bumble/transport/__init__.py | 2 +- bumble/transport/common.py | 2 +- bumble/utils.py | 7 +- tests/core_test.py | 10 +- 13 files changed, 234 insertions(+), 166 deletions(-) diff --git a/bumble/core.py b/bumble/core.py index 489d2f1b..3542405b 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import struct +from typing import List, Optional, Tuple, Union, cast from .company_ids import COMPANY_IDENTIFIERS @@ -146,7 +147,7 @@ class UUID: ''' BASE_UUID = bytes.fromhex('00001000800000805F9B34FB') - UUIDS: list[UUID] = [] # Registry of all instances created + UUIDS: List[UUID] = [] # Registry of all instances created def __init__(self, uuid_str_or_int, name=None): if isinstance(uuid_str_or_int, int): @@ -181,7 +182,7 @@ class UUID: return self @classmethod - def from_bytes(cls, uuid_bytes, name=None): + def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID: if len(uuid_bytes) in (2, 4, 16): self = cls.__new__(cls) self.uuid_bytes = uuid_bytes @@ -225,7 +226,7 @@ class UUID: ''' return self.to_bytes(force_128=(len(self.uuid_bytes) == 4)) - def to_hex_str(self): + def to_hex_str(self) -> str: if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4: return bytes(reversed(self.uuid_bytes)).hex().upper() @@ -607,6 +608,11 @@ class DeviceClass: # ----------------------------------------------------------------------------- # Advertising Data # ----------------------------------------------------------------------------- +AdvertisingObject = Union[ + List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes] +] + + class AdvertisingData: # fmt: off # pylint: disable=line-too-long @@ -722,10 +728,12 @@ class AdvertisingData: BR_EDR_CONTROLLER_FLAG = 0x08 BR_EDR_HOST_FLAG = 0x10 + ad_structures: List[Tuple[int, bytes]] + # fmt: on # pylint: enable=line-too-long - def __init__(self, ad_structures=None): + def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None: if ad_structures is None: ad_structures = [] self.ad_structures = ad_structures[:] @@ -752,7 +760,7 @@ class AdvertisingData: return ','.join(bit_flags_to_strings(flags, flag_names)) @staticmethod - def uuid_list_to_objects(ad_data, uuid_size): + def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]: uuids = [] offset = 0 while (uuid_size * (offset + 1)) <= len(ad_data): @@ -829,7 +837,7 @@ class AdvertisingData: # pylint: disable=too-many-return-statements @staticmethod - def ad_data_to_object(ad_type, ad_data): + def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject: if ad_type in ( AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, @@ -868,22 +876,22 @@ class AdvertisingData: return ad_data.decode("utf-8") if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS): - return ad_data[0] + return cast(int, struct.unpack('B', ad_data)[0]) if ad_type in ( AdvertisingData.APPEARANCE, AdvertisingData.ADVERTISING_INTERVAL, ): - return struct.unpack(' List[AdvertisingObject]: ''' Get Advertising Data Structure(s) with a given type - If return_all is True, returns a (possibly empty) list of matches, - else returns the first entry, or None if no structure matches. + Returns a (possibly empty) list of matches. ''' - def process_ad_data(ad_data): + def process_ad_data(ad_data: bytes) -> AdvertisingObject: return ad_data if raw else self.ad_data_to_object(type_id, ad_data) - if return_all: - return [ - process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id - ] + return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] - return next( - (process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), - None, - ) + def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]: + ''' + Get Advertising Data Structure(s) with a given type + + Returns the first entry, or None if no structure matches. + ''' + + all = self.get_all(type_id, raw=raw) + return all[0] if all else None def __bytes__(self): return b''.join( diff --git a/bumble/device.py b/bumble/device.py index 22777710..b5137c6b 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,7 +23,7 @@ import asyncio import logging from contextlib import asynccontextmanager, AsyncExitStack from dataclasses import dataclass -from typing import ClassVar +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from colors import color @@ -197,6 +197,8 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN # ----------------------------------------------------------------------------- class Advertisement: + address: Address + TX_POWER_NOT_AVAILABLE = ( HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE ) @@ -511,6 +513,17 @@ ConnectionParametersPreferences.default = ConnectionParametersPreferences() # ----------------------------------------------------------------------------- class Connection(CompositeEventEmitter): + device: Device + handle: int + transport: int + self_address: Address + peer_address: Address + role: int + encryption: int + authenticated: bool + sc: bool + link_key_type: int + @composite_listener class Listener: def on_disconnection(self, reason): @@ -611,6 +624,10 @@ class Connection(CompositeEventEmitter): def is_encrypted(self): return self.encryption != 0 + @property + def is_incomplete(self) -> bool: + return self.handle == None + def send_l2cap_pdu(self, cid, pdu): self.device.send_l2cap_pdu(self.handle, cid, pdu) @@ -626,20 +643,22 @@ class Connection(CompositeEventEmitter): ): return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps) - async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR): - return await self.device.disconnect(self, reason) + async def disconnect( + self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR + ) -> None: + await self.device.disconnect(self, reason) - async def pair(self): + async def pair(self) -> None: return await self.device.pair(self) - def request_pairing(self): + def request_pairing(self) -> None: return self.device.request_pairing(self) # [Classic only] - async def authenticate(self): + async def authenticate(self) -> None: return await self.device.authenticate(self) - async def encrypt(self, enable=True): + async def encrypt(self, enable: bool = True) -> None: return await self.device.encrypt(self, enable) async def sustain(self, timeout=None): @@ -707,10 +726,10 @@ class Connection(CompositeEventEmitter): # ----------------------------------------------------------------------------- class DeviceConfiguration: - def __init__(self): + def __init__(self) -> None: # Setup defaults self.name = DEVICE_DEFAULT_NAME - self.address = DEVICE_DEFAULT_ADDRESS + self.address = Address(DEVICE_DEFAULT_ADDRESS) self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL @@ -730,12 +749,13 @@ class DeviceConfiguration: ) self.irk = bytes(16) # This really must be changed for any level of security self.keystore = None - self.gatt_services = [] + self.gatt_services: List[Dict[str, Any]] = [] - def load_from_dict(self, config): + def load_from_dict(self, config: Dict[str, Any]) -> None: # Load simple properties self.name = config.get('name', self.name) - self.address = Address(config.get('address', self.address)) + if address := config.get('address', None): + self.address = Address(address) self.class_of_device = config.get('class_of_device', self.class_of_device) self.advertising_interval_min = config.get( 'advertising_interval', self.advertising_interval_min @@ -842,6 +862,22 @@ device_host_event_handlers: list[str] = [] # ----------------------------------------------------------------------------- class Device(CompositeEventEmitter): + # incomplete list of fields. + random_address: Address + public_address: Address + classic_enabled: bool + name: str + class_of_device: int + gatt_server: gatt_server.Server + advertising_data: bytes + scan_response_data: bytes + connections: Dict[int, Connection] + pending_connections: Dict[Address, Connection] + classic_pending_accepts: Dict[ + Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]] + ] + advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] + @composite_listener class Listener: def on_advertisement(self, advertisement): @@ -888,12 +924,12 @@ class Device(CompositeEventEmitter): def __init__( self, - name=None, - address=None, - config=None, - host=None, - generic_access_service=True, - ): + name: Optional[str] = None, + address: Optional[Address] = None, + config: Optional[DeviceConfiguration] = None, + host: Optional[Host] = None, + generic_access_service: bool = True, + ) -> None: super().__init__() self._host = None @@ -995,10 +1031,12 @@ class Device(CompositeEventEmitter): setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription') # Set the initial host - self.host = host + if host: + self.host = host @property - def host(self): + def host(self) -> Host: + assert self._host return self._host @host.setter @@ -1032,15 +1070,18 @@ class Device(CompositeEventEmitter): def sdp_service_records(self, service_records): self.sdp_server.service_records = service_records - def lookup_connection(self, connection_handle): + def lookup_connection(self, connection_handle: int) -> Optional[Connection]: if connection := self.connections.get(connection_handle): return connection return None def find_connection_by_bd_addr( - self, bd_addr, transport=None, check_address_type=False - ): + self, + bd_addr: Address, + transport: Optional[int] = None, + check_address_type: bool = False, + ) -> Optional[Connection]: for connection in self.connections.values(): if connection.peer_address.to_bytes() == bd_addr.to_bytes(): if ( @@ -1098,11 +1139,11 @@ class Device(CompositeEventEmitter): logger.warning('!!! Command timed out') raise CommandTimeoutError() from error - async def power_on(self): + async def power_on(self) -> None: # Reset the controller await self.host.reset() - response = await self.send_command(HCI_Read_BD_ADDR_Command()) + response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg] if response.return_parameters.status == HCI_SUCCESS: logger.debug( color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow') @@ -1114,7 +1155,7 @@ class Device(CompositeEventEmitter): HCI_Write_LE_Host_Support_Command( le_supported_host=int(self.le_enabled), simultaneous_le_host=int(self.le_simultaneous_enabled), - ) + ) # type: ignore[call-arg] ) if self.le_enabled: @@ -1124,7 +1165,7 @@ class Device(CompositeEventEmitter): if self.host.supports_command(HCI_LE_RAND_COMMAND): # Get 8 random bytes response = await self.send_command( - HCI_LE_Rand_Command(), check_result=True + HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg] ) # Ensure the address bytes can be a static random address @@ -1145,7 +1186,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Random_Address_Command( random_address=self.random_address - ), + ), # type: ignore[call-arg] check_result=True, ) @@ -1153,7 +1194,7 @@ class Device(CompositeEventEmitter): if self.keystore and self.host.supports_command( HCI_LE_CLEAR_RESOLVING_LIST_COMMAND ): - await self.send_command(HCI_LE_Clear_Resolving_List_Command()) + await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg] resolving_keys = await self.keystore.get_resolving_keys() for (irk, address) in resolving_keys: @@ -1163,7 +1204,7 @@ class Device(CompositeEventEmitter): peer_identity_address=address, peer_irk=irk, local_irk=self.irk, - ) + ) # type: ignore[call-arg] ) # Enable address resolution @@ -1178,28 +1219,24 @@ class Device(CompositeEventEmitter): if self.classic_enabled: await self.send_command( - HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) + HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg] ) await self.send_command( - HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) + HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg] ) await self.send_command( HCI_Write_Simple_Pairing_Mode_Command( simple_pairing_mode=int(self.classic_ssp_enabled) - ) + ) # type: ignore[call-arg] ) await self.send_command( HCI_Write_Secure_Connections_Host_Support_Command( secure_connections_host_support=int(self.classic_sc_enabled) - ) + ) # type: ignore[call-arg] ) await self.set_connectable(self.connectable) await self.set_discoverable(self.discoverable) - # Let the SMP manager know about the address - # TODO: allow using a public address - self.smp_manager.address = self.random_address - # Done self.powered_on = True @@ -1221,11 +1258,11 @@ class Device(CompositeEventEmitter): async def start_advertising( self, - advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, - target=None, - own_address_type=OwnAddressType.RANDOM, - auto_restart=False, - ): + advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, + target: Optional[Address] = None, + own_address_type: int = OwnAddressType.RANDOM, + auto_restart: bool = False, + ) -> None: # If we're advertising, stop first if self.advertising: await self.stop_advertising() @@ -1235,7 +1272,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Advertising_Data_Command( advertising_data=self.advertising_data - ), + ), # type: ignore[call-arg] check_result=True, ) @@ -1244,7 +1281,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Scan_Response_Data_Command( scan_response_data=self.scan_response_data - ), + ), # type: ignore[call-arg] check_result=True, ) @@ -1270,13 +1307,13 @@ class Device(CompositeEventEmitter): peer_address=peer_address, advertising_channel_map=7, advertising_filter_policy=0, - ), + ), # type: ignore[call-arg] check_result=True, ) # Enable advertising await self.send_command( - HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg] check_result=True, ) @@ -1285,11 +1322,11 @@ class Device(CompositeEventEmitter): self.advertising_type = advertising_type self.advertising = True - async def stop_advertising(self): + async def stop_advertising(self) -> None: # Disable advertising if self.advertising: await self.send_command( - HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), + HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg] check_result=True, ) @@ -1304,14 +1341,14 @@ class Device(CompositeEventEmitter): async def start_scanning( self, - legacy=False, - active=True, - scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms - scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms - own_address_type=OwnAddressType.RANDOM, - filter_duplicates=False, - scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY), - ): + legacy: bool = False, + active: bool = True, + scan_interval: int = DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms + scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms + own_address_type: int = OwnAddressType.RANDOM, + filter_duplicates: bool = False, + scanning_phys: Tuple[int, int] = (HCI_LE_1M_PHY, HCI_LE_CODED_PHY), + ) -> None: # Check that the arguments are legal if scan_interval < scan_window: raise ValueError('scan_interval must be >= scan_window') @@ -1361,7 +1398,7 @@ class Device(CompositeEventEmitter): scan_types=[scan_type] * scanning_phy_count, scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count, scan_windows=[int(scan_window / 0.625)] * scanning_phy_count, - ), + ), # type: ignore[call-arg] check_result=True, ) @@ -1372,7 +1409,7 @@ class Device(CompositeEventEmitter): filter_duplicates=1 if filter_duplicates else 0, duration=0, # TODO allow other values period=0, # TODO allow other values - ), + ), # type: ignore[call-arg] check_result=True, ) else: @@ -1390,7 +1427,7 @@ class Device(CompositeEventEmitter): le_scan_window=int(scan_window / 0.625), own_address_type=own_address_type, scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY, - ), + ), # type: ignore[call-arg] check_result=True, ) @@ -1398,25 +1435,25 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Scan_Enable_Command( le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0 - ), + ), # type: ignore[call-arg] check_result=True, ) self.scanning_is_passive = not active self.scanning = True - async def stop_scanning(self): + async def stop_scanning(self) -> None: # Disable scanning if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE): await self.send_command( HCI_LE_Set_Extended_Scan_Enable_Command( enable=0, filter_duplicates=0, duration=0, period=0 - ), + ), # type: ignore[call-arg] check_result=True, ) else: await self.send_command( - HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), + HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg] check_result=True, ) @@ -1434,9 +1471,9 @@ class Device(CompositeEventEmitter): if advertisement := accumulator.update(report): self.emit('advertisement', advertisement) - async def start_discovery(self, auto_restart=True): + async def start_discovery(self, auto_restart: bool = True) -> None: await self.send_command( - HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), + HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg] check_result=True, ) @@ -1445,7 +1482,7 @@ class Device(CompositeEventEmitter): lap=HCI_GENERAL_INQUIRY_LAP, inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH, num_responses=0, # Unlimited number of responses. - ) + ) # type: ignore[call-arg] ) if response.status != HCI_Command_Status_Event.PENDING: self.discovering = False @@ -1454,9 +1491,9 @@ class Device(CompositeEventEmitter): self.auto_restart_inquiry = auto_restart self.discovering = True - async def stop_discovery(self): + async def stop_discovery(self) -> None: if self.discovering: - await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) + await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg] self.auto_restart_inquiry = True self.discovering = False @@ -1484,7 +1521,7 @@ class Device(CompositeEventEmitter): HCI_Write_Scan_Enable_Command(scan_enable=scan_enable) ) - async def set_discoverable(self, discoverable=True): + async def set_discoverable(self, discoverable: bool = True) -> None: self.discoverable = discoverable if self.classic_enabled: # Synthesize an inquiry response if none is set already @@ -1504,7 +1541,7 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_Write_Extended_Inquiry_Response_Command( fec_required=0, extended_inquiry_response=self.inquiry_response - ), + ), # type: ignore[call-arg] check_result=True, ) await self.set_scan_enable( @@ -1512,7 +1549,7 @@ class Device(CompositeEventEmitter): page_scan_enabled=self.connectable, ) - async def set_connectable(self, connectable=True): + async def set_connectable(self, connectable: bool = True) -> None: self.connectable = connectable if self.classic_enabled: await self.set_scan_enable( @@ -1522,12 +1559,14 @@ class Device(CompositeEventEmitter): async def connect( self, - peer_address, - transport=BT_LE_TRANSPORT, - connection_parameters_preferences=None, - own_address_type=OwnAddressType.RANDOM, - timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT, - ): + peer_address: Union[Address, str], + transport: int = BT_LE_TRANSPORT, + connection_parameters_preferences: Optional[ + Dict[int, ConnectionParametersPreferences] + ] = None, + own_address_type: int = OwnAddressType.RANDOM, + timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, + ) -> Connection: ''' Request a connection to a peer. When transport is BLE, this method cannot be called if there is already a @@ -1574,6 +1613,8 @@ class Device(CompositeEventEmitter): ): raise ValueError('BR/EDR addresses must be PUBLIC') + assert isinstance(peer_address, Address) + def on_connection(connection): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection event against peer address @@ -1691,7 +1732,7 @@ class Device(CompositeEventEmitter): supervision_timeouts=supervision_timeouts, min_ce_lengths=min_ce_lengths, max_ce_lengths=max_ce_lengths, - ) + ) # type: ignore[call-arg] ) else: if HCI_LE_1M_PHY not in connection_parameters_preferences: @@ -1720,7 +1761,7 @@ class Device(CompositeEventEmitter): supervision_timeout=int(prefs.supervision_timeout / 10), min_ce_length=int(prefs.min_ce_length / 0.625), max_ce_length=int(prefs.max_ce_length / 0.625), - ) + ) # type: ignore[call-arg] ) else: # Save pending connection @@ -1737,7 +1778,7 @@ class Device(CompositeEventEmitter): clock_offset=0x0000, allow_role_switch=0x01, reserved=0, - ) + ) # type: ignore[call-arg] ) if result.status != HCI_Command_Status_Event.PENDING: @@ -1756,10 +1797,10 @@ class Device(CompositeEventEmitter): ) except asyncio.TimeoutError: if transport == BT_LE_TRANSPORT: - await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) + await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg] else: await self.send_command( - HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) + HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg] ) try: @@ -1777,10 +1818,10 @@ class Device(CompositeEventEmitter): async def accept( self, - peer_address=Address.ANY, - role=BT_PERIPHERAL_ROLE, - timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT, - ): + peer_address: Union[Address, str] = Address.ANY, + role: int = BT_PERIPHERAL_ROLE, + timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, + ) -> Connection: ''' Wait and accept any incoming connection or a connection from `peer_address` when set. @@ -1802,22 +1843,24 @@ class Device(CompositeEventEmitter): peer_address, BT_BR_EDR_TRANSPORT ) # TODO: timeout + assert isinstance(peer_address, Address) + if peer_address == Address.NIL: raise ValueError('accept on nil address') # Create a future so that we can wait for the request - pending_request = asyncio.get_running_loop().create_future() + pending_request_fut = asyncio.get_running_loop().create_future() if peer_address == Address.ANY: - self.classic_pending_accepts[Address.ANY].append(pending_request) + self.classic_pending_accepts[Address.ANY].append(pending_request_fut) elif peer_address in self.classic_pending_accepts: raise InvalidStateError('accept connection already pending') else: - self.classic_pending_accepts[peer_address] = pending_request + self.classic_pending_accepts[peer_address] = [pending_request_fut] try: # Wait for a request or a completed connection - pending_request = self.abort_on('flush', pending_request) + pending_request = self.abort_on('flush', pending_request_fut) result = await ( asyncio.wait_for(pending_request, timeout) if timeout @@ -1826,7 +1869,7 @@ class Device(CompositeEventEmitter): except Exception: # Remove future from device context if peer_address == Address.ANY: - self.classic_pending_accepts[Address.ANY].remove(pending_request) + self.classic_pending_accepts[Address.ANY].remove(pending_request_fut) else: self.classic_pending_accepts.pop(peer_address) raise @@ -1838,6 +1881,7 @@ class Device(CompositeEventEmitter): # Otherwise, result came from `on_connection_request` peer_address, _class_of_device, _link_type = result + assert isinstance(peer_address, Address) # Create a future so that we can wait for the connection's result pending_connection = asyncio.get_running_loop().create_future() @@ -1867,7 +1911,7 @@ class Device(CompositeEventEmitter): try: # Accept connection request await self.send_command( - HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) + HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg] ) # Wait for connection complete @@ -2243,7 +2287,7 @@ class Device(CompositeEventEmitter): ) # [Classic only] - async def request_remote_name(self, remote): # remote: Connection | Address + async def request_remote_name(self, remote: Union[Address, Connection]) -> str: # Set up event handlers pending_name = asyncio.get_running_loop().create_future() @@ -2271,7 +2315,7 @@ class Device(CompositeEventEmitter): page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2, reserved=0, clock_offset=0, # TODO investigate non-0 values - ) + ) # type: ignore[call-arg] ) if result.status != HCI_COMMAND_STATUS_PENDING: @@ -2372,7 +2416,7 @@ class Device(CompositeEventEmitter): # In this case, set the completed `connection` to the `accept` future # result. if peer_address in self.classic_pending_accepts: - future = self.classic_pending_accepts.pop(peer_address) + future, *_ = self.classic_pending_accepts.pop(peer_address) future.set_result(connection) # Emit an event to notify listeners of the new connection @@ -2473,7 +2517,7 @@ class Device(CompositeEventEmitter): # match a pending future using `bd_addr` if bd_addr in self.classic_pending_accepts: - future = self.classic_pending_accepts.pop(bd_addr) + future, *_ = self.classic_pending_accepts.pop(bd_addr) future.set_result((bd_addr, class_of_device, link_type)) # match first pending future for ANY address diff --git a/bumble/gatt.py b/bumble/gatt.py index de4c9a7a..180941b1 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -28,7 +28,7 @@ import enum import functools import logging import struct -from typing import Sequence +from typing import Optional, Sequence from colors import color from .core import UUID, get_dict_key_by_value @@ -204,6 +204,8 @@ class Service(Attribute): See Vol 3, Part G - 3.1 SERVICE DEFINITION ''' + uuid: UUID + def __init__(self, uuid, characteristics: list[Characteristic], primary=True): # Convert the uuid to a UUID object if it isn't already if isinstance(uuid, str): @@ -221,7 +223,7 @@ class Service(Attribute): self.characteristics = characteristics[:] self.primary = primary - def get_advertising_data(self): + def get_advertising_data(self) -> Optional[bytes]: """ Get Service specific advertising data Defined by each Service, default value is empty diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index ea863047..9efc9002 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -27,7 +27,7 @@ import asyncio import logging from collections import defaultdict import struct -from typing import Tuple, Optional +from typing import List, Tuple, Optional from pyee import EventEmitter from colors import color @@ -90,6 +90,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517 # GATT Server # ----------------------------------------------------------------------------- class Server(EventEmitter): + attributes: List[Attribute] + def __init__(self, device): super().__init__() self.device = device @@ -140,6 +142,7 @@ class Server(EventEmitter): attribute for attribute in self.attributes if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE + and isinstance(attribute, Service) and attribute.uuid == service_uuid ), None, diff --git a/bumble/hci.py b/bumble/hci.py index 91a10122..951e81cb 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -21,7 +21,7 @@ import collections import logging import functools from colors import color -from typing import Dict, Type +from typing import Dict, Type, Union from .core import ( BT_BR_EDR_TRANSPORT, @@ -1729,7 +1729,9 @@ class Address: address_type = data[offset - 1] return Address.parse_address_with_type(data, offset, address_type) - def __init__(self, address, address_type=RANDOM_DEVICE_ADDRESS): + def __init__( + self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS + ): ''' Initialize an instance. `address` may be a byte array in little-endian format, or a hex string in big-endian format (with optional ':' diff --git a/bumble/host.py b/bumble/host.py index 58ff8f0f..bb29eb05 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -141,7 +141,7 @@ class Host(AbortableEventEmitter): if controller_sink: self.set_packet_sink(controller_sink) - async def flush(self): + async def flush(self) -> None: # Make sure no command is pending await self.command_semaphore.acquire() diff --git a/bumble/keys.py b/bumble/keys.py index d06d19f1..d62011a2 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -24,6 +24,7 @@ import asyncio import logging import os import json +from typing import Optional from colors import color from .hci import Address @@ -242,7 +243,7 @@ class JsonKeyStore(KeyStore): # Atomically replace the previous file os.rename(temp_filename, self.filename) - async def delete(self, name): + async def delete(self, name: str) -> None: db = await self.load() namespace = db.get(self.namespace) @@ -278,7 +279,7 @@ class JsonKeyStore(KeyStore): await self.save(db) - async def get(self, name): + async def get(self, name: str) -> Optional[PairingKeys]: db = await self.load() namespace = db.get(self.namespace) diff --git a/bumble/sdp.py b/bumble/sdp.py index 0c773ef2..896a47a2 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -20,7 +20,7 @@ import logging import struct from colors import color import colors -from typing import Dict, Type +from typing import Dict, List, Type from . import core from .core import InvalidStateError @@ -183,63 +183,63 @@ class DataElement: raise ValueError('integer types must have a value size specified') @staticmethod - def nil(): + def nil() -> DataElement: return DataElement(DataElement.NIL, None) @staticmethod - def unsigned_integer(value, value_size): + def unsigned_integer(value: int, value_size: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size) @staticmethod - def unsigned_integer_8(value): + def unsigned_integer_8(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1) @staticmethod - def unsigned_integer_16(value): + def unsigned_integer_16(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2) @staticmethod - def unsigned_integer_32(value): + def unsigned_integer_32(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4) @staticmethod - def signed_integer(value, value_size): + def signed_integer(value: int, value_size: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size) @staticmethod - def signed_integer_8(value): + def signed_integer_8(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1) @staticmethod - def signed_integer_16(value): + def signed_integer_16(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2) @staticmethod - def signed_integer_32(value): + def signed_integer_32(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4) @staticmethod - def uuid(value): + def uuid(value: core.UUID) -> DataElement: return DataElement(DataElement.UUID, value) @staticmethod - def text_string(value): + def text_string(value: str) -> DataElement: return DataElement(DataElement.TEXT_STRING, value) @staticmethod - def boolean(value): + def boolean(value: bool) -> DataElement: return DataElement(DataElement.BOOLEAN, value) @staticmethod - def sequence(value): + def sequence(value: List[DataElement]) -> DataElement: return DataElement(DataElement.SEQUENCE, value) @staticmethod - def alternative(value): + def alternative(value: List[DataElement]) -> DataElement: return DataElement(DataElement.ALTERNATIVE, value) @staticmethod - def url(value): + def url(value: str) -> DataElement: return DataElement(DataElement.URL, value) @staticmethod @@ -458,7 +458,7 @@ class DataElement: # ----------------------------------------------------------------------------- class ServiceAttribute: - def __init__(self, attribute_id, value): + def __init__(self, attribute_id: int, value: DataElement) -> None: self.id = attribute_id self.value = value diff --git a/bumble/smp.py b/bumble/smp.py index d3feb6b2..8c0c50a5 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -26,7 +26,7 @@ from __future__ import annotations import logging import asyncio import secrets -from typing import Dict, Type +from typing import Dict, Optional, Type from pyee import EventEmitter from colors import color @@ -504,27 +504,27 @@ class PairingDelegate: def __init__( self, - io_capability=NO_OUTPUT_NO_INPUT, - local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION, - local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION, - ): + io_capability: int = NO_OUTPUT_NO_INPUT, + local_initiator_key_distribution: int = DEFAULT_KEY_DISTRIBUTION, + local_responder_key_distribution: int = DEFAULT_KEY_DISTRIBUTION, + ) -> None: self.io_capability = io_capability self.local_initiator_key_distribution = local_initiator_key_distribution self.local_responder_key_distribution = local_responder_key_distribution - async def accept(self): + async def accept(self) -> bool: return True - async def confirm(self): + async def confirm(self) -> bool: return True - async def compare_numbers(self, _number, _digits=6): + async def compare_numbers(self, _number: int, _digits: int = 6) -> bool: return True - async def get_number(self): + async def get_number(self) -> int: return 0 - async def display_number(self, _number, _digits=6): + async def display_number(self, _number: int, _digits: int = 6) -> None: pass async def key_distribution_response( @@ -538,7 +538,13 @@ class PairingDelegate: # ----------------------------------------------------------------------------- class PairingConfig: - def __init__(self, sc=True, mitm=True, bonding=True, delegate=None): + def __init__( + self, + sc: bool = True, + mitm: bool = True, + bonding: bool = True, + delegate: Optional[PairingDelegate] = None, + ) -> None: self.sc = sc self.mitm = mitm self.bonding = bonding diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index a3bb4ac8..5ba67b4d 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_transport(name): +async def open_transport(name: str) -> Transport: ''' Open a transport by name. The name must be : diff --git a/bumble/transport/common.py b/bumble/transport/common.py index 555a3326..a2964075 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -259,7 +259,7 @@ class Transport: def __iter__(self): return iter((self.source, self.sink)) - async def close(self): + async def close(self) -> None: self.source.close() self.sink.close() diff --git a/bumble/utils.py b/bumble/utils.py index 593546fa..6f811f7d 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -20,7 +20,7 @@ import logging import traceback import collections import sys -from typing import Awaitable +from typing import Awaitable, TypeVar from functools import wraps from colors import color from pyee import EventEmitter @@ -65,8 +65,11 @@ def composite_listener(cls): # ----------------------------------------------------------------------------- +_T = TypeVar('_T') + + class AbortableEventEmitter(EventEmitter): - def abort_on(self, event: str, awaitable: Awaitable): + def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]: """ Set a coroutine or future to abort when an event occur. """ diff --git a/tests/core_test.py b/tests/core_test.py index ba4ca5d3..7ee2dfd6 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -25,10 +25,8 @@ def test_ad_data(): assert data == ad_bytes assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]) - assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [] - assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [ - bytes([123]) - ] + assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == [] + assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [bytes([123])] data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234]) ad.append(data2) @@ -36,8 +34,8 @@ def test_ad_data(): assert ad_bytes == data + data2 assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]) - assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [] - assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [ + assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == [] + assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [ bytes([123]), bytes([234]), ]