From d064de35e08ed53d1defc26a762b8c696057001d Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Tue, 8 Oct 2024 21:57:28 -0700 Subject: [PATCH] wip --- apps/player/player.py | 112 ++++++++++++--- bumble/a2dp.py | 318 ++++++++++++++++++++++++++++++++++++++++-- bumble/avdtp.py | 34 ++++- bumble/device.py | 24 ++-- 4 files changed, 444 insertions(+), 44 deletions(-) diff --git a/apps/player/player.py b/apps/player/player.py index dbabf4fc..d9813c83 100644 --- a/apps/player/player.py +++ b/apps/player/player.py @@ -31,6 +31,9 @@ from bumble.a2dp import ( A2DP_SBC_CODEC_TYPE, A2DP_MPEG_2_4_AAC_CODEC_TYPE, MPEG_2_AAC_LC_OBJECT_TYPE, + A2DP_NON_A2DP_CODEC_TYPE, + OPUS_VENDOR_ID, + OPUS_CODEC_ID, AacFrame, AacParser, AacPacketSource, @@ -39,6 +42,10 @@ from bumble.a2dp import ( SbcParser, SbcPacketSource, SbcMediaCodecInformation, + OpusPacket, + OpusParser, + OpusPacketSource, + OpusMediaCodecInformation, ) from bumble.avrcp import Protocol as AvrcpProtocol from bumble.avdtp import ( @@ -57,7 +64,7 @@ from bumble.core import ( BT_BR_EDR_TRANSPORT, ) from bumble.device import Connection, Device, DeviceConfiguration -from bumble.hci import Address +from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant from bumble.pairing import PairingConfig from bumble.transport import open_transport from bumble.utils import AsyncRunner @@ -130,6 +137,36 @@ async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities: ) +# ----------------------------------------------------------------------------- +async def opus_codec_capabilities(read_function) -> MediaCodecCapabilities: + opus_parser = OpusParser(read_function) + opus_packet: OpusPacket + async for opus_packet in opus_parser.packets: + # We only need the first packet + print(color(f"Opus format: {opus_packet}", "cyan")) + break + + if opus_packet.channel_mode == OpusPacket.ChannelMode.MONO: + channel_mode = OpusMediaCodecInformation.ChannelMode.MONO + elif opus_packet.channel_mode == OpusPacket.ChannelMode.STEREO: + channel_mode = OpusMediaCodecInformation.ChannelMode.STEREO + else: + channel_mode = OpusMediaCodecInformation.ChannelMode.DUAL_MONO + + if opus_packet.duration == 10: + frame_size = OpusMediaCodecInformation.FrameSize.F_10MS + else: + frame_size = OpusMediaCodecInformation.FrameSize.F_20MS + + return MediaCodecCapabilities( + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_NON_A2DP_CODEC_TYPE, + media_codec_information=OpusMediaCodecInformation.from_discrete_values( + channel_mode=channel_mode, sampling_frequency=48000, frame_size=frame_size + ), + ) + + # ----------------------------------------------------------------------------- class Player: def __init__( @@ -144,14 +181,12 @@ class Player: self.authenticate = authenticate self.encrypt = encrypt self.avrcp_protocol: Optional[AvrcpProtocol] = None - self.done: Optional[asyncio.Future] + self.done: Optional[asyncio.Event] async def run(self, workload) -> None: - self.done = asyncio.get_running_loop().create_future() + self.done = asyncio.Event() try: await self._run(workload) - except BumbleConnectionError as error: - print(color(f"Failed to connect: {error}", "red")) except Exception as error: print(color(f"!!! ERROR: {error}", "red")) @@ -172,9 +207,12 @@ class Player: device_config.classic_enabled = True device_config.le_enabled = False + device_config.le_simultaneous_enabled = False + device_config.classic_sc_enabled = False + device_config.classic_smp_enabled = False device = Device.from_config_with_hci(device_config, hci_source, hci_sink) - # Setup the SDP to expose the SRC service + # Setup the SDP records to expose the SRC service device.sdp_service_records = a2dp_source_sdp_records() # Setup AVRCP @@ -200,15 +238,28 @@ class Player: device.on("connection", self.on_bluetooth_connection) # Run the workload - await workload(device) + try: + await workload(device) + except BumbleConnectionError as error: + if error.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR: + print(color("Connection already established", "blue")) + else: + print(color(f"Failed to connect: {error}", "red")) # Wait until it is time to exit + assert self.done is not None await asyncio.wait( - [hci_source.terminated, self.done], return_when=asyncio.FIRST_COMPLETED + [hci_source.terminated, asyncio.ensure_future(self.done.wait())], + return_when=asyncio.FIRST_COMPLETED, ) def on_bluetooth_connection(self, connection: Connection) -> None: print(color(f"--- Connected: {connection}", "cyan")) + connection.on("disconnection", self.on_bluetooth_disconnection) + + def on_bluetooth_disconnection(self, reason) -> None: + print(color(f"--- Disconnected: {HCI_Constant.error_name(reason)}", "cyan")) + self.set_done() async def connect(self, device: Device, address: str) -> Connection: print(color(f"Connecting to {address}...", "green")) @@ -243,7 +294,9 @@ class Player: self, protocol: AvdtpProtocol, codec_type: int, - packet_source: Union[SbcPacketSource, AacPacketSource], + vendor_id: int, + codec_id: int, + packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource], ): # Discover all endpoints on the remote device endpoints = await protocol.discover_remote_endpoints() @@ -251,7 +304,9 @@ class Player: print('@@@', endpoint) # Select a sink - sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, codec_type) + sink = protocol.find_remote_sink_by_codec( + AVDTP_AUDIO_MEDIA_TYPE, codec_type, vendor_id, codec_id + ) if sink is None: print(color('!!! no compatible sink found', 'red')) return @@ -313,8 +368,7 @@ class Player: print(color("Pairing...", "magenta")) await connection.authenticate() print(color("Pairing completed", "magenta")) - if self.done is not None: - self.done.set_result(None) + self.set_done() async def inquire(self, device: Device, address: str) -> None: connection = await self.connect(device, address) @@ -326,8 +380,7 @@ class Player: for endpoint in endpoints: print('@@@', endpoint) - if self.done is not None: - self.done.set_result(None) + self.set_done() async def play( self, @@ -341,6 +394,8 @@ class Player: audio_format = "sbc" elif audio_file.endswith(".aac") or audio_file.endswith(".adts"): audio_format = "aac" + elif audio_file.endswith(".ogg"): + audio_format = "opus" else: raise ValueError("Unable to determine audio format from file extension") @@ -359,7 +414,9 @@ class Player: return input_file.read(byte_count) # Obtain the codec capabilities from the stream - packet_source: Union[SbcPacketSource, AacPacketSource] + packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource] + vendor_id = 0 + codec_id = 0 if audio_format == "sbc": codec_type = A2DP_SBC_CODEC_TYPE codec_capabilities = await sbc_codec_capabilities(read_audio_data) @@ -368,7 +425,7 @@ class Player: avdtp_protocol.l2cap_channel.peer_mtu, codec_capabilities, ) - else: + elif audio_format == "aac": codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE codec_capabilities = await aac_codec_capabilities(read_audio_data) packet_source = AacPacketSource( @@ -376,23 +433,38 @@ class Player: avdtp_protocol.l2cap_channel.peer_mtu, codec_capabilities, ) + else: + codec_type = A2DP_NON_A2DP_CODEC_TYPE + vendor_id = OPUS_VENDOR_ID + codec_id = OPUS_CODEC_ID + codec_capabilities = await opus_codec_capabilities(read_audio_data) + packet_source = OpusPacketSource( + read_audio_data, + avdtp_protocol.l2cap_channel.peer_mtu, + codec_capabilities, + ) # Rewind to the start input_file.seek(0) try: - await self.stream_packets(avdtp_protocol, codec_type, packet_source) + await self.stream_packets( + avdtp_protocol, codec_type, vendor_id, codec_id, packet_source + ) except Exception as error: print(color(f"!!! Error while streaming: {error}", "red")) - if self.done: - self.done.set_result(None) + self.set_done() if address: await self.connect(device, address) else: print(color("Waiting for an incoming connection...", "magenta")) + def set_done(self) -> None: + if self.done: + self.done.set() + # ----------------------------------------------------------------------------- def create_player(context) -> Player: @@ -469,7 +541,7 @@ def pair(context, address): @click.option( "-f", "--audio-format", - type=click.Choice(["auto", "sbc", "aac"]), + type=click.Choice(["auto", "sbc", "aac", "opus"]), help="Audio file format (use 'auto' to infer the format from the file extension)", default="auto", ) diff --git a/bumble/a2dp.py b/bumble/a2dp.py index 285b748c..e6ccdf9a 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -22,7 +22,9 @@ import dataclasses import enum import logging import struct -from typing import List, Callable, Awaitable +from typing import Awaitable, Callable, Iterable, List +from typing_extensions import Self + from .codecs import AacAudioRtpPacket from .company_ids import COMPANY_IDENTIFIERS @@ -105,6 +107,8 @@ SBC_ALLOCATION_METHOD_NAMES = { SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD' } +SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15 + MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [ 8000, 11025, @@ -132,6 +136,11 @@ MPEG_2_4_OBJECT_TYPE_NAMES = { MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE' } + +OPUS_VENDOR_ID = 0x000000E0 +OPUS_CODEC_ID = 0x0001 +OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15 + # fmt: on @@ -516,7 +525,7 @@ class VendorSpecificMediaCodecInformation: return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:]) def __bytes__(self) -> bytes: - return struct.pack(' str: # pylint: disable=line-too-long @@ -530,6 +539,105 @@ class VendorSpecificMediaCodecInformation: ) +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation): + channel_mode: int + frame_size: int + sampling_frequency: int + + class ChannelMode(enum.IntEnum): + MONO = 0 + STEREO = 1 + DUAL_MONO = 2 + + CHANNEL_MODE_BITS = { + ChannelMode.MONO: 1 << 0, + ChannelMode.STEREO: 1 << 1, + ChannelMode.DUAL_MONO: 1 << 2, + } + + class FrameSize(enum.IntFlag): + F_10MS = 0 + F_20MS = 1 + + FRAME_SIZE_BITS = {FrameSize.F_10MS: 1 << 0, FrameSize.F_20MS: 1 << 1} + + SAMPLING_FREQUENCIES = [48000] + SAMPLING_FREQUENCY_BITS = { + 48000: 1 << 0, + } + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + """Create a new instance from the `value` part of the data, not including + the vendor id and codec id""" + channel_mode = data[0] & 0x07 + frame_size = (data[0] >> 3) & 0x03 + sampling_frequency = (data[0] >> 7) & 0x01 + + return cls( + OPUS_VENDOR_ID, + OPUS_CODEC_ID, + data, + channel_mode, + frame_size, + sampling_frequency, + ) + + @classmethod + def from_discrete_values( + cls, channel_mode: ChannelMode, frame_size: FrameSize, sampling_frequency: int + ) -> Self: + channel_mode_int = cls.CHANNEL_MODE_BITS[channel_mode] + frame_size_int = cls.FRAME_SIZE_BITS[frame_size] + sampling_frequency_int = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency] + value = bytes( + [channel_mode_int | (frame_size_int << 3) | (sampling_frequency_int << 7)] + ) + return cls( + vendor_id=OPUS_VENDOR_ID, + codec_id=OPUS_CODEC_ID, + value=value, + channel_mode=channel_mode_int, + frame_size=frame_size_int, + sampling_frequency=sampling_frequency_int, + ) + + @classmethod + def from_lists( + cls, + channel_modes: Iterable[ChannelMode], + frame_sizes: Iterable[FrameSize], + sampling_frequencies: Iterable[int], + ) -> Self: + channel_mode = sum(channel_modes) + frame_size = sum(frame_sizes) + sampling_frequency = sum( + cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies + ) + value = bytes([channel_mode | (frame_size << 3) | (sampling_frequency << 7)]) + return cls( + vendor_id=OPUS_VENDOR_ID, + codec_id=OPUS_CODEC_ID, + value=value, + channel_mode=channel_mode, + frame_size=frame_size, + sampling_frequency=sampling_frequency, + ) + + def __str__(self) -> str: + # pylint: disable=line-too-long + return '\n'.join( + [ + 'OpusMediaCodecInformation(', + f' channel_mode: {",".join([x.name for x in flags_to_list(self.channel_mode, list(self.ChannelMode))])}', + f' frame_size: {",".join([x.name for x in flags_to_list(self.frame_size, list(self.FrameSize))])}', + f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, self.SAMPLING_FREQUENCIES)])}', + ] + ) + + # ----------------------------------------------------------------------------- @dataclasses.dataclass class SbcFrame: @@ -628,7 +736,7 @@ class SbcPacketSource: from .avdtp import MediaPacket # Import here to avoid a circular reference sequence_number = 0 - timestamp = 0 + sample_count = 0 frames = [] frames_size = 0 max_rtp_payload = self.mtu - 12 - 1 @@ -638,26 +746,27 @@ class SbcPacketSource: async for frame in sbc_parser.frames: if ( frames_size + len(frame.payload) > max_rtp_payload - or len(frames) == 16 + or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD ): # Need to flush what has been accumulated so far logger.debug(f"yielding {len(frames)} frames") # Emit a packet - sbc_payload = bytes([len(frames)]) + b''.join( + sbc_payload = bytes([len(frames) & 0x0F]) + b''.join( [frame.payload for frame in frames] ) + timestamp_seconds = sample_count / frame.sampling_frequency + timestamp = int(1000 * timestamp_seconds) packet = MediaPacket( 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload ) - packet.timestamp_seconds = timestamp / frame.sampling_frequency + packet.timestamp_seconds = timestamp_seconds yield packet # Prepare for next packets sequence_number += 1 sequence_number &= 0xFFFF - timestamp += sum((frame.sample_count for frame in frames)) - timestamp &= 0xFFFFFFFF + sample_count += sum((frame.sample_count for frame in frames)) frames = [frame] frames_size = len(frame.payload) else: @@ -775,7 +884,7 @@ class AacPacketSource: from .avdtp import MediaPacket # Import here to avoid a circular reference sequence_number = 0 - timestamp = 0 + sample_count = 0 aac_parser = AacParser(self.read) async for frame in aac_parser.frames: @@ -789,17 +898,200 @@ class AacPacketSource: frame.payload, ) ) + timestamp_seconds = sample_count / frame.sampling_frequency + timestamp = int(1000 * timestamp_seconds) packet = MediaPacket( 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload ) - packet.timestamp_seconds = timestamp / frame.sampling_frequency + packet.timestamp_seconds = timestamp_seconds yield packet # Prepare for next packets sequence_number += 1 sequence_number &= 0xFFFF - timestamp += frame.sample_count - timestamp &= 0xFFFFFFFF - frames = [frame] + sample_count += frame.sample_count return generate_packets() + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class OpusPacket: + class ChannelMode(enum.IntEnum): + MONO = 0 + STEREO = 1 + DUAL_MONO = 2 + + channel_mode: ChannelMode + duration: int # Duration in ms. + sampling_frequency: int + payload: bytes + + def __str__(self) -> str: + return ( + f'Opus(ch={self.channel_mode.name}, ' + f'd={self.duration}ms, ' + f'size={len(self.payload)})' + ) + + +# ----------------------------------------------------------------------------- +class OpusParser: + """ + Parser for Opus packets in an Ogg stream + + See RFC 3533 + + NOTE: this parser only supports bitstreams with a single logical stream. + """ + + CAPTURE_PATTERN = b'OggS' + + class HeaderType(enum.IntFlag): + CONTINUED = 0x01 + FIRST = 0x02 + LAST = 0x04 + + def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None: + self.read = read + + @property + def packets(self) -> AsyncGenerator[OpusPacket, None]: + async def generate_frames() -> AsyncGenerator[OpusPacket, None]: + packet = b'' + packet_count = 0 + expected_bitstream_serial_number = None + expected_page_sequence_number = 0 + channel_mode = OpusPacket.ChannelMode.STEREO + + while True: + # Parse the page header + header = await self.read(27) + if len(header) != 27: + logger.debug("end of stream") + break + + capture_pattern = header[:4] + if capture_pattern != self.CAPTURE_PATTERN: + print(capture_pattern.hex()) + raise ValueError("invalid capture pattern at start of page") + + version = header[4] + if version != 0: + raise ValueError(f"version {version} not supported") + + header_type = self.HeaderType(header[5]) + (granule_position,) = struct.unpack_from(" None: + self.read = read + self.mtu = mtu + self.codec_capabilities = codec_capabilities + + @property + def packets(self): + async def generate_packets(): + # pylint: disable=import-outside-toplevel + from .avdtp import MediaPacket # Import here to avoid a circular reference + + sequence_number = 0 + elapsed_ms = 0 + + opus_parser = OpusParser(self.read) + async for opus_packet in opus_parser.packets: + # We only support sending one Opus frame per RTP packet + # TODO: check the spec for the first byte value here + opus_payload = bytes([1]) + opus_packet.payload + elapsed_s = elapsed_ms / 1000 + timestamp = int(elapsed_s * opus_packet.sampling_frequency) + rtp_packet = MediaPacket( + 2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload + ) + rtp_packet.timestamp_seconds = elapsed_s + yield rtp_packet + + # Prepare for next packets + sequence_number += 1 + sequence_number &= 0xFFFF + elapsed_ms += opus_packet.duration + + return generate_packets() + + +# ----------------------------------------------------------------------------- +# This map should be left at the end of the file so it can refer to the classes +# above +# ----------------------------------------------------------------------------- +A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = { + OPUS_VENDOR_ID: {OPUS_CODEC_ID: OpusMediaCodecInformation} +} diff --git a/bumble/avdtp.py b/bumble/avdtp.py index fe311986..f5a61505 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -51,6 +51,7 @@ from .a2dp import ( A2DP_MPEG_2_4_AAC_CODEC_TYPE, A2DP_NON_A2DP_CODEC_TYPE, A2DP_SBC_CODEC_TYPE, + A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES, AacMediaCodecInformation, SbcMediaCodecInformation, VendorSpecificMediaCodecInformation, @@ -328,6 +329,7 @@ class MediaPacket: self.marker = marker self.sequence_number = sequence_number & 0xFFFF self.timestamp = timestamp & 0xFFFFFFFF + self.timestamp_seconds = 0.0 self.ssrc = ssrc self.csrc_list = csrc_list self.payload_type = payload_type @@ -621,11 +623,25 @@ class MediaCodecCapabilities(ServiceCapabilities): self.media_codec_information ) elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE: - self.media_codec_information = ( + vendor_media_codec_information = ( VendorSpecificMediaCodecInformation.from_bytes( self.media_codec_information ) ) + if ( + vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get( + vendor_media_codec_information.vendor_id + ) + ) and ( + media_codec_information_class := vendor_class_map.get( + vendor_media_codec_information.codec_id + ) + ): + self.media_codec_information = media_codec_information_class.from_bytes( + vendor_media_codec_information.value + ) + else: + self.media_codec_information = vendor_media_codec_information def __init__( self, @@ -1388,7 +1404,7 @@ class Protocol(EventEmitter): return self.remote_endpoints.values() def find_remote_sink_by_codec( - self, media_type: int, codec_type: int + self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0 ) -> Optional[DiscoveredStreamEndPoint]: for endpoint in self.remote_endpoints.values(): if ( @@ -1413,7 +1429,19 @@ class Protocol(EventEmitter): codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE and codec_capabilities.media_codec_type == codec_type ): - has_codec = True + if isinstance( + codec_capabilities.media_codec_information, + VendorSpecificMediaCodecInformation, + ): + if ( + codec_capabilities.media_codec_information.vendor_id + == vendor_id + and codec_capabilities.media_codec_information.codec_id + == codec_id + ): + has_codec = True + else: + has_codec = True if has_media_transport and has_codec: return endpoint diff --git a/bumble/device.py b/bumble/device.py index 38d0ca6f..07f44c08 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1571,14 +1571,22 @@ class Connection(CompositeEventEmitter): raise def __str__(self): - return ( - f'Connection(handle=0x{self.handle:04X}, ' - f'role={self.role_name}, ' - f'self_address={self.self_address}, ' - f'self_resolvable_address={self.self_resolvable_address}, ' - f'peer_address={self.peer_address}, ' - f'peer_resolvable_address={self.peer_resolvable_address})' - ) + if self.transport == BT_LE_TRANSPORT: + return ( + f'Connection(transport=LE, handle=0x{self.handle:04X}, ' + f'role={self.role_name}, ' + f'self_address={self.self_address}, ' + f'self_resolvable_address={self.self_resolvable_address}, ' + f'peer_address={self.peer_address}, ' + f'peer_resolvable_address={self.peer_resolvable_address})' + ) + else: + return ( + f'Connection(transport=BR/EDR, handle=0x{self.handle:04X}, ' + f'role={self.role_name}, ' + f'self_address={self.self_address}, ' + f'peer_address={self.peer_address})' + ) # -----------------------------------------------------------------------------