forked from auracaster/bumble_mirror
Merge pull request #420 from zxzxwu/rfc
Add RFCOMM and SDP context manager and search helper
This commit is contained in:
@@ -50,10 +50,8 @@ from bumble.sdp import (
|
|||||||
SDP_PUBLIC_BROWSE_ROOT,
|
SDP_PUBLIC_BROWSE_ROOT,
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
|
||||||
DataElement,
|
DataElement,
|
||||||
ServiceAttribute,
|
ServiceAttribute,
|
||||||
Client as SdpClient,
|
|
||||||
)
|
)
|
||||||
from bumble.transport import open_transport_or_link
|
from bumble.transport import open_transport_or_link
|
||||||
import bumble.rfcomm
|
import bumble.rfcomm
|
||||||
@@ -198,48 +196,6 @@ def make_sdp_records(channel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def find_rfcomm_channel_with_uuid(connection: Connection, uuid: str) -> int:
|
|
||||||
# Connect to the SDP Server
|
|
||||||
sdp_client = SdpClient(connection)
|
|
||||||
await sdp_client.connect()
|
|
||||||
|
|
||||||
# Search for services with an L2CAP service attribute
|
|
||||||
search_result = await sdp_client.search_attributes(
|
|
||||||
[BT_L2CAP_PROTOCOL_ID],
|
|
||||||
[
|
|
||||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
|
||||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for attribute_list in search_result:
|
|
||||||
service_uuid = None
|
|
||||||
service_class_id_list = ServiceAttribute.find_attribute_in_list(
|
|
||||||
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
|
|
||||||
)
|
|
||||||
if service_class_id_list:
|
|
||||||
if service_class_id_list.value:
|
|
||||||
for service_class_id in service_class_id_list.value:
|
|
||||||
service_uuid = service_class_id.value
|
|
||||||
if str(service_uuid) != uuid:
|
|
||||||
# This service doesn't have a UUID or isn't the right one.
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Look for the RFCOMM Channel number
|
|
||||||
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
|
|
||||||
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
|
|
||||||
)
|
|
||||||
if protocol_descriptor_list:
|
|
||||||
for protocol_descriptor in protocol_descriptor_list.value:
|
|
||||||
if len(protocol_descriptor.value) >= 2:
|
|
||||||
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
|
|
||||||
await sdp_client.disconnect()
|
|
||||||
return protocol_descriptor.value[1].value
|
|
||||||
|
|
||||||
await sdp_client.disconnect()
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def log_stats(title, stats):
|
def log_stats(title, stats):
|
||||||
stats_min = min(stats)
|
stats_min = min(stats)
|
||||||
stats_max = max(stats)
|
stats_max = max(stats)
|
||||||
@@ -957,7 +913,9 @@ class RfcommClient(StreamedPacketIO):
|
|||||||
logging.info(
|
logging.info(
|
||||||
color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
|
color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
|
||||||
)
|
)
|
||||||
channel = await find_rfcomm_channel_with_uuid(connection, self.uuid)
|
channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
|
||||||
|
connection, self.uuid
|
||||||
|
)
|
||||||
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
|
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
|
||||||
if channel == 0:
|
if channel == 0:
|
||||||
logging.info(color('!!! No RFComm service with this UUID found', 'red'))
|
logging.info(color('!!! No RFComm service with this UUID found', 'red'))
|
||||||
|
|||||||
136
bumble/rfcomm.py
136
bumble/rfcomm.py
@@ -22,10 +22,13 @@ import asyncio
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
from . import core, l2cap
|
from bumble import core
|
||||||
|
from bumble import l2cap
|
||||||
|
from bumble import sdp
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .core import (
|
from .core import (
|
||||||
UUID,
|
UUID,
|
||||||
@@ -35,15 +38,6 @@ from .core import (
|
|||||||
InvalidStateError,
|
InvalidStateError,
|
||||||
ProtocolError,
|
ProtocolError,
|
||||||
)
|
)
|
||||||
from .sdp import (
|
|
||||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
|
||||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
|
||||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
|
||||||
SDP_PUBLIC_BROWSE_ROOT,
|
|
||||||
DataElement,
|
|
||||||
ServiceAttribute,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from bumble.device import Device, Connection
|
from bumble.device import Device, Connection
|
||||||
@@ -122,29 +116,33 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def make_service_sdp_records(
|
def make_service_sdp_records(
|
||||||
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
|
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
|
||||||
) -> List[ServiceAttribute]:
|
) -> List[sdp.ServiceAttribute]:
|
||||||
"""
|
"""
|
||||||
Create SDP records for an RFComm service given a channel number and an
|
Create SDP records for an RFComm service given a channel number and an
|
||||||
optional UUID. A Service Class Attribute is included only if the UUID is not None.
|
optional UUID. A Service Class Attribute is included only if the UUID is not None.
|
||||||
"""
|
"""
|
||||||
records = [
|
records = [
|
||||||
ServiceAttribute(
|
sdp.ServiceAttribute(
|
||||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
DataElement.unsigned_integer_32(service_record_handle),
|
sdp.DataElement.unsigned_integer_32(service_record_handle),
|
||||||
),
|
),
|
||||||
ServiceAttribute(
|
sdp.ServiceAttribute(
|
||||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
|
sdp.DataElement.sequence(
|
||||||
|
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
|
||||||
|
),
|
||||||
),
|
),
|
||||||
ServiceAttribute(
|
sdp.ServiceAttribute(
|
||||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
DataElement.sequence(
|
sdp.DataElement.sequence(
|
||||||
[
|
[
|
||||||
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
|
sdp.DataElement.sequence(
|
||||||
DataElement.sequence(
|
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
|
||||||
|
),
|
||||||
|
sdp.DataElement.sequence(
|
||||||
[
|
[
|
||||||
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
||||||
DataElement.unsigned_integer_8(channel),
|
sdp.DataElement.unsigned_integer_8(channel),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -154,15 +152,81 @@ def make_service_sdp_records(
|
|||||||
|
|
||||||
if uuid:
|
if uuid:
|
||||||
records.append(
|
records.append(
|
||||||
ServiceAttribute(
|
sdp.ServiceAttribute(
|
||||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
DataElement.sequence([DataElement.uuid(uuid)]),
|
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return records
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
|
||||||
|
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: ACL connection to make SDP search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping from channel number to service class UUID list.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
async with sdp.Client(connection) as sdp_client:
|
||||||
|
search_result = await sdp_client.search_attributes(
|
||||||
|
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
|
||||||
|
attribute_ids=[
|
||||||
|
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
|
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for attribute_lists in search_result:
|
||||||
|
service_classes: List[UUID] = []
|
||||||
|
channel: Optional[int] = None
|
||||||
|
for attribute in attribute_lists:
|
||||||
|
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
|
||||||
|
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
|
||||||
|
protocol_descriptor_list = attribute.value.value
|
||||||
|
channel = protocol_descriptor_list[1].value[1].value
|
||||||
|
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
|
||||||
|
service_class_id_list = attribute.value.value
|
||||||
|
service_classes = [
|
||||||
|
service_class.value for service_class in service_class_id_list
|
||||||
|
]
|
||||||
|
if not service_classes or not channel:
|
||||||
|
logger.warning(f"Bad result {attribute_lists}.")
|
||||||
|
else:
|
||||||
|
results[channel] = service_classes
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def find_rfcomm_channel_with_uuid(
|
||||||
|
connection: Connection, uuid: str | UUID
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Searches an RFCOMM channel associated with given UUID from service records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: ACL connection to make SDP search.
|
||||||
|
uuid: UUID of service record to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RFCOMM channel number if found, otherwise None.
|
||||||
|
"""
|
||||||
|
if isinstance(uuid, str):
|
||||||
|
uuid = UUID(uuid)
|
||||||
|
return next(
|
||||||
|
(
|
||||||
|
channel
|
||||||
|
for channel, class_id_list in (
|
||||||
|
await find_rfcomm_channels(connection)
|
||||||
|
).items()
|
||||||
|
if uuid in class_id_list
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def compute_fcs(buffer: bytes) -> int:
|
def compute_fcs(buffer: bytes) -> int:
|
||||||
result = 0xFF
|
result = 0xFF
|
||||||
@@ -876,7 +940,15 @@ class Client:
|
|||||||
self.multiplexer = None
|
self.multiplexer = None
|
||||||
|
|
||||||
# Close the L2CAP channel
|
# Close the L2CAP channel
|
||||||
# TODO
|
if self.l2cap_channel:
|
||||||
|
await self.l2cap_channel.disconnect()
|
||||||
|
self.l2cap_channel = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Multiplexer:
|
||||||
|
return await self.start()
|
||||||
|
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -890,7 +962,7 @@ class Server(EventEmitter):
|
|||||||
self.acceptors = {}
|
self.acceptors = {}
|
||||||
|
|
||||||
# Register ourselves with the L2CAP channel manager
|
# Register ourselves with the L2CAP channel manager
|
||||||
device.create_l2cap_server(
|
self.l2cap_server = device.create_l2cap_server(
|
||||||
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
|
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -941,3 +1013,9 @@ class Server(EventEmitter):
|
|||||||
acceptor = self.acceptors.get(dlc.dlci >> 1)
|
acceptor = self.acceptors.get(dlc.dlci >> 1)
|
||||||
if acceptor:
|
if acceptor:
|
||||||
acceptor(dlc)
|
acceptor(dlc)
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
self.l2cap_server.close()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from . import core, l2cap
|
from . import core, l2cap
|
||||||
from .colors import color
|
from .colors import color
|
||||||
@@ -920,6 +921,13 @@ class Client:
|
|||||||
|
|
||||||
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
|
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
await self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Server:
|
class Server:
|
||||||
|
|||||||
@@ -19,7 +19,17 @@ import asyncio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from . import test_utils
|
from . import test_utils
|
||||||
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
|
from bumble import core
|
||||||
|
from bumble.rfcomm import (
|
||||||
|
RFCOMM_Frame,
|
||||||
|
Server,
|
||||||
|
Client,
|
||||||
|
DLC,
|
||||||
|
make_service_sdp_records,
|
||||||
|
find_rfcomm_channels,
|
||||||
|
find_rfcomm_channel_with_uuid,
|
||||||
|
RFCOMM_PSM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -70,6 +80,45 @@ async def test_basic_connection():
|
|||||||
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
|
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_record():
|
||||||
|
HANDLE = 2
|
||||||
|
CHANNEL = 1
|
||||||
|
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')
|
||||||
|
|
||||||
|
devices = test_utils.TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
|
devices[0].sdp_service_records[HANDLE] = make_service_sdp_records(
|
||||||
|
HANDLE, CHANNEL, SERVICE_UUID
|
||||||
|
)
|
||||||
|
|
||||||
|
assert SERVICE_UUID in (await find_rfcomm_channels(devices.connections[1]))[CHANNEL]
|
||||||
|
assert (
|
||||||
|
await find_rfcomm_channel_with_uuid(devices.connections[1], SERVICE_UUID)
|
||||||
|
== CHANNEL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context():
|
||||||
|
devices = test_utils.TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
|
server = Server(devices[0])
|
||||||
|
with server:
|
||||||
|
assert server.l2cap_server is not None
|
||||||
|
|
||||||
|
client = Client(devices.connections[1])
|
||||||
|
async with client:
|
||||||
|
assert client.l2cap_channel is not None
|
||||||
|
|
||||||
|
assert client.l2cap_channel is None
|
||||||
|
assert RFCOMM_PSM not in devices[0].l2cap_channel_manager.servers
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_frames()
|
test_frames()
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from .test_utils import TwoDevices
|
|||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def basic_check(x: DataElement) -> None:
|
def basic_check(x: DataElement) -> None:
|
||||||
serialized = bytes(x)
|
serialized = bytes(x)
|
||||||
@@ -269,6 +270,20 @@ async def test_service_search_attribute():
|
|||||||
assert expect.value == actual.value
|
assert expect.value == actual.value
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_async_context():
|
||||||
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
|
client = Client(devices.connections[1])
|
||||||
|
|
||||||
|
async with client:
|
||||||
|
assert client.channel is not None
|
||||||
|
|
||||||
|
assert client.channel is None
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def run():
|
async def run():
|
||||||
test_data_elements()
|
test_data_elements()
|
||||||
|
|||||||
Reference in New Issue
Block a user