Migrate AVRCP packets to dataclasses

This commit is contained in:
Josh Wu
2025-08-27 21:54:29 +08:00
parent 8bda7d2212
commit dab0993cba

View File

@@ -21,7 +21,7 @@ import asyncio
import enum
import logging
import struct
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (
AsyncIterator,
Awaitable,
@@ -33,10 +33,11 @@ from typing import (
SupportsBytes,
TypeVar,
Union,
ClassVar,
cast,
)
from bumble import avc, avctp, core, l2cap, utils
from bumble import avc, avctp, core, l2cap, utils, hci
from bumble.colors import color
from bumble.device import Connection, Device
from bumble.sdp import (
@@ -64,6 +65,71 @@ AVRCP_PID = 0x110E
AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958
class PduId(utils.OpenIntEnum):
GET_CAPABILITIES = 0x10
LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11
LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12
GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13
SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14
GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15
GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16
INFORM_DISPLAYABLE_CHARACTER_SET = 0x17
INFORM_BATTERY_STATUS_OF_CT = 0x18
GET_ELEMENT_ATTRIBUTES = 0x20
GET_PLAY_STATUS = 0x30
REGISTER_NOTIFICATION = 0x31
REQUEST_CONTINUING_RESPONSE = 0x40
ABORT_CONTINUING_RESPONSE = 0x41
SET_ABSOLUTE_VOLUME = 0x50
SET_ADDRESSED_PLAYER = 0x60
SET_BROWSED_PLAYER = 0x70
GET_FOLDER_ITEMS = 0x71
GET_TOTAL_NUMBER_OF_ITEMS = 0x75
class CharacterSetId(hci.SpecableEnum):
UTF_8 = 0x06
class MediaAttributeId(hci.SpecableEnum):
TITLE = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
TRACK_NUMBER = 0x04
TOTAL_NUMBER_OF_TRACKS = 0x05
GENRE = 0x06
PLAYING_TIME = 0x07
DEFAULT_COVER_ART = 0x08
class PlayStatus(hci.SpecableEnum):
STOPPED = 0x00
PLAYING = 0x01
PAUSED = 0x02
FWD_SEEK = 0x03
REV_SEEK = 0x04
ERROR = 0xFF
class EventId(hci.SpecableEnum):
PLAYBACK_STATUS_CHANGED = 0x01
TRACK_CHANGED = 0x02
TRACK_REACHED_END = 0x03
TRACK_REACHED_START = 0x04
PLAYBACK_POS_CHANGED = 0x05
BATT_STATUS_CHANGED = 0x06
SYSTEM_STATUS_CHANGED = 0x07
PLAYER_APPLICATION_SETTING_CHANGED = 0x08
NOW_PLAYING_CONTENT_CHANGED = 0x09
AVAILABLE_PLAYERS_CHANGED = 0x0A
ADDRESSED_PLAYER_CHANGED = 0x0B
UIDS_CHANGED = 0x0C
VOLUME_CHANGED = 0x0D
def __bytes__(self) -> bytes:
return bytes([int(self)])
# -----------------------------------------------------------------------------
def make_controller_service_sdp_records(
service_record_handle: int,
@@ -218,10 +284,10 @@ class PduAssembler:
6.3.1 AVRCP specific AV//C commands
"""
pdu_id: Optional[Protocol.PduId]
pdu_id: Optional[PduId]
payload: bytes
def __init__(self, callback: Callable[[Protocol.PduId, bytes], None]) -> None:
def __init__(self, callback: Callable[[PduId, bytes], None]) -> None:
self.callback = callback
self.reset()
@@ -230,7 +296,7 @@ class PduAssembler:
self.parameter = b''
def on_pdu(self, pdu: bytes) -> None:
pdu_id = Protocol.PduId(pdu[0])
pdu_id = PduId(pdu[0])
packet_type = Protocol.PacketType(pdu[1] & 3)
parameter_length = struct.unpack_from('>H', pdu, 2)[0]
parameter = pdu[4 : 4 + parameter_length]
@@ -271,125 +337,103 @@ class PduAssembler:
# -----------------------------------------------------------------------------
@dataclass
class Command:
pdu_id: Protocol.PduId
parameter: bytes
pdu_id: ClassVar[PduId]
_payload: Optional[bytes] = field(init=False, default=None)
def to_string(self, properties: dict[str, str]) -> str:
properties_str = ",".join(
[f"{name}={value}" for name, value in properties.items()]
)
return f"Command[{self.pdu_id.name}]({properties_str})"
_Command = TypeVar('_Command', bound='Command')
subclasses: ClassVar[dict[int, type[Command]]] = {}
fields: ClassVar[hci.Fields] = ()
def __str__(self) -> str:
return self.to_string({"parameters": self.parameter.hex()})
@classmethod
def command(cls, subclass: type[_Command]) -> type[_Command]:
cls.subclasses[subclass.pdu_id] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod
def from_bytes(cls, pdu_id: int, pdu: bytes) -> Command:
if not (subclass := cls.subclasses.get(pdu_id)):
raise core.InvalidPacketError(f"Unimplemented PDU {pdu_id}")
instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 0, subclass.fields))
instance._payload = pdu[0:]
return instance
@property
def payload(self) -> bytes:
if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
def __repr__(self) -> str:
return str(self)
# -----------------------------------------------------------------------------
@Command.command
@dataclass
class GetCapabilitiesCommand(Command):
class CapabilityId(utils.OpenIntEnum):
pdu_id = PduId.GET_CAPABILITIES
class CapabilityId(hci.SpecableEnum):
COMPANY_ID = 0x02
EVENTS_SUPPORTED = 0x03
capability_id: CapabilityId
@classmethod
def from_bytes(cls, pdu: bytes) -> GetCapabilitiesCommand:
return cls(cls.CapabilityId(pdu[0]))
def __init__(self, capability_id: CapabilityId) -> None:
super().__init__(Protocol.PduId.GET_CAPABILITIES, bytes([capability_id]))
self.capability_id = capability_id
def __str__(self) -> str:
return self.to_string({"capability_id": self.capability_id.name})
capability_id: CapabilityId = field(metadata=CapabilityId.type_metadata(1))
# -----------------------------------------------------------------------------
@Command.command
@dataclass
class GetPlayStatusCommand(Command):
@classmethod
def from_bytes(cls, _: bytes) -> GetPlayStatusCommand:
return cls()
def __init__(self) -> None:
super().__init__(Protocol.PduId.GET_PLAY_STATUS, b'')
pdu_id = PduId.GET_PLAY_STATUS
# -----------------------------------------------------------------------------
@Command.command
@dataclass
class GetElementAttributesCommand(Command):
identifier: int
attribute_ids: list[MediaAttributeId]
pdu_id = PduId.GET_ELEMENT_ATTRIBUTES
@classmethod
def from_bytes(cls, pdu: bytes) -> GetElementAttributesCommand:
identifier = struct.unpack_from(">Q", pdu)[0]
num_attributes = pdu[8]
attribute_ids = [MediaAttributeId(pdu[9 + i]) for i in range(num_attributes)]
return cls(identifier, attribute_ids)
def __init__(
self, identifier: int, attribute_ids: Sequence[MediaAttributeId]
) -> None:
parameter = struct.pack(">QB", identifier, len(attribute_ids)) + b''.join(
[struct.pack(">I", int(attribute_id)) for attribute_id in attribute_ids]
)
super().__init__(Protocol.PduId.GET_ELEMENT_ATTRIBUTES, parameter)
self.identifier = identifier
self.attribute_ids = list(attribute_ids)
# -----------------------------------------------------------------------------
class SetAbsoluteVolumeCommand(Command):
MAXIMUM_VOLUME = 0x7F
volume: int
@classmethod
def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeCommand:
return cls(pdu[0])
def __init__(self, volume: int) -> None:
super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume]))
self.volume = volume
def __str__(self) -> str:
return self.to_string({"volume": str(self.volume)})
# -----------------------------------------------------------------------------
class RegisterNotificationCommand(Command):
event_id: EventId
playback_interval: int
@classmethod
def from_bytes(cls, pdu: bytes) -> RegisterNotificationCommand:
event_id = EventId(pdu[0])
playback_interval = struct.unpack_from(">I", pdu, 1)[0]
return cls(event_id, playback_interval)
def __init__(self, event_id: EventId, playback_interval: int) -> None:
super().__init__(
Protocol.PduId.REGISTER_NOTIFICATION,
struct.pack(">BI", int(event_id), playback_interval),
)
self.event_id = event_id
self.playback_interval = playback_interval
def __str__(self) -> str:
return self.to_string(
identifier: int = field(
metadata=hci.metadata(
{
"event_id": self.event_id.name,
"playback_interval": str(self.playback_interval),
'parser': lambda data, offset: (
offset + 8,
int.from_bytes(data[offset : offset + 8]),
),
'serializer': lambda x: x.to_bytes(8),
}
)
)
attribute_ids: Sequence[MediaAttributeId] = field(
metadata=MediaAttributeId.type_metadata(1, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------
@Command.command
@dataclass
class SetAbsoluteVolumeCommand(Command):
pdu_id = PduId.SET_ABSOLUTE_VOLUME
MAXIMUM_VOLUME = 0x7F
volume: int = field(metadata=hci.metadata(1))
# -----------------------------------------------------------------------------
@Command.command
@dataclass
class RegisterNotificationCommand(Command):
pdu_id = PduId.SET_ABSOLUTE_VOLUME
event_id: EventId = field(metadata=EventId.type_metadata(1))
playback_interval: int = field(metadata=hci.metadata('>4'))
# -----------------------------------------------------------------------------
@dataclass
class Response:
pdu_id: Protocol.PduId
pdu_id: PduId
parameter: bytes
def to_string(self, properties: dict[str, str]) -> str:
@@ -410,12 +454,10 @@ class RejectedResponse(Response):
status_code: Protocol.StatusCode
@classmethod
def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> RejectedResponse:
def from_bytes(cls, pdu_id: PduId, pdu: bytes) -> RejectedResponse:
return cls(pdu_id, Protocol.StatusCode(pdu[0]))
def __init__(
self, pdu_id: Protocol.PduId, status_code: Protocol.StatusCode
) -> None:
def __init__(self, pdu_id: PduId, status_code: Protocol.StatusCode) -> None:
super().__init__(pdu_id, bytes([int(status_code)]))
self.status_code = status_code
@@ -430,7 +472,7 @@ class RejectedResponse(Response):
# -----------------------------------------------------------------------------
class NotImplementedResponse(Response):
@classmethod
def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> NotImplementedResponse:
def from_bytes(cls, pdu_id: PduId, pdu: bytes) -> NotImplementedResponse:
return cls(pdu_id, pdu[1:])
@@ -468,7 +510,7 @@ class GetCapabilitiesResponse(Response):
capabilities: Sequence[Union[SupportsBytes, bytes]],
) -> None:
super().__init__(
Protocol.PduId.GET_CAPABILITIES,
PduId.GET_CAPABILITIES,
bytes([capability_id, len(capabilities)])
+ b''.join(bytes(capability) for capability in capabilities),
)
@@ -504,7 +546,7 @@ class GetPlayStatusResponse(Response):
play_status: PlayStatus,
) -> None:
super().__init__(
Protocol.PduId.GET_PLAY_STATUS,
PduId.GET_PLAY_STATUS,
struct.pack(">IIB", song_length, song_position, int(play_status)),
)
self.song_length = song_length
@@ -565,7 +607,7 @@ class GetElementAttributesResponse(Response):
+ attribute_value_bytes
)
super().__init__(
Protocol.PduId.GET_ELEMENT_ATTRIBUTES,
PduId.GET_ELEMENT_ATTRIBUTES,
parameter,
)
self.attributes = list(attributes)
@@ -588,7 +630,7 @@ class SetAbsoluteVolumeResponse(Response):
return cls(pdu[0])
def __init__(self, volume: int) -> None:
super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume]))
super().__init__(PduId.SET_ABSOLUTE_VOLUME, bytes([volume]))
self.volume = volume
def __str__(self) -> str:
@@ -605,7 +647,7 @@ class RegisterNotificationResponse(Response):
def __init__(self, event: Event) -> None:
super().__init__(
Protocol.PduId.REGISTER_NOTIFICATION,
PduId.REGISTER_NOTIFICATION,
bytes(event),
)
self.event = event
@@ -618,43 +660,6 @@ class RegisterNotificationResponse(Response):
)
# -----------------------------------------------------------------------------
class EventId(utils.OpenIntEnum):
PLAYBACK_STATUS_CHANGED = 0x01
TRACK_CHANGED = 0x02
TRACK_REACHED_END = 0x03
TRACK_REACHED_START = 0x04
PLAYBACK_POS_CHANGED = 0x05
BATT_STATUS_CHANGED = 0x06
SYSTEM_STATUS_CHANGED = 0x07
PLAYER_APPLICATION_SETTING_CHANGED = 0x08
NOW_PLAYING_CONTENT_CHANGED = 0x09
AVAILABLE_PLAYERS_CHANGED = 0x0A
ADDRESSED_PLAYER_CHANGED = 0x0B
UIDS_CHANGED = 0x0C
VOLUME_CHANGED = 0x0D
def __bytes__(self) -> bytes:
return bytes([int(self)])
# -----------------------------------------------------------------------------
class CharacterSetId(utils.OpenIntEnum):
UTF_8 = 0x06
# -----------------------------------------------------------------------------
class MediaAttributeId(utils.OpenIntEnum):
TITLE = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
TRACK_NUMBER = 0x04
TOTAL_NUMBER_OF_TRACKS = 0x05
GENRE = 0x06
PLAYING_TIME = 0x07
DEFAULT_COVER_ART = 0x08
# -----------------------------------------------------------------------------
@dataclass
class MediaAttribute:
@@ -663,16 +668,6 @@ class MediaAttribute:
attribute_value: str
# -----------------------------------------------------------------------------
class PlayStatus(utils.OpenIntEnum):
STOPPED = 0x00
PLAYING = 0x01
PAUSED = 0x02
FWD_SEEK = 0x03
REV_SEEK = 0x04
ERROR = 0xFF
# -----------------------------------------------------------------------------
@dataclass
class SongAndPlayStatus:
@@ -989,27 +984,6 @@ class Protocol(utils.EventEmitter):
CONTINUE = 0b10
END = 0b11
class PduId(utils.OpenIntEnum):
GET_CAPABILITIES = 0x10
LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11
LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12
GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13
SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14
GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15
GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16
INFORM_DISPLAYABLE_CHARACTER_SET = 0x17
INFORM_BATTERY_STATUS_OF_CT = 0x18
GET_ELEMENT_ATTRIBUTES = 0x20
GET_PLAY_STATUS = 0x30
REGISTER_NOTIFICATION = 0x31
REQUEST_CONTINUING_RESPONSE = 0x40
ABORT_CONTINUING_RESPONSE = 0x41
SET_ABSOLUTE_VOLUME = 0x50
SET_ADDRESSED_PLAYER = 0x60
SET_BROWSED_PLAYER = 0x70
GET_FOLDER_ITEMS = 0x71
GET_TOTAL_NUMBER_OF_ITEMS = 0x75
class StatusCode(utils.OpenIntEnum):
INVALID_COMMAND = 0x00
INVALID_PARAMETER = 0x01
@@ -1596,18 +1570,13 @@ class Protocol(utils.EventEmitter):
avc.CommandFrame.CommandType.NOTIFY,
):
# TODO: catch exceptions from delegates
if pdu_id == self.PduId.GET_CAPABILITIES:
self._on_get_capabilities_command(
transaction_label, GetCapabilitiesCommand.from_bytes(pdu)
)
elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME:
self._on_set_absolute_volume_command(
transaction_label, SetAbsoluteVolumeCommand.from_bytes(pdu)
)
elif pdu_id == self.PduId.REGISTER_NOTIFICATION:
self._on_register_notification_command(
transaction_label, RegisterNotificationCommand.from_bytes(pdu)
)
command = Command.from_bytes(pdu_id, pdu)
if isinstance(command, GetCapabilitiesCommand):
self._on_get_capabilities_command(transaction_label, command)
elif isinstance(command, SetAbsoluteVolumeCommand):
self._on_set_absolute_volume_command(transaction_label, command)
elif isinstance(command, RegisterNotificationCommand):
self._on_register_notification_command(transaction_label, command)
else:
# Not supported.
# TODO: check that this is the right way to respond in this case.
@@ -1652,15 +1621,15 @@ class Protocol(utils.EventEmitter):
avc.ResponseFrame.ResponseCode.CHANGED,
avc.ResponseFrame.ResponseCode.ACCEPTED,
):
if pdu_id == self.PduId.GET_CAPABILITIES:
if pdu_id == PduId.GET_CAPABILITIES:
response = GetCapabilitiesResponse.from_bytes(pdu)
elif pdu_id == self.PduId.GET_PLAY_STATUS:
elif pdu_id == PduId.GET_PLAY_STATUS:
response = GetPlayStatusResponse.from_bytes(pdu)
elif pdu_id == self.PduId.GET_ELEMENT_ATTRIBUTES:
elif pdu_id == PduId.GET_ELEMENT_ATTRIBUTES:
response = GetElementAttributesResponse.from_bytes(pdu)
elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME:
elif pdu_id == PduId.SET_ABSOLUTE_VOLUME:
response = SetAbsoluteVolumeResponse.from_bytes(pdu)
elif pdu_id == self.PduId.REGISTER_NOTIFICATION:
elif pdu_id == PduId.REGISTER_NOTIFICATION:
response = RegisterNotificationResponse.from_bytes(pdu)
else:
logger.debug("unexpected PDU ID")
@@ -1758,8 +1727,8 @@ class Protocol(utils.EventEmitter):
# Send the command.
logger.debug(f">>> AVRCP command PDU: {command}")
pdu = (
struct.pack(">BBH", command.pdu_id, 0, len(command.parameter))
+ command.parameter
struct.pack(">BBH", command.pdu_id, 0, len(command.payload))
+ command.payload
)
command_frame = avc.VendorDependentCommandFrame(
command_type,
@@ -1830,7 +1799,7 @@ class Protocol(utils.EventEmitter):
self.send_response(transaction_label, response)
def send_rejected_avrcp_response(
self, transaction_label: int, pdu_id: Protocol.PduId, status_code: StatusCode
self, transaction_label: int, pdu_id: PduId, status_code: StatusCode
) -> None:
self.send_avrcp_response(
transaction_label,
@@ -1839,7 +1808,7 @@ class Protocol(utils.EventEmitter):
)
def send_not_implemented_avrcp_response(
self, transaction_label: int, pdu_id: Protocol.PduId
self, transaction_label: int, pdu_id: PduId
) -> None:
self.send_avrcp_response(
transaction_label,
@@ -1895,7 +1864,7 @@ class Protocol(utils.EventEmitter):
if command.event_id not in supported_events:
logger.debug("event not supported")
self.send_not_implemented_avrcp_response(
transaction_label, self.PduId.REGISTER_NOTIFICATION
transaction_label, PduId.REGISTER_NOTIFICATION
)
return