Migrate AVRCP events to dataclasses

This commit is contained in:
Josh Wu
2025-08-27 23:42:46 +08:00
parent dab0993cba
commit 31961febe5
2 changed files with 211 additions and 212 deletions

View File

@@ -34,6 +34,7 @@ from typing import (
TypeVar,
Union,
ClassVar,
TypeAlias,
cast,
)
@@ -424,7 +425,7 @@ class SetAbsoluteVolumeCommand(Command):
@Command.command
@dataclass
class RegisterNotificationCommand(Command):
pdu_id = PduId.SET_ABSOLUTE_VOLUME
pdu_id = PduId.REGISTER_NOTIFICATION
event_id: EventId = field(metadata=EventId.type_metadata(1))
playback_interval: int = field(metadata=hci.metadata('>4'))
@@ -678,256 +679,184 @@ class SongAndPlayStatus:
# -----------------------------------------------------------------------------
class ApplicationSetting:
class AttributeId(utils.OpenIntEnum):
class AttributeId(hci.SpecableEnum):
EQUALIZER_ON_OFF = 0x01
REPEAT_MODE = 0x02
SHUFFLE_ON_OFF = 0x03
SCAN_ON_OFF = 0x04
class EqualizerOnOffStatus(utils.OpenIntEnum):
class EqualizerOnOffStatus(hci.SpecableEnum):
OFF = 0x01
ON = 0x02
class RepeatModeStatus(utils.OpenIntEnum):
class RepeatModeStatus(hci.SpecableEnum):
OFF = 0x01
SINGLE_TRACK_REPEAT = 0x02
ALL_TRACK_REPEAT = 0x03
GROUP_REPEAT = 0x04
class ShuffleOnOffStatus(utils.OpenIntEnum):
class ShuffleOnOffStatus(hci.SpecableEnum):
OFF = 0x01
ALL_TRACKS_SHUFFLE = 0x02
GROUP_SHUFFLE = 0x03
class ScanOnOffStatus(utils.OpenIntEnum):
class ScanOnOffStatus(hci.SpecableEnum):
OFF = 0x01
ALL_TRACKS_SCAN = 0x02
GROUP_SCAN = 0x03
class GenericValue(utils.OpenIntEnum):
class GenericValue(hci.SpecableEnum):
pass
# -----------------------------------------------------------------------------
@dataclass
class Event:
event_id: EventId
event_id: EventId = field(init=False)
_pdu: Optional[bytes] = field(init=False, default=None)
_Event = TypeVar('_Event', bound='Event')
subclasses: ClassVar[dict[int, type[Event]]] = {}
fields: ClassVar[hci.Fields] = ()
@classmethod
def event(cls, subclass: type[_Event]) -> type[_Event]:
cls.subclasses[subclass.event_id] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod
def from_bytes(cls, pdu: bytes) -> Event:
event_id = EventId(pdu[0])
subclass = EVENT_SUBCLASSES.get(event_id, GenericEvent)
return subclass.from_bytes(pdu)
if not (subclass := cls.subclasses.get(pdu[0])):
raise core.InvalidPacketError(f"Unimplemented PDU {pdu[0]}")
instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 1, subclass.fields))
instance._pdu = pdu
return instance
def __bytes__(self) -> bytes:
return bytes([self.event_id])
if self._pdu is None:
self._pdu = bytes([self.event_id]) + hci.HCI_Object.dict_to_bytes(
self.__dict__, self.fields
)
return self._pdu
def __repr__(self) -> str:
return str(self)
# -----------------------------------------------------------------------------
@dataclass
class GenericEvent(Event):
data: bytes
event_id: EventId = field(metadata=EventId.type_metadata(1))
data: bytes = field(metadata=hci.metadata('*'))
@classmethod
def from_bytes(cls, pdu: bytes) -> GenericEvent:
return cls(event_id=EventId(pdu[0]), data=pdu[1:])
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + self.data
GenericEvent.fields = hci.HCI_Object.fields_from_dataclass(GenericEvent)
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class PlaybackStatusChangedEvent(Event):
play_status: PlayStatus
@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent:
return cls(play_status=PlayStatus(pdu[1]))
def __init__(self, play_status: PlayStatus) -> None:
super().__init__(EventId.PLAYBACK_STATUS_CHANGED)
self.play_status = play_status
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + bytes([self.play_status])
event_id = EventId.PLAYBACK_STATUS_CHANGED
play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1))
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class PlaybackPositionChangedEvent(Event):
playback_position: int
@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent:
return cls(playback_position=struct.unpack_from(">I", pdu, 1)[0])
def __init__(self, playback_position: int) -> None:
super().__init__(EventId.PLAYBACK_POS_CHANGED)
self.playback_position = playback_position
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(">I", self.playback_position)
event_id = EventId.PLAYBACK_POS_CHANGED
playback_position: int = field(metadata=hci.metadata('>4'))
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class TrackChangedEvent(Event):
identifier: bytes
@classmethod
def from_bytes(cls, pdu: bytes) -> TrackChangedEvent:
return cls(identifier=pdu[1:])
def __init__(self, identifier: bytes) -> None:
super().__init__(EventId.TRACK_CHANGED)
self.identifier = identifier
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + self.identifier
event_id = EventId.TRACK_CHANGED
identifier: bytes = field(metadata=hci.metadata('*'))
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class PlayerApplicationSettingChangedEvent(Event):
event_id = EventId.PLAYER_APPLICATION_SETTING_CHANGED
@dataclass
class Setting:
attribute_id: ApplicationSetting.AttributeId
value_id: utils.OpenIntEnum
player_application_settings: list[Setting]
@classmethod
def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent:
def setting(attribute_id_int: int, value_id_int: int):
attribute_id = ApplicationSetting.AttributeId(attribute_id_int)
value_id: utils.OpenIntEnum
if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
value_id = ApplicationSetting.RepeatModeStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
value_id = ApplicationSetting.ShuffleOnOffStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
value_id = ApplicationSetting.ScanOnOffStatus(value_id_int)
else:
value_id = ApplicationSetting.GenericValue(value_id_int)
return cls.Setting(attribute_id, value_id)
settings = [
setting(pdu[2 + (i * 2)], pdu[2 + (i * 2) + 1]) for i in range(pdu[1])
]
return cls(player_application_settings=settings)
def __init__(self, player_application_settings: Sequence[Setting]) -> None:
super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED)
self.player_application_settings = list(player_application_settings)
def __bytes__(self) -> bytes:
return (
bytes([self.event_id])
+ bytes([len(self.player_application_settings)])
+ b''.join(
[
bytes([setting.attribute_id, setting.value_id])
for setting in self.player_application_settings
]
class Setting(hci.HCI_Dataclass_Object):
attribute_id: ApplicationSetting.AttributeId = field(
metadata=ApplicationSetting.AttributeId.type_metadata(1)
)
value_id: Union[
ApplicationSetting.EqualizerOnOffStatus,
ApplicationSetting.RepeatModeStatus,
ApplicationSetting.ShuffleOnOffStatus,
ApplicationSetting.ScanOnOffStatus,
ApplicationSetting.GenericValue,
] = field(metadata=hci.metadata(1))
def __post_init__(self) -> None:
super().__post_init__()
if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
else:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
player_application_settings: Sequence[Setting] = field(
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
)
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class NowPlayingContentChangedEvent(Event):
@classmethod
def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent:
return cls()
def __init__(self) -> None:
super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED)
event_id = EventId.NOW_PLAYING_CONTENT_CHANGED
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class AvailablePlayersChangedEvent(Event):
@classmethod
def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent:
return cls()
def __init__(self) -> None:
super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED)
event_id = EventId.AVAILABLE_PLAYERS_CHANGED
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class AddressedPlayerChangedEvent(Event):
event_id = EventId.ADDRESSED_PLAYER_CHANGED
@dataclass
class Player:
player_id: int
uid_counter: int
class Player(hci.HCI_Dataclass_Object):
player_id: int = field(metadata=hci.metadata('>2'))
uid_counter: int = field(metadata=hci.metadata('>2'))
@classmethod
def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent:
player_id, uid_counter = struct.unpack_from("<HH", pdu, 1)
return cls(cls.Player(player_id, uid_counter))
def __init__(self, player: Player) -> None:
super().__init__(EventId.ADDRESSED_PLAYER_CHANGED)
self.player = player
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(
">HH", self.player.player_id, self.player.uid_counter
)
player: Player = field(metadata=hci.metadata(Player.parse_from_bytes))
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class UidsChangedEvent(Event):
uid_counter: int
@classmethod
def from_bytes(cls, pdu: bytes) -> UidsChangedEvent:
return cls(uid_counter=struct.unpack_from(">H", pdu, 1)[0])
def __init__(self, uid_counter: int) -> None:
super().__init__(EventId.UIDS_CHANGED)
self.uid_counter = uid_counter
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(">H", self.uid_counter)
event_id = EventId.UIDS_CHANGED
uid_counter: int = field(metadata=hci.metadata('>2'))
# -----------------------------------------------------------------------------
@Event.event
@dataclass
class VolumeChangedEvent(Event):
volume: int
@classmethod
def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent:
return cls(volume=pdu[1])
def __init__(self, volume: int) -> None:
super().__init__(EventId.VOLUME_CHANGED)
self.volume = volume
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + bytes([self.volume])
# -----------------------------------------------------------------------------
EVENT_SUBCLASSES: dict[EventId, type[Event]] = {
EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent,
EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent,
EventId.TRACK_CHANGED: TrackChangedEvent,
EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent,
EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent,
EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent,
EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent,
EventId.UIDS_CHANGED: UidsChangedEvent,
EventId.VOLUME_CHANGED: VolumeChangedEvent,
}
event_id = EventId.VOLUME_CHANGED
volume: int = field(metadata=hci.metadata(1))
# -----------------------------------------------------------------------------
@@ -1300,7 +1229,7 @@ class Protocol(utils.EventEmitter):
if not isinstance(event, PlayerApplicationSettingChangedEvent):
logger.warning("unexpected event class")
continue
yield event.player_application_settings
yield list(event.player_application_settings)
async def monitor_now_playing_content(self) -> AsyncIterator[None]:
"""Monitor Now Playing changes from the connected peer."""

View File

@@ -15,67 +15,138 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import pytest
from collections.abc import Sequence
from typing import Self
from bumble import avc, avctp, avrcp, controller, core, device, host, link
from bumble.transport import common
from bumble import avc, avctp, avrcp
from . import test_utils
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = link.LocalLink()
self.controllers = [
controller.Controller('C1', link=self.link, public_address=addresses[0]),
controller.Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
device.Device(
address=addresses[0],
host=host.Host(
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
),
),
device.Device(
address=addresses[1],
host=host.Host(
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
),
),
]
self.devices[0].classic_enabled = True
self.devices[1].classic_enabled = True
self.connections = [None, None]
self.protocols = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
async def setup_connections(self):
await self.devices[0].power_on()
await self.devices[1].power_on()
self.connections = await asyncio.gather(
self.devices[0].connect(
self.devices[1].public_address, core.PhysicalTransport.BR_EDR
),
self.devices[1].accept(self.devices[0].public_address),
)
class TwoDevices(test_utils.TwoDevices):
protocols: Sequence[avrcp.Protocol] = ()
async def setup_avdtp_connections(self):
self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
self.protocols[0].listen(self.devices[1])
await self.protocols[1].connect(self.connections[0])
@classmethod
async def create_with_avdtp(cls) -> Self:
devices = await cls.create_with_connection()
await devices.setup_avdtp_connections()
return devices
# -----------------------------------------------------------------------------
def test_GetPlayStatusCommand():
command = avrcp.GetPlayStatusCommand()
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_GetCapabilitiesCommand():
command = avrcp.GetCapabilitiesCommand(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_SetAbsoluteVolumeCommand():
command = avrcp.SetAbsoluteVolumeCommand(volume=5)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_GetElementAttributesCommand():
command = avrcp.GetElementAttributesCommand(
identifier=999,
attribute_ids=[
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.ARTIST_NAME,
],
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_RegisterNotificationCommand():
command = avrcp.RegisterNotificationCommand(
event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_UidsChangedEvent():
event = avrcp.UidsChangedEvent(uid_counter=7)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_TrackChangedEvent():
event = avrcp.TrackChangedEvent(identifier=b'12356')
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_VolumeChangedEvent():
event = avrcp.VolumeChangedEvent(volume=9)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlaybackStatusChangedEvent():
event = avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_AddressedPlayerChangedEvent():
event = avrcp.AddressedPlayerChangedEvent(
player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10)
)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_AvailablePlayersChangedEvent():
event = avrcp.AvailablePlayersChangedEvent()
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlaybackPositionChangedEvent():
event = avrcp.PlaybackPositionChangedEvent(playback_position=1314)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_NowPlayingContentChangedEvent():
event = avrcp.NowPlayingContentChangedEvent()
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlayerApplicationSettingChangedEvent():
event = avrcp.PlayerApplicationSettingChangedEvent(
player_application_settings=[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
)
]
)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_frame_parser():
with pytest.raises(ValueError) as error:
with pytest.raises(ValueError):
avc.Frame.from_bytes(bytes.fromhex("11480000"))
x = bytes.fromhex("014D0208")
@@ -217,8 +288,7 @@ def test_passthrough_commands():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
two_devices = TwoDevices()
await two_devices.setup_connections()
two_devices = await TwoDevices.create_with_avdtp()
supported_events = await two_devices.protocols[0].get_supported_events()
assert supported_events == []