Migrate AVRCP events to dataclasses

This commit is contained in:
Josh Wu
2025-08-27 23:42:46 +08:00
parent dab0993cba
commit 31961febe5
2 changed files with 211 additions and 212 deletions

View File

@@ -34,6 +34,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
ClassVar, ClassVar,
TypeAlias,
cast, cast,
) )
@@ -424,7 +425,7 @@ class SetAbsoluteVolumeCommand(Command):
@Command.command @Command.command
@dataclass @dataclass
class RegisterNotificationCommand(Command): class RegisterNotificationCommand(Command):
pdu_id = PduId.SET_ABSOLUTE_VOLUME pdu_id = PduId.REGISTER_NOTIFICATION
event_id: EventId = field(metadata=EventId.type_metadata(1)) event_id: EventId = field(metadata=EventId.type_metadata(1))
playback_interval: int = field(metadata=hci.metadata('>4')) playback_interval: int = field(metadata=hci.metadata('>4'))
@@ -678,256 +679,184 @@ class SongAndPlayStatus:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ApplicationSetting: class ApplicationSetting:
class AttributeId(utils.OpenIntEnum): class AttributeId(hci.SpecableEnum):
EQUALIZER_ON_OFF = 0x01 EQUALIZER_ON_OFF = 0x01
REPEAT_MODE = 0x02 REPEAT_MODE = 0x02
SHUFFLE_ON_OFF = 0x03 SHUFFLE_ON_OFF = 0x03
SCAN_ON_OFF = 0x04 SCAN_ON_OFF = 0x04
class EqualizerOnOffStatus(utils.OpenIntEnum): class EqualizerOnOffStatus(hci.SpecableEnum):
OFF = 0x01 OFF = 0x01
ON = 0x02 ON = 0x02
class RepeatModeStatus(utils.OpenIntEnum): class RepeatModeStatus(hci.SpecableEnum):
OFF = 0x01 OFF = 0x01
SINGLE_TRACK_REPEAT = 0x02 SINGLE_TRACK_REPEAT = 0x02
ALL_TRACK_REPEAT = 0x03 ALL_TRACK_REPEAT = 0x03
GROUP_REPEAT = 0x04 GROUP_REPEAT = 0x04
class ShuffleOnOffStatus(utils.OpenIntEnum): class ShuffleOnOffStatus(hci.SpecableEnum):
OFF = 0x01 OFF = 0x01
ALL_TRACKS_SHUFFLE = 0x02 ALL_TRACKS_SHUFFLE = 0x02
GROUP_SHUFFLE = 0x03 GROUP_SHUFFLE = 0x03
class ScanOnOffStatus(utils.OpenIntEnum): class ScanOnOffStatus(hci.SpecableEnum):
OFF = 0x01 OFF = 0x01
ALL_TRACKS_SCAN = 0x02 ALL_TRACKS_SCAN = 0x02
GROUP_SCAN = 0x03 GROUP_SCAN = 0x03
class GenericValue(utils.OpenIntEnum): class GenericValue(hci.SpecableEnum):
pass pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclass @dataclass
class Event: class Event:
event_id: EventId event_id: EventId = field(init=False)
_pdu: Optional[bytes] = field(init=False, default=None)
_Event = TypeVar('_Event', bound='Event')
subclasses: ClassVar[dict[int, type[Event]]] = {}
fields: ClassVar[hci.Fields] = ()
@classmethod
def event(cls, subclass: type[_Event]) -> type[_Event]:
cls.subclasses[subclass.event_id] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod @classmethod
def from_bytes(cls, pdu: bytes) -> Event: def from_bytes(cls, pdu: bytes) -> Event:
event_id = EventId(pdu[0]) if not (subclass := cls.subclasses.get(pdu[0])):
subclass = EVENT_SUBCLASSES.get(event_id, GenericEvent) raise core.InvalidPacketError(f"Unimplemented PDU {pdu[0]}")
return subclass.from_bytes(pdu) instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 1, subclass.fields))
instance._pdu = pdu
return instance
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return bytes([self.event_id]) if self._pdu is None:
self._pdu = bytes([self.event_id]) + hci.HCI_Object.dict_to_bytes(
self.__dict__, self.fields
)
return self._pdu
def __repr__(self) -> str:
return str(self)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclass @dataclass
class GenericEvent(Event): class GenericEvent(Event):
data: bytes event_id: EventId = field(metadata=EventId.type_metadata(1))
data: bytes = field(metadata=hci.metadata('*'))
@classmethod
def from_bytes(cls, pdu: bytes) -> GenericEvent:
return cls(event_id=EventId(pdu[0]), data=pdu[1:])
def __bytes__(self) -> bytes: GenericEvent.fields = hci.HCI_Object.fields_from_dataclass(GenericEvent)
return bytes([self.event_id]) + self.data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class PlaybackStatusChangedEvent(Event): class PlaybackStatusChangedEvent(Event):
play_status: PlayStatus event_id = EventId.PLAYBACK_STATUS_CHANGED
play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1))
@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent:
return cls(play_status=PlayStatus(pdu[1]))
def __init__(self, play_status: PlayStatus) -> None:
super().__init__(EventId.PLAYBACK_STATUS_CHANGED)
self.play_status = play_status
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + bytes([self.play_status])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class PlaybackPositionChangedEvent(Event): class PlaybackPositionChangedEvent(Event):
playback_position: int event_id = EventId.PLAYBACK_POS_CHANGED
playback_position: int = field(metadata=hci.metadata('>4'))
@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent:
return cls(playback_position=struct.unpack_from(">I", pdu, 1)[0])
def __init__(self, playback_position: int) -> None:
super().__init__(EventId.PLAYBACK_POS_CHANGED)
self.playback_position = playback_position
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(">I", self.playback_position)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class TrackChangedEvent(Event): class TrackChangedEvent(Event):
identifier: bytes event_id = EventId.TRACK_CHANGED
identifier: bytes = field(metadata=hci.metadata('*'))
@classmethod
def from_bytes(cls, pdu: bytes) -> TrackChangedEvent:
return cls(identifier=pdu[1:])
def __init__(self, identifier: bytes) -> None:
super().__init__(EventId.TRACK_CHANGED)
self.identifier = identifier
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + self.identifier
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class PlayerApplicationSettingChangedEvent(Event): class PlayerApplicationSettingChangedEvent(Event):
event_id = EventId.PLAYER_APPLICATION_SETTING_CHANGED
@dataclass @dataclass
class Setting: class Setting(hci.HCI_Dataclass_Object):
attribute_id: ApplicationSetting.AttributeId attribute_id: ApplicationSetting.AttributeId = field(
value_id: utils.OpenIntEnum metadata=ApplicationSetting.AttributeId.type_metadata(1)
player_application_settings: list[Setting]
@classmethod
def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent:
def setting(attribute_id_int: int, value_id_int: int):
attribute_id = ApplicationSetting.AttributeId(attribute_id_int)
value_id: utils.OpenIntEnum
if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
value_id = ApplicationSetting.RepeatModeStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
value_id = ApplicationSetting.ShuffleOnOffStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
value_id = ApplicationSetting.ScanOnOffStatus(value_id_int)
else:
value_id = ApplicationSetting.GenericValue(value_id_int)
return cls.Setting(attribute_id, value_id)
settings = [
setting(pdu[2 + (i * 2)], pdu[2 + (i * 2) + 1]) for i in range(pdu[1])
]
return cls(player_application_settings=settings)
def __init__(self, player_application_settings: Sequence[Setting]) -> None:
super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED)
self.player_application_settings = list(player_application_settings)
def __bytes__(self) -> bytes:
return (
bytes([self.event_id])
+ bytes([len(self.player_application_settings)])
+ b''.join(
[
bytes([setting.attribute_id, setting.value_id])
for setting in self.player_application_settings
]
)
) )
value_id: Union[
ApplicationSetting.EqualizerOnOffStatus,
ApplicationSetting.RepeatModeStatus,
ApplicationSetting.ShuffleOnOffStatus,
ApplicationSetting.ScanOnOffStatus,
ApplicationSetting.GenericValue,
] = field(metadata=hci.metadata(1))
def __post_init__(self) -> None:
super().__post_init__()
if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF:
self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id)
elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF:
self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id)
else:
self.value_id = ApplicationSetting.GenericValue(self.value_id)
player_application_settings: Sequence[Setting] = field(
metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class NowPlayingContentChangedEvent(Event): class NowPlayingContentChangedEvent(Event):
@classmethod event_id = EventId.NOW_PLAYING_CONTENT_CHANGED
def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent:
return cls()
def __init__(self) -> None:
super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class AvailablePlayersChangedEvent(Event): class AvailablePlayersChangedEvent(Event):
@classmethod event_id = EventId.AVAILABLE_PLAYERS_CHANGED
def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent:
return cls()
def __init__(self) -> None:
super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class AddressedPlayerChangedEvent(Event): class AddressedPlayerChangedEvent(Event):
event_id = EventId.ADDRESSED_PLAYER_CHANGED
@dataclass @dataclass
class Player: class Player(hci.HCI_Dataclass_Object):
player_id: int player_id: int = field(metadata=hci.metadata('>2'))
uid_counter: int uid_counter: int = field(metadata=hci.metadata('>2'))
@classmethod player: Player = field(metadata=hci.metadata(Player.parse_from_bytes))
def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent:
player_id, uid_counter = struct.unpack_from("<HH", pdu, 1)
return cls(cls.Player(player_id, uid_counter))
def __init__(self, player: Player) -> None:
super().__init__(EventId.ADDRESSED_PLAYER_CHANGED)
self.player = player
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(
">HH", self.player.player_id, self.player.uid_counter
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class UidsChangedEvent(Event): class UidsChangedEvent(Event):
uid_counter: int event_id = EventId.UIDS_CHANGED
uid_counter: int = field(metadata=hci.metadata('>2'))
@classmethod
def from_bytes(cls, pdu: bytes) -> UidsChangedEvent:
return cls(uid_counter=struct.unpack_from(">H", pdu, 1)[0])
def __init__(self, uid_counter: int) -> None:
super().__init__(EventId.UIDS_CHANGED)
self.uid_counter = uid_counter
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + struct.pack(">H", self.uid_counter)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Event.event
@dataclass @dataclass
class VolumeChangedEvent(Event): class VolumeChangedEvent(Event):
volume: int event_id = EventId.VOLUME_CHANGED
volume: int = field(metadata=hci.metadata(1))
@classmethod
def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent:
return cls(volume=pdu[1])
def __init__(self, volume: int) -> None:
super().__init__(EventId.VOLUME_CHANGED)
self.volume = volume
def __bytes__(self) -> bytes:
return bytes([self.event_id]) + bytes([self.volume])
# -----------------------------------------------------------------------------
EVENT_SUBCLASSES: dict[EventId, type[Event]] = {
EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent,
EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent,
EventId.TRACK_CHANGED: TrackChangedEvent,
EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent,
EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent,
EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent,
EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent,
EventId.UIDS_CHANGED: UidsChangedEvent,
EventId.VOLUME_CHANGED: VolumeChangedEvent,
}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1300,7 +1229,7 @@ class Protocol(utils.EventEmitter):
if not isinstance(event, PlayerApplicationSettingChangedEvent): if not isinstance(event, PlayerApplicationSettingChangedEvent):
logger.warning("unexpected event class") logger.warning("unexpected event class")
continue continue
yield event.player_application_settings yield list(event.player_application_settings)
async def monitor_now_playing_content(self) -> AsyncIterator[None]: async def monitor_now_playing_content(self) -> AsyncIterator[None]:
"""Monitor Now Playing changes from the connected peer.""" """Monitor Now Playing changes from the connected peer."""

View File

@@ -15,67 +15,138 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import struct import struct
import pytest import pytest
from collections.abc import Sequence
from typing import Self
from bumble import avc, avctp, avrcp, controller, core, device, host, link from bumble import avc, avctp, avrcp
from bumble.transport import common from . import test_utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TwoDevices: class TwoDevices(test_utils.TwoDevices):
def __init__(self): protocols: Sequence[avrcp.Protocol] = ()
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = link.LocalLink()
self.controllers = [
controller.Controller('C1', link=self.link, public_address=addresses[0]),
controller.Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
device.Device(
address=addresses[0],
host=host.Host(
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
),
),
device.Device(
address=addresses[1],
host=host.Host(
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
),
),
]
self.devices[0].classic_enabled = True
self.devices[1].classic_enabled = True
self.connections = [None, None]
self.protocols = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
async def setup_connections(self):
await self.devices[0].power_on()
await self.devices[1].power_on()
self.connections = await asyncio.gather(
self.devices[0].connect(
self.devices[1].public_address, core.PhysicalTransport.BR_EDR
),
self.devices[1].accept(self.devices[0].public_address),
)
async def setup_avdtp_connections(self):
self.protocols = [avrcp.Protocol(), avrcp.Protocol()] self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
self.protocols[0].listen(self.devices[1]) self.protocols[0].listen(self.devices[1])
await self.protocols[1].connect(self.connections[0]) await self.protocols[1].connect(self.connections[0])
@classmethod
async def create_with_avdtp(cls) -> Self:
devices = await cls.create_with_connection()
await devices.setup_avdtp_connections()
return devices
# -----------------------------------------------------------------------------
def test_GetPlayStatusCommand():
command = avrcp.GetPlayStatusCommand()
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_GetCapabilitiesCommand():
command = avrcp.GetCapabilitiesCommand(
capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_SetAbsoluteVolumeCommand():
command = avrcp.SetAbsoluteVolumeCommand(volume=5)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_GetElementAttributesCommand():
command = avrcp.GetElementAttributesCommand(
identifier=999,
attribute_ids=[
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.ARTIST_NAME,
],
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_RegisterNotificationCommand():
command = avrcp.RegisterNotificationCommand(
event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123
)
assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command
# -----------------------------------------------------------------------------
def test_UidsChangedEvent():
event = avrcp.UidsChangedEvent(uid_counter=7)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_TrackChangedEvent():
event = avrcp.TrackChangedEvent(identifier=b'12356')
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_VolumeChangedEvent():
event = avrcp.VolumeChangedEvent(volume=9)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlaybackStatusChangedEvent():
event = avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_AddressedPlayerChangedEvent():
event = avrcp.AddressedPlayerChangedEvent(
player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10)
)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_AvailablePlayersChangedEvent():
event = avrcp.AvailablePlayersChangedEvent()
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlaybackPositionChangedEvent():
event = avrcp.PlaybackPositionChangedEvent(playback_position=1314)
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_NowPlayingContentChangedEvent():
event = avrcp.NowPlayingContentChangedEvent()
assert avrcp.Event.from_bytes(bytes(event)) == event
# -----------------------------------------------------------------------------
def test_PlayerApplicationSettingChangedEvent():
event = avrcp.PlayerApplicationSettingChangedEvent(
player_application_settings=[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
)
]
)
assert avrcp.Event.from_bytes(bytes(event)) == event
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_frame_parser(): def test_frame_parser():
with pytest.raises(ValueError) as error: with pytest.raises(ValueError):
avc.Frame.from_bytes(bytes.fromhex("11480000")) avc.Frame.from_bytes(bytes.fromhex("11480000"))
x = bytes.fromhex("014D0208") x = bytes.fromhex("014D0208")
@@ -217,8 +288,7 @@ def test_passthrough_commands():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_supported_events(): async def test_get_supported_events():
two_devices = TwoDevices() two_devices = await TwoDevices.create_with_avdtp()
await two_devices.setup_connections()
supported_events = await two_devices.protocols[0].get_supported_events() supported_events = await two_devices.protocols[0].get_supported_events()
assert supported_events == [] assert supported_events == []