mirror of
https://github.com/google/bumble.git
synced 2026-04-18 00:45:32 +00:00
wip
This commit is contained in:
318
bumble/a2dp.py
318
bumble/a2dp.py
@@ -22,7 +22,9 @@ import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from typing import List, Callable, Awaitable
|
||||
from typing import Awaitable, Callable, Iterable, List
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
from .codecs import AacAudioRtpPacket
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
@@ -105,6 +107,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
||||
}
|
||||
|
||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||
8000,
|
||||
11025,
|
||||
@@ -132,6 +136,11 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
||||
}
|
||||
|
||||
|
||||
OPUS_VENDOR_ID = 0x000000E0
|
||||
OPUS_CODEC_ID = 0x0001
|
||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -516,7 +525,7 @@ class VendorSpecificMediaCodecInformation:
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
@@ -530,6 +539,105 @@ class VendorSpecificMediaCodecInformation:
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
|
||||
channel_mode: int
|
||||
frame_size: int
|
||||
sampling_frequency: int
|
||||
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
CHANNEL_MODE_BITS = {
|
||||
ChannelMode.MONO: 1 << 0,
|
||||
ChannelMode.STEREO: 1 << 1,
|
||||
ChannelMode.DUAL_MONO: 1 << 2,
|
||||
}
|
||||
|
||||
class FrameSize(enum.IntFlag):
|
||||
F_10MS = 0
|
||||
F_20MS = 1
|
||||
|
||||
FRAME_SIZE_BITS = {FrameSize.F_10MS: 1 << 0, FrameSize.F_20MS: 1 << 1}
|
||||
|
||||
SAMPLING_FREQUENCIES = [48000]
|
||||
SAMPLING_FREQUENCY_BITS = {
|
||||
48000: 1 << 0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
"""Create a new instance from the `value` part of the data, not including
|
||||
the vendor id and codec id"""
|
||||
channel_mode = data[0] & 0x07
|
||||
frame_size = (data[0] >> 3) & 0x03
|
||||
sampling_frequency = (data[0] >> 7) & 0x01
|
||||
|
||||
return cls(
|
||||
OPUS_VENDOR_ID,
|
||||
OPUS_CODEC_ID,
|
||||
data,
|
||||
channel_mode,
|
||||
frame_size,
|
||||
sampling_frequency,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls, channel_mode: ChannelMode, frame_size: FrameSize, sampling_frequency: int
|
||||
) -> Self:
|
||||
channel_mode_int = cls.CHANNEL_MODE_BITS[channel_mode]
|
||||
frame_size_int = cls.FRAME_SIZE_BITS[frame_size]
|
||||
sampling_frequency_int = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency]
|
||||
value = bytes(
|
||||
[channel_mode_int | (frame_size_int << 3) | (sampling_frequency_int << 7)]
|
||||
)
|
||||
return cls(
|
||||
vendor_id=OPUS_VENDOR_ID,
|
||||
codec_id=OPUS_CODEC_ID,
|
||||
value=value,
|
||||
channel_mode=channel_mode_int,
|
||||
frame_size=frame_size_int,
|
||||
sampling_frequency=sampling_frequency_int,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
channel_modes: Iterable[ChannelMode],
|
||||
frame_sizes: Iterable[FrameSize],
|
||||
sampling_frequencies: Iterable[int],
|
||||
) -> Self:
|
||||
channel_mode = sum(channel_modes)
|
||||
frame_size = sum(frame_sizes)
|
||||
sampling_frequency = sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
)
|
||||
value = bytes([channel_mode | (frame_size << 3) | (sampling_frequency << 7)])
|
||||
return cls(
|
||||
vendor_id=OPUS_VENDOR_ID,
|
||||
codec_id=OPUS_CODEC_ID,
|
||||
value=value,
|
||||
channel_mode=channel_mode,
|
||||
frame_size=frame_size,
|
||||
sampling_frequency=sampling_frequency,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
'OpusMediaCodecInformation(',
|
||||
f' channel_mode: {",".join([x.name for x in flags_to_list(self.channel_mode, list(self.ChannelMode))])}',
|
||||
f' frame_size: {",".join([x.name for x in flags_to_list(self.frame_size, list(self.FrameSize))])}',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, self.SAMPLING_FREQUENCIES)])}',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
@@ -628,7 +736,7 @@ class SbcPacketSource:
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
frames = []
|
||||
frames_size = 0
|
||||
max_rtp_payload = self.mtu - 12 - 1
|
||||
@@ -638,26 +746,27 @@ class SbcPacketSource:
|
||||
async for frame in sbc_parser.frames:
|
||||
if (
|
||||
frames_size + len(frame.payload) > max_rtp_payload
|
||||
or len(frames) == 16
|
||||
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
|
||||
):
|
||||
# Need to flush what has been accumulated so far
|
||||
logger.debug(f"yielding {len(frames)} frames")
|
||||
|
||||
# Emit a packet
|
||||
sbc_payload = bytes([len(frames)]) + b''.join(
|
||||
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
|
||||
[frame.payload for frame in frames]
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += sum((frame.sample_count for frame in frames))
|
||||
timestamp &= 0xFFFFFFFF
|
||||
sample_count += sum((frame.sample_count for frame in frames))
|
||||
frames = [frame]
|
||||
frames_size = len(frame.payload)
|
||||
else:
|
||||
@@ -775,7 +884,7 @@ class AacPacketSource:
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
|
||||
aac_parser = AacParser(self.read)
|
||||
async for frame in aac_parser.frames:
|
||||
@@ -789,17 +898,200 @@ class AacPacketSource:
|
||||
frame.payload,
|
||||
)
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += frame.sample_count
|
||||
timestamp &= 0xFFFFFFFF
|
||||
frames = [frame]
|
||||
sample_count += frame.sample_count
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusPacket:
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
channel_mode: ChannelMode
|
||||
duration: int # Duration in ms.
|
||||
sampling_frequency: int
|
||||
payload: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Opus(ch={self.channel_mode.name}, '
|
||||
f'd={self.duration}ms, '
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusParser:
|
||||
"""
|
||||
Parser for Opus packets in an Ogg stream
|
||||
|
||||
See RFC 3533
|
||||
|
||||
NOTE: this parser only supports bitstreams with a single logical stream.
|
||||
"""
|
||||
|
||||
CAPTURE_PATTERN = b'OggS'
|
||||
|
||||
class HeaderType(enum.IntFlag):
|
||||
CONTINUED = 0x01
|
||||
FIRST = 0x02
|
||||
LAST = 0x04
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def packets(self) -> AsyncGenerator[OpusPacket, None]:
|
||||
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
|
||||
packet = b''
|
||||
packet_count = 0
|
||||
expected_bitstream_serial_number = None
|
||||
expected_page_sequence_number = 0
|
||||
channel_mode = OpusPacket.ChannelMode.STEREO
|
||||
|
||||
while True:
|
||||
# Parse the page header
|
||||
header = await self.read(27)
|
||||
if len(header) != 27:
|
||||
logger.debug("end of stream")
|
||||
break
|
||||
|
||||
capture_pattern = header[:4]
|
||||
if capture_pattern != self.CAPTURE_PATTERN:
|
||||
print(capture_pattern.hex())
|
||||
raise ValueError("invalid capture pattern at start of page")
|
||||
|
||||
version = header[4]
|
||||
if version != 0:
|
||||
raise ValueError(f"version {version} not supported")
|
||||
|
||||
header_type = self.HeaderType(header[5])
|
||||
(granule_position,) = struct.unpack_from("<Q", header, 6)
|
||||
(bitstream_serial_number,) = struct.unpack_from("<I", header, 14)
|
||||
(page_sequence_number,) = struct.unpack_from("<I", header, 18)
|
||||
(crc_checksum,) = struct.unpack_from("<I", header, 22)
|
||||
page_segments = header[26]
|
||||
segment_table = await self.read(page_segments)
|
||||
|
||||
if header_type & self.HeaderType.FIRST:
|
||||
if expected_bitstream_serial_number is None:
|
||||
# We will only accept pages for the first encountered stream
|
||||
logger.debug("BOS")
|
||||
expected_bitstream_serial_number = bitstream_serial_number
|
||||
expected_page_sequence_number = page_sequence_number
|
||||
|
||||
if (
|
||||
expected_bitstream_serial_number is None
|
||||
or expected_bitstream_serial_number != bitstream_serial_number
|
||||
):
|
||||
logger.debug("skipping page (not the first logical bitstream)")
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
await self.read(lacing_value)
|
||||
continue
|
||||
|
||||
if expected_page_sequence_number != page_sequence_number:
|
||||
raise ValueError(
|
||||
f"expected page sequence number {expected_page_sequence_number}"
|
||||
f" but got {page_sequence_number}"
|
||||
)
|
||||
expected_page_sequence_number = page_sequence_number + 1
|
||||
|
||||
# Assemble the page
|
||||
if not header_type & self.HeaderType.CONTINUED:
|
||||
packet = b''
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
packet += await self.read(lacing_value)
|
||||
if lacing_value < 255:
|
||||
# End of packet
|
||||
packet_count += 1
|
||||
|
||||
if packet_count == 1:
|
||||
# The first packet contains the identification header
|
||||
logger.debug("first packet (header)")
|
||||
if packet[:8] != b"OpusHead":
|
||||
raise ValueError("first packet is not OpusHead")
|
||||
packet_count = (
|
||||
OpusPacket.ChannelMode.MONO
|
||||
if packet[9] == 1
|
||||
else OpusPacket.ChannelMode.STEREO
|
||||
)
|
||||
|
||||
elif packet_count == 2:
|
||||
# The second packet contains the comment header
|
||||
logger.debug("second packet (tags)")
|
||||
if packet[:8] != b"OpusTags":
|
||||
logger.warning("second packet is not OpusTags")
|
||||
else:
|
||||
yield OpusPacket(channel_mode, 20, 48000, packet)
|
||||
|
||||
packet = b''
|
||||
|
||||
if header_type & self.HeaderType.LAST:
|
||||
logger.debug("EOS")
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusPacketSource:
|
||||
def __init__(
|
||||
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
|
||||
) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
elapsed_ms = 0
|
||||
|
||||
opus_parser = OpusParser(self.read)
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only support sending one Opus frame per RTP packet
|
||||
# TODO: check the spec for the first byte value here
|
||||
opus_payload = bytes([1]) + opus_packet.payload
|
||||
elapsed_s = elapsed_ms / 1000
|
||||
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
|
||||
rtp_packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
|
||||
)
|
||||
rtp_packet.timestamp_seconds = elapsed_s
|
||||
yield rtp_packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
elapsed_ms += opus_packet.duration
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# This map should be left at the end of the file so it can refer to the classes
|
||||
# above
|
||||
# -----------------------------------------------------------------------------
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
|
||||
OPUS_VENDOR_ID: {OPUS_CODEC_ID: OpusMediaCodecInformation}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user