forked from auracaster/bumble_mirror
Merge pull request #256 from zxzxwu/sdp-type-fix
Typing SDP and add tests
This commit is contained in:
111
bumble/sdp.py
111
bumble/sdp.py
@@ -18,13 +18,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import Dict, List, Type
|
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||||
|
|
||||||
from . import core
|
from . import core, l2cap
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .core import InvalidStateError
|
from .core import InvalidStateError
|
||||||
from .hci import HCI_Object, name_or_number, key_with_value
|
from .hci import HCI_Object, name_or_number, key_with_value
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .device import Device, Connection
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Logging
|
# Logging
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -466,7 +469,7 @@ class ServiceAttribute:
|
|||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_from_data_elements(elements):
|
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
|
||||||
attribute_list = []
|
attribute_list = []
|
||||||
for i in range(0, len(elements) // 2):
|
for i in range(0, len(elements) // 2):
|
||||||
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
||||||
@@ -478,7 +481,9 @@ class ServiceAttribute:
|
|||||||
return attribute_list
|
return attribute_list
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_attribute_in_list(attribute_list, attribute_id):
|
def find_attribute_in_list(
|
||||||
|
attribute_list: List[ServiceAttribute], attribute_id: int
|
||||||
|
) -> Optional[DataElement]:
|
||||||
return next(
|
return next(
|
||||||
(
|
(
|
||||||
attribute.value
|
attribute.value
|
||||||
@@ -493,7 +498,7 @@ class ServiceAttribute:
|
|||||||
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
|
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_uuid_in_value(uuid, value):
|
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
|
||||||
# Find if a uuid matches a value, either directly or recursing into sequences
|
# Find if a uuid matches a value, either directly or recursing into sequences
|
||||||
if value.type == DataElement.UUID:
|
if value.type == DataElement.UUID:
|
||||||
return value.value == uuid
|
return value.value == uuid
|
||||||
@@ -547,7 +552,9 @@ class SDP_PDU:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_service_record_handle_list_preceded_by_count(data, offset):
|
def parse_service_record_handle_list_preceded_by_count(
|
||||||
|
data: bytes, offset: int
|
||||||
|
) -> Tuple[int, List[int]]:
|
||||||
count = struct.unpack_from('>H', data, offset - 2)[0]
|
count = struct.unpack_from('>H', data, offset - 2)[0]
|
||||||
handle_list = [
|
handle_list = [
|
||||||
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
||||||
@@ -645,6 +652,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
service_search_pattern: DataElement
|
||||||
|
maximum_service_record_count: int
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@SDP_PDU.subclass(
|
@SDP_PDU.subclass(
|
||||||
@@ -663,6 +674,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
service_record_handle_list: List[int]
|
||||||
|
total_service_record_count: int
|
||||||
|
current_service_record_count: int
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@SDP_PDU.subclass(
|
@SDP_PDU.subclass(
|
||||||
@@ -678,6 +694,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
service_record_handle: int
|
||||||
|
maximum_attribute_byte_count: int
|
||||||
|
attribute_id_list: DataElement
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@SDP_PDU.subclass(
|
@SDP_PDU.subclass(
|
||||||
@@ -692,6 +713,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
attribute_list_byte_count: int
|
||||||
|
attribute_list: bytes
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@SDP_PDU.subclass(
|
@SDP_PDU.subclass(
|
||||||
@@ -707,6 +732,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
service_search_pattern: DataElement
|
||||||
|
maximum_attribute_byte_count: int
|
||||||
|
attribute_id_list: DataElement
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@SDP_PDU.subclass(
|
@SDP_PDU.subclass(
|
||||||
@@ -721,26 +751,34 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
|||||||
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
attribute_list_byte_count: int
|
||||||
|
attribute_list: bytes
|
||||||
|
continuation_state: bytes
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, device):
|
channel: Optional[l2cap.Channel]
|
||||||
|
|
||||||
|
def __init__(self, device: Device) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.pending_request = None
|
self.pending_request = None
|
||||||
self.channel = None
|
self.channel = None
|
||||||
|
|
||||||
async def connect(self, connection):
|
async def connect(self, connection: Connection) -> None:
|
||||||
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
|
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
|
||||||
self.channel = result
|
self.channel = result
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self) -> None:
|
||||||
if self.channel:
|
if self.channel:
|
||||||
await self.channel.disconnect()
|
await self.channel.disconnect()
|
||||||
self.channel = None
|
self.channel = None
|
||||||
|
|
||||||
async def search_services(self, uuids):
|
async def search_services(self, uuids: List[core.UUID]) -> List[int]:
|
||||||
if self.pending_request is not None:
|
if self.pending_request is not None:
|
||||||
raise InvalidStateError('request already pending')
|
raise InvalidStateError('request already pending')
|
||||||
|
if self.channel is None:
|
||||||
|
raise InvalidStateError('L2CAP not connected')
|
||||||
|
|
||||||
service_search_pattern = DataElement.sequence(
|
service_search_pattern = DataElement.sequence(
|
||||||
[DataElement.uuid(uuid) for uuid in uuids]
|
[DataElement.uuid(uuid) for uuid in uuids]
|
||||||
@@ -770,9 +808,13 @@ class Client:
|
|||||||
|
|
||||||
return service_record_handle_list
|
return service_record_handle_list
|
||||||
|
|
||||||
async def search_attributes(self, uuids, attribute_ids):
|
async def search_attributes(
|
||||||
|
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
|
||||||
|
) -> List[List[ServiceAttribute]]:
|
||||||
if self.pending_request is not None:
|
if self.pending_request is not None:
|
||||||
raise InvalidStateError('request already pending')
|
raise InvalidStateError('request already pending')
|
||||||
|
if self.channel is None:
|
||||||
|
raise InvalidStateError('L2CAP not connected')
|
||||||
|
|
||||||
service_search_pattern = DataElement.sequence(
|
service_search_pattern = DataElement.sequence(
|
||||||
[DataElement.uuid(uuid) for uuid in uuids]
|
[DataElement.uuid(uuid) for uuid in uuids]
|
||||||
@@ -823,9 +865,15 @@ class Client:
|
|||||||
if sequence.type == DataElement.SEQUENCE
|
if sequence.type == DataElement.SEQUENCE
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_attributes(self, service_record_handle, attribute_ids):
|
async def get_attributes(
|
||||||
|
self,
|
||||||
|
service_record_handle: int,
|
||||||
|
attribute_ids: List[Union[int, Tuple[int, int]]],
|
||||||
|
) -> List[ServiceAttribute]:
|
||||||
if self.pending_request is not None:
|
if self.pending_request is not None:
|
||||||
raise InvalidStateError('request already pending')
|
raise InvalidStateError('request already pending')
|
||||||
|
if self.channel is None:
|
||||||
|
raise InvalidStateError('L2CAP not connected')
|
||||||
|
|
||||||
attribute_id_list = DataElement.sequence(
|
attribute_id_list = DataElement.sequence(
|
||||||
[
|
[
|
||||||
@@ -873,21 +921,25 @@ class Client:
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Server:
|
class Server:
|
||||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||||
|
channel: Optional[l2cap.Channel]
|
||||||
|
Service = NewType('Service', List[ServiceAttribute])
|
||||||
|
service_records: Dict[int, Service]
|
||||||
|
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device: Device) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.service_records = {} # Service records maps, by record handle
|
self.service_records = {} # Service records maps, by record handle
|
||||||
self.channel = None
|
self.channel = None
|
||||||
self.current_response = None
|
self.current_response = None
|
||||||
|
|
||||||
def register(self, l2cap_channel_manager):
|
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||||
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
|
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
|
||||||
|
|
||||||
def send_response(self, response):
|
def send_response(self, response):
|
||||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||||
self.channel.send_pdu(response)
|
self.channel.send_pdu(response)
|
||||||
|
|
||||||
def match_services(self, search_pattern):
|
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
|
||||||
# Find the services for which the attributes in the pattern is a subset of the
|
# Find the services for which the attributes in the pattern is a subset of the
|
||||||
# service's attribute values (NOTE: the value search recurses into sequences)
|
# service's attribute values (NOTE: the value search recurses into sequences)
|
||||||
matching_services = {}
|
matching_services = {}
|
||||||
@@ -957,7 +1009,9 @@ class Server:
|
|||||||
return (payload, continuation_state)
|
return (payload, continuation_state)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_service_attributes(service, attribute_ids):
|
def get_service_attributes(
|
||||||
|
service: Service, attribute_ids: List[DataElement]
|
||||||
|
) -> DataElement:
|
||||||
attributes = []
|
attributes = []
|
||||||
for attribute_id in attribute_ids:
|
for attribute_id in attribute_ids:
|
||||||
if attribute_id.value_size == 4:
|
if attribute_id.value_size == 4:
|
||||||
@@ -982,10 +1036,10 @@ class Server:
|
|||||||
|
|
||||||
return attribute_list
|
return attribute_list
|
||||||
|
|
||||||
def on_sdp_service_search_request(self, request):
|
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
|
||||||
# Check if this is a continuation
|
# Check if this is a continuation
|
||||||
if len(request.continuation_state) > 1:
|
if len(request.continuation_state) > 1:
|
||||||
if not self.current_response:
|
if self.current_response is None:
|
||||||
self.send_response(
|
self.send_response(
|
||||||
SDP_ErrorResponse(
|
SDP_ErrorResponse(
|
||||||
transaction_id=request.transaction_id,
|
transaction_id=request.transaction_id,
|
||||||
@@ -1014,6 +1068,7 @@ class Server:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Respond, keeping any unsent handles for later
|
# Respond, keeping any unsent handles for later
|
||||||
|
assert isinstance(self.current_response, tuple)
|
||||||
service_record_handles = self.current_response[1][
|
service_record_handles = self.current_response[1][
|
||||||
: request.maximum_service_record_count
|
: request.maximum_service_record_count
|
||||||
]
|
]
|
||||||
@@ -1037,10 +1092,12 @@ class Server:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_sdp_service_attribute_request(self, request):
|
def on_sdp_service_attribute_request(
|
||||||
|
self, request: SDP_ServiceAttributeRequest
|
||||||
|
) -> None:
|
||||||
# Check if this is a continuation
|
# Check if this is a continuation
|
||||||
if len(request.continuation_state) > 1:
|
if len(request.continuation_state) > 1:
|
||||||
if not self.current_response:
|
if self.current_response is None:
|
||||||
self.send_response(
|
self.send_response(
|
||||||
SDP_ErrorResponse(
|
SDP_ErrorResponse(
|
||||||
transaction_id=request.transaction_id,
|
transaction_id=request.transaction_id,
|
||||||
@@ -1073,22 +1130,24 @@ class Server:
|
|||||||
self.current_response = bytes(attribute_list)
|
self.current_response = bytes(attribute_list)
|
||||||
|
|
||||||
# Respond, keeping any pending chunks for later
|
# Respond, keeping any pending chunks for later
|
||||||
attribute_list, continuation_state = self.get_next_response_payload(
|
attribute_list_response, continuation_state = self.get_next_response_payload(
|
||||||
request.maximum_attribute_byte_count
|
request.maximum_attribute_byte_count
|
||||||
)
|
)
|
||||||
self.send_response(
|
self.send_response(
|
||||||
SDP_ServiceAttributeResponse(
|
SDP_ServiceAttributeResponse(
|
||||||
transaction_id=request.transaction_id,
|
transaction_id=request.transaction_id,
|
||||||
attribute_list_byte_count=len(attribute_list),
|
attribute_list_byte_count=len(attribute_list_response),
|
||||||
attribute_list=attribute_list,
|
attribute_list=attribute_list,
|
||||||
continuation_state=continuation_state,
|
continuation_state=continuation_state,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_sdp_service_search_attribute_request(self, request):
|
def on_sdp_service_search_attribute_request(
|
||||||
|
self, request: SDP_ServiceSearchAttributeRequest
|
||||||
|
) -> None:
|
||||||
# Check if this is a continuation
|
# Check if this is a continuation
|
||||||
if len(request.continuation_state) > 1:
|
if len(request.continuation_state) > 1:
|
||||||
if not self.current_response:
|
if self.current_response is None:
|
||||||
self.send_response(
|
self.send_response(
|
||||||
SDP_ErrorResponse(
|
SDP_ErrorResponse(
|
||||||
transaction_id=request.transaction_id,
|
transaction_id=request.transaction_id,
|
||||||
@@ -1118,13 +1177,13 @@ class Server:
|
|||||||
self.current_response = bytes(attribute_lists)
|
self.current_response = bytes(attribute_lists)
|
||||||
|
|
||||||
# Respond, keeping any pending chunks for later
|
# Respond, keeping any pending chunks for later
|
||||||
attribute_lists, continuation_state = self.get_next_response_payload(
|
attribute_lists_response, continuation_state = self.get_next_response_payload(
|
||||||
request.maximum_attribute_byte_count
|
request.maximum_attribute_byte_count
|
||||||
)
|
)
|
||||||
self.send_response(
|
self.send_response(
|
||||||
SDP_ServiceSearchAttributeResponse(
|
SDP_ServiceSearchAttributeResponse(
|
||||||
transaction_id=request.transaction_id,
|
transaction_id=request.transaction_id,
|
||||||
attribute_lists_byte_count=len(attribute_lists),
|
attribute_lists_byte_count=len(attribute_lists_response),
|
||||||
attribute_lists=attribute_lists,
|
attribute_lists=attribute_lists,
|
||||||
continuation_state=continuation_state,
|
continuation_state=continuation_state,
|
||||||
)
|
)
|
||||||
|
|||||||
13
tests/__init__.py
Normal file
13
tests/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2023 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -21,13 +21,9 @@ import os
|
|||||||
import random
|
import random
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from bumble.controller import Controller
|
|
||||||
from bumble.link import LocalLink
|
|
||||||
from bumble.device import Device
|
|
||||||
from bumble.host import Host
|
|
||||||
from bumble.transport import AsyncPipeSink
|
|
||||||
from bumble.core import ProtocolError
|
from bumble.core import ProtocolError
|
||||||
from bumble.l2cap import L2CAP_Connection_Request
|
from bumble.l2cap import L2CAP_Connection_Request
|
||||||
|
from .test_utils import TwoDevices
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -37,60 +33,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class TwoDevices:
|
|
||||||
def __init__(self):
|
|
||||||
self.connections = [None, None]
|
|
||||||
|
|
||||||
self.link = LocalLink()
|
|
||||||
self.controllers = [
|
|
||||||
Controller('C1', link=self.link),
|
|
||||||
Controller('C2', link=self.link),
|
|
||||||
]
|
|
||||||
self.devices = [
|
|
||||||
Device(
|
|
||||||
address='F0:F1:F2:F3:F4:F5',
|
|
||||||
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
|
|
||||||
),
|
|
||||||
Device(
|
|
||||||
address='F5:F4:F3:F2:F1:F0',
|
|
||||||
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
self.paired = [None, None]
|
|
||||||
|
|
||||||
def on_connection(self, which, connection):
|
|
||||||
self.connections[which] = connection
|
|
||||||
|
|
||||||
def on_paired(self, which, keys):
|
|
||||||
self.paired[which] = keys
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
async def setup_connection():
|
|
||||||
# Create two devices, each with a controller, attached to the same link
|
|
||||||
two_devices = TwoDevices()
|
|
||||||
|
|
||||||
# Attach listeners
|
|
||||||
two_devices.devices[0].on(
|
|
||||||
'connection', lambda connection: two_devices.on_connection(0, connection)
|
|
||||||
)
|
|
||||||
two_devices.devices[1].on(
|
|
||||||
'connection', lambda connection: two_devices.on_connection(1, connection)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start
|
|
||||||
await two_devices.devices[0].power_on()
|
|
||||||
await two_devices.devices[1].power_on()
|
|
||||||
|
|
||||||
# Connect the two devices
|
|
||||||
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
|
|
||||||
|
|
||||||
# Check the post conditions
|
|
||||||
assert two_devices.connections[0] is not None
|
|
||||||
assert two_devices.connections[1] is not None
|
|
||||||
|
|
||||||
return two_devices
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -132,7 +74,8 @@ def test_helpers():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_connection():
|
async def test_basic_connection():
|
||||||
devices = await setup_connection()
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
psm = 1234
|
psm = 1234
|
||||||
|
|
||||||
# Check that if there's no one listening, we can't connect
|
# Check that if there's no one listening, we can't connect
|
||||||
@@ -184,7 +127,8 @@ async def test_basic_connection():
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def transfer_payload(max_credits, mtu, mps):
|
async def transfer_payload(max_credits, mtu, mps):
|
||||||
devices = await setup_connection()
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
received = []
|
received = []
|
||||||
|
|
||||||
@@ -226,7 +170,8 @@ async def test_transfer():
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bidirectional_transfer():
|
async def test_bidirectional_transfer():
|
||||||
devices = await setup_connection()
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
client_received = []
|
client_received = []
|
||||||
server_received = []
|
server_received = []
|
||||||
|
|||||||
@@ -15,15 +15,30 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from bumble.core import UUID
|
import asyncio
|
||||||
from bumble.sdp import DataElement
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
|
||||||
|
from bumble.sdp import (
|
||||||
|
DataElement,
|
||||||
|
ServiceAttribute,
|
||||||
|
Client,
|
||||||
|
Server,
|
||||||
|
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
|
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
|
SDP_PUBLIC_BROWSE_ROOT,
|
||||||
|
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
|
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
|
)
|
||||||
|
from .test_utils import TwoDevices
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def basic_check(x):
|
def basic_check(x: DataElement) -> None:
|
||||||
serialized = bytes(x)
|
serialized = bytes(x)
|
||||||
if len(serialized) < 500:
|
if len(serialized) < 500:
|
||||||
print('Original:', x)
|
print('Original:', x)
|
||||||
@@ -41,7 +56,7 @@ def basic_check(x):
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_data_elements():
|
def test_data_elements() -> None:
|
||||||
e = DataElement(DataElement.NIL, None)
|
e = DataElement(DataElement.NIL, None)
|
||||||
basic_check(e)
|
basic_check(e)
|
||||||
|
|
||||||
@@ -157,5 +172,108 @@ def test_data_elements():
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
def sdp_records():
|
||||||
|
return {
|
||||||
|
0x00010001: [
|
||||||
|
ServiceAttribute(
|
||||||
|
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||||
|
DataElement.unsigned_integer_32(0x00010001),
|
||||||
|
),
|
||||||
|
ServiceAttribute(
|
||||||
|
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||||
|
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
|
||||||
|
),
|
||||||
|
ServiceAttribute(
|
||||||
|
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||||
|
DataElement.sequence(
|
||||||
|
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ServiceAttribute(
|
||||||
|
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||||
|
DataElement.sequence(
|
||||||
|
[
|
||||||
|
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def test_service_search():
|
||||||
|
# Setup connections
|
||||||
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
assert devices.connections[0]
|
||||||
|
assert devices.connections[1]
|
||||||
|
|
||||||
|
# Register SDP service
|
||||||
|
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||||
|
|
||||||
|
# Search for service
|
||||||
|
client = Client(devices.devices[1])
|
||||||
|
await client.connect(devices.connections[1])
|
||||||
|
services = await client.search_services(
|
||||||
|
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
assert services[0] == 0x00010001
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def test_service_attribute():
|
||||||
|
# Setup connections
|
||||||
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
|
# Register SDP service
|
||||||
|
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||||
|
|
||||||
|
# Search for service
|
||||||
|
client = Client(devices.devices[1])
|
||||||
|
await client.connect(devices.connections[1])
|
||||||
|
attributes = await client.get_attributes(
|
||||||
|
0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def test_service_search_attribute():
|
||||||
|
# Setup connections
|
||||||
|
devices = TwoDevices()
|
||||||
|
await devices.setup_connection()
|
||||||
|
|
||||||
|
# Register SDP service
|
||||||
|
devices.devices[0].sdp_server.service_records.update(sdp_records())
|
||||||
|
|
||||||
|
# Search for service
|
||||||
|
client = Client(devices.devices[1])
|
||||||
|
await client.connect(devices.connections[1])
|
||||||
|
attributes = await client.search_attributes(
|
||||||
|
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
for expect, actual in zip(attributes, sdp_records().values()):
|
||||||
|
assert expect.id == actual.id
|
||||||
|
assert expect.value == actual.value
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def run():
|
||||||
test_data_elements()
|
test_data_elements()
|
||||||
|
await test_service_attribute()
|
||||||
|
await test_service_search()
|
||||||
|
await test_service_search_attribute()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
asyncio.run(run())
|
||||||
|
|||||||
73
tests/test_utils.py
Normal file
73
tests/test_utils.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Copyright 2023 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from bumble.controller import Controller
|
||||||
|
from bumble.link import LocalLink
|
||||||
|
from bumble.device import Device, Connection
|
||||||
|
from bumble.host import Host
|
||||||
|
from bumble.transport import AsyncPipeSink
|
||||||
|
from bumble.hci import Address
|
||||||
|
|
||||||
|
|
||||||
|
class TwoDevices:
|
||||||
|
connections: List[Optional[Connection]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.connections = [None, None]
|
||||||
|
|
||||||
|
self.link = LocalLink()
|
||||||
|
self.controllers = [
|
||||||
|
Controller('C1', link=self.link),
|
||||||
|
Controller('C2', link=self.link),
|
||||||
|
]
|
||||||
|
self.devices = [
|
||||||
|
Device(
|
||||||
|
address=Address('F0:F1:F2:F3:F4:F5'),
|
||||||
|
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
|
||||||
|
),
|
||||||
|
Device(
|
||||||
|
address=Address('F5:F4:F3:F2:F1:F0'),
|
||||||
|
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.paired = [None, None]
|
||||||
|
|
||||||
|
def on_connection(self, which, connection):
|
||||||
|
self.connections[which] = connection
|
||||||
|
|
||||||
|
def on_paired(self, which, keys):
|
||||||
|
self.paired[which] = keys
|
||||||
|
|
||||||
|
async def setup_connection(self) -> None:
|
||||||
|
# Attach listeners
|
||||||
|
self.devices[0].on(
|
||||||
|
'connection', lambda connection: self.on_connection(0, connection)
|
||||||
|
)
|
||||||
|
self.devices[1].on(
|
||||||
|
'connection', lambda connection: self.on_connection(1, connection)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start
|
||||||
|
await self.devices[0].power_on()
|
||||||
|
await self.devices[1].power_on()
|
||||||
|
|
||||||
|
# Connect the two devices
|
||||||
|
await self.devices[0].connect(self.devices[1].random_address)
|
||||||
|
|
||||||
|
# Check the post conditions
|
||||||
|
assert self.connections[0] is not None
|
||||||
|
assert self.connections[1] is not None
|
||||||
Reference in New Issue
Block a user