Merge pull request #420 from zxzxwu/rfc

Add RFCOMM and SDP context manager and search helper
This commit is contained in:
zxzxwu
2024-02-04 00:42:24 +08:00
committed by GitHub
5 changed files with 183 additions and 75 deletions

View File

@@ -50,10 +50,8 @@ from bumble.sdp import (
SDP_PUBLIC_BROWSE_ROOT,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
Client as SdpClient,
)
from bumble.transport import open_transport_or_link
import bumble.rfcomm
@@ -198,48 +196,6 @@ def make_sdp_records(channel):
}
async def find_rfcomm_channel_with_uuid(connection: Connection, uuid: str) -> int:
# Connect to the SDP Server
sdp_client = SdpClient(connection)
await sdp_client.connect()
# Search for services with an L2CAP service attribute
search_result = await sdp_client.search_attributes(
[BT_L2CAP_PROTOCOL_ID],
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_list in search_result:
service_uuid = None
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
for service_class_id in service_class_id_list.value:
service_uuid = service_class_id.value
if str(service_uuid) != uuid:
# This service doesn't have a UUID or isn't the right one.
continue
# Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
await sdp_client.disconnect()
return protocol_descriptor.value[1].value
await sdp_client.disconnect()
return 0
def log_stats(title, stats):
stats_min = min(stats)
stats_max = max(stats)
@@ -957,7 +913,9 @@ class RfcommClient(StreamedPacketIO):
logging.info(
color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
)
channel = await find_rfcomm_channel_with_uuid(connection, self.uuid)
channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
connection, self.uuid
)
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
if channel == 0:
logging.info(color('!!! No RFComm service with this UUID found', 'red'))

View File

@@ -22,10 +22,13 @@ import asyncio
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self
from pyee import EventEmitter
from . import core, l2cap
from bumble import core
from bumble import l2cap
from bumble import sdp
from .colors import color
from .core import (
UUID,
@@ -35,15 +38,6 @@ from .core import (
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -122,29 +116,33 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
) -> List[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
]
),
]
@@ -154,15 +152,81 @@ def make_service_sdp_records(
if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
)
)
return records
# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
uuid: UUID of service record to search for.
Returns:
RFCOMM channel number if found, otherwise None.
"""
if isinstance(uuid, str):
uuid = UUID(uuid)
return next(
(
channel
for channel, class_id_list in (
await find_rfcomm_channels(connection)
).items()
if uuid in class_id_list
),
None,
)
# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
@@ -876,7 +940,15 @@ class Client:
self.multiplexer = None
# Close the L2CAP channel
# TODO
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None
async def __aenter__(self) -> Multiplexer:
return await self.start()
async def __aexit__(self, *args) -> None:
await self.shutdown()
# -----------------------------------------------------------------------------
@@ -890,7 +962,7 @@ class Server(EventEmitter):
self.acceptors = {}
# Register ourselves with the L2CAP channel manager
device.create_l2cap_server(
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
)
@@ -941,3 +1013,9 @@ class Server(EventEmitter):
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc)
def __enter__(self) -> Self:
return self
def __exit__(self, *args) -> None:
self.l2cap_server.close()

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self
from . import core, l2cap
from .colors import color
@@ -920,6 +921,13 @@ class Client:
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(self, *args) -> None:
await self.disconnect()
# -----------------------------------------------------------------------------
class Server:

View File

@@ -19,7 +19,17 @@ import asyncio
import pytest
from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
from bumble import core
from bumble.rfcomm import (
RFCOMM_Frame,
Server,
Client,
DLC,
make_service_sdp_records,
find_rfcomm_channels,
find_rfcomm_channel_with_uuid,
RFCOMM_PSM,
)
# -----------------------------------------------------------------------------
@@ -70,6 +80,45 @@ async def test_basic_connection():
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
HANDLE = 2
CHANNEL = 1
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')
devices = test_utils.TwoDevices()
await devices.setup_connection()
devices[0].sdp_service_records[HANDLE] = make_service_sdp_records(
HANDLE, CHANNEL, SERVICE_UUID
)
assert SERVICE_UUID in (await find_rfcomm_channels(devices.connections[1]))[CHANNEL]
assert (
await find_rfcomm_channel_with_uuid(devices.connections[1], SERVICE_UUID)
== CHANNEL
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_context():
devices = test_utils.TwoDevices()
await devices.setup_connection()
server = Server(devices[0])
with server:
assert server.l2cap_server is not None
client = Client(devices.connections[1])
async with client:
assert client.l2cap_channel is not None
assert client.l2cap_channel is None
assert RFCOMM_PSM not in devices[0].l2cap_channel_manager.servers
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()

View File

@@ -38,6 +38,7 @@ from .test_utils import TwoDevices
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def basic_check(x: DataElement) -> None:
serialized = bytes(x)
@@ -269,6 +270,20 @@ async def test_service_search_attribute():
assert expect.value == actual.value
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_client_async_context():
devices = TwoDevices()
await devices.setup_connection()
client = Client(devices.connections[1])
async with client:
assert client.channel is not None
assert client.channel is None
# -----------------------------------------------------------------------------
async def run():
test_data_elements()