diff --git a/bumble/l2cap.py b/bumble/l2cap.py index a7a944d..4a49f79 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -773,7 +773,6 @@ class ClassicChannel(EventEmitter): self.psm = psm self.source_cid = source_cid self.destination_cid = 0 - self.response = None self.connection_result = None self.disconnection_result = None self.sink = None @@ -783,27 +782,15 @@ class ClassicChannel(EventEmitter): self.state = new_state def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: + if self.state != self.State.OPEN: + raise InvalidStateError('channel not open') self.manager.send_pdu(self.connection, self.destination_cid, pdu) def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: self.manager.send_control_frame(self.connection, self.signaling_cid, frame) - async def send_request(self, request: SupportsBytes) -> bytes: - # Check that there isn't already a request pending - if self.response: - raise InvalidStateError('request already pending') - if self.state != self.State.OPEN: - raise InvalidStateError('channel not open') - - self.response = asyncio.get_running_loop().create_future() - self.send_pdu(request) - return await self.response - def on_pdu(self, pdu: bytes) -> None: - if self.response: - self.response.set_result(pdu) - self.response = None - elif self.sink: + if self.sink: # pylint: disable=not-callable self.sink(pdu) else: diff --git a/bumble/sdp.py b/bumble/sdp.py index 826bd59..15d3cc8 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -16,15 +16,21 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations +import asyncio import logging import struct -from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING +from typing import Iterable, NewType, Optional, Union, Type, TYPE_CHECKING from typing_extensions import Self -from . import core, l2cap -from .colors import color -from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError -from .hci import HCI_Object, name_or_number, key_with_value +from bumble import core, l2cap +from bumble.colors import color +from bumble.core import ( + InvalidStateError, + InvalidArgumentError, + InvalidPacketError, + ProtocolError, +) +from bumble.hci import HCI_Object, name_or_number, key_with_value if TYPE_CHECKING: from .device import Device, Connection @@ -242,11 +248,11 @@ class DataElement: return DataElement(DataElement.BOOLEAN, value) @staticmethod - def sequence(value: List[DataElement]) -> DataElement: + def sequence(value: list[DataElement]) -> DataElement: return DataElement(DataElement.SEQUENCE, value) @staticmethod - def alternative(value: List[DataElement]) -> DataElement: + def alternative(value: list[DataElement]) -> DataElement: return DataElement(DataElement.ALTERNATIVE, value) @staticmethod @@ -473,7 +479,7 @@ class ServiceAttribute: self.value = value @staticmethod - def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: + def list_from_data_elements(elements: list[DataElement]) -> list[ServiceAttribute]: attribute_list = [] for i in range(0, len(elements) // 2): attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] @@ -486,7 +492,7 @@ class ServiceAttribute: @staticmethod def find_attribute_in_list( - attribute_list: List[ServiceAttribute], attribute_id: int + attribute_list: list[ServiceAttribute], attribute_id: int ) -> Optional[DataElement]: return next( ( @@ -534,7 +540,12 @@ class SDP_PDU: See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT ''' - sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {} + RESPONSE_PDU_IDS = { + SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE, + SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE, + SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE, + } + sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {} name = None pdu_id = 0 @@ -558,7 +569,7 @@ class SDP_PDU: @staticmethod def parse_service_record_handle_list_preceded_by_count( data: bytes, offset: int - ) -> Tuple[int, List[int]]: + ) -> tuple[int, list[int]]: count = struct.unpack_from('>H', data, offset - 2)[0] handle_list = [ struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) @@ -639,6 +650,8 @@ class SDP_ErrorResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU ''' + error_code: int + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -675,7 +688,7 @@ class SDP_ServiceSearchResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU ''' - service_record_handle_list: List[int] + service_record_handle_list: list[int] total_service_record_count: int current_service_record_count: int continuation_state: bytes @@ -752,31 +765,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU ''' - attribute_list_byte_count: int - attribute_list: bytes + attribute_lists_byte_count: int + attribute_lists: bytes continuation_state: bytes # ----------------------------------------------------------------------------- class Client: - channel: Optional[l2cap.ClassicChannel] - - def __init__(self, connection: Connection) -> None: + def __init__(self, connection: Connection, mtu: int = 0) -> None: self.connection = connection - self.pending_request = None - self.channel = None + self.channel: Optional[l2cap.ClassicChannel] = None + self.mtu = mtu + self.request_semaphore = asyncio.Semaphore(1) + self.pending_request: Optional[SDP_PDU] = None + self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None + self.next_transaction_id = 0 async def connect(self) -> None: self.channel = await self.connection.create_l2cap_channel( - spec=l2cap.ClassicChannelSpec(SDP_PSM) + spec=( + l2cap.ClassicChannelSpec(SDP_PSM, self.mtu) + if self.mtu + else l2cap.ClassicChannelSpec(SDP_PSM) + ) ) + self.channel.sink = self.on_pdu async def disconnect(self) -> None: if self.channel: await self.channel.disconnect() self.channel = None - async def search_services(self, uuids: List[core.UUID]) -> List[int]: + def make_transaction_id(self) -> int: + transaction_id = self.next_transaction_id + self.next_transaction_id = self.next_transaction_id & 0xFFFF + return transaction_id + + def on_pdu(self, pdu: bytes) -> None: + if not self.pending_request: + logger.warning('received response with no pending request') + return + assert self.pending_response is not None + + response = SDP_PDU.from_bytes(pdu) + + # Check that the transaction ID is what we expect + if self.pending_request.transaction_id != response.transaction_id: + logger.warning( + f"received response with transaction ID {response.transaction_id} " + f"but expected {self.pending_request.transaction_id}" + ) + return + + # Check if the response is an error + if isinstance(response, SDP_ErrorResponse): + self.pending_response.set_exception( + ProtocolError(error_code=response.error_code) + ) + return + + # Check that the type of the response matches the request + if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id): + logger.warning("response type mismatch") + return + + self.pending_response.set_result(response) + + async def send_request(self, request: SDP_PDU) -> SDP_PDU: + assert self.channel is not None + async with self.request_semaphore: + assert self.pending_request is None + assert self.pending_response is None + + # Create a future value to hold the eventual response + self.pending_response = asyncio.get_running_loop().create_future() + self.pending_request = request + + try: + self.channel.send_pdu(bytes(request)) + return await self.pending_response + finally: + self.pending_request = None + self.pending_response = None + + async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]: + """ + Search for services by UUID. + + Args: + uuids: service the UUIDs to search for. + + Returns: + A list of matching service record handles. + """ if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: @@ -791,16 +872,16 @@ class Client: continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: - response_pdu = await self.channel.send_request( + response = await self.send_request( SDP_ServiceSearchRequest( - transaction_id=0, # Transaction ID TODO: pick a real value + transaction_id=self.make_transaction_id(), service_search_pattern=service_search_pattern, maximum_service_record_count=0xFFFF, continuation_state=continuation_state, ) ) - response = SDP_PDU.from_bytes(response_pdu) logger.debug(f'<<< Response: {response}') + assert isinstance(response, SDP_ServiceSearchResponse) service_record_handle_list += response.service_record_handle_list continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: @@ -811,8 +892,21 @@ class Client: return service_record_handle_list async def search_attributes( - self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] - ) -> List[List[ServiceAttribute]]: + self, + uuids: list[core.UUID], + attribute_ids: Iterable[Union[int, tuple[int, int]]], + ) -> list[list[ServiceAttribute]]: + """ + Search for attributes by UUID and attribute IDs. + + Args: + uuids: the service UUIDs to search for. + attribute_ids: list of attribute IDs or (start, end) attribute ID ranges. + (use (0, 0xFFFF) to include all attributes) + + Returns: + A list of list of attributes, one list per matching service. + """ if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: @@ -824,8 +918,8 @@ class Client: attribute_id_list = DataElement.sequence( [ ( - DataElement.unsigned_integer( - attribute_id[0], value_size=attribute_id[1] + DataElement.unsigned_integer_32( + attribute_id[0] << 16 | attribute_id[1] ) if isinstance(attribute_id, tuple) else DataElement.unsigned_integer_16(attribute_id) @@ -839,17 +933,17 @@ class Client: continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: - response_pdu = await self.channel.send_request( + response = await self.send_request( SDP_ServiceSearchAttributeRequest( - transaction_id=0, # Transaction ID TODO: pick a real value + transaction_id=self.make_transaction_id(), service_search_pattern=service_search_pattern, maximum_attribute_byte_count=0xFFFF, attribute_id_list=attribute_id_list, continuation_state=continuation_state, ) ) - response = SDP_PDU.from_bytes(response_pdu) logger.debug(f'<<< Response: {response}') + assert isinstance(response, SDP_ServiceSearchAttributeResponse) accumulator += response.attribute_lists continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: @@ -872,8 +966,18 @@ class Client: async def get_attributes( self, service_record_handle: int, - attribute_ids: List[Union[int, Tuple[int, int]]], - ) -> List[ServiceAttribute]: + attribute_ids: list[Union[int, tuple[int, int]]], + ) -> list[ServiceAttribute]: + """ + Get attributes for a service. + + Args: + service_record_handle: the handle for a service + attribute_ids: list or attribute IDs or (start, end) attribute ID handles. + + Returns: + A list of attributes. + """ if self.pending_request is not None: raise InvalidStateError('request already pending') if self.channel is None: @@ -882,8 +986,8 @@ class Client: attribute_id_list = DataElement.sequence( [ ( - DataElement.unsigned_integer( - attribute_id[0], value_size=attribute_id[1] + DataElement.unsigned_integer_32( + attribute_id[0] << 16 | attribute_id[1] ) if isinstance(attribute_id, tuple) else DataElement.unsigned_integer_16(attribute_id) @@ -897,17 +1001,17 @@ class Client: continuation_state = bytes([0]) watchdog = SDP_CONTINUATION_WATCHDOG while watchdog > 0: - response_pdu = await self.channel.send_request( + response = await self.send_request( SDP_ServiceAttributeRequest( - transaction_id=0, # Transaction ID TODO: pick a real value + transaction_id=self.make_transaction_id(), service_record_handle=service_record_handle, maximum_attribute_byte_count=0xFFFF, attribute_id_list=attribute_id_list, continuation_state=continuation_state, ) ) - response = SDP_PDU.from_bytes(response_pdu) logger.debug(f'<<< Response: {response}') + assert isinstance(response, SDP_ServiceAttributeResponse) accumulator += response.attribute_list continuation_state = response.continuation_state if len(continuation_state) == 1 and continuation_state[0] == 0: @@ -933,17 +1037,17 @@ class Client: # ----------------------------------------------------------------------------- class Server: - CONTINUATION_STATE = bytes([0x01, 0x43]) + CONTINUATION_STATE = bytes([0x01, 0x00]) channel: Optional[l2cap.ClassicChannel] - Service = NewType('Service', List[ServiceAttribute]) - service_records: Dict[int, Service] - current_response: Union[None, bytes, Tuple[int, List[int]]] + Service = NewType('Service', list[ServiceAttribute]) + service_records: dict[int, Service] + current_response: Union[None, bytes, tuple[int, list[int]]] def __init__(self, device: Device) -> None: self.device = device self.service_records = {} # Service records maps, by record handle self.channel = None - self.current_response = None + self.current_response = None # Current response data, used for continuations def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: l2cap_channel_manager.create_classic_server( @@ -954,7 +1058,7 @@ class Server: logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') self.channel.send_pdu(response) - def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: + def match_services(self, search_pattern: DataElement) -> dict[int, Service]: # Find the services for which the attributes in the pattern is a subset of the # service's attribute values (NOTE: the value search recurses into sequences) matching_services = {} @@ -1011,6 +1115,31 @@ class Server: ) ) + def check_continuation( + self, + continuation_state: bytes, + transaction_id: int, + ) -> Optional[bool]: + # Check if this is a valid continuation + if len(continuation_state) > 1: + if ( + self.current_response is None + or continuation_state != self.CONTINUATION_STATE + ): + self.send_response( + SDP_ErrorResponse( + transaction_id=transaction_id, + error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, + ) + ) + return None + return True + + # Cleanup any partial response leftover + self.current_response = None + + return False + def get_next_response_payload(self, maximum_size): if len(self.current_response) > maximum_size: payload = self.current_response[:maximum_size] @@ -1025,7 +1154,7 @@ class Server: @staticmethod def get_service_attributes( - service: Service, attribute_ids: List[DataElement] + service: Service, attribute_ids: list[DataElement] ) -> DataElement: attributes = [] for attribute_id in attribute_ids: @@ -1053,30 +1182,24 @@ class Server: def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: # Check if this is a continuation - if len(request.continuation_state) > 1: - if self.current_response is None: - self.send_response( - SDP_ErrorResponse( - transaction_id=request.transaction_id, - error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, - ) - ) - return - else: - # Cleanup any partial response leftover - self.current_response = None + if ( + continuation := self.check_continuation( + request.continuation_state, request.transaction_id + ) + ) is None: + return + if not continuation: # Find the matching services matching_services = self.match_services(request.service_search_pattern) service_record_handles = list(matching_services.keys()) + logger.debug(f'Service Record Handles: {service_record_handles}') # Only return up to the maximum requested service_record_handles_subset = service_record_handles[ : request.maximum_service_record_count ] - # Serialize to a byte array, and remember the total count - logger.debug(f'Service Record Handles: {service_record_handles}') self.current_response = ( len(service_record_handles), service_record_handles_subset, @@ -1084,15 +1207,21 @@ class Server: # Respond, keeping any unsent handles for later assert isinstance(self.current_response, tuple) - service_record_handles = self.current_response[1][ - : request.maximum_service_record_count + assert self.channel is not None + total_service_record_count, service_record_handles = self.current_response + maximum_service_record_count = (self.channel.peer_mtu - 11) // 4 + service_record_handles_remaining = service_record_handles[ + maximum_service_record_count: ] + service_record_handles = service_record_handles[:maximum_service_record_count] self.current_response = ( - self.current_response[0], - self.current_response[1][request.maximum_service_record_count :], + total_service_record_count, + service_record_handles_remaining, ) continuation_state = ( - Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) + Server.CONTINUATION_STATE + if service_record_handles_remaining + else bytes([0]) ) service_record_handle_list = b''.join( [struct.pack('>I', handle) for handle in service_record_handles] @@ -1100,7 +1229,7 @@ class Server: self.send_response( SDP_ServiceSearchResponse( transaction_id=request.transaction_id, - total_service_record_count=self.current_response[0], + total_service_record_count=total_service_record_count, current_service_record_count=len(service_record_handles), service_record_handle_list=service_record_handle_list, continuation_state=continuation_state, @@ -1111,19 +1240,14 @@ class Server: self, request: SDP_ServiceAttributeRequest ) -> None: # Check if this is a continuation - if len(request.continuation_state) > 1: - if self.current_response is None: - self.send_response( - SDP_ErrorResponse( - transaction_id=request.transaction_id, - error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, - ) - ) - return - else: - # Cleanup any partial response leftover - self.current_response = None + if ( + continuation := self.check_continuation( + request.continuation_state, request.transaction_id + ) + ) is None: + return + if not continuation: # Check that the service exists service = self.service_records.get(request.service_record_handle) if service is None: @@ -1145,14 +1269,18 @@ class Server: self.current_response = bytes(attribute_list) # Respond, keeping any pending chunks for later + assert self.channel is not None + maximum_attribute_byte_count = min( + request.maximum_attribute_byte_count, self.channel.peer_mtu - 9 + ) attribute_list_response, continuation_state = self.get_next_response_payload( - request.maximum_attribute_byte_count + maximum_attribute_byte_count ) self.send_response( SDP_ServiceAttributeResponse( transaction_id=request.transaction_id, attribute_list_byte_count=len(attribute_list_response), - attribute_list=attribute_list, + attribute_list=attribute_list_response, continuation_state=continuation_state, ) ) @@ -1161,18 +1289,14 @@ class Server: self, request: SDP_ServiceSearchAttributeRequest ) -> None: # Check if this is a continuation - if len(request.continuation_state) > 1: - if self.current_response is None: - self.send_response( - SDP_ErrorResponse( - transaction_id=request.transaction_id, - error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, - ) - ) - else: - # Cleanup any partial response leftover - self.current_response = None + if ( + continuation := self.check_continuation( + request.continuation_state, request.transaction_id + ) + ) is None: + return + if not continuation: # Find the matching services matching_services = self.match_services( request.service_search_pattern @@ -1192,14 +1316,18 @@ class Server: self.current_response = bytes(attribute_lists) # Respond, keeping any pending chunks for later + assert self.channel is not None + maximum_attribute_byte_count = min( + request.maximum_attribute_byte_count, self.channel.peer_mtu - 9 + ) attribute_lists_response, continuation_state = self.get_next_response_payload( - request.maximum_attribute_byte_count + maximum_attribute_byte_count ) self.send_response( SDP_ServiceSearchAttributeResponse( transaction_id=request.transaction_id, attribute_lists_byte_count=len(attribute_lists_response), - attribute_lists=attribute_lists, + attribute_lists=attribute_lists_response, continuation_state=continuation_state, ) ) diff --git a/tests/sdp_test.py b/tests/sdp_test.py index 91835e7..26d9d38 100644 --- a/tests/sdp_test.py +++ b/tests/sdp_test.py @@ -20,12 +20,11 @@ import logging import os import pytest -from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID +from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID from bumble.sdp import ( DataElement, ServiceAttribute, Client, - Server, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_PUBLIC_BROWSE_ROOT, @@ -174,9 +173,10 @@ def test_data_elements() -> None: # ----------------------------------------------------------------------------- -def sdp_records(): +def sdp_records(record_count=1): return { - 0x00010001: [ + 0x00010001 + + i: [ ServiceAttribute( SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001), @@ -200,6 +200,7 @@ def sdp_records(): ), ), ] + for i in range(record_count) } @@ -216,19 +217,55 @@ async def test_service_search(): devices.devices[0].sdp_server.service_records.update(sdp_records()) # Search for service - client = Client(devices.connections[1]) - await client.connect() - services = await client.search_services( - [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')] - ) + async with Client(devices.connections[1]) as client: + services = await client.search_services( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AF')] + ) + assert len(services) == 0 - # Then - assert services[0] == 0x00010001 + services = await client.search_services( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')] + ) + assert len(services) == 1 + assert services[0] == 0x00010001 + + services = await client.search_services( + [BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT] + ) + assert len(services) == 1 + assert services[0] == 0x00010001 + + services = await client.search_services( + [BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT] + ) + assert len(services) == 1 + assert services[0] == 0x00010001 # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_service_attribute(): +async def test_service_search_with_continuation(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + + # Register SDP service + records = sdp_records(100) + devices.devices[0].sdp_server.service_records.update(records) + + # Search for service + async with Client(devices.connections[1], mtu=48) as client: + services = await client.search_services( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')] + ) + assert len(services) == len(records) + for i in range(len(records)): + assert services[i] == 0x00010001 + i + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_service_attributes(): # Setup connections devices = TwoDevices() await devices.setup_connection() @@ -236,15 +273,43 @@ async def test_service_attribute(): # Register SDP service devices.devices[0].sdp_server.service_records.update(sdp_records()) - # Search for service - client = Client(devices.connections[1]) - await client.connect() - attributes = await client.get_attributes( - 0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID] - ) + # Get attributes + async with Client(devices.connections[1]) as client: + attributes = await client.get_attributes(0x00010001, [1234]) + assert len(attributes) == 0 - # Then - assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value + attributes = await client.get_attributes( + 0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID] + ) + assert len(attributes) == 1 + assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_service_attributes_with_continuation(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + + # Register SDP service + records = { + 0x00010001: [ + ServiceAttribute( + x, + DataElement.unsigned_integer_32(0x00010001), + ) + for x in range(100) + ] + } + devices.devices[0].sdp_server.service_records.update(records) + + # Get attributes + async with Client(devices.connections[1], mtu=48) as client: + attributes = await client.get_attributes(0x00010001, list(range(100))) + assert len(attributes) == 100 + for i, attribute in enumerate(attributes): + assert attribute.id == i # ----------------------------------------------------------------------------- @@ -255,19 +320,81 @@ async def test_service_search_attribute(): await devices.setup_connection() # Register SDP service - devices.devices[0].sdp_server.service_records.update(sdp_records()) + records = { + 0x00010001: [ + ServiceAttribute( + 4, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ), + ServiceAttribute( + 3, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ), + ServiceAttribute( + 1, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ), + ] + } + + devices.devices[0].sdp_server.service_records.update(records) # Search for service - client = Client(devices.connections[1]) - await client.connect() - attributes = await client.search_attributes( - [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)] - ) + async with Client(devices.connections[1]) as client: + attributes = await client.search_attributes( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)] + ) + assert len(attributes) == 1 + assert len(attributes[0]) == 3 + assert attributes[0][0].id == 1 + assert attributes[0][1].id == 3 + assert attributes[0][2].id == 4 - # Then - for expect, actual in zip(attributes, sdp_records().values()): - assert expect.id == actual.id - assert expect.value == actual.value + attributes = await client.search_attributes( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [1, 2, 3] + ) + assert len(attributes) == 1 + assert len(attributes[0]) == 2 + assert attributes[0][0].id == 1 + assert attributes[0][1].id == 3 + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_service_search_attribute_with_continuation(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + + # Register SDP service + records = { + 0x00010001: [ + ServiceAttribute( + x, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ) + for x in range(100) + ] + } + devices.devices[0].sdp_server.service_records.update(records) + + # Search for service + async with Client(devices.connections[1], mtu=48) as client: + attributes = await client.search_attributes( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)] + ) + assert len(attributes) == 1 + assert len(attributes[0]) == 100 + for i in range(100): + assert attributes[0][i].id == i # ----------------------------------------------------------------------------- @@ -287,9 +414,12 @@ async def test_client_async_context(): # ----------------------------------------------------------------------------- async def run(): test_data_elements() - await test_service_attribute() + await test_service_attributes() + await test_service_attributes_with_continuation() await test_service_search() + await test_service_search_with_continuation() await test_service_search_attribute() + await test_service_search_attribute_with_continuation() # -----------------------------------------------------------------------------