From 10e53553d74b7fe917e01f31d04a7851d91931dc Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 2 Feb 2024 13:55:54 +0800 Subject: [PATCH] Add RFCOMM and SDP helpers --- apps/bench.py | 48 +-------------- bumble/rfcomm.py | 136 ++++++++++++++++++++++++++++++++++--------- bumble/sdp.py | 8 +++ tests/rfcomm_test.py | 51 +++++++++++++++- tests/sdp_test.py | 15 +++++ 5 files changed, 183 insertions(+), 75 deletions(-) diff --git a/apps/bench.py b/apps/bench.py index d4635d4..1f9d45f 100644 --- a/apps/bench.py +++ b/apps/bench.py @@ -50,10 +50,8 @@ from bumble.sdp import ( SDP_PUBLIC_BROWSE_ROOT, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement, ServiceAttribute, - Client as SdpClient, ) from bumble.transport import open_transport_or_link import bumble.rfcomm @@ -198,48 +196,6 @@ def make_sdp_records(channel): } -async def find_rfcomm_channel_with_uuid(connection: Connection, uuid: str) -> int: - # Connect to the SDP Server - sdp_client = SdpClient(connection) - await sdp_client.connect() - - # Search for services with an L2CAP service attribute - search_result = await sdp_client.search_attributes( - [BT_L2CAP_PROTOCOL_ID], - [ - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - ], - ) - for attribute_list in search_result: - service_uuid = None - service_class_id_list = ServiceAttribute.find_attribute_in_list( - attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID - ) - if service_class_id_list: - if service_class_id_list.value: - for service_class_id in service_class_id_list.value: - service_uuid = service_class_id.value - if str(service_uuid) != uuid: - # This service doesn't have a UUID or isn't the right one. - continue - - # Look for the RFCOMM Channel number - protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( - attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID - ) - if protocol_descriptor_list: - for protocol_descriptor in protocol_descriptor_list.value: - if len(protocol_descriptor.value) >= 2: - if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID: - await sdp_client.disconnect() - return protocol_descriptor.value[1].value - - await sdp_client.disconnect() - return 0 - - def log_stats(title, stats): stats_min = min(stats) stats_max = max(stats) @@ -957,7 +913,9 @@ class RfcommClient(StreamedPacketIO): logging.info( color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan') ) - channel = await find_rfcomm_channel_with_uuid(connection, self.uuid) + channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid( + connection, self.uuid + ) logging.info(color(f'@@@ Channel number = {channel}', 'cyan')) if channel == 0: logging.info(color('!!! No RFComm service with this UUID found', 'red')) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 9a8ad77..5500bc1 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -22,10 +22,13 @@ import asyncio import dataclasses import enum from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from typing_extensions import Self from pyee import EventEmitter -from . import core, l2cap +from bumble import core +from bumble import l2cap +from bumble import sdp from .colors import color from .core import ( UUID, @@ -35,15 +38,6 @@ from .core import ( InvalidStateError, ProtocolError, ) -from .sdp import ( - SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_PUBLIC_BROWSE_ROOT, - DataElement, - ServiceAttribute, -) if TYPE_CHECKING: from bumble.device import Device, Connection @@ -122,29 +116,33 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 # ----------------------------------------------------------------------------- def make_service_sdp_records( service_record_handle: int, channel: int, uuid: Optional[UUID] = None -) -> List[ServiceAttribute]: +) -> List[sdp.ServiceAttribute]: """ Create SDP records for an RFComm service given a channel number and an optional UUID. A Service Class Attribute is included only if the UUID is not None. """ records = [ - ServiceAttribute( - SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(service_record_handle), + sdp.ServiceAttribute( + sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + sdp.DataElement.unsigned_integer_32(service_record_handle), ), - ServiceAttribute( - SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, - DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + sdp.ServiceAttribute( + sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + sdp.DataElement.sequence( + [sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)] + ), ), - ServiceAttribute( - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence( + sdp.ServiceAttribute( + sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + sdp.DataElement.sequence( [ - DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]), - DataElement.sequence( + sdp.DataElement.sequence( + [sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)] + ), + sdp.DataElement.sequence( [ - DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), - DataElement.unsigned_integer_8(channel), + sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), + sdp.DataElement.unsigned_integer_8(channel), ] ), ] @@ -154,15 +152,81 @@ def make_service_sdp_records( if uuid: records.append( - ServiceAttribute( - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - DataElement.sequence([DataElement.uuid(uuid)]), + sdp.ServiceAttribute( + sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]), ) ) return records +# ----------------------------------------------------------------------------- +async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]: + """Searches all RFCOMM channels and their associated UUID from SDP service records. + + Args: + connection: ACL connection to make SDP search. + + Returns: + Dictionary mapping from channel number to service class UUID list. + """ + results = {} + async with sdp.Client(connection) as sdp_client: + search_result = await sdp_client.search_attributes( + uuids=[core.BT_RFCOMM_PROTOCOL_ID], + attribute_ids=[ + sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + ], + ) + for attribute_lists in search_result: + service_classes: List[UUID] = [] + channel: Optional[int] = None + for attribute in attribute_lists: + # The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]]. + if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: + protocol_descriptor_list = attribute.value.value + channel = protocol_descriptor_list[1].value[1].value + elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: + service_class_id_list = attribute.value.value + service_classes = [ + service_class.value for service_class in service_class_id_list + ] + if not service_classes or not channel: + logger.warning(f"Bad result {attribute_lists}.") + else: + results[channel] = service_classes + return results + + +# ----------------------------------------------------------------------------- +async def find_rfcomm_channel_with_uuid( + connection: Connection, uuid: str | UUID +) -> Optional[int]: + """Searches an RFCOMM channel associated with given UUID from service records. + + Args: + connection: ACL connection to make SDP search. + uuid: UUID of service record to search for. + + Returns: + RFCOMM channel number if found, otherwise None. + """ + if isinstance(uuid, str): + uuid = UUID(uuid) + return next( + ( + channel + for channel, class_id_list in ( + await find_rfcomm_channels(connection) + ).items() + if uuid in class_id_list + ), + None, + ) + + # ----------------------------------------------------------------------------- def compute_fcs(buffer: bytes) -> int: result = 0xFF @@ -876,7 +940,15 @@ class Client: self.multiplexer = None # Close the L2CAP channel - # TODO + if self.l2cap_channel: + await self.l2cap_channel.disconnect() + self.l2cap_channel = None + + async def __aenter__(self) -> Multiplexer: + return await self.start() + + async def __aexit__(self, *args) -> None: + await self.shutdown() # ----------------------------------------------------------------------------- @@ -890,7 +962,7 @@ class Server(EventEmitter): self.acceptors = {} # Register ourselves with the L2CAP channel manager - device.create_l2cap_server( + self.l2cap_server = device.create_l2cap_server( spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection ) @@ -941,3 +1013,9 @@ class Server(EventEmitter): acceptor = self.acceptors.get(dlc.dlci >> 1) if acceptor: acceptor(dlc) + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args) -> None: + self.l2cap_server.close() diff --git a/bumble/sdp.py b/bumble/sdp.py index 749e295..6423790 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -19,6 +19,7 @@ from __future__ import annotations import logging import struct from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING +from typing_extensions import Self from . import core, l2cap from .colors import color @@ -920,6 +921,13 @@ class Client: return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value) + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *args) -> None: + await self.disconnect() + # ----------------------------------------------------------------------------- class Server: diff --git a/tests/rfcomm_test.py b/tests/rfcomm_test.py index de6f4af..2ab3c2c 100644 --- a/tests/rfcomm_test.py +++ b/tests/rfcomm_test.py @@ -19,7 +19,17 @@ import asyncio import pytest from . import test_utils -from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC +from bumble import core +from bumble.rfcomm import ( + RFCOMM_Frame, + Server, + Client, + DLC, + make_service_sdp_records, + find_rfcomm_channels, + find_rfcomm_channel_with_uuid, + RFCOMM_PSM, +) # ----------------------------------------------------------------------------- @@ -70,6 +80,45 @@ async def test_basic_connection(): assert await queues[0].get() == b'Lorem ipsum dolor sit amet' +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_service_record(): + HANDLE = 2 + CHANNEL = 1 + SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001') + + devices = test_utils.TwoDevices() + await devices.setup_connection() + + devices[0].sdp_service_records[HANDLE] = make_service_sdp_records( + HANDLE, CHANNEL, SERVICE_UUID + ) + + assert SERVICE_UUID in (await find_rfcomm_channels(devices.connections[1]))[CHANNEL] + assert ( + await find_rfcomm_channel_with_uuid(devices.connections[1], SERVICE_UUID) + == CHANNEL + ) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_context(): + devices = test_utils.TwoDevices() + await devices.setup_connection() + + server = Server(devices[0]) + with server: + assert server.l2cap_server is not None + + client = Client(devices.connections[1]) + async with client: + assert client.l2cap_channel is not None + + assert client.l2cap_channel is None + assert RFCOMM_PSM not in devices[0].l2cap_channel_manager.servers + + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_frames() diff --git a/tests/sdp_test.py b/tests/sdp_test.py index ea8e0ab..91835e7 100644 --- a/tests/sdp_test.py +++ b/tests/sdp_test.py @@ -38,6 +38,7 @@ from .test_utils import TwoDevices # pylint: disable=invalid-name # ----------------------------------------------------------------------------- + # ----------------------------------------------------------------------------- def basic_check(x: DataElement) -> None: serialized = bytes(x) @@ -269,6 +270,20 @@ async def test_service_search_attribute(): assert expect.value == actual.value +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_client_async_context(): + devices = TwoDevices() + await devices.setup_connection() + + client = Client(devices.connections[1]) + + async with client: + assert client.channel is not None + + assert client.channel is None + + # ----------------------------------------------------------------------------- async def run(): test_data_elements()