diff --git a/bumble/sdp.py b/bumble/sdp.py index 1d4faf9a..64281874 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -18,13 +18,16 @@ from __future__ import annotations import logging import struct -from typing import Dict, List, Type +from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING -from . import core +from . import core, l2cap from .colors import color from .core import InvalidStateError from .hci import HCI_Object, name_or_number, key_with_value +if TYPE_CHECKING: + from .device import Device, Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -466,7 +469,7 @@ class ServiceAttribute: self.value = value @staticmethod - def list_from_data_elements(elements): + def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: attribute_list = [] for i in range(0, len(elements) // 2): attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] @@ -478,7 +481,9 @@ class ServiceAttribute: return attribute_list @staticmethod - def find_attribute_in_list(attribute_list, attribute_id): + def find_attribute_in_list( + attribute_list: List[ServiceAttribute], attribute_id: int + ) -> Optional[DataElement]: return next( ( attribute.value @@ -493,7 +498,7 @@ class ServiceAttribute: return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code) @staticmethod - def is_uuid_in_value(uuid, value): + def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool: # Find if a uuid matches a value, either directly or recursing into sequences if value.type == DataElement.UUID: return value.value == uuid @@ -547,7 +552,9 @@ class SDP_PDU: return self @staticmethod - def parse_service_record_handle_list_preceded_by_count(data, offset): + def parse_service_record_handle_list_preceded_by_count( + data: bytes, offset: int + ) -> Tuple[int, List[int]]: count = struct.unpack_from('>H', data, offset - 2)[0] handle_list = [ struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) @@ -645,6 +652,10 @@ class SDP_ServiceSearchRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU ''' + service_search_pattern: DataElement + maximum_service_record_count: int + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -663,6 +674,11 @@ class SDP_ServiceSearchResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU ''' + service_record_handle_list: List[int] + total_service_record_count: int + current_service_record_count: int + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -678,6 +694,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU ''' + service_record_handle: int + maximum_attribute_byte_count: int + attribute_id_list: DataElement + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -692,6 +713,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU ''' + attribute_list_byte_count: int + attribute_list: bytes + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -707,6 +732,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU ''' + service_search_pattern: DataElement + maximum_attribute_byte_count: int + attribute_id_list: DataElement + continuation_state: bytes + # ----------------------------------------------------------------------------- @SDP_PDU.subclass( @@ -721,26 +751,34 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU): See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU ''' + attribute_list_byte_count: int + attribute_list: bytes + continuation_state: bytes + # ----------------------------------------------------------------------------- class Client: - def __init__(self, device): + channel: Optional[l2cap.Channel] + + def __init__(self, device: Device) -> None: self.device = device self.pending_request = None self.channel = None - async def connect(self, connection): + async def connect(self, connection: Connection) -> None: result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM) self.channel = result - async def disconnect(self): + async def disconnect(self) -> None: if self.channel: await self.channel.disconnect() self.channel = None - async def search_services(self, uuids): + async def search_services(self, uuids: List[core.UUID]) -> List[int]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] @@ -770,9 +808,13 @@ class Client: return service_record_handle_list - async def search_attributes(self, uuids, attribute_ids): + async def search_attributes( + self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] + ) -> List[List[ServiceAttribute]]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') service_search_pattern = DataElement.sequence( [DataElement.uuid(uuid) for uuid in uuids] @@ -823,9 +865,15 @@ class Client: if sequence.type == DataElement.SEQUENCE ] - async def get_attributes(self, service_record_handle, attribute_ids): + async def get_attributes( + self, + service_record_handle: int, + attribute_ids: List[Union[int, Tuple[int, int]]], + ) -> List[ServiceAttribute]: if self.pending_request is not None: raise InvalidStateError('request already pending') + if self.channel is None: + raise InvalidStateError('L2CAP not connected') attribute_id_list = DataElement.sequence( [ @@ -873,21 +921,25 @@ class Client: # ----------------------------------------------------------------------------- class Server: CONTINUATION_STATE = bytes([0x01, 0x43]) + channel: Optional[l2cap.Channel] + Service = NewType('Service', List[ServiceAttribute]) + service_records: Dict[int, Service] + current_response: Union[None, bytes, Tuple[int, List[int]]] - def __init__(self, device): + def __init__(self, device: Device) -> None: self.device = device self.service_records = {} # Service records maps, by record handle self.channel = None self.current_response = None - def register(self, l2cap_channel_manager): + def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: l2cap_channel_manager.register_server(SDP_PSM, self.on_connection) def send_response(self, response): logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') self.channel.send_pdu(response) - def match_services(self, search_pattern): + def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: # Find the services for which the attributes in the pattern is a subset of the # service's attribute values (NOTE: the value search recurses into sequences) matching_services = {} @@ -957,7 +1009,9 @@ class Server: return (payload, continuation_state) @staticmethod - def get_service_attributes(service, attribute_ids): + def get_service_attributes( + service: Service, attribute_ids: List[DataElement] + ) -> DataElement: attributes = [] for attribute_id in attribute_ids: if attribute_id.value_size == 4: @@ -982,10 +1036,10 @@ class Server: return attribute_list - def on_sdp_service_search_request(self, request): + def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1014,6 +1068,7 @@ class Server: ) # Respond, keeping any unsent handles for later + assert isinstance(self.current_response, tuple) service_record_handles = self.current_response[1][ : request.maximum_service_record_count ] @@ -1037,10 +1092,12 @@ class Server: ) ) - def on_sdp_service_attribute_request(self, request): + def on_sdp_service_attribute_request( + self, request: SDP_ServiceAttributeRequest + ) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1073,22 +1130,24 @@ class Server: self.current_response = bytes(attribute_list) # Respond, keeping any pending chunks for later - attribute_list, continuation_state = self.get_next_response_payload( + attribute_list_response, continuation_state = self.get_next_response_payload( request.maximum_attribute_byte_count ) self.send_response( SDP_ServiceAttributeResponse( transaction_id=request.transaction_id, - attribute_list_byte_count=len(attribute_list), + attribute_list_byte_count=len(attribute_list_response), attribute_list=attribute_list, continuation_state=continuation_state, ) ) - def on_sdp_service_search_attribute_request(self, request): + def on_sdp_service_search_attribute_request( + self, request: SDP_ServiceSearchAttributeRequest + ) -> None: # Check if this is a continuation if len(request.continuation_state) > 1: - if not self.current_response: + if self.current_response is None: self.send_response( SDP_ErrorResponse( transaction_id=request.transaction_id, @@ -1118,13 +1177,13 @@ class Server: self.current_response = bytes(attribute_lists) # Respond, keeping any pending chunks for later - attribute_lists, continuation_state = self.get_next_response_payload( + attribute_lists_response, continuation_state = self.get_next_response_payload( request.maximum_attribute_byte_count ) self.send_response( SDP_ServiceSearchAttributeResponse( transaction_id=request.transaction_id, - attribute_lists_byte_count=len(attribute_lists), + attribute_lists_byte_count=len(attribute_lists_response), attribute_lists=attribute_lists, continuation_state=continuation_state, ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..1e45f74c --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 6f8e1810..c6b2340e 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -21,13 +21,9 @@ import os import random import pytest -from bumble.controller import Controller -from bumble.link import LocalLink -from bumble.device import Device -from bumble.host import Host -from bumble.transport import AsyncPipeSink from bumble.core import ProtocolError from bumble.l2cap import L2CAP_Connection_Request +from .test_utils import TwoDevices # ----------------------------------------------------------------------------- @@ -37,60 +33,6 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -class TwoDevices: - def __init__(self): - self.connections = [None, None] - - self.link = LocalLink() - self.controllers = [ - Controller('C1', link=self.link), - Controller('C2', link=self.link), - ] - self.devices = [ - Device( - address='F0:F1:F2:F3:F4:F5', - host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), - ), - Device( - address='F5:F4:F3:F2:F1:F0', - host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), - ), - ] - - self.paired = [None, None] - - def on_connection(self, which, connection): - self.connections[which] = connection - - def on_paired(self, which, keys): - self.paired[which] = keys - - -# ----------------------------------------------------------------------------- -async def setup_connection(): - # Create two devices, each with a controller, attached to the same link - two_devices = TwoDevices() - - # Attach listeners - two_devices.devices[0].on( - 'connection', lambda connection: two_devices.on_connection(0, connection) - ) - two_devices.devices[1].on( - 'connection', lambda connection: two_devices.on_connection(1, connection) - ) - - # Start - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() - - # Connect the two devices - await two_devices.devices[0].connect(two_devices.devices[1].random_address) - - # Check the post conditions - assert two_devices.connections[0] is not None - assert two_devices.connections[1] is not None - - return two_devices # ----------------------------------------------------------------------------- @@ -132,7 +74,8 @@ def test_helpers(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_basic_connection(): - devices = await setup_connection() + devices = TwoDevices() + await devices.setup_connection() psm = 1234 # Check that if there's no one listening, we can't connect @@ -184,7 +127,8 @@ async def test_basic_connection(): # ----------------------------------------------------------------------------- async def transfer_payload(max_credits, mtu, mps): - devices = await setup_connection() + devices = TwoDevices() + await devices.setup_connection() received = [] @@ -226,7 +170,8 @@ async def test_transfer(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_bidirectional_transfer(): - devices = await setup_connection() + devices = TwoDevices() + await devices.setup_connection() client_received = [] server_received = [] diff --git a/tests/sdp_test.py b/tests/sdp_test.py index f07b5790..090e7b2c 100644 --- a/tests/sdp_test.py +++ b/tests/sdp_test.py @@ -15,15 +15,30 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -from bumble.core import UUID -from bumble.sdp import DataElement +import asyncio +import logging +import os + +from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID +from bumble.sdp import ( + DataElement, + ServiceAttribute, + Client, + Server, + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + SDP_PUBLIC_BROWSE_ROOT, + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, +) +from .test_utils import TwoDevices # ----------------------------------------------------------------------------- # pylint: disable=invalid-name # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- -def basic_check(x): +def basic_check(x: DataElement) -> None: serialized = bytes(x) if len(serialized) < 500: print('Original:', x) @@ -41,7 +56,7 @@ def basic_check(x): # ----------------------------------------------------------------------------- -def test_data_elements(): +def test_data_elements() -> None: e = DataElement(DataElement.NIL, None) basic_check(e) @@ -157,5 +172,108 @@ def test_data_elements(): # ----------------------------------------------------------------------------- -if __name__ == '__main__': +def sdp_records(): + return { + 0x00010001: [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(0x00010001), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]), + ] + ), + ), + ] + } + + +# ----------------------------------------------------------------------------- +async def test_service_search(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + assert devices.connections[0] + assert devices.connections[1] + + # Register SDP service + devices.devices[0].sdp_server.service_records.update(sdp_records()) + + # Search for service + client = Client(devices.devices[1]) + await client.connect(devices.connections[1]) + services = await client.search_services( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')] + ) + + # Then + assert services[0] == 0x00010001 + + +# ----------------------------------------------------------------------------- +async def test_service_attribute(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + + # Register SDP service + devices.devices[0].sdp_server.service_records.update(sdp_records()) + + # Search for service + client = Client(devices.devices[1]) + await client.connect(devices.connections[1]) + attributes = await client.get_attributes( + 0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID] + ) + + # Then + assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value + + +# ----------------------------------------------------------------------------- +async def test_service_search_attribute(): + # Setup connections + devices = TwoDevices() + await devices.setup_connection() + + # Register SDP service + devices.devices[0].sdp_server.service_records.update(sdp_records()) + + # Search for service + client = Client(devices.devices[1]) + await client.connect(devices.connections[1]) + attributes = await client.search_attributes( + [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)] + ) + + # Then + for expect, actual in zip(attributes, sdp_records().values()): + assert expect.id == actual.id + assert expect.value == actual.value + + +# ----------------------------------------------------------------------------- +async def run(): test_data_elements() + await test_service_attribute() + await test_service_search() + await test_service_search_attribute() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run()) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..f19f18c8 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,73 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from bumble.controller import Controller +from bumble.link import LocalLink +from bumble.device import Device, Connection +from bumble.host import Host +from bumble.transport import AsyncPipeSink +from bumble.hci import Address + + +class TwoDevices: + connections: List[Optional[Connection]] + + def __init__(self) -> None: + self.connections = [None, None] + + self.link = LocalLink() + self.controllers = [ + Controller('C1', link=self.link), + Controller('C2', link=self.link), + ] + self.devices = [ + Device( + address=Address('F0:F1:F2:F3:F4:F5'), + host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), + ), + Device( + address=Address('F5:F4:F3:F2:F1:F0'), + host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), + ), + ] + + self.paired = [None, None] + + def on_connection(self, which, connection): + self.connections[which] = connection + + def on_paired(self, which, keys): + self.paired[which] = keys + + async def setup_connection(self) -> None: + # Attach listeners + self.devices[0].on( + 'connection', lambda connection: self.on_connection(0, connection) + ) + self.devices[1].on( + 'connection', lambda connection: self.on_connection(1, connection) + ) + + # Start + await self.devices[0].power_on() + await self.devices[1].power_on() + + # Connect the two devices + await self.devices[0].connect(self.devices[1].random_address) + + # Check the post conditions + assert self.connections[0] is not None + assert self.connections[1] is not None