diff --git a/apps/player/player.py b/apps/player/player.py new file mode 100644 index 00000000..dbabf4fc --- /dev/null +++ b/apps/player/player.py @@ -0,0 +1,495 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import asyncio.subprocess +import os +import logging +from typing import Optional, Union + +import click + +from bumble.a2dp import ( + SBC_JOINT_STEREO_CHANNEL_MODE, + SBC_LOUDNESS_ALLOCATION_METHOD, + make_audio_source_service_sdp_records, + A2DP_SBC_CODEC_TYPE, + A2DP_MPEG_2_4_AAC_CODEC_TYPE, + MPEG_2_AAC_LC_OBJECT_TYPE, + AacFrame, + AacParser, + AacPacketSource, + AacMediaCodecInformation, + SbcFrame, + SbcParser, + SbcPacketSource, + SbcMediaCodecInformation, +) +from bumble.avrcp import Protocol as AvrcpProtocol +from bumble.avdtp import ( + find_avdtp_service_with_connection, + AVDTP_AUDIO_MEDIA_TYPE, + AVDTP_DELAY_REPORTING_SERVICE_CATEGORY, + MediaCodecCapabilities, + MediaPacketPump, + Protocol as AvdtpProtocol, +) +from bumble.colors import color +from bumble.core import ( + AdvertisingData, + ConnectionError as BumbleConnectionError, + DeviceClass, + BT_BR_EDR_TRANSPORT, +) +from bumble.device import Connection, Device, DeviceConfiguration +from bumble.hci import Address +from bumble.pairing import PairingConfig +from bumble.transport import open_transport +from bumble.utils import AsyncRunner + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +def a2dp_source_sdp_records(): + service_record_handle = 0x00010001 + return { + service_record_handle: make_audio_source_service_sdp_records( + service_record_handle + ) + } + + +# ----------------------------------------------------------------------------- +async def sbc_codec_capabilities(read_function) -> MediaCodecCapabilities: + sbc_parser = SbcParser(read_function) + sbc_frame: SbcFrame + async for sbc_frame in sbc_parser.frames: + # We only need the first frame + print(color(f"SBC format: {sbc_frame}", "cyan")) + break + + return MediaCodecCapabilities( + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_discrete_values( + sampling_frequency=sbc_frame.sampling_frequency, + channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE, + block_length=16, + subbands=8, + allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, + minimum_bitpool_value=2, + maximum_bitpool_value=40, + ), + ) + + +# ----------------------------------------------------------------------------- +async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities: + aac_parser = AacParser(read_function) + aac_frame: AacFrame + async for aac_frame in aac_parser.frames: + # We only need the first frame + print(color(f"AAC format: {aac_frame}", "cyan")) + break + + return MediaCodecCapabilities( + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE, + media_codec_information=AacMediaCodecInformation.from_discrete_values( + object_type=MPEG_2_AAC_LC_OBJECT_TYPE, + sampling_frequency=aac_frame.sampling_frequency, + channels=aac_frame.channel_configuration, + vbr=1, + bitrate=128000, + ), + ) + + +# ----------------------------------------------------------------------------- +class Player: + def __init__( + self, + transport: str, + device_config: Optional[str], + authenticate: bool, + encrypt: bool, + ) -> None: + self.transport = transport + self.device_config = device_config + self.authenticate = authenticate + self.encrypt = encrypt + self.avrcp_protocol: Optional[AvrcpProtocol] = None + self.done: Optional[asyncio.Future] + + async def run(self, workload) -> None: + self.done = asyncio.get_running_loop().create_future() + try: + await self._run(workload) + except BumbleConnectionError as error: + print(color(f"Failed to connect: {error}", "red")) + except Exception as error: + print(color(f"!!! ERROR: {error}", "red")) + + async def _run(self, workload) -> None: + async with await open_transport(self.transport) as (hci_source, hci_sink): + # Create a device + device_config = DeviceConfiguration() + if self.device_config: + device_config.load_from_file(self.device_config) + else: + device_config.name = "Bumble Player" + device_config.class_of_device = DeviceClass.pack_class_of_device( + DeviceClass.AUDIO_SERVICE_CLASS, + DeviceClass.AUDIO_VIDEO_MAJOR_DEVICE_CLASS, + DeviceClass.AUDIO_VIDEO_UNCATEGORIZED_MINOR_DEVICE_CLASS, + ) + device_config.keystore = "JsonKeyStore" + + device_config.classic_enabled = True + device_config.le_enabled = False + device = Device.from_config_with_hci(device_config, hci_source, hci_sink) + + # Setup the SDP to expose the SRC service + device.sdp_service_records = a2dp_source_sdp_records() + + # Setup AVRCP + self.avrcp_protocol = AvrcpProtocol() + self.avrcp_protocol.listen(device) + + # Don't require MITM when pairing. + device.pairing_config_factory = lambda connection: PairingConfig(mitm=False) + + # Start the controller + await device.power_on() + + # Print some of the config/properties + print( + "Player Bluetooth Address:", + color( + device.public_address.to_string(with_type_qualifier=False), + "yellow", + ), + ) + + # Listen for connections + device.on("connection", self.on_bluetooth_connection) + + # Run the workload + await workload(device) + + # Wait until it is time to exit + await asyncio.wait( + [hci_source.terminated, self.done], return_when=asyncio.FIRST_COMPLETED + ) + + def on_bluetooth_connection(self, connection: Connection) -> None: + print(color(f"--- Connected: {connection}", "cyan")) + + async def connect(self, device: Device, address: str) -> Connection: + print(color(f"Connecting to {address}...", "green")) + connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT) + + # Request authentication + if self.authenticate: + print(color("*** Authenticating...", "blue")) + await connection.authenticate() + print(color("*** Authenticated", "blue")) + + # Enable encryption + if self.encrypt: + print(color("*** Enabling encryption...", "blue")) + await connection.encrypt() + print(color("*** Encryption on", "blue")) + + return connection + + async def create_avdtp_protocol(self, connection: Connection) -> AvdtpProtocol: + # Look for an A2DP service + avdtp_version = await find_avdtp_service_with_connection(connection) + if not avdtp_version: + raise RuntimeError("no A2DP service found") + + print(color(f"AVDTP Version: {avdtp_version}")) + + # Create a client to interact with the remote device + return await AvdtpProtocol.connect(connection, avdtp_version) + + async def stream_packets( + self, + protocol: AvdtpProtocol, + codec_type: int, + packet_source: Union[SbcPacketSource, AacPacketSource], + ): + # Discover all endpoints on the remote device + endpoints = await protocol.discover_remote_endpoints() + for endpoint in endpoints: + print('@@@', endpoint) + + # Select a sink + sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, codec_type) + if sink is None: + print(color('!!! no compatible sink found', 'red')) + return + print(f'### Selected sink: {sink.seid}') + + # Check if the sink supports delay reporting + delay_reporting = False + for capability in sink.capabilities: + if capability.service_category == AVDTP_DELAY_REPORTING_SERVICE_CATEGORY: + delay_reporting = True + break + + def on_delay_report(delay: int): + print(color(f"*** DELAY REPORT: {delay}", "blue")) + + # Stream the packets + packet_pump = MediaPacketPump(packet_source.packets) + source = protocol.add_source( + packet_source.codec_capabilities, packet_pump, delay_reporting + ) + source.on("delay_report", on_delay_report) + stream = await protocol.create_stream(source, sink) + await stream.start() + + await packet_pump.wait_for_completion() + + async def discover(self, device: Device) -> None: + @device.on("inquiry_result") + def on_inquiry_result( + address: Address, class_of_device: int, data: AdvertisingData, rssi: int + ): + ( + service_classes, + major_device_class, + minor_device_class, + ) = DeviceClass.split_class_of_device(class_of_device) + separator = "\n " + print(f">>> {color(address.to_string(False), 'yellow')}:") + print(f" Device Class (raw): {class_of_device:06X}") + major_class_name = DeviceClass.major_device_class_name(major_device_class) + print(" Device Major Class: " f"{major_class_name}") + minor_class_name = DeviceClass.minor_device_class_name( + major_device_class, minor_device_class + ) + print(" Device Minor Class: " f"{minor_class_name}") + print( + " Device Services: " + f"{', '.join(DeviceClass.service_class_labels(service_classes))}" + ) + print(f" RSSI: {rssi}") + if data.ad_structures: + print(f" {data.to_string(separator)}") + + await device.start_discovery() + + async def pair(self, device: Device, address: str) -> None: + connection = await self.connect(device, address) + + print(color("Pairing...", "magenta")) + await connection.authenticate() + print(color("Pairing completed", "magenta")) + if self.done is not None: + self.done.set_result(None) + + async def inquire(self, device: Device, address: str) -> None: + connection = await self.connect(device, address) + avdtp_protocol = await self.create_avdtp_protocol(connection) + + # Discover the remote endpoints + endpoints = await avdtp_protocol.discover_remote_endpoints() + print(f'@@@ Found {len(list(endpoints))} endpoints') + for endpoint in endpoints: + print('@@@', endpoint) + + if self.done is not None: + self.done.set_result(None) + + async def play( + self, + device: Device, + address: Optional[str], + audio_format: str, + audio_file: str, + ) -> None: + if audio_format == "auto": + if audio_file.endswith(".sbc"): + audio_format = "sbc" + elif audio_file.endswith(".aac") or audio_file.endswith(".adts"): + audio_format = "aac" + else: + raise ValueError("Unable to determine audio format from file extension") + + device.on( + "connection", + lambda connection: AsyncRunner.spawn(on_connection(connection)), + ) + + async def on_connection(connection: Connection): + avdtp_protocol = await self.create_avdtp_protocol(connection) + + with open(audio_file, 'rb') as input_file: + # NOTE: this should be using asyncio file reading, but blocking reads + # are good enough for this command line app. + async def read_audio_data(byte_count): + return input_file.read(byte_count) + + # Obtain the codec capabilities from the stream + packet_source: Union[SbcPacketSource, AacPacketSource] + if audio_format == "sbc": + codec_type = A2DP_SBC_CODEC_TYPE + codec_capabilities = await sbc_codec_capabilities(read_audio_data) + packet_source = SbcPacketSource( + read_audio_data, + avdtp_protocol.l2cap_channel.peer_mtu, + codec_capabilities, + ) + else: + codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE + codec_capabilities = await aac_codec_capabilities(read_audio_data) + packet_source = AacPacketSource( + read_audio_data, + avdtp_protocol.l2cap_channel.peer_mtu, + codec_capabilities, + ) + + # Rewind to the start + input_file.seek(0) + + try: + await self.stream_packets(avdtp_protocol, codec_type, packet_source) + except Exception as error: + print(color(f"!!! Error while streaming: {error}", "red")) + + if self.done: + self.done.set_result(None) + + if address: + await self.connect(device, address) + else: + print(color("Waiting for an incoming connection...", "magenta")) + + +# ----------------------------------------------------------------------------- +def create_player(context) -> Player: + return Player( + transport=context.obj["hci_transport"], + device_config=context.obj["device_config"], + authenticate=context.obj["authenticate"], + encrypt=context.obj["encrypt"], + ) + + +# ----------------------------------------------------------------------------- +@click.group() +@click.pass_context +@click.option("--hci-transport", metavar="TRANSPORT", required=True) +@click.option("--device-config", metavar="FILENAME", help="Device configuration file") +@click.option( + "--authenticate", + is_flag=True, + help="Request authentication when connecting", + default=False, +) +@click.option( + "--encrypt", is_flag=True, help="Request encryption when connecting", default=False +) +def player_cli(ctx, hci_transport, device_config, authenticate, encrypt): + ctx.ensure_object(dict) + ctx.obj["hci_transport"] = hci_transport + ctx.obj["device_config"] = device_config + ctx.obj["authenticate"] = authenticate + ctx.obj["encrypt"] = encrypt + + +@player_cli.command("discover") +@click.pass_context +def discover(context): + """Discover for speakers or headphones""" + player = create_player(context) + asyncio.run(player.run(player.discover)) + + +@player_cli.command("inquire") +@click.pass_context +@click.argument( + "address", + metavar="ADDRESS", +) +def inquire(context, address): + """Connect to a speaker or headphone and inquire about their capabilities""" + player = create_player(context) + asyncio.run(player.run(lambda device: player.inquire(device, address))) + + +@player_cli.command("pair") +@click.pass_context +@click.argument( + "address", + metavar="ADDRESS", +) +def pair(context, address): + """Pair with a speaker or headphone""" + player = create_player(context) + asyncio.run(player.run(lambda device: player.pair(device, address))) + + +@player_cli.command("play") +@click.pass_context +@click.option( + "--connect", + "address", + metavar="ADDRESS", + help="Address or name to connect to", +) +@click.option( + "-f", + "--audio-format", + type=click.Choice(["auto", "sbc", "aac"]), + help="Audio file format (use 'auto' to infer the format from the file extension)", + default="auto", +) +@click.argument("audio_file") +def play(context, address, audio_format, audio_file): + """Play and audio file""" + player = create_player(context) + asyncio.run( + player.run( + lambda device: player.play(device, address, audio_format, audio_file) + ) + ) + + +# ----------------------------------------------------------------------------- +def main(): + logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper()) + player_cli() + + +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/apps/speaker/speaker.py b/apps/speaker/speaker.py index fc2230a9..aa1a92d7 100644 --- a/apps/speaker/speaker.py +++ b/apps/speaker/speaker.py @@ -93,7 +93,7 @@ class AudioExtractor: # ----------------------------------------------------------------------------- class AacAudioExtractor: def extract_audio(self, packet: MediaPacket) -> bytes: - return AacAudioRtpPacket(packet.payload).to_adts() + return AacAudioRtpPacket.from_bytes(packet.payload).to_adts() # ----------------------------------------------------------------------------- diff --git a/bumble/a2dp.py b/bumble/a2dp.py index cac14e91..285b748c 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -17,12 +17,14 @@ # ----------------------------------------------------------------------------- from __future__ import annotations -import dataclasses -import struct -import logging from collections.abc import AsyncGenerator +import dataclasses +import enum +import logging +import struct from typing import List, Callable, Awaitable +from .codecs import AacAudioRtpPacket from .company_ids import COMPANY_IDENTIFIERS from .sdp import ( DataElement, @@ -535,6 +537,7 @@ class SbcFrame: block_count: int channel_mode: int subband_count: int + bitpool: int payload: bytes @property @@ -555,6 +558,7 @@ class SbcFrame: f'cm={self.channel_mode},' f'br={self.bitrate},' f'sc={self.sample_count},' + f'bp={self.bitpool},' f'size={len(self.payload)})' ) @@ -602,7 +606,7 @@ class SbcParser: # Emit the next frame yield SbcFrame( - sampling_frequency, blocks, channel_mode, subbands, payload + sampling_frequency, blocks, channel_mode, subbands, bitpool, payload ) return generate_frames() @@ -632,13 +636,12 @@ class SbcPacketSource: # NOTE: this doesn't support frame fragments sbc_parser = SbcParser(self.read) async for frame in sbc_parser.frames: - print(frame) - if ( frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16 ): # Need to flush what has been accumulated so far + logger.debug(f"yielding {len(frames)} frames") # Emit a packet sbc_payload = bytes([len(frames)]) + b''.join( @@ -663,3 +666,140 @@ class SbcPacketSource: frames_size += len(frame.payload) return generate_packets() + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class AacFrame: + class Profile(enum.IntEnum): + MAIN = 0 + LC = 1 + SSR = 2 + LTP = 3 + + profile: Profile + sampling_frequency: int + channel_configuration: int + payload: bytes + + @property + def sample_count(self) -> int: + return 1024 + + @property + def duration(self) -> float: + return self.sample_count / self.sampling_frequency + + def __str__(self) -> str: + return ( + f'AAC(sf={self.sampling_frequency},' + f'ch={self.channel_configuration},' + f'size={len(self.payload)})' + ) + + +# ----------------------------------------------------------------------------- +ADTS_AAC_SAMPLING_FREQUENCIES = [ + 96000, + 88200, + 64000, + 48000, + 44100, + 32000, + 24000, + 22050, + 16000, + 12000, + 11025, + 8000, + 7350, + 0, + 0, + 0, +] + + +# ----------------------------------------------------------------------------- +class AacParser: + """Parser for AAC frames in an ADTS stream""" + + def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None: + self.read = read + + @property + def frames(self) -> AsyncGenerator[AacFrame, None]: + async def generate_frames() -> AsyncGenerator[AacFrame, None]: + while True: + header = await self.read(7) + if not header: + return + + sync_word = (header[0] << 4) | (header[1] >> 4) + if sync_word != 0b111111111111: + raise ValueError(f"invalid sync word ({sync_word:06x})") + layer = (header[1] >> 1) & 0b11 + profile = AacFrame.Profile((header[2] >> 6) & 0b11) + sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[ + (header[2] >> 2) & 0b1111 + ] + channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6) + frame_length = ( + ((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5) + ) + + if layer != 0: + raise ValueError("layer must be 0") + + payload = await self.read(frame_length - 7) + if payload: + yield AacFrame( + profile, sampling_frequency, channel_configuration, payload + ) + + return generate_frames() + + +# ----------------------------------------------------------------------------- +class AacPacketSource: + def __init__( + self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities + ) -> None: + self.read = read + self.mtu = mtu + self.codec_capabilities = codec_capabilities + + @property + def packets(self): + async def generate_packets(): + # pylint: disable=import-outside-toplevel + from .avdtp import MediaPacket # Import here to avoid a circular reference + + sequence_number = 0 + timestamp = 0 + + aac_parser = AacParser(self.read) + async for frame in aac_parser.frames: + logger.debug("yielding one AAC frame") + + # Emit a packet + aac_payload = bytes( + AacAudioRtpPacket.for_simple_aac( + frame.sampling_frequency, + frame.channel_configuration, + frame.payload, + ) + ) + packet = MediaPacket( + 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload + ) + packet.timestamp_seconds = timestamp / frame.sampling_frequency + yield packet + + # Prepare for next packets + sequence_number += 1 + sequence_number &= 0xFFFF + timestamp += frame.sample_count + timestamp &= 0xFFFFFFFF + frames = [frame] + + return generate_packets() diff --git a/bumble/avc.py b/bumble/avc.py index 8e6b968e..81502780 100644 --- a/bumble/avc.py +++ b/bumble/avc.py @@ -119,7 +119,7 @@ class Frame: # Not supported raise NotImplementedError("extended subunit types not supported") - if subunit_id < 5: + if subunit_id < 5 or subunit_id == 7: opcode_offset = 2 elif subunit_id == 5: # Extended to the next byte @@ -132,7 +132,6 @@ class Frame: else: subunit_id = 5 + extension opcode_offset = 3 - elif subunit_id == 6: raise core.InvalidPacketError("reserved subunit ID") diff --git a/bumble/avdtp.py b/bumble/avdtp.py index fd79dc33..fe311986 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -377,6 +377,7 @@ class MediaPacketPump: self.packets = packets self.clock = clock self.pump_task = None + self.completed = asyncio.Event() async def start(self, rtp_channel: l2cap.ClassicChannel) -> None: async def pump_packets(): @@ -406,6 +407,8 @@ class MediaPacketPump: ) except asyncio.exceptions.CancelledError: logger.debug('pump canceled') + finally: + self.completed.set() # Pump packets self.pump_task = asyncio.create_task(pump_packets()) @@ -417,6 +420,9 @@ class MediaPacketPump: await self.pump_task self.pump_task = None + async def wait_for_completion(self) -> None: + await self.completed.wait() + # ----------------------------------------------------------------------------- class MessageAssembler: @@ -1316,10 +1322,20 @@ class Protocol(EventEmitter): return None def add_source( - self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump + self, + codec_capabilities: MediaCodecCapabilities, + packet_pump: MediaPacketPump, + delay_reporting: bool = False, ) -> LocalSource: seid = len(self.local_endpoints) + 1 - source = LocalSource(self, seid, codec_capabilities, packet_pump) + service_capabilities = ( + [ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)] + if delay_reporting + else [] + ) + source = LocalSource( + self, seid, codec_capabilities, service_capabilities, packet_pump + ) self.local_endpoints.append(source) return source @@ -2180,12 +2196,13 @@ class LocalSource(LocalStreamEndPoint): protocol: Protocol, seid: int, codec_capabilities: MediaCodecCapabilities, + other_capabilitiles: Iterable[ServiceCapabilities], packet_pump: MediaPacketPump, ) -> None: capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), codec_capabilities, - ] + ] + list(other_capabilitiles) super().__init__( protocol, seid, diff --git a/bumble/avrcp.py b/bumble/avrcp.py index e06a5a67..4bc625a1 100644 --- a/bumble/avrcp.py +++ b/bumble/avrcp.py @@ -1491,10 +1491,14 @@ class Protocol(pyee.EventEmitter): f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}" ) - # Only the PANEL subunit type with subunit ID 0 is supported in this profile. - if ( - command.subunit_type != avc.Frame.SubunitType.PANEL - or command.subunit_id != 0 + # Only addressing the unit, or the PANEL subunit with subunit ID 0 is supported + # in this profile. + if not ( + command.subunit_type == avc.Frame.SubunitType.UNIT + and command.subunit_id == 7 + ) and not ( + command.subunit_type == avc.Frame.SubunitType.PANEL + and command.subunit_id == 0 ): logger.debug("subunit not supported") self.send_not_implemented_response(transaction_label, command) @@ -1528,8 +1532,8 @@ class Protocol(pyee.EventEmitter): # TODO: delegate response = avc.PassThroughResponseFrame( avc.ResponseFrame.ResponseCode.ACCEPTED, - avc.Frame.SubunitType.PANEL, - 0, + command.subunit_type, + command.subunit_id, command.state_flag, command.operation_id, command.operation_data, @@ -1846,6 +1850,15 @@ class Protocol(pyee.EventEmitter): RejectedResponse(pdu_id, status_code), ) + def send_not_implemented_avrcp_response( + self, transaction_label: int, pdu_id: Protocol.PduId + ) -> None: + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED, + NotImplementedResponse(pdu_id, b''), + ) + def _on_get_capabilities_command( self, transaction_label: int, command: GetCapabilitiesCommand ) -> None: @@ -1891,29 +1904,35 @@ class Protocol(pyee.EventEmitter): async def register_notification(): # Check if the event is supported. supported_events = await self.delegate.get_supported_events() - if command.event_id in supported_events: - if command.event_id == EventId.VOLUME_CHANGED: - volume = await self.delegate.get_absolute_volume() - response = RegisterNotificationResponse(VolumeChangedEvent(volume)) - self.send_avrcp_response( - transaction_label, - avc.ResponseFrame.ResponseCode.INTERIM, - response, - ) - self._register_notification_listener(transaction_label, command) - return + if command.event_id not in supported_events: + logger.debug("event not supported") + self.send_not_implemented_avrcp_response( + transaction_label, self.PduId.REGISTER_NOTIFICATION + ) + return - if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: - # TODO: testing only, use delegate - response = RegisterNotificationResponse( - PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING) - ) - self.send_avrcp_response( - transaction_label, - avc.ResponseFrame.ResponseCode.INTERIM, - response, - ) - self._register_notification_listener(transaction_label, command) - return + if command.event_id == EventId.VOLUME_CHANGED: + volume = await self.delegate.get_absolute_volume() + response = RegisterNotificationResponse(VolumeChangedEvent(volume)) + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.INTERIM, + response, + ) + self._register_notification_listener(transaction_label, command) + return + + if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: + # TODO: testing only, use delegate + response = RegisterNotificationResponse( + PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING) + ) + self.send_avrcp_response( + transaction_label, + avc.ResponseFrame.ResponseCode.INTERIM, + response, + ) + self._register_notification_listener(transaction_label, command) + return self._delegate_command(transaction_label, command, register_notification()) diff --git a/bumble/codecs.py b/bumble/codecs.py index cfb3cad1..4d4c48cb 100644 --- a/bumble/codecs.py +++ b/bumble/codecs.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations from dataclasses import dataclass +from typing_extensions import Self from bumble import core @@ -101,12 +102,40 @@ class BitReader: break +# ----------------------------------------------------------------------------- +class BitWriter: + """Simple but not optimized bit stream writer.""" + + data: int + bit_count: int + + def __init__(self) -> None: + self.data = 0 + self.bit_count = 0 + + def write(self, value: int, bit_count: int) -> None: + self.data = (self.data << bit_count) | value + self.bit_count += bit_count + + def write_bytes(self, data: bytes) -> None: + bit_count = 8 * len(data) + self.data = (self.data << bit_count) | int.from_bytes(data, 'big') + self.bit_count += bit_count + + def __bytes__(self) -> bytes: + return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes( + (self.bit_count + 7) // 8, 'big' + ) + + # ----------------------------------------------------------------------------- class AacAudioRtpPacket: """AAC payload encapsulated in an RTP packet payload""" + audio_mux_element: AudioMuxElement + @staticmethod - def latm_value(reader: BitReader) -> int: + def read_latm_value(reader: BitReader) -> int: bytes_for_value = reader.read(2) value = 0 for _ in range(bytes_for_value + 1): @@ -114,24 +143,33 @@ class AacAudioRtpPacket: return value @staticmethod - def program_config_element(reader: BitReader): - raise core.InvalidPacketError('program_config_element not supported') + def read_audio_object_type(reader: BitReader): + # GetAudioObjectType - ISO/EIC 14496-3 Table 1.16 + audio_object_type = reader.read(5) + if audio_object_type == 31: + audio_object_type = 32 + reader.read(6) + + return audio_object_type @dataclass class GASpecificConfig: - def __init__( - self, reader: BitReader, channel_configuration: int, audio_object_type: int - ) -> None: + audio_object_type: int + # NOTE: other fields not supported + + @classmethod + def from_bits( + cls, reader: BitReader, channel_configuration: int, audio_object_type: int + ) -> Self: # GASpecificConfig - ISO/EIC 14496-3 Table 4.1 frame_length_flag = reader.read(1) depends_on_core_coder = reader.read(1) if depends_on_core_coder: - self.core_coder_delay = reader.read(14) + core_coder_delay = reader.read(14) extension_flag = reader.read(1) if not channel_configuration: - AacAudioRtpPacket.program_config_element(reader) + raise core.InvalidPacketError('program_config_element not supported') if audio_object_type in (6, 20): - self.layer_nr = reader.read(3) + layer_nr = reader.read(3) if extension_flag: if audio_object_type == 22: num_of_sub_frame = reader.read(5) @@ -144,14 +182,13 @@ class AacAudioRtpPacket: if extension_flag_3 == 1: raise core.InvalidPacketError('extensionFlag3 == 1 not supported') - @staticmethod - def audio_object_type(reader: BitReader): - # GetAudioObjectType - ISO/EIC 14496-3 Table 1.16 - audio_object_type = reader.read(5) - if audio_object_type == 31: - audio_object_type = 32 + reader.read(6) + return cls(audio_object_type) - return audio_object_type + def to_bits(self, writer: BitWriter) -> None: + assert self.audio_object_type in (1, 2) + writer.write(0, 1) # frame_length_flag = 0 + writer.write(0, 1) # depends_on_core_coder = 0 + writer.write(0, 1) # extension_flag = 0 @dataclass class AudioSpecificConfig: @@ -159,6 +196,7 @@ class AacAudioRtpPacket: sampling_frequency_index: int sampling_frequency: int channel_configuration: int + ga_specific_config: AacAudioRtpPacket.GASpecificConfig sbr_present_flag: int ps_present_flag: int extension_audio_object_type: int @@ -182,44 +220,73 @@ class AacAudioRtpPacket: 7350, ] - def __init__(self, reader: BitReader) -> None: - # AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15 - self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader) - self.sampling_frequency_index = reader.read(4) - if self.sampling_frequency_index == 0xF: - self.sampling_frequency = reader.read(24) - else: - self.sampling_frequency = self.SAMPLING_FREQUENCIES[ - self.sampling_frequency_index - ] - self.channel_configuration = reader.read(4) - self.sbr_present_flag = -1 - self.ps_present_flag = -1 - if self.audio_object_type in (5, 29): - self.extension_audio_object_type = 5 - self.sbc_present_flag = 1 - if self.audio_object_type == 29: - self.ps_present_flag = 1 - self.extension_sampling_frequency_index = reader.read(4) - if self.extension_sampling_frequency_index == 0xF: - self.extension_sampling_frequency = reader.read(24) - else: - self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[ - self.extension_sampling_frequency_index - ] - self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader) - if self.audio_object_type == 22: - self.extension_channel_configuration = reader.read(4) - else: - self.extension_audio_object_type = 0 + @classmethod + def for_simple_aac( + cls, + audio_object_type: int, + sampling_frequency: int, + channel_configuration: int, + ) -> Self: + if sampling_frequency not in cls.SAMPLING_FREQUENCIES: + raise ValueError(f'invalid sampling frequency {sampling_frequency}') - if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23): - ga_specific_config = AacAudioRtpPacket.GASpecificConfig( - reader, self.channel_configuration, self.audio_object_type + ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type) + + return cls( + audio_object_type=audio_object_type, + sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index( + sampling_frequency + ), + sampling_frequency=sampling_frequency, + channel_configuration=channel_configuration, + ga_specific_config=ga_specific_config, + sbr_present_flag=0, + ps_present_flag=0, + extension_audio_object_type=0, + extension_sampling_frequency_index=0, + extension_sampling_frequency=0, + extension_channel_configuration=0, + ) + + @classmethod + def from_bits(cls, reader: BitReader) -> Self: + # AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15 + audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader) + sampling_frequency_index = reader.read(4) + if sampling_frequency_index == 0xF: + sampling_frequency = reader.read(24) + else: + sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index] + channel_configuration = reader.read(4) + sbr_present_flag = 0 + ps_present_flag = 0 + extension_sampling_frequency_index = 0 + extension_sampling_frequency = 0 + extension_channel_configuration = 0 + extension_audio_object_type = 0 + if audio_object_type in (5, 29): + extension_audio_object_type = 5 + sbr_present_flag = 1 + if audio_object_type == 29: + ps_present_flag = 1 + extension_sampling_frequency_index = reader.read(4) + if extension_sampling_frequency_index == 0xF: + extension_sampling_frequency = reader.read(24) + else: + extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[ + extension_sampling_frequency_index + ] + audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader) + if audio_object_type == 22: + extension_channel_configuration = reader.read(4) + + if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23): + ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits( + reader, channel_configuration, audio_object_type ) else: raise core.InvalidPacketError( - f'audioObjectType {self.audio_object_type} not supported' + f'audioObjectType {audio_object_type} not supported' ) # if self.extension_audio_object_type != 5 and bits_to_decode >= 16: @@ -248,13 +315,44 @@ class AacAudioRtpPacket: # self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index] # self.extension_channel_configuration = reader.read(4) + return cls( + audio_object_type, + sampling_frequency_index, + sampling_frequency, + channel_configuration, + ga_specific_config, + sbr_present_flag, + ps_present_flag, + extension_audio_object_type, + extension_sampling_frequency_index, + extension_sampling_frequency, + extension_channel_configuration, + ) + + def to_bits(self, writer: BitWriter) -> None: + if self.sampling_frequency_index >= 15: + raise ValueError( + f"unsupported sampling frequency index {self.sampling_frequency_index}" + ) + + if self.audio_object_type not in (1, 2): + raise ValueError( + f"unsupported audio object type {self.audio_object_type} " + ) + + writer.write(self.audio_object_type, 5) + writer.write(self.sampling_frequency_index, 4) + writer.write(self.channel_configuration, 4) + self.ga_specific_config.to_bits(writer) + @dataclass class StreamMuxConfig: other_data_present: int other_data_len_bits: int audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig - def __init__(self, reader: BitReader) -> None: + @classmethod + def from_bits(cls, reader: BitReader) -> Self: # StreamMuxConfig - ISO/EIC 14496-3 Table 1.42 audio_mux_version = reader.read(1) if audio_mux_version == 1: @@ -264,7 +362,7 @@ class AacAudioRtpPacket: if audio_mux_version_a != 0: raise core.InvalidPacketError('audioMuxVersionA != 0 not supported') if audio_mux_version == 1: - tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader) + tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader) stream_cnt = 0 all_streams_same_time_framing = reader.read(1) num_sub_frames = reader.read(6) @@ -275,13 +373,13 @@ class AacAudioRtpPacket: if num_layer != 0: raise core.InvalidPacketError('num_layer != 0 not supported') if audio_mux_version == 0: - self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig( + audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits( reader ) else: - asc_len = AacAudioRtpPacket.latm_value(reader) + asc_len = AacAudioRtpPacket.read_latm_value(reader) marker = reader.bit_position - self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig( + audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits( reader ) audio_specific_config_len = reader.bit_position - marker @@ -299,36 +397,49 @@ class AacAudioRtpPacket: f'frame_length_type {frame_length_type} not supported' ) - self.other_data_present = reader.read(1) - if self.other_data_present: + other_data_present = reader.read(1) + other_data_len_bits = 0 + if other_data_present: if audio_mux_version == 1: - self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader) + other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader) else: - self.other_data_len_bits = 0 while True: - self.other_data_len_bits *= 256 + other_data_len_bits *= 256 other_data_len_esc = reader.read(1) - self.other_data_len_bits += reader.read(8) + other_data_len_bits += reader.read(8) if other_data_len_esc == 0: break crc_check_present = reader.read(1) if crc_check_present: crc_checksum = reader.read(8) + return cls(other_data_present, other_data_len_bits, audio_specific_config) + + def to_bits(self, writer: BitWriter) -> None: + writer.write(0, 1) # audioMuxVersion = 0 + writer.write(1, 1) # allStreamsSameTimeFraming = 1 + writer.write(0, 6) # numSubFrames = 0 + writer.write(0, 4) # numProgram = 0 + writer.write(0, 3) # numLayer = 0 + self.audio_specific_config.to_bits(writer) + writer.write(0, 3) # frameLengthType = 0 + writer.write(0, 8) # latmBufferFullness = 0 + writer.write(0, 1) # otherDataPresent = 0 + writer.write(0, 1) # crcCheckPresent = 0 + @dataclass class AudioMuxElement: - payload: bytes stream_mux_config: AacAudioRtpPacket.StreamMuxConfig + payload: bytes - def __init__(self, reader: BitReader, mux_config_present: int): - if mux_config_present == 0: - raise core.InvalidPacketError('muxConfigPresent == 0 not supported') - + @classmethod + def from_bits(cls, reader: BitReader) -> Self: # AudioMuxElement - ISO/EIC 14496-3 Table 1.41 + # (only supports mux_config_present=1) use_same_stream_mux = reader.read(1) if use_same_stream_mux: raise core.InvalidPacketError('useSameStreamMux == 1 not supported') - self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader) + stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader) # We only support: # allStreamsSameTimeFraming == 1 @@ -344,19 +455,46 @@ class AacAudioRtpPacket: if tmp != 255: break - self.payload = reader.read_bytes(mux_slot_length_bytes) + payload = reader.read_bytes(mux_slot_length_bytes) - if self.stream_mux_config.other_data_present: - reader.skip(self.stream_mux_config.other_data_len_bits) + if stream_mux_config.other_data_present: + reader.skip(stream_mux_config.other_data_len_bits) # ByteAlign while reader.bit_position % 8: reader.read(1) - def __init__(self, data: bytes) -> None: + return cls(stream_mux_config, payload) + + def to_bits(self, writer: BitWriter) -> None: + writer.write(0, 1) # useSameStreamMux = 0 + self.stream_mux_config.to_bits(writer) + mux_slot_length_bytes = len(self.payload) + while mux_slot_length_bytes > 255: + writer.write(255, 8) + mux_slot_length_bytes -= 255 + writer.write(mux_slot_length_bytes, 8) + if mux_slot_length_bytes == 255: + writer.write(0, 8) + writer.write_bytes(self.payload) + + @classmethod + def from_bytes(cls, data: bytes) -> Self: # Parse the bit stream reader = BitReader(data) - self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1) + return cls(cls.AudioMuxElement.from_bits(reader)) + + @classmethod + def for_simple_aac( + cls, sampling_frequency: int, channel_configuration: int, payload: bytes + ) -> Self: + audio_specific_config = cls.AudioSpecificConfig.for_simple_aac( + 2, sampling_frequency, channel_configuration + ) + stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config) + audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload) + + return cls(audio_mux_element) def to_adts(self): # pylint: disable=line-too-long @@ -383,3 +521,11 @@ class AacAudioRtpPacket: ) + self.audio_mux_element.payload ) + + def __init__(self, audio_mux_element: AudioMuxElement) -> None: + self.audio_mux_element = audio_mux_element + + def __bytes__(self) -> bytes: + writer = BitWriter() + self.audio_mux_element.to_bits(writer) + return bytes(writer) diff --git a/tests/codecs_test.py b/tests/codecs_test.py index b8affada..2a44e1e9 100644 --- a/tests/codecs_test.py +++ b/tests/codecs_test.py @@ -15,8 +15,9 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +import random import pytest -from bumble.codecs import AacAudioRtpPacket, BitReader +from bumble.codecs import AacAudioRtpPacket, BitReader, BitWriter # ----------------------------------------------------------------------------- @@ -49,19 +50,58 @@ def test_reader(): assert value == int.from_bytes(data, byteorder='big') +def test_writer(): + writer = BitWriter() + assert bytes(writer) == b'' + + for i in range(100): + for j in range(1, 10): + writer = BitWriter() + chunks = [] + for k in range(j): + n_bits = random.randint(1, 32) + random_bits = random.getrandbits(n_bits) + chunks.append((n_bits, random_bits)) + writer.write(random_bits, n_bits) + + written_data = bytes(writer) + reader = BitReader(written_data) + for n_bits, written_bits in chunks: + read_bits = reader.read(n_bits) + assert read_bits == written_bits + + def test_aac_rtp(): # pylint: disable=line-too-long packet_data = bytes.fromhex( '47fc0000b090800300202066000198000de120000000000000000000000000000000000000000000001c' ) - packet = AacAudioRtpPacket(packet_data) + packet = AacAudioRtpPacket.from_bytes(packet_data) adts = packet.to_adts() assert adts == bytes.fromhex( 'fff1508004fffc2066000198000de120000000000000000000000000000000000000000000001c' ) + payload = bytes(list(range(1, 200))) + rtp = AacAudioRtpPacket.for_simple_aac(44100, 2, payload) + assert rtp.audio_mux_element.payload == payload + assert ( + rtp.audio_mux_element.stream_mux_config.audio_specific_config.sampling_frequency + == 44100 + ) + assert ( + rtp.audio_mux_element.stream_mux_config.audio_specific_config.channel_configuration + == 2 + ) + rtp2 = AacAudioRtpPacket.from_bytes(bytes(rtp)) + assert str(rtp2.audio_mux_element.stream_mux_config) == str( + rtp.audio_mux_element.stream_mux_config + ) + assert rtp2.audio_mux_element.payload == rtp.audio_mux_element.payload + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_reader() + test_writer() test_aac_rtp() diff --git a/web/speaker/speaker.py b/web/speaker/speaker.py index 2b8ce006..ab211b1f 100644 --- a/web/speaker/speaker.py +++ b/web/speaker/speaker.py @@ -72,7 +72,7 @@ class AudioExtractor: # ----------------------------------------------------------------------------- class AacAudioExtractor: def extract_audio(self, packet: MediaPacket) -> bytes: - return AacAudioRtpPacket(packet.payload).to_adts() + return AacAudioRtpPacket.from_bytes(packet.payload).to_adts() # ----------------------------------------------------------------------------- @@ -282,9 +282,6 @@ class Speaker: mitm=False ) - # Start the controller - await self.device.power_on() - # Listen for Bluetooth connections self.device.on('connection', self.on_bluetooth_connection) @@ -295,6 +292,9 @@ class Speaker: self.avdtp_listener = Listener.for_device(self.device) self.avdtp_listener.on('connection', self.on_avdtp_connection) + # Start the controller + await self.device.power_on() + print(f'Speaker ready to play, codec={self.codec}') if connect_address: