diff --git a/bumble/profiles/ams.py b/bumble/profiles/ams.py new file mode 100644 index 0000000..6d2b86b --- /dev/null +++ b/bumble/profiles/ams.py @@ -0,0 +1,362 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Apple Media Service (AMS). +""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import dataclasses +import enum +import logging +from typing import Optional, Iterable, Union + + +from bumble.device import Peer +from bumble.gatt import ( + Characteristic, + GATT_AMS_SERVICE, + GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC, + GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC, + GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC, + TemplateService, +) +from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy +from bumble import utils + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Protocol +# ----------------------------------------------------------------------------- +class RemoteCommandId(utils.OpenIntEnum): + PLAY = 0 + PAUSE = 1 + TOGGLE_PLAY_PAUSE = 2 + NEXT_TRACK = 3 + PREVIOUS_TRACK = 4 + VOLUME_UP = 5 + VOLUME_DOWN = 6 + ADVANCE_REPEAT_MODE = 7 + ADVANCE_SHUFFLE_MODE = 8 + SKIP_FORWARD = 9 + SKIP_BACKWARD = 10 + LIKE_TRACK = 11 + DISLIKE_TRACK = 12 + BOOKMARK_TRACK = 13 + + +class EntityId(utils.OpenIntEnum): + PLAYER = 0 + QUEUE = 1 + TRACK = 2 + + +class ActionId(utils.OpenIntEnum): + POSITIVE = 0 + NEGATIVE = 1 + + +class EntityUpdateFlags(enum.IntFlag): + TRUNCATED = 1 + + +class PlayerAttributeId(utils.OpenIntEnum): + NAME = 0 + PLAYBACK_INFO = 1 + VOLUME = 2 + + +class QueueAttributeId(utils.OpenIntEnum): + INDEX = 0 + COUNT = 1 + SHUFFLE_MODE = 2 + REPEAT_MODE = 3 + + +class ShuffleMode(utils.OpenIntEnum): + OFF = 0 + ONE = 1 + ALL = 2 + + +class RepeatMode(utils.OpenIntEnum): + OFF = 0 + ONE = 1 + ALL = 2 + + +class TrackAttributeId(utils.OpenIntEnum): + ARTIST = 0 + ALBUM = 1 + TITLE = 2 + DURATION = 3 + + +class PlaybackState(utils.OpenIntEnum): + PAUSED = 0 + PLAYING = 1 + REWINDING = 2 + FAST_FORWARDING = 3 + + +@dataclasses.dataclass +class PlaybackInfo: + playback_state: PlaybackState = PlaybackState.PAUSED + playback_rate: float = 1.0 + elapsed_time: float = 0.0 + + +# ----------------------------------------------------------------------------- +# GATT Server-side +# ----------------------------------------------------------------------------- +class Ams(TemplateService): + UUID = GATT_AMS_SERVICE + + remote_command_characteristic: Characteristic + entity_update_characteristic: Characteristic + entity_attribute_characteristic: Characteristic + + def __init__(self) -> None: + # TODO not the final implementation + self.remote_command_characteristic = Characteristic( + GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC, + Characteristic.Properties.NOTIFY + | Characteristic.Properties.WRITE_WITHOUT_RESPONSE, + Characteristic.Permissions.WRITEABLE, + ) + + # TODO not the final implementation + self.entity_update_characteristic = Characteristic( + GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC, + Characteristic.Properties.NOTIFY | Characteristic.Properties.WRITE, + Characteristic.Permissions.WRITEABLE, + ) + + # TODO not the final implementation + self.entity_attribute_characteristic = Characteristic( + GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC, + Characteristic.Properties.READ + | Characteristic.Properties.WRITE_WITHOUT_RESPONSE, + Characteristic.Permissions.WRITEABLE | Characteristic.Permissions.READABLE, + ) + + super().__init__( + [ + self.remote_command_characteristic, + self.entity_update_characteristic, + self.entity_attribute_characteristic, + ] + ) + + +# ----------------------------------------------------------------------------- +# GATT Client-side +# ----------------------------------------------------------------------------- +class AmsProxy(ProfileServiceProxy): + SERVICE_CLASS = Ams + + # NOTE: these don't use adapters, because the format for write and notifications + # are different. + remote_command: CharacteristicProxy[bytes] + entity_update: CharacteristicProxy[bytes] + entity_attribute: CharacteristicProxy[bytes] + + def __init__(self, service_proxy: ServiceProxy): + self.remote_command = service_proxy.get_required_characteristic_by_uuid( + GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC + ) + + self.entity_update = service_proxy.get_required_characteristic_by_uuid( + GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC + ) + + self.entity_attribute = service_proxy.get_required_characteristic_by_uuid( + GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC + ) + + +class AmsClient(utils.EventEmitter): + EVENT_SUPPORTED_COMMANDS = "supported_commands" + EVENT_PLAYER_NAME = "player_name" + EVENT_PLAYER_PLAYBACK_INFO = "player_playback_info" + EVENT_PLAYER_VOLUME = "player_volume" + EVENT_QUEUE_COUNT = "queue_count" + EVENT_QUEUE_INDEX = "queue_index" + EVENT_QUEUE_SHUFFLE_MODE = "queue_shuffle_mode" + EVENT_QUEUE_REPEAT_MODE = "queue_repeat_mode" + EVENT_TRACK_ARTIST = "track_artist" + EVENT_TRACK_ALBUM = "track_album" + EVENT_TRACK_TITLE = "track_title" + EVENT_TRACK_DURATION = "track_duration" + + supported_commands: set[RemoteCommandId] + player_name: str = "" + player_playback_info: PlaybackInfo + player_volume: float = 1.0 + queue_count: int = 0 + queue_index: int = 0 + queue_shuffle_mode: ShuffleMode = ShuffleMode.OFF + queue_repeat_mode: RepeatMode = RepeatMode.OFF + track_artist: str = "" + track_album: str = "" + track_title: str = "" + track_duration: float = 0.0 + + def __init__(self, ams_proxy: AmsProxy) -> None: + super().__init__() + self._ams_proxy = ams_proxy + self._started = False + self._read_attribute_semaphore = asyncio.Semaphore() + self.supported_commands = set() + + @classmethod + async def for_peer(cls, peer: Peer) -> Optional[AmsClient]: + ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy) + if ams_proxy is None: + return None + return cls(ams_proxy) + + async def start(self) -> None: + logger.debug("subscribing to remote command characteristic") + await self._ams_proxy.remote_command.subscribe( + self._on_remote_command_notification + ) + + logger.debug("subscribing to entity update characteristic") + await self._ams_proxy.entity_update.subscribe( + lambda data: utils.AsyncRunner.spawn( + self._on_entity_update_notification(data) + ) + ) + + self._started = True + + async def stop(self) -> None: + await self._ams_proxy.remote_command.unsubscribe( + self._on_remote_command_notification + ) + await self._ams_proxy.entity_update.unsubscribe( + self._on_entity_update_notification + ) + self._started = False + + async def observe( + self, + entity: EntityId, + attributes: Iterable[ + Union[PlayerAttributeId, QueueAttributeId, TrackAttributeId] + ], + ) -> None: + await self._ams_proxy.entity_update.write_value( + bytes([entity] + list(attributes)), with_response=True + ) + + async def command(self, command: RemoteCommandId) -> None: + await self._ams_proxy.remote_command.write_value( + bytes([command]), with_response=True + ) + + def _on_remote_command_notification(self, data: bytes) -> None: + supported_commands = [RemoteCommandId(command) for command in data] + logger.debug( + f"supported commands: {[command.name for command in supported_commands]}" + ) + for command in supported_commands: + self.supported_commands.add(command) + + self.emit(self.EVENT_SUPPORTED_COMMANDS) + + async def _on_entity_update_notification(self, data: bytes) -> None: + entity = EntityId(data[0]) + flags = EntityUpdateFlags(data[2]) + value = data[3:] + + if flags & EntityUpdateFlags.TRUNCATED: + logger.debug("truncated attribute, fetching full value") + + # Write the entity and attribute we're interested in + # (protected by a semaphore, so that we only read one attribute at a time) + async with self._read_attribute_semaphore: + await self._ams_proxy.entity_attribute.write_value( + data[:2], with_response=True + ) + value = await self._ams_proxy.entity_attribute.read_value() + + if entity == EntityId.PLAYER: + player_attribute = PlayerAttributeId(data[1]) + if player_attribute == PlayerAttributeId.NAME: + self.player_name = value.decode() + self.emit(self.EVENT_PLAYER_NAME) + elif player_attribute == PlayerAttributeId.PLAYBACK_INFO: + playback_state_str, playback_rate_str, elapsed_time_str = ( + value.decode().split(",") + ) + self.player_playback_info = PlaybackInfo( + PlaybackState(int(playback_state_str)), + float(playback_rate_str), + float(elapsed_time_str), + ) + self.emit(self.EVENT_PLAYER_PLAYBACK_INFO) + elif player_attribute == PlayerAttributeId.VOLUME: + self.player_volume = float(value.decode()) + self.emit(self.EVENT_PLAYER_VOLUME) + else: + logger.warning(f"received unknown player attribute {player_attribute}") + + elif entity == EntityId.QUEUE: + queue_attribute = QueueAttributeId(data[1]) + if queue_attribute == QueueAttributeId.COUNT: + self.queue_count = int(value) + self.emit(self.EVENT_QUEUE_COUNT) + elif queue_attribute == QueueAttributeId.INDEX: + self.queue_index = int(value) + self.emit(self.EVENT_QUEUE_INDEX) + elif queue_attribute == QueueAttributeId.REPEAT_MODE: + self.queue_repeat_mode = RepeatMode(int(value)) + self.emit(self.EVENT_QUEUE_REPEAT_MODE) + elif queue_attribute == QueueAttributeId.SHUFFLE_MODE: + self.queue_shuffle_mode = ShuffleMode(int(value)) + self.emit(self.EVENT_QUEUE_SHUFFLE_MODE) + else: + logger.warning(f"received unknown queue attribute {queue_attribute}") + + elif entity == EntityId.TRACK: + track_attribute = TrackAttributeId(data[1]) + if track_attribute == TrackAttributeId.ARTIST: + self.track_artist = value.decode() + self.emit(self.EVENT_TRACK_ARTIST) + elif track_attribute == TrackAttributeId.ALBUM: + self.track_album = value.decode() + self.emit(self.EVENT_TRACK_ALBUM) + elif track_attribute == TrackAttributeId.TITLE: + self.track_title = value.decode() + self.emit(self.EVENT_TRACK_TITLE) + elif track_attribute == TrackAttributeId.DURATION: + self.track_duration = float(value.decode()) + self.emit(self.EVENT_TRACK_DURATION) + else: + logger.warning(f"received unknown track attribute {track_attribute}") + + else: + logger.warning(f"received unknown attribute ID {data[1]}") diff --git a/examples/run_ams_client.py b/examples/run_ams_client.py new file mode 100644 index 0000000..098567e --- /dev/null +++ b/examples/run_ams_client.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import sys +import os +import logging +from bumble.colors import color + +from bumble.device import Device, Peer +from bumble.transport import open_transport +from bumble.profiles.ams import ( + AmsClient, + EntityId, + PlayerAttributeId, + QueueAttributeId, + TrackAttributeId, + RemoteCommandId, +) + + +# ----------------------------------------------------------------------------- +async def handle_command_client( + ams_client: AmsClient, reader: asyncio.StreamReader, writer: asyncio.StreamWriter +) -> None: + while True: + command = (await reader.readline()).decode("utf-8") + if not command.endswith("\n"): + print("command client terminated") + return + command = command.strip() + + try: + if command.upper() in [member.name for member in RemoteCommandId]: + await ams_client.command(RemoteCommandId[command.upper()]) + continue + except Exception as error: + writer.write(f"ERROR: {error}\n".encode("utf-8")) + + writer.write(f"unknown command {command}\n".encode("utf-8")) + + +# ----------------------------------------------------------------------------- +async def main() -> None: + if len(sys.argv) < 3: + print( + 'Usage: run_ams_client.py ' + ' ' + ) + print('example: run_ams_client.py device1.json usb:0 E1:CA:72:48:C4:E8 512') + return + device_config, transport_spec, bluetooth_address, mtu = sys.argv[1:] + + print('<<< connecting to HCI...') + async with await open_transport(transport_spec) as hci_transport: + print('<<< connected') + + # Create a device to manage the host, with a custom listener + device = Device.from_config_file_with_hci( + device_config, hci_transport.source, hci_transport.sink + ) + await device.power_on() + + # Connect to the peer + print(f'=== Connecting to {bluetooth_address}...') + connection = await device.connect(bluetooth_address) + print(f'=== Connected: {connection}') + + await connection.encrypt() + + peer = Peer(connection) + mtu_int = int(mtu) + if mtu_int: + new_mtu = await peer.request_mtu(mtu_int) + print(f'ATT MTU = {new_mtu}') + ams_client = await AmsClient.for_peer(peer) + if ams_client is None: + print("!!! no AMS service found") + return + + # Register event handlers + + def on_supported_commands(): + print( + color("Supported commands:", "magenta"), + ", ".join([command.name for command in ams_client.supported_commands]), + ) + + ams_client.on(AmsClient.EVENT_SUPPORTED_COMMANDS, on_supported_commands) + + def on_player_name(): + print(color("Player Name:", "green"), ams_client.player_name) + + ams_client.on(AmsClient.EVENT_PLAYER_NAME, on_player_name) + + def on_player_playback_info(): + print( + color("Playback State:", "green"), + ams_client.player_playback_info.playback_state.name, + ) + print( + color("Playback Rate: ", "green"), + ams_client.player_playback_info.playback_rate, + ) + print( + color("Elapsed Time: ", "green"), + ams_client.player_playback_info.elapsed_time, + ) + + ams_client.on(AmsClient.EVENT_PLAYER_PLAYBACK_INFO, on_player_playback_info) + + def on_player_volume(): + print(color("Volume:", "green"), ams_client.player_volume) + + ams_client.on(AmsClient.EVENT_PLAYER_VOLUME, on_player_volume) + + def on_queue_count(): + print(color("Queue Count:", "yellow"), ams_client.queue_count) + + ams_client.on(AmsClient.EVENT_QUEUE_COUNT, on_queue_count) + + def on_queue_index(): + print(color("Queue Index:", "yellow"), ams_client.queue_index) + + ams_client.on(AmsClient.EVENT_QUEUE_INDEX, on_queue_index) + + def on_queue_shuffle_mode(): + print( + color("Queue Shuffle Mode:", "yellow"), + ams_client.queue_shuffle_mode.name, + ) + + ams_client.on(AmsClient.EVENT_QUEUE_SHUFFLE_MODE, on_queue_shuffle_mode) + + def on_queue_repeat_mode(): + print( + color("Queue Repeat Mode:", "yellow"), ams_client.queue_repeat_mode.name + ) + + ams_client.on(AmsClient.EVENT_QUEUE_REPEAT_MODE, on_queue_repeat_mode) + + def on_track_artist(): + print(color("Track Artist:", "cyan"), ams_client.track_artist) + + ams_client.on(AmsClient.EVENT_TRACK_ARTIST, on_track_artist) + + def on_track_album(): + print(color("Track Album:", "cyan"), ams_client.track_album) + + ams_client.on(AmsClient.EVENT_TRACK_ALBUM, on_track_album) + + def on_track_title(): + print(color("Track Title:", "cyan"), ams_client.track_title) + + ams_client.on(AmsClient.EVENT_TRACK_TITLE, on_track_title) + + def on_track_duration(): + print(color("Track Duration:", "cyan"), ams_client.track_duration) + + ams_client.on(AmsClient.EVENT_TRACK_DURATION, on_track_duration) + + # Start the client + await ams_client.start() + + # Observe the player, queue and track + await ams_client.observe( + EntityId.PLAYER, + [ + PlayerAttributeId.NAME, + PlayerAttributeId.PLAYBACK_INFO, + PlayerAttributeId.VOLUME, + ], + ) + await ams_client.observe( + EntityId.QUEUE, + [ + QueueAttributeId.COUNT, + QueueAttributeId.INDEX, + QueueAttributeId.REPEAT_MODE, + QueueAttributeId.SHUFFLE_MODE, + ], + ) + await ams_client.observe( + EntityId.TRACK, + [ + TrackAttributeId.ALBUM, + TrackAttributeId.ARTIST, + TrackAttributeId.DURATION, + TrackAttributeId.TITLE, + ], + ) + + # Accept a TCP connection to handle commands. + tcp_server = await asyncio.start_server( + lambda reader, writer: handle_command_client(ams_client, reader, writer), + '127.0.0.1', + 9000, + ) + print("Accepting command client on port 9000") + async with tcp_server: + await tcp_server.serve_forever() + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) +asyncio.run(main())