Merge pull request #629 from google/gbg/sdp-enforce-mtu

SDP: enforce MTU limits
This commit is contained in:
Gilles Boccon-Gibod
2025-01-21 12:29:18 -05:00
committed by GitHub
3 changed files with 386 additions and 139 deletions

View File

@@ -773,7 +773,6 @@ class ClassicChannel(EventEmitter):
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = 0 self.destination_cid = 0
self.response = None
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.sink = None self.sink = None
@@ -783,27 +782,15 @@ class ClassicChannel(EventEmitter):
self.state = new_state self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame) self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request: SupportsBytes) -> bytes:
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
self.send_pdu(request)
return await self.response
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.response: if self.sink:
self.response.set_result(pdu)
self.response = None
elif self.sink:
# pylint: disable=not-callable # pylint: disable=not-callable
self.sink(pdu) self.sink(pdu)
else: else:

View File

@@ -16,15 +16,21 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import struct import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from . import core, l2cap from bumble import core, l2cap
from .colors import color from bumble.colors import color
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError from bumble.core import (
from .hci import HCI_Object, name_or_number, key_with_value InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
ProtocolError,
)
from bumble.hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING: if TYPE_CHECKING:
from .device import Device, Connection from .device import Device, Connection
@@ -242,11 +248,11 @@ class DataElement:
return DataElement(DataElement.BOOLEAN, value) return DataElement(DataElement.BOOLEAN, value)
@staticmethod @staticmethod
def sequence(value: List[DataElement]) -> DataElement: def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value) return DataElement(DataElement.SEQUENCE, value)
@staticmethod @staticmethod
def alternative(value: List[DataElement]) -> DataElement: def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value) return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod @staticmethod
@@ -473,7 +479,9 @@ class ServiceAttribute:
self.value = value self.value = value
@staticmethod @staticmethod
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: def list_from_data_elements(
elements: Sequence[DataElement],
) -> list[ServiceAttribute]:
attribute_list = [] attribute_list = []
for i in range(0, len(elements) // 2): for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -486,7 +494,7 @@ class ServiceAttribute:
@staticmethod @staticmethod
def find_attribute_in_list( def find_attribute_in_list(
attribute_list: List[ServiceAttribute], attribute_id: int attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]: ) -> Optional[DataElement]:
return next( return next(
( (
@@ -534,7 +542,12 @@ class SDP_PDU:
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
''' '''
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {} RESPONSE_PDU_IDS = {
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
}
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
name = None name = None
pdu_id = 0 pdu_id = 0
@@ -558,7 +571,7 @@ class SDP_PDU:
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count( def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int data: bytes, offset: int
) -> Tuple[int, List[int]]: ) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [ handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
@@ -639,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
''' '''
error_code: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -675,7 +690,7 @@ class SDP_ServiceSearchResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
''' '''
service_record_handle_list: List[int] service_record_handle_list: list[int]
total_service_record_count: int total_service_record_count: int
current_service_record_count: int current_service_record_count: int
continuation_state: bytes continuation_state: bytes
@@ -752,31 +767,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
''' '''
attribute_list_byte_count: int attribute_lists_byte_count: int
attribute_list: bytes attribute_lists: bytes
continuation_state: bytes continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
channel: Optional[l2cap.ClassicChannel] def __init__(self, connection: Connection, mtu: int = 0) -> None:
def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
self.pending_request = None self.channel: Optional[l2cap.ClassicChannel] = None
self.channel = None self.mtu = mtu
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request: Optional[SDP_PDU] = None
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
self.next_transaction_id = 0
async def connect(self) -> None: async def connect(self) -> None:
self.channel = await self.connection.create_l2cap_channel( self.channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM) spec=(
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
if self.mtu
else l2cap.ClassicChannelSpec(SDP_PSM)
)
) )
self.channel.sink = self.on_pdu
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self.channel: if self.channel:
await self.channel.disconnect() await self.channel.disconnect()
self.channel = None self.channel = None
async def search_services(self, uuids: List[core.UUID]) -> List[int]: def make_transaction_id(self) -> int:
transaction_id = self.next_transaction_id
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
return transaction_id
def on_pdu(self, pdu: bytes) -> None:
if not self.pending_request:
logger.warning('received response with no pending request')
return
assert self.pending_response is not None
response = SDP_PDU.from_bytes(pdu)
# Check that the transaction ID is what we expect
if self.pending_request.transaction_id != response.transaction_id:
logger.warning(
f"received response with transaction ID {response.transaction_id} "
f"but expected {self.pending_request.transaction_id}"
)
return
# Check if the response is an error
if isinstance(response, SDP_ErrorResponse):
self.pending_response.set_exception(
ProtocolError(error_code=response.error_code)
)
return
# Check that the type of the response matches the request
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
logger.warning("response type mismatch")
return
self.pending_response.set_result(response)
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
assert self.channel is not None
async with self.request_semaphore:
assert self.pending_request is None
assert self.pending_response is None
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request
try:
self.channel.send_pdu(bytes(request))
return await self.pending_response
finally:
self.pending_request = None
self.pending_response = None
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
"""
Search for services by UUID.
Args:
uuids: service the UUIDs to search for.
Returns:
A list of matching service record handles.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -791,16 +874,16 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchRequest( SDP_ServiceSearchRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF, maximum_service_record_count=0xFFFF,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchResponse)
service_record_handle_list += response.service_record_handle_list service_record_handle_list += response.service_record_handle_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -811,8 +894,21 @@ class Client:
return service_record_handle_list return service_record_handle_list
async def search_attributes( async def search_attributes(
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] self,
) -> List[List[ServiceAttribute]]: uuids: Iterable[core.UUID],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[list[ServiceAttribute]]:
"""
Search for attributes by UUID and attribute IDs.
Args:
uuids: the service UUIDs to search for.
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
(use (0, 0xFFFF) to include all attributes)
Returns:
A list of list of attributes, one list per matching service.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -824,8 +920,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -839,17 +935,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchAttributeRequest( SDP_ServiceSearchAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
accumulator += response.attribute_lists accumulator += response.attribute_lists
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -872,8 +968,18 @@ class Client:
async def get_attributes( async def get_attributes(
self, self,
service_record_handle: int, service_record_handle: int,
attribute_ids: List[Union[int, Tuple[int, int]]], attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> List[ServiceAttribute]: ) -> list[ServiceAttribute]:
"""
Get attributes for a service.
Args:
service_record_handle: the handle for a service
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
Returns:
A list of attributes.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -882,8 +988,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -897,17 +1003,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceAttributeRequest( SDP_ServiceAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_record_handle=service_record_handle, service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceAttributeResponse)
accumulator += response.attribute_list accumulator += response.attribute_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -933,17 +1039,17 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x00])
channel: Optional[l2cap.ClassicChannel] channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', List[ServiceAttribute]) Service = NewType('Service', list[ServiceAttribute])
service_records: Dict[int, Service] service_records: dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]] current_response: Union[None, bytes, tuple[int, list[int]]]
def __init__(self, device: Device) -> None: def __init__(self, device: Device) -> None:
self.device = device self.device = device
self.service_records = {} # Service records maps, by record handle self.service_records = {} # Service records maps, by record handle
self.channel = None self.channel = None
self.current_response = None self.current_response = None # Current response data, used for continuations
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.create_classic_server( l2cap_channel_manager.create_classic_server(
@@ -954,7 +1060,7 @@ class Server:
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response) self.channel.send_pdu(response)
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the # Find the services for which the attributes in the pattern is a subset of the
# service's attribute values (NOTE: the value search recurses into sequences) # service's attribute values (NOTE: the value search recurses into sequences)
matching_services = {} matching_services = {}
@@ -1011,6 +1117,31 @@ class Server:
) )
) )
def check_continuation(
self,
continuation_state: bytes,
transaction_id: int,
) -> Optional[bool]:
# Check if this is a valid continuation
if len(continuation_state) > 1:
if (
self.current_response is None
or continuation_state != self.CONTINUATION_STATE
):
self.send_response(
SDP_ErrorResponse(
transaction_id=transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return None
return True
# Cleanup any partial response leftover
self.current_response = None
return False
def get_next_response_payload(self, maximum_size): def get_next_response_payload(self, maximum_size):
if len(self.current_response) > maximum_size: if len(self.current_response) > maximum_size:
payload = self.current_response[:maximum_size] payload = self.current_response[:maximum_size]
@@ -1025,7 +1156,7 @@ class Server:
@staticmethod @staticmethod
def get_service_attributes( def get_service_attributes(
service: Service, attribute_ids: List[DataElement] service: Service, attribute_ids: Iterable[DataElement]
) -> DataElement: ) -> DataElement:
attributes = [] attributes = []
for attribute_id in attribute_ids: for attribute_id in attribute_ids:
@@ -1053,30 +1184,24 @@ class Server:
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services(request.service_search_pattern) matching_services = self.match_services(request.service_search_pattern)
service_record_handles = list(matching_services.keys()) service_record_handles = list(matching_services.keys())
logger.debug(f'Service Record Handles: {service_record_handles}')
# Only return up to the maximum requested # Only return up to the maximum requested
service_record_handles_subset = service_record_handles[ service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count : request.maximum_service_record_count
] ]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = ( self.current_response = (
len(service_record_handles), len(service_record_handles),
service_record_handles_subset, service_record_handles_subset,
@@ -1084,15 +1209,21 @@ class Server:
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple) assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][ assert self.channel is not None
: request.maximum_service_record_count total_service_record_count, service_record_handles = self.current_response
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
service_record_handles_remaining = service_record_handles[
maximum_service_record_count:
] ]
service_record_handles = service_record_handles[:maximum_service_record_count]
self.current_response = ( self.current_response = (
self.current_response[0], total_service_record_count,
self.current_response[1][request.maximum_service_record_count :], service_record_handles_remaining,
) )
continuation_state = ( continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) Server.CONTINUATION_STATE
if service_record_handles_remaining
else bytes([0])
) )
service_record_handle_list = b''.join( service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles] [struct.pack('>I', handle) for handle in service_record_handles]
@@ -1100,7 +1231,7 @@ class Server:
self.send_response( self.send_response(
SDP_ServiceSearchResponse( SDP_ServiceSearchResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0], total_service_record_count=total_service_record_count,
current_service_record_count=len(service_record_handles), current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list, service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state, continuation_state=continuation_state,
@@ -1111,19 +1242,14 @@ class Server:
self, request: SDP_ServiceAttributeRequest self, request: SDP_ServiceAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Check that the service exists # Check that the service exists
service = self.service_records.get(request.service_record_handle) service = self.service_records.get(request.service_record_handle)
if service is None: if service is None:
@@ -1145,14 +1271,18 @@ class Server:
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_list_response, continuation_state = self.get_next_response_payload( attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list_response), attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list, attribute_list=attribute_list_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
@@ -1161,18 +1291,14 @@ class Server:
self, request: SDP_ServiceSearchAttributeRequest self, request: SDP_ServiceSearchAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services( matching_services = self.match_services(
request.service_search_pattern request.service_search_pattern
@@ -1192,14 +1318,18 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_lists_response, continuation_state = self.get_next_response_payload( attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists_response), attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists, attribute_lists=attribute_lists_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )

View File

@@ -20,12 +20,11 @@ import logging
import os import os
import pytest import pytest
from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID
from bumble.sdp import ( from bumble.sdp import (
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
Client, Client,
Server,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT, SDP_PUBLIC_BROWSE_ROOT,
@@ -174,9 +173,10 @@ def test_data_elements() -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sdp_records(): def sdp_records(record_count=1):
return { return {
0x00010001: [ 0x00010001
+ i: [
ServiceAttribute( ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001), DataElement.unsigned_integer_32(0x00010001),
@@ -200,6 +200,7 @@ def sdp_records():
), ),
), ),
] ]
for i in range(record_count)
} }
@@ -216,19 +217,55 @@ async def test_service_search():
devices.devices[0].sdp_server.service_records.update(sdp_records()) devices.devices[0].sdp_server.service_records.update(sdp_records())
# Search for service # Search for service
client = Client(devices.connections[1]) async with Client(devices.connections[1]) as client:
await client.connect() services = await client.search_services(
services = await client.search_services( [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AF')]
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')] )
) assert len(services) == 0
# Then services = await client.search_services(
assert services[0] == 0x00010001 [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
)
assert len(services) == 1
assert services[0] == 0x00010001
services = await client.search_services(
[BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT]
)
assert len(services) == 1
assert services[0] == 0x00010001
services = await client.search_services(
[BT_L2CAP_PROTOCOL_ID, SDP_PUBLIC_BROWSE_ROOT]
)
assert len(services) == 1
assert services[0] == 0x00010001
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_attribute(): async def test_service_search_with_continuation():
# Setup connections
devices = TwoDevices()
await devices.setup_connection()
# Register SDP service
records = sdp_records(100)
devices.devices[0].sdp_server.service_records.update(records)
# Search for service
async with Client(devices.connections[1], mtu=48) as client:
services = await client.search_services(
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
)
assert len(services) == len(records)
for i in range(len(records)):
assert services[i] == 0x00010001 + i
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_attributes():
# Setup connections # Setup connections
devices = TwoDevices() devices = TwoDevices()
await devices.setup_connection() await devices.setup_connection()
@@ -236,15 +273,43 @@ async def test_service_attribute():
# Register SDP service # Register SDP service
devices.devices[0].sdp_server.service_records.update(sdp_records()) devices.devices[0].sdp_server.service_records.update(sdp_records())
# Search for service # Get attributes
client = Client(devices.connections[1]) async with Client(devices.connections[1]) as client:
await client.connect() attributes = await client.get_attributes(0x00010001, [1234])
attributes = await client.get_attributes( assert len(attributes) == 0
0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
)
# Then attributes = await client.get_attributes(
assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value 0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
)
assert len(attributes) == 1
assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_attributes_with_continuation():
# Setup connections
devices = TwoDevices()
await devices.setup_connection()
# Register SDP service
records = {
0x00010001: [
ServiceAttribute(
x,
DataElement.unsigned_integer_32(0x00010001),
)
for x in range(100)
]
}
devices.devices[0].sdp_server.service_records.update(records)
# Get attributes
async with Client(devices.connections[1], mtu=48) as client:
attributes = await client.get_attributes(0x00010001, list(range(100)))
assert len(attributes) == 100
for i, attribute in enumerate(attributes):
assert attribute.id == i
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -255,19 +320,81 @@ async def test_service_search_attribute():
await devices.setup_connection() await devices.setup_connection()
# Register SDP service # Register SDP service
devices.devices[0].sdp_server.service_records.update(sdp_records()) records = {
0x00010001: [
ServiceAttribute(
4,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
ServiceAttribute(
3,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
ServiceAttribute(
1,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
]
}
devices.devices[0].sdp_server.service_records.update(records)
# Search for service # Search for service
client = Client(devices.connections[1]) async with Client(devices.connections[1]) as client:
await client.connect() attributes = await client.search_attributes(
attributes = await client.search_attributes( [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)]
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)] )
) assert len(attributes) == 1
assert len(attributes[0]) == 3
assert attributes[0][0].id == 1
assert attributes[0][1].id == 3
assert attributes[0][2].id == 4
# Then attributes = await client.search_attributes(
for expect, actual in zip(attributes, sdp_records().values()): [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [1, 2, 3]
assert expect.id == actual.id )
assert expect.value == actual.value assert len(attributes) == 1
assert len(attributes[0]) == 2
assert attributes[0][0].id == 1
assert attributes[0][1].id == 3
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search_attribute_with_continuation():
# Setup connections
devices = TwoDevices()
await devices.setup_connection()
# Register SDP service
records = {
0x00010001: [
ServiceAttribute(
x,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
)
for x in range(100)
]
}
devices.devices[0].sdp_server.service_records.update(records)
# Search for service
async with Client(devices.connections[1], mtu=48) as client:
attributes = await client.search_attributes(
[UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0, 0xFFFF)]
)
assert len(attributes) == 1
assert len(attributes[0]) == 100
for i in range(100):
assert attributes[0][i].id == i
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -287,9 +414,12 @@ async def test_client_async_context():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(): async def run():
test_data_elements() test_data_elements()
await test_service_attribute() await test_service_attributes()
await test_service_attributes_with_continuation()
await test_service_search() await test_service_search()
await test_service_search_with_continuation()
await test_service_search_attribute() await test_service_search_attribute()
await test_service_search_attribute_with_continuation()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------