From 727586e40e3f8ea36c7105829211040898c6f2f2 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 23 Aug 2023 14:43:10 +0800 Subject: [PATCH] Typing SDP --- bumble/sdp.py | 111 +++++++++++++++++++++++++++++++++++----------- tests/sdp_test.py | 4 +- 2 files changed, 87 insertions(+), 28 deletions(-) diff --git a/bumble/sdp.py b/bumble/sdp.py index 1d4faf9a..64281874 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -18,13 +18,16 @@ from __future__ import annotations import logging import struct -from typing import Dict, List, Type +from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING -from . import core +from . import core, l2cap from .colors import color from .core import InvalidStateError from .hci import HCI_Object, name_or_number, key_with_value +if TYPE_CHECKING: + from .device import Device, Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -466,7 +469,7 @@ class ServiceAttribute: self.value = value @staticmethod - def list_from_data_elements(elements): + def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: attribute_list = [] for i in range(0, len(elements) // 2): attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] @@ -478,7 +481,9 @@ class ServiceAttribute: return attribute_list @staticmethod - def find_attribute_in_list(attribute_list, attribute_id): + def find_attribute_in_list( + attribute_list: List[ServiceAttribute], attribute_id: int + ) -> Optional[DataElement]: return next( ( attribute.value @@ -493,7 +498,7 @@ class ServiceAttribute: return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code) @staticmethod - def is_uuid_in_value(uuid, value): + def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool: # Find if a uuid matches a value, either directly or recursing into sequences if value.type == DataElement.UUID: return value.value == uuid @@ -547,7 +552,9 @@ class SDP_PDU: return self @staticmethod - def parse_service_record_handle_list_preceded_by_count(data, offset): + def parse_service_record_handle_list_preceded_by_count( + data: bytes, offset: int + ) -> Tuple[int, List[int]]: count = struct.unpack_from('>H', data, offset - 2)[0] handle_list = [ struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) @@ -645,6 +652,10 @@ class SDP_ServiceSearchRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU ''' + service_search_pattern: DataElement + maximum_service_record_count: int + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -663,6 +674,11 @@ class SDP_ServiceSearchResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU ''' + service_record_handle_list: List[int] + total_service_record_count: int + current_service_record_count: int + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -678,6 +694,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU ''' + service_record_handle: int + maximum_attribute_byte_count: int + attribute_id_list: DataElement + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -692,6 +713,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU ''' + attribute_list_byte_count: int + attribute_list: bytes + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -707,6 +732,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU ''' + service_search_pattern: DataElement + maximum_attribute_byte_count: int + attribute_id_list: DataElement + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -721,26 +751,34 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU ''' + attribute_list_byte_count: int + attribute_list: bytes + continuation_state: bytes + # ----------------------------------------------------------------------------- class Client: - def __init__(self, device): + channel: Optional[l2cap.Channel] + + def __init__(self, device: Device) -> None: self.device = device self.pending_request = None self.channel = None - async def connect(self, connection): + async def connect(self, connection: Connection) -> None: result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM) self.channel = result - async def disconnect(self): + async def disconnect(self) -> None: if self.channel: await self.channel.disconnect() self.channel = None - async def search_services(self, uuids): + async def search_services(self, uuids: List[core.UUID]) -> List[int]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] @@ -770,9 +808,13 @@ class Client: return service_record_handle_list - async def search_attributes(self, uuids, attribute_ids): + async def search_attributes( + self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] + ) -> List[List[ServiceAttribute]]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] @@ -823,9 +865,15 @@ class Client: if sequence.type == DataElement.SEQUENCE ] - async def get_attributes(self, service_record_handle, attribute_ids): + async def get_attributes( + self, + service_record_handle: int, + attribute_ids: List[Union[int, Tuple[int, int]]], + ) -> List[ServiceAttribute]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') attribute_id_list = DataElement.sequence( [ @@ -873,21 +921,25 @@ class Client: # ----------------------------------------------------------------------------- class Server: CONTINUATION_STATE = bytes([0x01, 0x43]) + channel: Optional[l2cap.Channel] + Service = NewType('Service', List[ServiceAttribute]) + service_records: Dict[int, Service] + current_response: Union[None, bytes, Tuple[int, List[int]]] - def __init__(self, device): + def __init__(self, device: Device) -> None: self.device = device self.service_records = {} # Service records maps, by record handle self.channel = None self.current_response = None - def register(self, l2cap_channel_manager): + def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: l2cap_channel_manager.register_server(SDP_PSM, self.on_connection) def send_response(self, response): logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') self.channel.send_pdu(response) - def match_services(self, search_pattern): + def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: # Find the services for which the attributes in the pattern is a subset of the # service's attribute values (NOTE: the value search recurses into sequences) matching_services = {} @@ -957,7 +1009,9 @@ class Server: return (payload, continuation_state) @staticmethod - def get_service_attributes(service, attribute_ids): + def get_service_attributes( + service: Service, attribute_ids: List[DataElement] + ) -> DataElement: attributes = [] for attribute_id in attribute_ids: if attribute_id.value_size == 4: @@ -982,10 +1036,10 @@ class Server: return attribute_list - def on_sdp_service_search_request(self, request): + def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1014,6 +1068,7 @@ class Server: ) # Respond, keeping any unsent handles for later + assert isinstance(self.current_response, tuple) service_record_handles = self.current_response[1][ : request.maximum_service_record_count ] @@ -1037,10 +1092,12 @@ class Server: ) ) - def on_sdp_service_attribute_request(self, request): + def on_sdp_service_attribute_request( + self, request: SDP_ServiceAttributeRequest + ) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1073,22 +1130,24 @@ class Server: self.current_response = bytes(attribute_list) # Respond, keeping any pending chunks for later - attribute_list, continuation_state = self.get_next_response_payload( + attribute_list_response, continuation_state = self.get_next_response_payload( request.maximum_attribute_byte_count ) self.send_response( SDP_ServiceAttributeResponse( transaction_id=request.transaction_id, - attribute_list_byte_count=len(attribute_list), + attribute_list_byte_count=len(attribute_list_response), attribute_list=attribute_list, continuation_state=continuation_state, ) ) - def on_sdp_service_search_attribute_request(self, request): + def on_sdp_service_search_attribute_request( + self, request: SDP_ServiceSearchAttributeRequest + ) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1118,13 +1177,13 @@ class Server: self.current_response = bytes(attribute_lists) # Respond, keeping any pending chunks for later - attribute_lists, continuation_state = self.get_next_response_payload( + attribute_lists_response, continuation_state = self.get_next_response_payload( request.maximum_attribute_byte_count ) self.send_response( SDP_ServiceSearchAttributeResponse( transaction_id=request.transaction_id, - attribute_lists_byte_count=len(attribute_lists), + attribute_lists_byte_count=len(attribute_lists_response), attribute_lists=attribute_lists, continuation_state=continuation_state, ) diff --git a/tests/sdp_test.py b/tests/sdp_test.py index f07b5790..505539c3 100644 --- a/tests/sdp_test.py +++ b/tests/sdp_test.py @@ -23,7 +23,7 @@ from bumble.sdp import DataElement # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- -def basic_check(x): +def basic_check(x: DataElement) -> None: serialized = bytes(x) if len(serialized) < 500: print('Original:', x) @@ -41,7 +41,7 @@ def basic_check(x): # ----------------------------------------------------------------------------- -def test_data_elements(): +def test_data_elements() -> None: e = DataElement(DataElement.NIL, None) basic_check(e)