Merge pull request #688 from zxzxwu/hci2

Dataclass-based ASCS Packets
This commit is contained in:
zxzxwu
2025-06-05 20:59:06 +08:00
committed by GitHub
3 changed files with 173 additions and 123 deletions

View File

@@ -107,6 +107,36 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
return phy_bits return phy_bits
# -----------------------------------------------------------------------------
# Field Metadata
# -----------------------------------------------------------------------------
# Field specification can be:
# - a dict with "serializer", "parser", "size", "mapper" keys
# - a callable that takes (packet, offset) and returns (new_offset, value) (deserialize only)
# - a string of
# - ">2" and ">4" for 2-byte and 4-byte big-endian integers
# - "*" for all remaining bytes in the packet
# - "v" for variable length bytes with a leading length byte
# - an integer [1, 4] for 1-byte, 2-byte or 4-byte unsigned little-endian integers
# - an integer [-2, -1] for 1-byte, 2-byte signed little-endian integers
FieldSpec = Union[dict[str, Any], Callable[[bytes, int], tuple[int, Any]], str, int]
@dataclasses.dataclass
class FieldMetadata:
spec: FieldSpec
list_begin: bool = False
list_end: bool = False
def metadata(
spec: FieldSpec, list_begin: bool = False, list_end: bool = False
) -> dict[str, Any]:
return {
"bumble.hci": FieldMetadata(spec=spec, list_begin=list_begin, list_end=list_end)
}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1659,7 +1689,7 @@ class HCI_Object:
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values()) HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod @staticmethod
def parse_field(data, offset, field_type): def parse_field(data: bytes, offset: int, field_type: FieldSpec):
# The field_type may be a dictionary with a mapper, parser, and/or size # The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict): if isinstance(field_type, dict):
if 'size' in field_type: if 'size' in field_type:
@@ -1741,7 +1771,7 @@ class HCI_Object:
return result return result
@staticmethod @staticmethod
def serialize_field(field_value, field_type): def serialize_field(field_value: Any, field_type: FieldSpec) -> bytes:
# The field_type may be a dictionary with a mapper, parser, serializer, # The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size # and/or size
serializer = None serializer = None
@@ -1932,6 +1962,24 @@ class HCI_Object:
for field_name, field_value in field_strings for field_name, field_value in field_strings
) )
@classmethod
def fields_from_dataclass(cls, obj: Any) -> list[Any]:
stack: list[list[Any]] = [[]]
for field in dataclasses.fields(obj):
# Fields without metadata should be ignored.
if not isinstance(
(metadata := field.metadata.get("bumble.hci")), FieldMetadata
):
continue
if metadata.list_begin:
stack.append([])
if metadata.spec:
stack[-1].append((field.name, metadata.spec))
if metadata.list_end:
top = stack.pop()
stack[-1].append(top)
return stack[0]
def __init__(self, fields, **kwargs): def __init__(self, fields, **kwargs):
self.fields = fields self.fields = fields
self.init_from_fields(self, fields, kwargs) self.init_from_fields(self, fields, kwargs)

View File

@@ -18,10 +18,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field
import enum import enum
import functools
import logging import logging
import struct import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import Any, Optional, Union, TypeVar
from collections.abc import Sequence
from bumble import utils from bumble import utils
from bumble import colors from bumble import colors
@@ -48,11 +51,11 @@ class ASE_Operation:
See Audio Stream Control Service - 5 ASE Control operations. See Audio Stream Control Service - 5 ASE Control operations.
''' '''
classes: Dict[int, Type[ASE_Operation]] = {} classes: dict[int, type[ASE_Operation]] = {}
op_code: int op_code: Opcode
name: str name: str
fields: Optional[Sequence[Any]] = None fields: Optional[Sequence[Any]] = None
ase_id: List[int] ase_id: Sequence[int]
class Opcode(enum.IntEnum): class Opcode(enum.IntEnum):
# fmt: off # fmt: off
@@ -65,51 +68,30 @@ class ASE_Operation:
UPDATE_METADATA = 0x07 UPDATE_METADATA = 0x07
RELEASE = 0x08 RELEASE = 0x08
@staticmethod @classmethod
def from_bytes(pdu: bytes) -> ASE_Operation: def from_bytes(cls, pdu: bytes) -> ASE_Operation:
op_code = pdu[0] op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code) clazz = ASE_Operation.classes[op_code]
if cls is None: return clazz(
instance = ASE_Operation(pdu) **hci.HCI_Object.dict_from_bytes(pdu, offset=1, fields=clazz.fields)
instance.name = ASE_Operation.Opcode(op_code).name )
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod _OP = TypeVar("_OP", bound="ASE_Operation")
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class @classmethod
ASE_Operation.classes[cls.op_code] = cls def subclass(cls, clazz: type[_OP]) -> type[_OP]:
clazz.name = f"ASE_{clazz.op_code.name.upper()}"
clazz.fields = hci.HCI_Object.fields_from_dataclass(clazz)
# Register a factory for this class
ASE_Operation.classes[clazz.op_code] = clazz
return clazz
return cls @functools.cached_property
def pdu(self) -> bytes:
return inner return bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
self.__dict__, self.fields
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: )
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.pdu return self.pdu
@@ -124,105 +106,128 @@ class ASE_Operation:
return result return result
@ASE_Operation.subclass( @ASE_Operation.subclass
[ @dataclass
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation): class ASE_Config_Codec(ASE_Operation):
''' '''
See Audio Stream Control Service 5.1 - Config Codec Operation See Audio Stream Control Service 5.1 - Config Codec Operation
''' '''
target_latency: List[int] op_code = ASE_Operation.Opcode.CONFIG_CODEC
target_phy: List[int]
codec_id: List[hci.CodingFormat] ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
codec_specific_configuration: List[bytes] target_latency: Sequence[int] = field(metadata=hci.metadata(1))
target_phy: Sequence[int] = field(metadata=hci.metadata(1))
codec_id: Sequence[hci.CodingFormat] = field(
metadata=hci.metadata(hci.CodingFormat.parse_from_bytes)
)
codec_specific_configuration: Sequence[bytes] = field(
metadata=hci.metadata('v', list_end=True)
)
@ASE_Operation.subclass( @ASE_Operation.subclass
[ @dataclass
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation): class ASE_Config_QOS(ASE_Operation):
''' '''
See Audio Stream Control Service 5.2 - Config Qos Operation See Audio Stream Control Service 5.2 - Config Qos Operation
''' '''
cig_id: List[int] op_code = ASE_Operation.Opcode.CONFIG_QOS
cis_id: List[int]
sdu_interval: List[int] ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
framing: List[int] cig_id: Sequence[int] = field(metadata=hci.metadata(1))
phy: List[int] cis_id: Sequence[int] = field(metadata=hci.metadata(1))
max_sdu: List[int] sdu_interval: Sequence[int] = field(metadata=hci.metadata(3))
retransmission_number: List[int] framing: Sequence[int] = field(metadata=hci.metadata(1))
max_transport_latency: List[int] phy: Sequence[int] = field(metadata=hci.metadata(1))
presentation_delay: List[int] max_sdu: Sequence[int] = field(metadata=hci.metadata(2))
retransmission_number: Sequence[int] = field(metadata=hci.metadata(1))
max_transport_latency: Sequence[int] = field(metadata=hci.metadata(2))
presentation_delay: Sequence[int] = field(metadata=hci.metadata(3, list_end=True))
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) @ASE_Operation.subclass
@dataclass
class ASE_Enable(ASE_Operation): class ASE_Enable(ASE_Operation):
''' '''
See Audio Stream Control Service 5.3 - Enable Operation See Audio Stream Control Service 5.3 - Enable Operation
''' '''
metadata: bytes op_code = ASE_Operation.Opcode.ENABLE
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
@ASE_Operation.subclass([[('ase_id', 1)]]) @ASE_Operation.subclass
@dataclass
class ASE_Receiver_Start_Ready(ASE_Operation): class ASE_Receiver_Start_Ready(ASE_Operation):
''' '''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
''' '''
op_code = ASE_Operation.Opcode.RECEIVER_START_READY
@ASE_Operation.subclass([[('ase_id', 1)]]) ase_id: Sequence[int] = field(
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Disable(ASE_Operation): class ASE_Disable(ASE_Operation):
''' '''
See Audio Stream Control Service 5.5 - Disable Operation See Audio Stream Control Service 5.5 - Disable Operation
''' '''
op_code = ASE_Operation.Opcode.DISABLE
@ASE_Operation.subclass([[('ase_id', 1)]]) ase_id: Sequence[int] = field(
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Receiver_Stop_Ready(ASE_Operation): class ASE_Receiver_Stop_Ready(ASE_Operation):
''' '''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
''' '''
op_code = ASE_Operation.Opcode.RECEIVER_STOP_READY
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) ase_id: Sequence[int] = field(
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Update_Metadata(ASE_Operation): class ASE_Update_Metadata(ASE_Operation):
''' '''
See Audio Stream Control Service 5.7 - Update Metadata Operation See Audio Stream Control Service 5.7 - Update Metadata Operation
''' '''
metadata: List[bytes] op_code = ASE_Operation.Opcode.UPDATE_METADATA
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
@ASE_Operation.subclass([[('ase_id', 1)]]) @ASE_Operation.subclass
@dataclass
class ASE_Release(ASE_Operation): class ASE_Release(ASE_Operation):
''' '''
See Audio Stream Control Service 5.8 - Release Operation See Audio Stream Control Service 5.8 - Release Operation
''' '''
op_code = ASE_Operation.Opcode.RELEASE
ase_id: Sequence[int] = field(
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
class AseResponseCode(enum.IntEnum): class AseResponseCode(enum.IntEnum):
# fmt: off # fmt: off
@@ -384,7 +389,7 @@ class AseStateMachine(gatt.Characteristic):
target_phy: int, target_phy: int,
codec_id: hci.CodingFormat, codec_id: hci.CodingFormat,
codec_specific_configuration: bytes, codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]: ) -> tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
self.State.IDLE, self.State.IDLE,
self.State.CODEC_CONFIGURED, self.State.CODEC_CONFIGURED,
@@ -420,7 +425,7 @@ class AseStateMachine(gatt.Characteristic):
retransmission_number: int, retransmission_number: int,
max_transport_latency: int, max_transport_latency: int,
presentation_delay: int, presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]: ) -> tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED, AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED, AseStateMachine.State.QOS_CONFIGURED,
@@ -444,7 +449,7 @@ class AseStateMachine(gatt.Characteristic):
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: def on_enable(self, metadata: bytes) -> tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED: if self.state != AseStateMachine.State.QOS_CONFIGURED:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -456,7 +461,7 @@ class AseStateMachine(gatt.Characteristic):
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_receiver_start_ready(self) -> tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING: if self.state != AseStateMachine.State.ENABLING:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -465,7 +470,7 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.STREAMING self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_disable(self) -> tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.ENABLING, AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING, AseStateMachine.State.STREAMING,
@@ -480,7 +485,7 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.DISABLING self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_receiver_stop_ready(self) -> tuple[AseResponseCode, AseReasonCode]:
if ( if (
self.role != AudioRole.SOURCE self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING or self.state != AseStateMachine.State.DISABLING
@@ -494,7 +499,7 @@ class AseStateMachine(gatt.Characteristic):
def on_update_metadata( def on_update_metadata(
self, metadata: bytes self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]: ) -> tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.ENABLING, AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING, AseStateMachine.State.STREAMING,
@@ -506,7 +511,7 @@ class AseStateMachine(gatt.Characteristic):
self.metadata = le_audio.Metadata.from_bytes(metadata) self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: def on_release(self) -> tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE: if self.state == AseStateMachine.State.IDLE:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -604,7 +609,7 @@ class AseStateMachine(gatt.Characteristic):
class AudioStreamControlService(gatt.TemplateService): class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine] ase_state_machines: dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic[bytes] ase_control_point: gatt.Characteristic[bytes]
_active_client: Optional[device.Connection] = None _active_client: Optional[device.Connection] = None
@@ -649,7 +654,9 @@ class AudioStreamControlService(gatt.TemplateService):
ase.state = AseStateMachine.State.IDLE ase.state = AseStateMachine.State.IDLE
self._active_client = None self._active_client = None
def on_write_ase_control_point(self, connection, data): def on_write_ase_control_point(
self, connection: device.Connection, data: bytes
) -> None:
if not self._active_client and connection: if not self._active_client and connection:
self._active_client = connection self._active_client = connection
connection.once('disconnection', self._on_client_disconnected) connection.once('disconnection', self._on_client_disconnected)
@@ -658,7 +665,7 @@ class AudioStreamControlService(gatt.TemplateService):
responses = [] responses = []
logger.debug(f'*** ASCS Write {operation} ***') logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: if isinstance(operation, ASE_Config_Codec):
for ase_id, *args in zip( for ase_id, *args in zip(
operation.ase_id, operation.ase_id,
operation.target_latency, operation.target_latency,
@@ -667,7 +674,7 @@ class AudioStreamControlService(gatt.TemplateService):
operation.codec_specific_configuration, operation.codec_specific_configuration,
): ):
responses.append(self.on_operation(operation.op_code, ase_id, args)) responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: elif isinstance(operation, ASE_Config_QOS):
for ase_id, *args in zip( for ase_id, *args in zip(
operation.ase_id, operation.ase_id,
operation.cig_id, operation.cig_id,
@@ -681,20 +688,20 @@ class AudioStreamControlService(gatt.TemplateService):
operation.presentation_delay, operation.presentation_delay,
): ):
responses.append(self.on_operation(operation.op_code, ase_id, args)) responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in ( elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip( for ase_id, *args in zip(
operation.ase_id, operation.ase_id,
operation.metadata, operation.metadata,
): ):
responses.append(self.on_operation(operation.op_code, ase_id, args)) responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in ( elif isinstance(
ASE_Operation.Opcode.RECEIVER_START_READY, operation,
ASE_Operation.Opcode.DISABLE, (
ASE_Operation.Opcode.RECEIVER_STOP_READY, ASE_Receiver_Start_Ready,
ASE_Operation.Opcode.RELEASE, ASE_Disable,
ASE_Receiver_Stop_Ready,
ASE_Release,
),
): ):
for ase_id in operation.ase_id: for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, [])) responses.append(self.on_operation(operation.op_code, ase_id, []))
@@ -723,8 +730,8 @@ class AudioStreamControlService(gatt.TemplateService):
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy[bytes]] sink_ase: list[gatt_client.CharacteristicProxy[bytes]]
source_ase: List[gatt_client.CharacteristicProxy[bytes]] source_ase: list[gatt_client.CharacteristicProxy[bytes]]
ase_control_point: gatt_client.CharacteristicProxy[bytes] ase_control_point: gatt_client.CharacteristicProxy[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):

View File

@@ -431,12 +431,7 @@ async def test_ascs():
) )
# Release # Release
await ascs_client.ase_control_point.write_value( await ascs_client.ase_control_point.write_value(ASE_Release(ase_id=[1, 2]))
ASE_Release(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
)
assert (await notifications[1].get())[:2] == bytes( assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.RELEASING] [1, AseStateMachine.State.RELEASING]
) )