Typing SDP

This commit is contained in:
Josh Wu
2023-08-23 14:43:10 +08:00
parent 7341172739
commit 727586e40e
2 changed files with 87 additions and 28 deletions

View File

@@ -18,13 +18,16 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import struct import struct
from typing import Dict, List, Type from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from . import core from . import core, l2cap
from .colors import color from .colors import color
from .core import InvalidStateError from .core import InvalidStateError
from .hci import HCI_Object, name_or_number, key_with_value from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING:
from .device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -466,7 +469,7 @@ class ServiceAttribute:
self.value = value self.value = value
@staticmethod @staticmethod
def list_from_data_elements(elements): def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
attribute_list = [] attribute_list = []
for i in range(0, len(elements) // 2): for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -478,7 +481,9 @@ class ServiceAttribute:
return attribute_list return attribute_list
@staticmethod @staticmethod
def find_attribute_in_list(attribute_list, attribute_id): def find_attribute_in_list(
attribute_list: List[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]:
return next( return next(
( (
attribute.value attribute.value
@@ -493,7 +498,7 @@ class ServiceAttribute:
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code) return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod @staticmethod
def is_uuid_in_value(uuid, value): def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
# Find if a uuid matches a value, either directly or recursing into sequences # Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID: if value.type == DataElement.UUID:
return value.value == uuid return value.value == uuid
@@ -547,7 +552,9 @@ class SDP_PDU:
return self return self
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset): def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int
) -> Tuple[int, List[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [ handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
@@ -645,6 +652,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
''' '''
service_search_pattern: DataElement
maximum_service_record_count: int
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -663,6 +674,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
''' '''
service_record_handle_list: List[int]
total_service_record_count: int
current_service_record_count: int
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -678,6 +694,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
''' '''
service_record_handle: int
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -692,6 +713,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
''' '''
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -707,6 +732,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
''' '''
service_search_pattern: DataElement
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -721,26 +751,34 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
''' '''
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
def __init__(self, device): channel: Optional[l2cap.Channel]
def __init__(self, device: Device) -> None:
self.device = device self.device = device
self.pending_request = None self.pending_request = None
self.channel = None self.channel = None
async def connect(self, connection): async def connect(self, connection: Connection) -> None:
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM) result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
self.channel = result self.channel = result
async def disconnect(self): async def disconnect(self) -> None:
if self.channel: if self.channel:
await self.channel.disconnect() await self.channel.disconnect()
self.channel = None self.channel = None
async def search_services(self, uuids): async def search_services(self, uuids: List[core.UUID]) -> List[int]:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
service_search_pattern = DataElement.sequence( service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids] [DataElement.uuid(uuid) for uuid in uuids]
@@ -770,9 +808,13 @@ class Client:
return service_record_handle_list return service_record_handle_list
async def search_attributes(self, uuids, attribute_ids): async def search_attributes(
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
) -> List[List[ServiceAttribute]]:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
service_search_pattern = DataElement.sequence( service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids] [DataElement.uuid(uuid) for uuid in uuids]
@@ -823,9 +865,15 @@ class Client:
if sequence.type == DataElement.SEQUENCE if sequence.type == DataElement.SEQUENCE
] ]
async def get_attributes(self, service_record_handle, attribute_ids): async def get_attributes(
self,
service_record_handle: int,
attribute_ids: List[Union[int, Tuple[int, int]]],
) -> List[ServiceAttribute]:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
@@ -873,21 +921,25 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x43])
channel: Optional[l2cap.Channel]
Service = NewType('Service', List[ServiceAttribute])
service_records: Dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]]
def __init__(self, device): def __init__(self, device: Device) -> None:
self.device = device self.device = device
self.service_records = {} # Service records maps, by record handle self.service_records = {} # Service records maps, by record handle
self.channel = None self.channel = None
self.current_response = None self.current_response = None
def register(self, l2cap_channel_manager): def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection) l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
def send_response(self, response): def send_response(self, response):
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response) self.channel.send_pdu(response)
def match_services(self, search_pattern): def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the # Find the services for which the attributes in the pattern is a subset of the
# service's attribute values (NOTE: the value search recurses into sequences) # service's attribute values (NOTE: the value search recurses into sequences)
matching_services = {} matching_services = {}
@@ -957,7 +1009,9 @@ class Server:
return (payload, continuation_state) return (payload, continuation_state)
@staticmethod @staticmethod
def get_service_attributes(service, attribute_ids): def get_service_attributes(
service: Service, attribute_ids: List[DataElement]
) -> DataElement:
attributes = [] attributes = []
for attribute_id in attribute_ids: for attribute_id in attribute_ids:
if attribute_id.value_size == 4: if attribute_id.value_size == 4:
@@ -982,10 +1036,10 @@ class Server:
return attribute_list return attribute_list
def on_sdp_service_search_request(self, request): def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1014,6 +1068,7 @@ class Server:
) )
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][ service_record_handles = self.current_response[1][
: request.maximum_service_record_count : request.maximum_service_record_count
] ]
@@ -1037,10 +1092,12 @@ class Server:
) )
) )
def on_sdp_service_attribute_request(self, request): def on_sdp_service_attribute_request(
self, request: SDP_ServiceAttributeRequest
) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1073,22 +1130,24 @@ class Server:
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload( attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count request.maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list), attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list, attribute_list=attribute_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
def on_sdp_service_search_attribute_request(self, request): def on_sdp_service_search_attribute_request(
self, request: SDP_ServiceSearchAttributeRequest
) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1118,13 +1177,13 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload( attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count request.maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists), attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists, attribute_lists=attribute_lists,
continuation_state=continuation_state, continuation_state=continuation_state,
) )

View File

@@ -23,7 +23,7 @@ from bumble.sdp import DataElement
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def basic_check(x): def basic_check(x: DataElement) -> None:
serialized = bytes(x) serialized = bytes(x)
if len(serialized) < 500: if len(serialized) < 500:
print('Original:', x) print('Original:', x)
@@ -41,7 +41,7 @@ def basic_check(x):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_data_elements(): def test_data_elements() -> None:
e = DataElement(DataElement.NIL, None) e = DataElement(DataElement.NIL, None)
basic_check(e) basic_check(e)