diff --git a/bumble/avrcp.py b/bumble/avrcp.py index 2119e15..39754b4 100644 --- a/bumble/avrcp.py +++ b/bumble/avrcp.py @@ -1618,10 +1618,16 @@ class Delegate: self.status_code = status_code supported_events: list[EventId] + supported_company_ids: list[int] volume: int playback_status: PlayStatus - def __init__(self, supported_events: Iterable[EventId] = ()) -> None: + def __init__( + self, + supported_events: Iterable[EventId] = (), + supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,), + ) -> None: + self.supported_company_ids = list(supported_company_ids) self.supported_events = list(supported_events) self.volume = 0 self.playback_status = PlayStatus.STOPPED @@ -1629,6 +1635,9 @@ class Delegate: async def get_supported_events(self) -> list[EventId]: return self.supported_events + async def get_supported_company_ids(self) -> list[int]: + return self.supported_company_ids + async def set_absolute_volume(self, volume: int) -> None: """ Set the absolute volume. @@ -1867,6 +1876,19 @@ class Protocol(utils.EventEmitter): if isinstance(capability, EventId) ) + async def get_supported_company_ids(self) -> list[int]: + """Get the list of events supported by the connected peer.""" + response_context = await self.send_avrcp_command( + avc.CommandFrame.CommandType.STATUS, + GetCapabilitiesCommand(GetCapabilitiesCommand.CapabilityId.COMPANY_ID), + ) + response = self._check_response(response_context, GetCapabilitiesResponse) + return list( + int.from_bytes(capability, 'big') + for capability in response.capabilities + if isinstance(capability, bytes) + ) + async def get_play_status(self) -> SongAndPlayStatus: """Get the play status of the connected peer.""" response_context = await self.send_avrcp_command( @@ -2489,17 +2511,27 @@ class Protocol(utils.EventEmitter): logger.debug(f"<<< AVRCP command PDU: {command}") async def get_supported_events() -> None: + capabilities: Sequence[bytes | SupportsBytes] if ( command.capability_id - != GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED + == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED ): - raise core.InvalidArgumentError() - - supported_events = await self.delegate.get_supported_events() + capabilities = await self.delegate.get_supported_events() + elif ( + command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID + ): + company_ids = await self.delegate.get_supported_company_ids() + capabilities = [ + company_id.to_bytes(3, 'big') for company_id in company_ids + ] + else: + raise core.InvalidArgumentError( + f"Unsupported capability: {command.capability_id}" + ) self.send_avrcp_response( transaction_label, avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, - GetCapabilitiesResponse(command.capability_id, supported_events), + GetCapabilitiesResponse(command.capability_id, capabilities), ) self._delegate_command(transaction_label, command, get_supported_events()) diff --git a/tests/avrcp_test.py b/tests/avrcp_test.py index c5a9dc2..755ff17 100644 --- a/tests/avrcp_test.py +++ b/tests/avrcp_test.py @@ -16,8 +16,8 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations -import asyncio +import asyncio import struct from collections.abc import Sequence @@ -566,6 +566,21 @@ async def test_get_playback_status(): assert response.play_status == status +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_get_supported_company_ids(): + two_devices = await TwoDevices.create_with_avdtp() + + for status in avrcp.PlayStatus: + two_devices.protocols[0].delegate = avrcp.Delegate( + supported_company_ids=[avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID] + ) + supported_company_ids = await two_devices.protocols[ + 1 + ].get_supported_company_ids() + assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID] + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_monitor_volume():