diff --git a/bumble/avrcp.py b/bumble/avrcp.py index b8f22fa0..733f327c 100644 --- a/bumble/avrcp.py +++ b/bumble/avrcp.py @@ -34,6 +34,7 @@ from typing import ( TypeVar, Union, ClassVar, + TypeAlias, cast, ) @@ -424,7 +425,7 @@ class SetAbsoluteVolumeCommand(Command): @Command.command @dataclass class RegisterNotificationCommand(Command): - pdu_id = PduId.SET_ABSOLUTE_VOLUME + pdu_id = PduId.REGISTER_NOTIFICATION event_id: EventId = field(metadata=EventId.type_metadata(1)) playback_interval: int = field(metadata=hci.metadata('>4')) @@ -678,256 +679,184 @@ class SongAndPlayStatus: # ----------------------------------------------------------------------------- class ApplicationSetting: - class AttributeId(utils.OpenIntEnum): + class AttributeId(hci.SpecableEnum): EQUALIZER_ON_OFF = 0x01 REPEAT_MODE = 0x02 SHUFFLE_ON_OFF = 0x03 SCAN_ON_OFF = 0x04 - class EqualizerOnOffStatus(utils.OpenIntEnum): + class EqualizerOnOffStatus(hci.SpecableEnum): OFF = 0x01 ON = 0x02 - class RepeatModeStatus(utils.OpenIntEnum): + class RepeatModeStatus(hci.SpecableEnum): OFF = 0x01 SINGLE_TRACK_REPEAT = 0x02 ALL_TRACK_REPEAT = 0x03 GROUP_REPEAT = 0x04 - class ShuffleOnOffStatus(utils.OpenIntEnum): + class ShuffleOnOffStatus(hci.SpecableEnum): OFF = 0x01 ALL_TRACKS_SHUFFLE = 0x02 GROUP_SHUFFLE = 0x03 - class ScanOnOffStatus(utils.OpenIntEnum): + class ScanOnOffStatus(hci.SpecableEnum): OFF = 0x01 ALL_TRACKS_SCAN = 0x02 GROUP_SCAN = 0x03 - class GenericValue(utils.OpenIntEnum): + class GenericValue(hci.SpecableEnum): pass # ----------------------------------------------------------------------------- @dataclass class Event: - event_id: EventId + event_id: EventId = field(init=False) + _pdu: Optional[bytes] = field(init=False, default=None) + + _Event = TypeVar('_Event', bound='Event') + subclasses: ClassVar[dict[int, type[Event]]] = {} + fields: ClassVar[hci.Fields] = () + + @classmethod + def event(cls, subclass: type[_Event]) -> type[_Event]: + cls.subclasses[subclass.event_id] = subclass + subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass) + return subclass @classmethod def from_bytes(cls, pdu: bytes) -> Event: - event_id = EventId(pdu[0]) - subclass = EVENT_SUBCLASSES.get(event_id, GenericEvent) - return subclass.from_bytes(pdu) + if not (subclass := cls.subclasses.get(pdu[0])): + raise core.InvalidPacketError(f"Unimplemented PDU {pdu[0]}") + instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 1, subclass.fields)) + instance._pdu = pdu + return instance def __bytes__(self) -> bytes: - return bytes([self.event_id]) + if self._pdu is None: + self._pdu = bytes([self.event_id]) + hci.HCI_Object.dict_to_bytes( + self.__dict__, self.fields + ) + return self._pdu + + def __repr__(self) -> str: + return str(self) # ----------------------------------------------------------------------------- @dataclass class GenericEvent(Event): - data: bytes + event_id: EventId = field(metadata=EventId.type_metadata(1)) + data: bytes = field(metadata=hci.metadata('*')) - @classmethod - def from_bytes(cls, pdu: bytes) -> GenericEvent: - return cls(event_id=EventId(pdu[0]), data=pdu[1:]) - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + self.data +GenericEvent.fields = hci.HCI_Object.fields_from_dataclass(GenericEvent) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlaybackStatusChangedEvent(Event): - play_status: PlayStatus - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent: - return cls(play_status=PlayStatus(pdu[1])) - - def __init__(self, play_status: PlayStatus) -> None: - super().__init__(EventId.PLAYBACK_STATUS_CHANGED) - self.play_status = play_status - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + bytes([self.play_status]) + event_id = EventId.PLAYBACK_STATUS_CHANGED + play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlaybackPositionChangedEvent(Event): - playback_position: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent: - return cls(playback_position=struct.unpack_from(">I", pdu, 1)[0]) - - def __init__(self, playback_position: int) -> None: - super().__init__(EventId.PLAYBACK_POS_CHANGED) - self.playback_position = playback_position - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack(">I", self.playback_position) + event_id = EventId.PLAYBACK_POS_CHANGED + playback_position: int = field(metadata=hci.metadata('>4')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class TrackChangedEvent(Event): - identifier: bytes - - @classmethod - def from_bytes(cls, pdu: bytes) -> TrackChangedEvent: - return cls(identifier=pdu[1:]) - - def __init__(self, identifier: bytes) -> None: - super().__init__(EventId.TRACK_CHANGED) - self.identifier = identifier - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + self.identifier + event_id = EventId.TRACK_CHANGED + identifier: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlayerApplicationSettingChangedEvent(Event): + event_id = EventId.PLAYER_APPLICATION_SETTING_CHANGED + @dataclass - class Setting: - attribute_id: ApplicationSetting.AttributeId - value_id: utils.OpenIntEnum - - player_application_settings: list[Setting] - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent: - def setting(attribute_id_int: int, value_id_int: int): - attribute_id = ApplicationSetting.AttributeId(attribute_id_int) - value_id: utils.OpenIntEnum - if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: - value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: - value_id = ApplicationSetting.RepeatModeStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: - value_id = ApplicationSetting.ShuffleOnOffStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF: - value_id = ApplicationSetting.ScanOnOffStatus(value_id_int) - else: - value_id = ApplicationSetting.GenericValue(value_id_int) - - return cls.Setting(attribute_id, value_id) - - settings = [ - setting(pdu[2 + (i * 2)], pdu[2 + (i * 2) + 1]) for i in range(pdu[1]) - ] - return cls(player_application_settings=settings) - - def __init__(self, player_application_settings: Sequence[Setting]) -> None: - super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED) - self.player_application_settings = list(player_application_settings) - - def __bytes__(self) -> bytes: - return ( - bytes([self.event_id]) - + bytes([len(self.player_application_settings)]) - + b''.join( - [ - bytes([setting.attribute_id, setting.value_id]) - for setting in self.player_application_settings - ] - ) + class Setting(hci.HCI_Dataclass_Object): + attribute_id: ApplicationSetting.AttributeId = field( + metadata=ApplicationSetting.AttributeId.type_metadata(1) ) + value_id: Union[ + ApplicationSetting.EqualizerOnOffStatus, + ApplicationSetting.RepeatModeStatus, + ApplicationSetting.ShuffleOnOffStatus, + ApplicationSetting.ScanOnOffStatus, + ApplicationSetting.GenericValue, + ] = field(metadata=hci.metadata(1)) + + def __post_init__(self) -> None: + super().__post_init__() + if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: + self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: + self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: + self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF: + self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id) + else: + self.value_id = ApplicationSetting.GenericValue(self.value_id) + + player_application_settings: Sequence[Setting] = field( + metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True) + ) # ----------------------------------------------------------------------------- +@Event.event @dataclass class NowPlayingContentChangedEvent(Event): - @classmethod - def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent: - return cls() - - def __init__(self) -> None: - super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED) + event_id = EventId.NOW_PLAYING_CONTENT_CHANGED # ----------------------------------------------------------------------------- +@Event.event @dataclass class AvailablePlayersChangedEvent(Event): - @classmethod - def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent: - return cls() - - def __init__(self) -> None: - super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED) + event_id = EventId.AVAILABLE_PLAYERS_CHANGED # ----------------------------------------------------------------------------- +@Event.event @dataclass class AddressedPlayerChangedEvent(Event): + event_id = EventId.ADDRESSED_PLAYER_CHANGED + @dataclass - class Player: - player_id: int - uid_counter: int + class Player(hci.HCI_Dataclass_Object): + player_id: int = field(metadata=hci.metadata('>2')) + uid_counter: int = field(metadata=hci.metadata('>2')) - @classmethod - def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent: - player_id, uid_counter = struct.unpack_from(" None: - super().__init__(EventId.ADDRESSED_PLAYER_CHANGED) - self.player = player - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack( - ">HH", self.player.player_id, self.player.uid_counter - ) + player: Player = field(metadata=hci.metadata(Player.parse_from_bytes)) # ----------------------------------------------------------------------------- +@Event.event @dataclass class UidsChangedEvent(Event): - uid_counter: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> UidsChangedEvent: - return cls(uid_counter=struct.unpack_from(">H", pdu, 1)[0]) - - def __init__(self, uid_counter: int) -> None: - super().__init__(EventId.UIDS_CHANGED) - self.uid_counter = uid_counter - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack(">H", self.uid_counter) + event_id = EventId.UIDS_CHANGED + uid_counter: int = field(metadata=hci.metadata('>2')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class VolumeChangedEvent(Event): - volume: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent: - return cls(volume=pdu[1]) - - def __init__(self, volume: int) -> None: - super().__init__(EventId.VOLUME_CHANGED) - self.volume = volume - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + bytes([self.volume]) - - -# ----------------------------------------------------------------------------- -EVENT_SUBCLASSES: dict[EventId, type[Event]] = { - EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent, - EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent, - EventId.TRACK_CHANGED: TrackChangedEvent, - EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent, - EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent, - EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent, - EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent, - EventId.UIDS_CHANGED: UidsChangedEvent, - EventId.VOLUME_CHANGED: VolumeChangedEvent, -} + event_id = EventId.VOLUME_CHANGED + volume: int = field(metadata=hci.metadata(1)) # ----------------------------------------------------------------------------- @@ -1300,7 +1229,7 @@ class Protocol(utils.EventEmitter): if not isinstance(event, PlayerApplicationSettingChangedEvent): logger.warning("unexpected event class") continue - yield event.player_application_settings + yield list(event.player_application_settings) async def monitor_now_playing_content(self) -> AsyncIterator[None]: """Monitor Now Playing changes from the connected peer.""" diff --git a/tests/avrcp_test.py b/tests/avrcp_test.py index dfc0bb31..26f9ba9d 100644 --- a/tests/avrcp_test.py +++ b/tests/avrcp_test.py @@ -15,67 +15,138 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import asyncio import struct import pytest +from collections.abc import Sequence +from typing import Self -from bumble import avc, avctp, avrcp, controller, core, device, host, link -from bumble.transport import common +from bumble import avc, avctp, avrcp +from . import test_utils # ----------------------------------------------------------------------------- -class TwoDevices: - def __init__(self): - self.connections = [None, None] - - addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] - self.link = link.LocalLink() - self.controllers = [ - controller.Controller('C1', link=self.link, public_address=addresses[0]), - controller.Controller('C2', link=self.link, public_address=addresses[1]), - ] - self.devices = [ - device.Device( - address=addresses[0], - host=host.Host( - self.controllers[0], common.AsyncPipeSink(self.controllers[0]) - ), - ), - device.Device( - address=addresses[1], - host=host.Host( - self.controllers[1], common.AsyncPipeSink(self.controllers[1]) - ), - ), - ] - self.devices[0].classic_enabled = True - self.devices[1].classic_enabled = True - self.connections = [None, None] - self.protocols = [None, None] - - def on_connection(self, which, connection): - self.connections[which] = connection - - async def setup_connections(self): - await self.devices[0].power_on() - await self.devices[1].power_on() - - self.connections = await asyncio.gather( - self.devices[0].connect( - self.devices[1].public_address, core.PhysicalTransport.BR_EDR - ), - self.devices[1].accept(self.devices[0].public_address), - ) +class TwoDevices(test_utils.TwoDevices): + protocols: Sequence[avrcp.Protocol] = () + async def setup_avdtp_connections(self): self.protocols = [avrcp.Protocol(), avrcp.Protocol()] self.protocols[0].listen(self.devices[1]) await self.protocols[1].connect(self.connections[0]) + @classmethod + async def create_with_avdtp(cls) -> Self: + devices = await cls.create_with_connection() + await devices.setup_avdtp_connections() + return devices + + +# ----------------------------------------------------------------------------- +def test_GetPlayStatusCommand(): + command = avrcp.GetPlayStatusCommand() + assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command + + +# ----------------------------------------------------------------------------- +def test_GetCapabilitiesCommand(): + command = avrcp.GetCapabilitiesCommand( + capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID + ) + assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command + + +# ----------------------------------------------------------------------------- +def test_SetAbsoluteVolumeCommand(): + command = avrcp.SetAbsoluteVolumeCommand(volume=5) + assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command + + +# ----------------------------------------------------------------------------- +def test_GetElementAttributesCommand(): + command = avrcp.GetElementAttributesCommand( + identifier=999, + attribute_ids=[ + avrcp.MediaAttributeId.ALBUM_NAME, + avrcp.MediaAttributeId.ARTIST_NAME, + ], + ) + assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command + + +# ----------------------------------------------------------------------------- +def test_RegisterNotificationCommand(): + command = avrcp.RegisterNotificationCommand( + event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123 + ) + assert avrcp.Command.from_bytes(command.pdu_id, command.payload) == command + + +# ----------------------------------------------------------------------------- +def test_UidsChangedEvent(): + event = avrcp.UidsChangedEvent(uid_counter=7) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_TrackChangedEvent(): + event = avrcp.TrackChangedEvent(identifier=b'12356') + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_VolumeChangedEvent(): + event = avrcp.VolumeChangedEvent(volume=9) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlaybackStatusChangedEvent(): + event = avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_AddressedPlayerChangedEvent(): + event = avrcp.AddressedPlayerChangedEvent( + player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10) + ) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_AvailablePlayersChangedEvent(): + event = avrcp.AvailablePlayersChangedEvent() + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlaybackPositionChangedEvent(): + event = avrcp.PlaybackPositionChangedEvent(playback_position=1314) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_NowPlayingContentChangedEvent(): + event = avrcp.NowPlayingContentChangedEvent() + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlayerApplicationSettingChangedEvent(): + event = avrcp.PlayerApplicationSettingChangedEvent( + player_application_settings=[ + avrcp.PlayerApplicationSettingChangedEvent.Setting( + avrcp.ApplicationSetting.AttributeId.REPEAT_MODE, + avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT, + ) + ] + ) + assert avrcp.Event.from_bytes(bytes(event)) == event + # ----------------------------------------------------------------------------- def test_frame_parser(): - with pytest.raises(ValueError) as error: + with pytest.raises(ValueError): avc.Frame.from_bytes(bytes.fromhex("11480000")) x = bytes.fromhex("014D0208") @@ -217,8 +288,7 @@ def test_passthrough_commands(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_get_supported_events(): - two_devices = TwoDevices() - await two_devices.setup_connections() + two_devices = await TwoDevices.create_with_avdtp() supported_events = await two_devices.protocols[0].get_supported_events() assert supported_events == []