# Copyright 2021-2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio import logging import struct from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, NewType, TypeVar from typing_extensions import Self from bumble import core, hci, l2cap, utils from bumble.colors import color from bumble.core import ( InvalidArgumentError, InvalidPacketError, InvalidStateError, ProtocolError, ) if TYPE_CHECKING: from bumble.device import Connection, Device # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- # fmt: off # pylint: disable=line-too-long SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do SDP_PSM = 0x0001 class PduId(hci.SpecableEnum): SDP_ERROR_RESPONSE = 0x01 SDP_SERVICE_SEARCH_REQUEST = 0x02 SDP_SERVICE_SEARCH_RESPONSE = 0x03 SDP_SERVICE_ATTRIBUTE_REQUEST = 0x04 SDP_SERVICE_ATTRIBUTE_RESPONSE = 0x05 
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST = 0x06 SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07 class ErrorCode(hci.SpecableEnum): INVALID_SDP_VERSION = 0x0001 INVALID_SERVICE_RECORD_HANDLE = 0x0002 INVALID_REQUEST_SYNTAX = 0x0003 INVALID_PDU_SIZE = 0x0004 INVALID_CONTINUATION_STATE = 0x0005 INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST = 0x0006 SDP_SERVICE_NAME_ATTRIBUTE_ID_OFFSET = 0x0000 SDP_SERVICE_DESCRIPTION_ATTRIBUTE_ID_OFFSET = 0x0001 SDP_PROVIDER_NAME_ATTRIBUTE_ID_OFFSET = 0x0002 SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID = 0X0000 SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID = 0X0001 SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID = 0X0002 SDP_SERVICE_ID_ATTRIBUTE_ID = 0X0003 SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X0004 SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID = 0X0005 SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID = 0X0006 SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID = 0X0007 SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID = 0X0008 SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X0009 SDP_DOCUMENTATION_URL_ATTRIBUTE_ID = 0X000A SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B SDP_ICON_URL_ATTRIBUTE_ID = 0X000C SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D # Profile-specific Attribute Identifiers (cf. 
Assigned Numbers for Service Discovery) # used by AVRCP, HFP and A2DP SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311 SDP_ATTRIBUTE_ID_NAMES = { SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID', SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: 'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID', SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID', SDP_SERVICE_ID_ATTRIBUTE_ID: 'SDP_SERVICE_ID_ATTRIBUTE_ID', SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID', SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID: 'SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID', SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID: 'SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID', SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID: 'SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID', SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID: 'SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID', SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID', SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID', SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID', SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID', SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID', SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID', } SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') # To be used in searches where an attribute ID list allows a range to be specified SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF) # fmt: on # pylint: enable=line-too-long # pylint: disable=invalid-name # ----------------------------------------------------------------------------- @dataclass class DataElement: class Type(utils.OpenIntEnum): NIL = 0 UNSIGNED_INTEGER = 1 SIGNED_INTEGER = 2 UUID = 3 TEXT_STRING = 4 BOOLEAN = 5 SEQUENCE = 6 ALTERNATIVE = 7 URL = 8 NIL = Type.NIL UNSIGNED_INTEGER = Type.UNSIGNED_INTEGER SIGNED_INTEGER = 
Type.SIGNED_INTEGER UUID = Type.UUID TEXT_STRING = Type.TEXT_STRING BOOLEAN = Type.BOOLEAN SEQUENCE = Type.SEQUENCE ALTERNATIVE = Type.ALTERNATIVE URL = Type.URL TYPE_CONSTRUCTORS = { NIL: lambda x: DataElement(DataElement.NIL, None), UNSIGNED_INTEGER: lambda x, y: DataElement( DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y, ), SIGNED_INTEGER: lambda x, y: DataElement( DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y, ), UUID: lambda x: DataElement( DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x))) ), TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x), BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), SEQUENCE: lambda x: DataElement( DataElement.SEQUENCE, DataElement.list_from_bytes(x) ), ALTERNATIVE: lambda x: DataElement( DataElement.ALTERNATIVE, DataElement.list_from_bytes(x) ), URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')), } type: Type value: Any value_size: int | None = None def __post_init__(self) -> None: # Used as a cache when parsing from bytes so we can emit a byte-for-byte replica self._bytes: bytes | None = None if self.type in ( DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER, ): if self.value_size is None: raise InvalidArgumentError( 'integer types must have a value size specified' ) @staticmethod def nil() -> DataElement: return DataElement(DataElement.NIL, None) @staticmethod def unsigned_integer(value: int, value_size: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size) @staticmethod def unsigned_integer_8(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1) @staticmethod def unsigned_integer_16(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2) @staticmethod def unsigned_integer_32(value: int) -> DataElement: return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4) 
@staticmethod def signed_integer(value: int, value_size: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size) @staticmethod def signed_integer_8(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1) @staticmethod def signed_integer_16(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2) @staticmethod def signed_integer_32(value: int) -> DataElement: return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4) @staticmethod def uuid(value: core.UUID) -> DataElement: return DataElement(DataElement.UUID, value) @staticmethod def text_string(value: bytes) -> DataElement: return DataElement(DataElement.TEXT_STRING, value) @staticmethod def boolean(value: bool) -> DataElement: return DataElement(DataElement.BOOLEAN, value) @staticmethod def sequence(value: Iterable[DataElement]) -> DataElement: return DataElement(DataElement.SEQUENCE, value) @staticmethod def alternative(value: Iterable[DataElement]) -> DataElement: return DataElement(DataElement.ALTERNATIVE, value) @staticmethod def url(value: str) -> DataElement: return DataElement(DataElement.URL, value) @staticmethod def unsigned_integer_from_bytes(data): if len(data) == 1: return data[0] if len(data) == 2: return struct.unpack('>H', data)[0] if len(data) == 4: return struct.unpack('>I', data)[0] if len(data) == 8: return struct.unpack('>Q', data)[0] raise InvalidPacketError(f'invalid integer length {len(data)}') @staticmethod def signed_integer_from_bytes(data): if len(data) == 1: return struct.unpack('b', data)[0] if len(data) == 2: return struct.unpack('>h', data)[0] if len(data) == 4: return struct.unpack('>i', data)[0] if len(data) == 8: return struct.unpack('>q', data)[0] raise InvalidPacketError(f'invalid integer length {len(data)}') @staticmethod def list_from_bytes(data): elements = [] while data: element = DataElement.from_bytes(data) elements.append(element) data = 
data[len(bytes(element)) :] return elements @staticmethod def parse_from_bytes(data, offset): element = DataElement.from_bytes(data[offset:]) return offset + len(bytes(element)), element @staticmethod def from_bytes(data): element_type = data[0] >> 3 size_index = data[0] & 7 value_offset = 0 if size_index == 0: if element_type == DataElement.NIL: value_size = 0 else: value_size = 1 elif size_index == 1: value_size = 2 elif size_index == 2: value_size = 4 elif size_index == 3: value_size = 8 elif size_index == 4: value_size = 16 elif size_index == 5: value_size = data[1] value_offset = 1 elif size_index == 6: value_size = struct.unpack('>H', data[1:3])[0] value_offset = 2 else: # size_index == 7 value_size = struct.unpack('>I', data[1:5])[0] value_offset = 4 value_data = data[1 + value_offset : 1 + value_offset + value_size] constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type) if constructor: if element_type in ( DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER, ): result = constructor(value_data, value_size) else: result = constructor(value_data) else: result = DataElement(element_type, value_data) result._bytes = data[ : 1 + value_offset + value_size ] # Keep a copy so we can re-serialize to an exact replica return result def __bytes__(self): # Return early if we have a cache if self._bytes: return self._bytes if self.type == DataElement.NIL: data = b'' elif self.type == DataElement.UNSIGNED_INTEGER: if self.value < 0: raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative') if self.value_size == 1: data = struct.pack('B', self.value) elif self.value_size == 2: data = struct.pack('>H', self.value) elif self.value_size == 4: data = struct.pack('>I', self.value) elif self.value_size == 8: data = struct.pack('>Q', self.value) else: raise InvalidArgumentError('invalid value_size') elif self.type == DataElement.SIGNED_INTEGER: if self.value_size == 1: data = struct.pack('b', self.value) elif self.value_size == 2: data = struct.pack('>h', 
self.value) elif self.value_size == 4: data = struct.pack('>i', self.value) elif self.value_size == 8: data = struct.pack('>q', self.value) else: raise InvalidArgumentError('invalid value_size') elif self.type == DataElement.UUID: data = bytes(reversed(bytes(self.value))) elif self.type == DataElement.URL: data = self.value.encode('utf8') elif self.type == DataElement.BOOLEAN: data = bytes([1 if self.value else 0]) elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): data = b''.join([bytes(element) for element in self.value]) else: data = self.value size = len(data) size_bytes = b'' if self.type == DataElement.NIL: if size != 0: raise InvalidArgumentError('NIL must be empty') size_index = 0 elif self.type in ( DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER, DataElement.UUID, ): if size <= 1: size_index = 0 elif size == 2: size_index = 1 elif size == 4: size_index = 2 elif size == 8: size_index = 3 elif size == 16: size_index = 4 else: raise InvalidArgumentError('invalid data size') elif self.type in ( DataElement.TEXT_STRING, DataElement.SEQUENCE, DataElement.ALTERNATIVE, DataElement.URL, ): if size <= 0xFF: size_index = 5 size_bytes = bytes([size]) elif size <= 0xFFFF: size_index = 6 size_bytes = struct.pack('>H', size) elif size <= 0xFFFFFFFF: size_index = 7 size_bytes = struct.pack('>I', size) else: raise InvalidArgumentError('invalid data size') elif self.type == DataElement.BOOLEAN: if size != 1: raise InvalidArgumentError('boolean must be 1 byte') size_index = 0 else: raise RuntimeError("internal error - self.type not supported") self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data return self._bytes def to_string(self, pretty=False, indentation=0): prefix = ' ' * indentation type_name = self.type.name if self.type == DataElement.NIL: value_string = '' elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE): container_separator = '\n' if pretty else '' element_separator = '\n' if pretty else ',' elements = [ 
element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value ] value_string = ( f'[{container_separator}' f'{element_separator.join(elements)}' f'{container_separator}{prefix}]' ) elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): value_string = f'{self.value}#{self.value_size}' elif isinstance(self.value, DataElement): value_string = self.value.to_string(pretty, indentation) else: value_string = str(self.value) return f'{prefix}{type_name}({value_string})' def __str__(self): return self.to_string() # ----------------------------------------------------------------------------- @dataclass class ServiceAttribute: id: int value: DataElement @staticmethod def list_from_data_elements( elements: Sequence[DataElement], ) -> list[ServiceAttribute]: attribute_list = [] for i in range(0, len(elements) // 2): attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] if attribute_id.type != DataElement.UNSIGNED_INTEGER: logger.warning('attribute ID element is not an integer') continue attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value)) return attribute_list @staticmethod def find_attribute_in_list( attribute_list: Iterable[ServiceAttribute], attribute_id: int ) -> DataElement | None: return next( ( attribute.value for attribute in attribute_list if attribute.id == attribute_id ), None, ) @staticmethod def id_name(id_code): return hci.name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code) @staticmethod def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool: # Find if a uuid matches a value, either directly or recursing into sequences if value.type == DataElement.UUID: return value.value == uuid if value.type == DataElement.SEQUENCE: for element in value.value: if ServiceAttribute.is_uuid_in_value(uuid, element): return True return False return False def to_string(self, with_colors=False): if with_colors: return ( f'Attribute(id={color(self.id_name(self.id), "magenta")},' f'value={self.value})' 
) return f'Attribute(id={self.id_name(self.id)},value={self.value})' def __str__(self): return self.to_string() # ----------------------------------------------------------------------------- def _parse_service_record_handle_list( data: bytes, offset: int ) -> tuple[int, list[int]]: count = struct.unpack_from('>H', data, offset)[0] offset += 2 handle_list = [ struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) ] return offset + count * 4, handle_list def _serialize_service_record_handle_list( handles: list[int], ) -> bytes: return struct.pack('>H', len(handles)) + b''.join( struct.pack('>I', handle) for handle in handles ) def _parse_bytes_preceded_by_length(data: bytes, offset: int) -> tuple[int, bytes]: length = struct.unpack_from('>H', data, offset)[0] offset += 2 return offset + length, data[offset : offset + length] def _serialize_bytes_preceded_by_length(data: bytes) -> bytes: return struct.pack('>H', len(data)) + data _SERVICE_RECORD_HANDLE_LIST_METADATA = hci.metadata( { 'parser': _parse_service_record_handle_list, 'serializer': _serialize_service_record_handle_list, } ) _BYTES_PRECEDED_BY_LENGTH_METADATA = hci.metadata( { 'parser': _parse_bytes_preceded_by_length, 'serializer': _serialize_bytes_preceded_by_length, } ) # ----------------------------------------------------------------------------- @dataclass class SDP_PDU: ''' See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT ''' RESPONSE_PDU_IDS = { PduId.SDP_SERVICE_SEARCH_REQUEST: PduId.SDP_SERVICE_SEARCH_RESPONSE, PduId.SDP_SERVICE_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE, PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE, } subclasses: ClassVar[dict[int, type[SDP_PDU]]] = {} pdu_id: ClassVar[PduId] fields: ClassVar[hci.Fields] transaction_id: int _payload: bytes | None = field(init=False, repr=False, default=None) @classmethod def from_bytes(cls, pdu: bytes) -> SDP_PDU: pdu_id, transaction_id, _parameters_length = 
struct.unpack_from('>BHH', pdu, 0) subclass = cls.subclasses.get(pdu_id) if not (subclass := cls.subclasses.get(pdu_id)): raise InvalidPacketError(f"Unknown PDU type {pdu_id}") instance = subclass( transaction_id=transaction_id, **hci.HCI_Object.dict_from_bytes(pdu, 5, subclass.fields), ) instance._payload = pdu return instance _PDU = TypeVar('_PDU', bound='SDP_PDU') @classmethod def subclass(cls, subclass: type[_PDU]) -> type[_PDU]: subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass) cls.subclasses[subclass.pdu_id] = subclass return subclass def __bytes__(self): if self._payload is None: self._payload = struct.pack( '>BHH', self.pdu_id, self.transaction_id, 0 ) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) return self._payload @property def name(self) -> str: return self.pdu_id.name def __str__(self): result = f'{color(self.name, "blue")} [TID={self.transaction_id}]' if fields := getattr(self, 'fields', None): result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') elif len(self.pdu) > 1: result += f': {self.pdu.hex()}' return result # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ErrorResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU ''' pdu_id = PduId.SDP_ERROR_RESPONSE error_code: ErrorCode = field(metadata=ErrorCode.type_metadata(2)) # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ServiceSearchRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU ''' pdu_id = PduId.SDP_SERVICE_SEARCH_REQUEST service_search_pattern: DataElement = field( metadata=hci.metadata(DataElement.parse_from_bytes) ) maximum_service_record_count: int = field(metadata=hci.metadata('>2')) continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- 
@SDP_PDU.subclass @dataclass class SDP_ServiceSearchResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU ''' pdu_id = PduId.SDP_SERVICE_SEARCH_RESPONSE total_service_record_count: int = field(metadata=hci.metadata('>2')) service_record_handle_list: Sequence[int] = field( metadata=_SERVICE_RECORD_HANDLE_LIST_METADATA ) continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ServiceAttributeRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU ''' pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_REQUEST service_record_handle: int = field(metadata=hci.metadata('>4')) maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2')) attribute_id_list: DataElement = field( metadata=hci.metadata(DataElement.parse_from_bytes) ) continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ServiceAttributeResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU ''' pdu_id = PduId.SDP_SERVICE_ATTRIBUTE_RESPONSE attribute_list: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA) continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ServiceSearchAttributeRequest(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU ''' pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST service_search_pattern: DataElement = field( metadata=hci.metadata(DataElement.parse_from_bytes) ) maximum_attribute_byte_count: int = field(metadata=hci.metadata('>2')) attribute_id_list: DataElement = field( metadata=hci.metadata(DataElement.parse_from_bytes) ) 
continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- @SDP_PDU.subclass @dataclass class SDP_ServiceSearchAttributeResponse(SDP_PDU): ''' See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU ''' pdu_id = PduId.SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE attribute_lists: bytes = field(metadata=_BYTES_PRECEDED_BY_LENGTH_METADATA) continuation_state: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- class Client: def __init__(self, connection: Connection, mtu: int = 0) -> None: self.connection = connection self.channel: l2cap.ClassicChannel | None = None self.mtu = mtu self.request_semaphore = asyncio.Semaphore(1) self.pending_request: SDP_PDU | None = None self.pending_response: asyncio.futures.Future[SDP_PDU] | None = None self.next_transaction_id = 0 async def connect(self) -> None: self.channel = await self.connection.create_l2cap_channel( spec=( l2cap.ClassicChannelSpec(SDP_PSM, self.mtu) if self.mtu else l2cap.ClassicChannelSpec(SDP_PSM) ) ) self.channel.sink = self.on_pdu async def disconnect(self) -> None: if self.channel: await self.channel.disconnect() self.channel = None def make_transaction_id(self) -> int: transaction_id = self.next_transaction_id self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF return transaction_id def on_pdu(self, pdu: bytes) -> None: if not self.pending_request: logger.warning('received response with no pending request') return assert self.pending_response is not None response = SDP_PDU.from_bytes(pdu) # Check that the transaction ID is what we expect if self.pending_request.transaction_id != response.transaction_id: logger.warning( f"received response with transaction ID {response.transaction_id} " f"but expected {self.pending_request.transaction_id}" ) return # Check if the response is an error if isinstance(response, 
SDP_ErrorResponse): self.pending_response.set_exception( ProtocolError(error_code=response.error_code) ) return # Check that the type of the response matches the request if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id): logger.warning("response type mismatch") return self.pending_response.set_result(response) async def send_request(self, request: SDP_PDU) -> SDP_PDU: assert self.channel is not None async with self.request_semaphore: assert self.pending_request is None assert self.pending_response is None # Create a future value to hold the eventual response self.pending_response = asyncio.get_running_loop().create_future() self.pending_request = request try: self.channel.write(bytes(request)) return await self.pending_response finally: self.pending_request = None self.pending_response = None async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]: """ Search for services by UUID. Args: uuids: service the UUIDs to search for. Returns: A list of matching service record handles. 
""" if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] ) # Request and accumulate until there's no more continuation service_record_handle_list: list[int] = [] continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: response = await self.send_request( SDP_ServiceSearchRequest( transaction_id=self.make_transaction_id(), service_search_pattern=service_search_pattern, maximum_service_record_count=0xFFFF, continuation_state=continuation_state, ) ) logger.debug(f'<<< Response: {response}') assert isinstance(response, SDP_ServiceSearchResponse) service_record_handle_list += response.service_record_handle_list continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: break logger.debug(f'continuation: {continuation_state.hex()}') watchdog -= 1 return service_record_handle_list async def search_attributes( self, uuids: Iterable[core.UUID], attribute_ids: Iterable[int | tuple[int, int]], ) -> list[list[ServiceAttribute]]: """ Search for attributes by UUID and attribute IDs. Args: uuids: the service UUIDs to search for. attribute_ids: list of attribute IDs or (start, end) attribute ID ranges. (use (0, 0xFFFF) to include all attributes) Returns: A list of list of attributes, one list per matching service. 
""" if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] ) attribute_id_list = DataElement.sequence( [ ( DataElement.unsigned_integer_32( attribute_id[0] << 16 | attribute_id[1] ) if isinstance(attribute_id, tuple) else DataElement.unsigned_integer_16(attribute_id) ) for attribute_id in attribute_ids ] ) # Request and accumulate until there's no more continuation accumulator = b'' continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: response = await self.send_request( SDP_ServiceSearchAttributeRequest( transaction_id=self.make_transaction_id(), service_search_pattern=service_search_pattern, maximum_attribute_byte_count=0xFFFF, attribute_id_list=attribute_id_list, continuation_state=continuation_state, ) ) logger.debug(f'<<< Response: {response}') assert isinstance(response, SDP_ServiceSearchAttributeResponse) accumulator += response.attribute_lists continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: break logger.debug(f'continuation: {continuation_state.hex()}') watchdog -= 1 # Parse the result into attribute lists attribute_lists_sequences = DataElement.from_bytes(accumulator) if attribute_lists_sequences.type != DataElement.SEQUENCE: logger.warning('unexpected data type') return [] return [ ServiceAttribute.list_from_data_elements(sequence.value) for sequence in attribute_lists_sequences.value if sequence.type == DataElement.SEQUENCE ] async def get_attributes( self, service_record_handle: int, attribute_ids: Iterable[int | tuple[int, int]], ) -> list[ServiceAttribute]: """ Get attributes for a service. Args: service_record_handle: the handle for a service attribute_ids: list or attribute IDs or (start, end) attribute ID handles. Returns: A list of attributes. 
""" if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: raise InvalidStateError('L2CAP not connected') attribute_id_list = DataElement.sequence( [ ( DataElement.unsigned_integer_32( attribute_id[0] << 16 | attribute_id[1] ) if isinstance(attribute_id, tuple) else DataElement.unsigned_integer_16(attribute_id) ) for attribute_id in attribute_ids ] ) # Request and accumulate until there's no more continuation accumulator = b'' continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: response = await self.send_request( SDP_ServiceAttributeRequest( transaction_id=self.make_transaction_id(), service_record_handle=service_record_handle, maximum_attribute_byte_count=0xFFFF, attribute_id_list=attribute_id_list, continuation_state=continuation_state, ) ) logger.debug(f'<<< Response: {response}') assert isinstance(response, SDP_ServiceAttributeResponse) accumulator += response.attribute_list continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: break logger.debug(f'continuation: {continuation_state.hex()}') watchdog -= 1 # Parse the result into a list of attributes attribute_list_sequence = DataElement.from_bytes(accumulator) if attribute_list_sequence.type != DataElement.SEQUENCE: logger.warning('unexpected data type') return [] return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value) async def __aenter__(self) -> Self: await self.connect() return self async def __aexit__(self, *args) -> None: await self.disconnect() # ----------------------------------------------------------------------------- class Server: CONTINUATION_STATE = bytes([0x01, 0x00]) channel: l2cap.ClassicChannel | None Service = NewType('Service', list[ServiceAttribute]) service_records: dict[int, Service] current_response: None | bytes | tuple[int, list[int]] def __init__(self, device: Device) -> None: self.device = device 
self.service_records = {}  # Service records maps, by record handle
self.channel = None
self.current_response = None  # Current response data, used for continuations

def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
    """Create a classic L2CAP server on SDP_PSM whose connections go to on_connection."""
    l2cap_channel_manager.create_classic_server(
        spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
    )

def send_response(self, response):
    """Log a response PDU and write it to the current channel."""
    logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
    self.channel.write(response)

def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
    """Return the registered services matching a service search pattern.

    A service matches when *every* UUID of the pattern is found in at least
    one of the service's attribute values (the value search recurses into
    sequences), i.e. the pattern is a subset of the service's UUIDs. An empty
    pattern matches nothing.

    Args:
        search_pattern: A sequence DataElement whose items are UUID elements.

    Returns:
        A dict mapping record handle to service, in the iteration order of
        ``self.service_records``.
    """
    uuids = [element.value for element in search_pattern.value]
    if not uuids:
        return {}

    def contains_uuid(service: Service, uuid) -> bool:
        # True if the UUID appears in any attribute value of the service
        return any(
            ServiceAttribute.is_uuid_in_value(uuid, attribute.value)
            for attribute in service
        )

    return {
        handle: service
        for handle, service in self.service_records.items()
        if all(contains_uuid(service, uuid) for uuid in uuids)
    }

def on_connection(self, channel):
    """Adopt a newly connected channel and route its incoming PDUs to on_pdu."""
    self.channel = channel
    self.channel.sink = self.on_pdu

def on_pdu(self, pdu: bytes) -> None:
    """Parse an incoming SDP request PDU and dispatch it to its handler.

    Requests are dispatched to the method named ``on_<pdu name, lowercased>``.
    Unparsable PDUs and PDUs without a handler are answered with
    INVALID_REQUEST_SYNTAX; an exception raised by a handler is answered with
    INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST.
    """
    try:
        sdp_pdu = SDP_PDU.from_bytes(pdu)
    except Exception:
        logger.exception(color('failed to parse SDP Request PDU', 'red'))
        # Echo the 16-bit big-endian Transaction ID that follows the 1-byte
        # PDU ID, when the PDU is long enough to carry one.
        transaction_id = struct.unpack_from('>H', pdu, 1)[0] if len(pdu) >= 3 else 0
        self.send_response(
            SDP_ErrorResponse(
                transaction_id=transaction_id,
                error_code=ErrorCode.INVALID_REQUEST_SYNTAX,
            )
        )
        # Nothing was parsed, so there is nothing to dispatch.
        return

    logger.debug(f'{color("<<< Received SDP Request", "green")}: {sdp_pdu}')

    # Find the handler method
    handler = getattr(self, f'on_{sdp_pdu.name.lower()}', None)
    if handler is None:
        logger.error(color('SDP Request not handled???', 'red'))
        self.send_response(
            SDP_ErrorResponse(
                transaction_id=sdp_pdu.transaction_id,
                error_code=ErrorCode.INVALID_REQUEST_SYNTAX,
            )
        )
        return

    try:
        handler(sdp_pdu)
    except Exception:
        logger.exception(color("!!! Exception in handler:", "red"))
        self.send_response(
            SDP_ErrorResponse(
                transaction_id=sdp_pdu.transaction_id,
                error_code=ErrorCode.INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST,
            )
        )

def _send_invalid_continuation_state(self, transaction_id: int) -> None:
    """Send an SDP_ErrorResponse with INVALID_CONTINUATION_STATE."""
    self.send_response(
        SDP_ErrorResponse(
            transaction_id=transaction_id,
            error_code=ErrorCode.INVALID_CONTINUATION_STATE,
        )
    )

def check_continuation(
    self,
    continuation_state: bytes,
    transaction_id: int,
) -> bool | None:
    """Classify the continuation state carried by a request.

    Returns:
        None if the continuation state is invalid (an error response has
        already been sent); True if the request continues the pending
        response; False if it is a fresh request, in which case any leftover
        partial response is discarded.
    """
    # More than one byte means the client is sending continuation information
    if len(continuation_state) > 1:
        if (
            self.current_response is None
            or continuation_state != self.CONTINUATION_STATE
        ):
            self._send_invalid_continuation_state(transaction_id)
            return None
        return True

    # Cleanup any partial response leftover
    self.current_response = None
    return False

def get_next_response_payload(self, maximum_size):
    """Take up to ``maximum_size`` bytes off the pending byte response.

    Returns:
        A ``(payload, continuation_state)`` tuple. ``continuation_state`` is
        ``Server.CONTINUATION_STATE`` when bytes remain pending; otherwise it
        is ``bytes([0])`` and the pending response is cleared.
    """
    if len(self.current_response) > maximum_size:
        payload = self.current_response[:maximum_size]
        self.current_response = self.current_response[maximum_size:]
        return (payload, Server.CONTINUATION_STATE)

    payload = self.current_response
    self.current_response = None
    return (payload, bytes([0]))

@staticmethod
def get_service_attributes(
    service: Service, attribute_ids: Iterable[DataElement]
) -> DataElement:
    """Build the attribute list of a service for the requested attribute IDs.

    Each element of ``attribute_ids`` is either a single attribute ID or,
    when its value is 4 bytes wide, a range with the first ID in the upper
    16 bits and the last ID (inclusive) in the lower 16 bits.

    Returns:
        A sequence DataElement of alternating (uint16 attribute ID, value)
        items, sorted by attribute ID. Each service attribute appears at most
        once, even when requested ranges overlap.
    """
    id_ranges = []
    for attribute_id in attribute_ids:
        if attribute_id.value_size == 4:
            # Attribute ID range
            id_ranges.append((attribute_id.value >> 16, attribute_id.value & 0xFFFF))
        else:
            id_ranges.append((attribute_id.value, attribute_id.value))

    # Return the matching attributes, sorted by attribute id
    attributes = sorted(
        (
            attribute
            for attribute in service
            if any(start <= attribute.id <= end for start, end in id_ranges)
        ),
        key=lambda attribute: attribute.id,
    )

    attribute_list = DataElement.sequence([])
    for attribute in attributes:
        attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
        attribute_list.value.append(attribute.value)
    return attribute_list

def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
    """Answer an SDP_ServiceSearchRequest, splitting long handle lists across continuations."""
    # Check if this is a continuation
    if (
        continuation := self.check_continuation(
            request.continuation_state, request.transaction_id
        )
    ) is None:
        return

    if not continuation:
        # Find the matching services
        matching_services = self.match_services(request.service_search_pattern)
        service_record_handles = list(matching_services.keys())
        logger.debug(f'Service Record Handles: {service_record_handles}')

        # Only return up to the maximum requested. TotalServiceRecordCount
        # must not exceed MaximumServiceRecordCount either, so it is the
        # size of that subset.
        service_record_handles_subset = service_record_handles[
            : request.maximum_service_record_count
        ]
        self.current_response = (
            len(service_record_handles_subset),
            service_record_handles_subset,
        )
    elif not isinstance(self.current_response, tuple):
        # The pending response was produced by a different request type
        self._send_invalid_continuation_state(request.transaction_id)
        return

    # Respond, keeping any unsent handles for later
    assert self.channel is not None
    total_service_record_count, pending_handles = self.current_response
    # 4 bytes per handle; the 11 reserved bytes presumably cover the PDU
    # header, both counts and the continuation state.
    maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
    service_record_handles = pending_handles[:maximum_service_record_count]
    service_record_handles_remaining = pending_handles[maximum_service_record_count:]
    self.current_response = (
        total_service_record_count,
        service_record_handles_remaining,
    )
    continuation_state = (
        Server.CONTINUATION_STATE if service_record_handles_remaining else bytes([0])
    )

    self.send_response(
        SDP_ServiceSearchResponse(
            transaction_id=request.transaction_id,
            total_service_record_count=total_service_record_count,
            service_record_handle_list=service_record_handles,
            continuation_state=continuation_state,
        )
    )

def on_sdp_service_attribute_request(
    self, request: SDP_ServiceAttributeRequest
) -> None:
    """Answer an SDP_ServiceAttributeRequest, chunking the attribute list across continuations."""
    # Check if this is a continuation
    if (
        continuation := self.check_continuation(
            request.continuation_state, request.transaction_id
        )
    ) is None:
        return

    if not continuation:
        # Check that the service exists
        service = self.service_records.get(request.service_record_handle)
        if service is None:
            self.send_response(
                SDP_ErrorResponse(
                    transaction_id=request.transaction_id,
                    error_code=ErrorCode.INVALID_SERVICE_RECORD_HANDLE,
                )
            )
            return

        # Get the attributes for the service
        attribute_list = Server.get_service_attributes(
            service, request.attribute_id_list.value
        )

        # Serialize to a byte array
        logger.debug(f'Attributes: {attribute_list}')
        self.current_response = bytes(attribute_list)
    elif not isinstance(self.current_response, bytes):
        # The pending response was produced by a different request type
        self._send_invalid_continuation_state(request.transaction_id)
        return

    # Respond, keeping any pending chunks for later
    assert self.channel is not None
    maximum_attribute_byte_count = min(
        request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
    )
    attribute_list_response, continuation_state = self.get_next_response_payload(
        maximum_attribute_byte_count
    )
    self.send_response(
        SDP_ServiceAttributeResponse(
            transaction_id=request.transaction_id,
            attribute_list=attribute_list_response,
            continuation_state=continuation_state,
        )
    )

def on_sdp_service_search_attribute_request(
    self, request: SDP_ServiceSearchAttributeRequest
) -> None:
    """Answer an SDP_ServiceSearchAttributeRequest, chunking the attribute lists across continuations."""
    # Check if this is a continuation
    if (
        continuation := self.check_continuation(
            request.continuation_state, request.transaction_id
        )
    ) is None:
        return

    if not continuation:
        # Find the matching services
        matching_services = self.match_services(
            request.service_search_pattern
        ).values()

        # Filter the required attributes, skipping services with none of them
        attribute_lists = DataElement.sequence([])
        for service in matching_services:
            attribute_list = Server.get_service_attributes(
                service, request.attribute_id_list.value
            )
            if attribute_list.value:
                attribute_lists.value.append(attribute_list)

        # Serialize to a byte array
        logger.debug(f'Search response: {attribute_lists}')
        self.current_response = bytes(attribute_lists)
    elif not isinstance(self.current_response, bytes):
        # The pending response was produced by a different request type
        self._send_invalid_continuation_state(request.transaction_id)
        return

    # Respond, keeping any pending chunks for later
    assert self.channel is not None
    maximum_attribute_byte_count = min(
        request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
    )
    attribute_lists_response, continuation_state = self.get_next_response_payload(
        maximum_attribute_byte_count
    )
    self.send_response(
        SDP_ServiceSearchAttributeResponse(
            transaction_id=request.transaction_id,
            attribute_lists=attribute_lists_response,
            continuation_state=continuation_state,
        )
    )