diff --git a/bumble/hci.py b/bumble/hci.py index 029c2dc..efdd596 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -107,6 +107,36 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int: return phy_bits +# ----------------------------------------------------------------------------- +# Field Metadata +# ----------------------------------------------------------------------------- +# Field specification can be: +# - a dict with "serializer", "parser", "size", "mapper" keys +# - a callable that takes (packet, offset) and returns (new_offset, value) (deserialize only) +# - a string of +# - ">2" and ">4" for 2-byte and 4-byte big-endian integers +# - "*" for all remaining bytes in the packet +# - "v" for variable length bytes with a leading length byte +# - an integer [1, 4] for 1-byte, 2-byte or 4-byte unsigned little-endian integers +# - an integer [-2, -1] for 1-byte, 2-byte signed little-endian integers +FieldSpec = Union[dict[str, Any], Callable[[bytes, int], tuple[int, Any]], str, int] + + +@dataclasses.dataclass +class FieldMetadata: + spec: FieldSpec + list_begin: bool = False + list_end: bool = False + + +def metadata( + spec: FieldSpec, list_begin: bool = False, list_end: bool = False +) -> dict[str, Any]: + return { + "bumble.hci": FieldMetadata(spec=spec, list_begin=list_begin, list_end=list_end) + } + + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -1659,7 +1689,7 @@ class HCI_Object: HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values()) @staticmethod - def parse_field(data, offset, field_type): + def parse_field(data: bytes, offset: int, field_type: FieldSpec): # The field_type may be a dictionary with a mapper, parser, and/or size if isinstance(field_type, dict): if 'size' in field_type: @@ -1741,7 +1771,7 @@ class HCI_Object: return result @staticmethod - def serialize_field(field_value, field_type): + def serialize_field(field_value: Any, field_type: FieldSpec) -> bytes: # The field_type may be a dictionary with a mapper, parser, serializer, # and/or size serializer = None @@ -1932,6 +1962,24 @@ class HCI_Object: for field_name, field_value in field_strings ) + @classmethod + def fields_from_dataclass(cls, obj: Any) -> list[Any]: + stack: list[list[Any]] = [[]] + for field in dataclasses.fields(obj): + # Fields without metadata should be ignored. + if not isinstance( + (metadata := field.metadata.get("bumble.hci")), FieldMetadata + ): + continue + if metadata.list_begin: + stack.append([]) + if metadata.spec: + stack[-1].append((field.name, metadata.spec)) + if metadata.list_end: + top = stack.pop() + stack[-1].append(top) + return stack[0] + def __init__(self, fields, **kwargs): self.fields = fields self.init_from_fields(self, fields, kwargs) diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py index a8bdce3..409d054 100644 --- a/bumble/profiles/ascs.py +++ b/bumble/profiles/ascs.py @@ -18,10 +18,13 @@ # ----------------------------------------------------------------------------- from __future__ import annotations +from dataclasses import dataclass, field import enum +import functools import logging import struct -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Optional, Union, TypeVar +from collections.abc import Sequence from bumble import utils from bumble import colors @@ -48,11 +51,11 @@ class ASE_Operation: See Audio Stream Control Service - 5 ASE Control operations. ''' - classes: Dict[int, Type[ASE_Operation]] = {} - op_code: int + classes: dict[int, type[ASE_Operation]] = {} + op_code: Opcode name: str fields: Optional[Sequence[Any]] = None - ase_id: List[int] + ase_id: Sequence[int] class Opcode(enum.IntEnum): # fmt: off @@ -65,51 +68,30 @@ class ASE_Operation: UPDATE_METADATA = 0x07 RELEASE = 0x08 - @staticmethod - def from_bytes(pdu: bytes) -> ASE_Operation: + @classmethod + def from_bytes(cls, pdu: bytes) -> ASE_Operation: op_code = pdu[0] - cls = ASE_Operation.classes.get(op_code) - if cls is None: - instance = ASE_Operation(pdu) - instance.name = ASE_Operation.Opcode(op_code).name - instance.op_code = op_code - return instance - self = cls.__new__(cls) - ASE_Operation.__init__(self, pdu) - if self.fields is not None: - self.init_from_bytes(pdu, 1) - return self + clazz = ASE_Operation.classes[op_code] + return clazz( + **hci.HCI_Object.dict_from_bytes(pdu, offset=1, fields=clazz.fields) + ) - @staticmethod - def subclass(fields): - def inner(cls: Type[ASE_Operation]): - try: - operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] - cls.name = operation.name - cls.op_code = operation - except: - raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') - cls.fields = fields + _OP = TypeVar("_OP", bound="ASE_Operation") - # Register a factory for this class - ASE_Operation.classes[cls.op_code] = cls + @classmethod + def subclass(cls, clazz: type[_OP]) -> type[_OP]: + clazz.name = f"ASE_{clazz.op_code.name.upper()}" + clazz.fields = hci.HCI_Object.fields_from_dataclass(clazz) + # Register a factory for this class + ASE_Operation.classes[clazz.op_code] = clazz + return clazz - return cls - - return inner - - def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: - if self.fields is not None and kwargs: - hci.HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( - kwargs, self.fields - ) - self.pdu = pdu - - def init_from_bytes(self, pdu: bytes, offset: int): - return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) + @functools.cached_property + def pdu(self) -> bytes: + return bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( + self.__dict__, self.fields + ) def __bytes__(self) -> bytes: return self.pdu @@ -124,105 +106,128 @@ class ASE_Operation: return result -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('target_latency', 1), - ('target_phy', 1), - ('codec_id', hci.CodingFormat.parse_from_bytes), - ('codec_specific_configuration', 'v'), - ], - ] -) +@ASE_Operation.subclass +@dataclass class ASE_Config_Codec(ASE_Operation): ''' See Audio Stream Control Service 5.1 - Config Codec Operation ''' - target_latency: List[int] - target_phy: List[int] - codec_id: List[hci.CodingFormat] - codec_specific_configuration: List[bytes] + op_code = ASE_Operation.Opcode.CONFIG_CODEC + + ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) + target_latency: Sequence[int] = field(metadata=hci.metadata(1)) + target_phy: Sequence[int] = field(metadata=hci.metadata(1)) + codec_id: Sequence[hci.CodingFormat] = field( + metadata=hci.metadata(hci.CodingFormat.parse_from_bytes) + ) + codec_specific_configuration: Sequence[bytes] = field( + metadata=hci.metadata('v', list_end=True) + ) -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('cig_id', 1), - ('cis_id', 1), - ('sdu_interval', 3), - ('framing', 1), - ('phy', 1), - ('max_sdu', 2), - ('retransmission_number', 1), - ('max_transport_latency', 2), - ('presentation_delay', 3), - ], - ] -) +@ASE_Operation.subclass +@dataclass class ASE_Config_QOS(ASE_Operation): ''' See Audio Stream Control Service 5.2 - Config Qos Operation ''' - cig_id: List[int] - cis_id: List[int] - sdu_interval: List[int] - framing: List[int] - phy: List[int] - max_sdu: List[int] - retransmission_number: List[int] - max_transport_latency: List[int] - presentation_delay: List[int] + op_code = ASE_Operation.Opcode.CONFIG_QOS + + ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) + cig_id: Sequence[int] = field(metadata=hci.metadata(1)) + cis_id: Sequence[int] = field(metadata=hci.metadata(1)) + sdu_interval: Sequence[int] = field(metadata=hci.metadata(3)) + framing: Sequence[int] = field(metadata=hci.metadata(1)) + phy: Sequence[int] = field(metadata=hci.metadata(1)) + max_sdu: Sequence[int] = field(metadata=hci.metadata(2)) + retransmission_number: Sequence[int] = field(metadata=hci.metadata(1)) + max_transport_latency: Sequence[int] = field(metadata=hci.metadata(2)) + presentation_delay: Sequence[int] = field(metadata=hci.metadata(3, list_end=True)) -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +@ASE_Operation.subclass +@dataclass class ASE_Enable(ASE_Operation): ''' See Audio Stream Control Service 5.3 - Enable Operation ''' - metadata: bytes + op_code = ASE_Operation.Opcode.ENABLE + + ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) + metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True)) -@ASE_Operation.subclass([[('ase_id', 1)]]) +@ASE_Operation.subclass +@dataclass class ASE_Receiver_Start_Ready(ASE_Operation): ''' See Audio Stream Control Service 5.4 - Receiver Start Ready Operation ''' + op_code = ASE_Operation.Opcode.RECEIVER_START_READY -@ASE_Operation.subclass([[('ase_id', 1)]]) + ase_id: Sequence[int] = field( + metadata=hci.metadata(1, list_begin=True, list_end=True) + ) + + +@ASE_Operation.subclass +@dataclass class ASE_Disable(ASE_Operation): ''' See Audio Stream Control Service 5.5 - Disable Operation ''' + op_code = ASE_Operation.Opcode.DISABLE -@ASE_Operation.subclass([[('ase_id', 1)]]) + ase_id: Sequence[int] = field( + metadata=hci.metadata(1, list_begin=True, list_end=True) + ) + + +@ASE_Operation.subclass +@dataclass class ASE_Receiver_Stop_Ready(ASE_Operation): ''' See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation ''' + op_code = ASE_Operation.Opcode.RECEIVER_STOP_READY -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) + ase_id: Sequence[int] = field( + metadata=hci.metadata(1, list_begin=True, list_end=True) + ) + + +@ASE_Operation.subclass +@dataclass class ASE_Update_Metadata(ASE_Operation): ''' See Audio Stream Control Service 5.7 - Update Metadata Operation ''' - metadata: List[bytes] + op_code = ASE_Operation.Opcode.UPDATE_METADATA + + ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) + metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True)) -@ASE_Operation.subclass([[('ase_id', 1)]]) +@ASE_Operation.subclass +@dataclass class ASE_Release(ASE_Operation): ''' See Audio Stream Control Service 5.8 - Release Operation ''' + op_code = ASE_Operation.Opcode.RELEASE + + ase_id: Sequence[int] = field( + metadata=hci.metadata(1, list_begin=True, list_end=True) + ) + class AseResponseCode(enum.IntEnum): # fmt: off @@ -384,7 +389,7 @@ class AseStateMachine(gatt.Characteristic): target_phy: int, codec_id: hci.CodingFormat, codec_specific_configuration: bytes, - ) -> Tuple[AseResponseCode, AseReasonCode]: + ) -> tuple[AseResponseCode, AseReasonCode]: if self.state not in ( self.State.IDLE, self.State.CODEC_CONFIGURED, @@ -420,7 +425,7 @@ class AseStateMachine(gatt.Characteristic): retransmission_number: int, max_transport_latency: int, presentation_delay: int, - ) -> Tuple[AseResponseCode, AseReasonCode]: + ) -> tuple[AseResponseCode, AseReasonCode]: if self.state not in ( AseStateMachine.State.CODEC_CONFIGURED, AseStateMachine.State.QOS_CONFIGURED, @@ -444,7 +449,7 @@ class AseStateMachine(gatt.Characteristic): return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: + def on_enable(self, metadata: bytes) -> tuple[AseResponseCode, AseReasonCode]: if self.state != AseStateMachine.State.QOS_CONFIGURED: return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, @@ -456,7 +461,7 @@ class AseStateMachine(gatt.Characteristic): return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + def on_receiver_start_ready(self) -> tuple[AseResponseCode, AseReasonCode]: if self.state != AseStateMachine.State.ENABLING: return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, @@ -465,7 +470,7 @@ class AseStateMachine(gatt.Characteristic): self.state = self.State.STREAMING return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: + def on_disable(self) -> tuple[AseResponseCode, AseReasonCode]: if self.state not in ( AseStateMachine.State.ENABLING, AseStateMachine.State.STREAMING, @@ -480,7 +485,7 @@ class AseStateMachine(gatt.Characteristic): self.state = self.State.DISABLING return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + def on_receiver_stop_ready(self) -> tuple[AseResponseCode, AseReasonCode]: if ( self.role != AudioRole.SOURCE or self.state != AseStateMachine.State.DISABLING @@ -494,7 +499,7 @@ class AseStateMachine(gatt.Characteristic): def on_update_metadata( self, metadata: bytes - ) -> Tuple[AseResponseCode, AseReasonCode]: + ) -> tuple[AseResponseCode, AseReasonCode]: if self.state not in ( AseStateMachine.State.ENABLING, AseStateMachine.State.STREAMING, @@ -506,7 +511,7 @@ class AseStateMachine(gatt.Characteristic): self.metadata = le_audio.Metadata.from_bytes(metadata) return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: + def on_release(self) -> tuple[AseResponseCode, AseReasonCode]: if self.state == AseStateMachine.State.IDLE: return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, @@ -604,7 +609,7 @@ class AseStateMachine(gatt.Characteristic): class AudioStreamControlService(gatt.TemplateService): UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE - ase_state_machines: Dict[int, AseStateMachine] + ase_state_machines: dict[int, AseStateMachine] ase_control_point: gatt.Characteristic[bytes] _active_client: Optional[device.Connection] = None @@ -649,7 +654,9 @@ class AudioStreamControlService(gatt.TemplateService): ase.state = AseStateMachine.State.IDLE self._active_client = None - def on_write_ase_control_point(self, connection, data): + def on_write_ase_control_point( + self, connection: device.Connection, data: bytes + ) -> None: if not self._active_client and connection: self._active_client = connection connection.once('disconnection', self._on_client_disconnected) @@ -658,7 +665,7 @@ class AudioStreamControlService(gatt.TemplateService): responses = [] logger.debug(f'*** ASCS Write {operation} ***') - if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: + if isinstance(operation, ASE_Config_Codec): for ase_id, *args in zip( operation.ase_id, operation.target_latency, @@ -667,7 +674,7 @@ class AudioStreamControlService(gatt.TemplateService): operation.codec_specific_configuration, ): responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: + elif isinstance(operation, ASE_Config_QOS): for ase_id, *args in zip( operation.ase_id, operation.cig_id, @@ -681,20 +688,20 @@ class AudioStreamControlService(gatt.TemplateService): operation.presentation_delay, ): responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.ENABLE, - ASE_Operation.Opcode.UPDATE_METADATA, - ): + elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)): for ase_id, *args in zip( operation.ase_id, operation.metadata, ): responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.RECEIVER_START_READY, - ASE_Operation.Opcode.DISABLE, - ASE_Operation.Opcode.RECEIVER_STOP_READY, - ASE_Operation.Opcode.RELEASE, + elif isinstance( + operation, + ( + ASE_Receiver_Start_Ready, + ASE_Disable, + ASE_Receiver_Stop_Ready, + ASE_Release, + ), ): for ase_id in operation.ase_id: responses.append(self.on_operation(operation.op_code, ase_id, [])) @@ -723,8 +730,8 @@ class AudioStreamControlService(gatt.TemplateService): class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): SERVICE_CLASS = AudioStreamControlService - sink_ase: List[gatt_client.CharacteristicProxy[bytes]] - source_ase: List[gatt_client.CharacteristicProxy[bytes]] + sink_ase: list[gatt_client.CharacteristicProxy[bytes]] + source_ase: list[gatt_client.CharacteristicProxy[bytes]] ase_control_point: gatt_client.CharacteristicProxy[bytes] def __init__(self, service_proxy: gatt_client.ServiceProxy): diff --git a/tests/bap_test.py b/tests/bap_test.py index aad635e..e8e84f9 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -431,12 +431,7 @@ async def test_ascs(): ) # Release - await ascs_client.ase_control_point.write_value( - ASE_Release( - ase_id=[1, 2], - metadata=[b'foo', b'bar'], - ) - ) + await ascs_client.ase_control_point.write_value(ASE_Release(ase_id=[1, 2])) assert (await notifications[1].get())[:2] == bytes( [1, AseStateMachine.State.RELEASING] )