This commit is contained in:
Gilles Boccon-Gibod
2024-10-08 21:57:28 -07:00
parent dab4d13303
commit d064de35e0
4 changed files with 444 additions and 44 deletions

View File

@@ -22,7 +22,9 @@ import dataclasses
import enum
import logging
import struct
from typing import List, Callable, Awaitable
from typing import Awaitable, Callable, Iterable, List
from typing_extensions import Self
from .codecs import AacAudioRtpPacket
from .company_ids import COMPANY_IDENTIFIERS
@@ -105,6 +107,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
8000,
11025,
@@ -132,6 +136,11 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_VENDOR_ID = 0x000000E0
OPUS_CODEC_ID = 0x0001
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
# fmt: on
@@ -516,7 +525,7 @@ class VendorSpecificMediaCodecInformation:
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
def __str__(self) -> str:
# pylint: disable=line-too-long
@@ -530,6 +539,105 @@ class VendorSpecificMediaCodecInformation:
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
channel_mode: int
frame_size: int
sampling_frequency: int
class ChannelMode(enum.IntEnum):
MONO = 0
STEREO = 1
DUAL_MONO = 2
CHANNEL_MODE_BITS = {
ChannelMode.MONO: 1 << 0,
ChannelMode.STEREO: 1 << 1,
ChannelMode.DUAL_MONO: 1 << 2,
}
class FrameSize(enum.IntFlag):
F_10MS = 0
F_20MS = 1
FRAME_SIZE_BITS = {FrameSize.F_10MS: 1 << 0, FrameSize.F_20MS: 1 << 1}
SAMPLING_FREQUENCIES = [48000]
SAMPLING_FREQUENCY_BITS = {
48000: 1 << 0,
}
@classmethod
def from_bytes(cls, data: bytes) -> Self:
"""Create a new instance from the `value` part of the data, not including
the vendor id and codec id"""
channel_mode = data[0] & 0x07
frame_size = (data[0] >> 3) & 0x03
sampling_frequency = (data[0] >> 7) & 0x01
return cls(
OPUS_VENDOR_ID,
OPUS_CODEC_ID,
data,
channel_mode,
frame_size,
sampling_frequency,
)
@classmethod
def from_discrete_values(
cls, channel_mode: ChannelMode, frame_size: FrameSize, sampling_frequency: int
) -> Self:
channel_mode_int = cls.CHANNEL_MODE_BITS[channel_mode]
frame_size_int = cls.FRAME_SIZE_BITS[frame_size]
sampling_frequency_int = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency]
value = bytes(
[channel_mode_int | (frame_size_int << 3) | (sampling_frequency_int << 7)]
)
return cls(
vendor_id=OPUS_VENDOR_ID,
codec_id=OPUS_CODEC_ID,
value=value,
channel_mode=channel_mode_int,
frame_size=frame_size_int,
sampling_frequency=sampling_frequency_int,
)
@classmethod
def from_lists(
cls,
channel_modes: Iterable[ChannelMode],
frame_sizes: Iterable[FrameSize],
sampling_frequencies: Iterable[int],
) -> Self:
channel_mode = sum(channel_modes)
frame_size = sum(frame_sizes)
sampling_frequency = sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
)
value = bytes([channel_mode | (frame_size << 3) | (sampling_frequency << 7)])
return cls(
vendor_id=OPUS_VENDOR_ID,
codec_id=OPUS_CODEC_ID,
value=value,
channel_mode=channel_mode,
frame_size=frame_size,
sampling_frequency=sampling_frequency,
)
def __str__(self) -> str:
# pylint: disable=line-too-long
return '\n'.join(
[
'OpusMediaCodecInformation(',
f' channel_mode: {",".join([x.name for x in flags_to_list(self.channel_mode, list(self.ChannelMode))])}',
f' frame_size: {",".join([x.name for x in flags_to_list(self.frame_size, list(self.FrameSize))])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, self.SAMPLING_FREQUENCIES)])}',
]
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
@@ -628,7 +736,7 @@ class SbcPacketSource:
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
sample_count = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
@@ -638,26 +746,27 @@ class SbcPacketSource:
async for frame in sbc_parser.frames:
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
):
# Need to flush what has been accumulated so far
logger.debug(f"yielding {len(frames)} frames")
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join(
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
[frame.payload for frame in frames]
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
sample_count += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:
@@ -775,7 +884,7 @@ class AacPacketSource:
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
sample_count = 0
aac_parser = AacParser(self.read)
async for frame in aac_parser.frames:
@@ -789,17 +898,200 @@ class AacPacketSource:
frame.payload,
)
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
timestamp += frame.sample_count
timestamp &= 0xFFFFFFFF
frames = [frame]
sample_count += frame.sample_count
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusPacket:
class ChannelMode(enum.IntEnum):
MONO = 0
STEREO = 1
DUAL_MONO = 2
channel_mode: ChannelMode
duration: int # Duration in ms.
sampling_frequency: int
payload: bytes
def __str__(self) -> str:
return (
f'Opus(ch={self.channel_mode.name}, '
f'd={self.duration}ms, '
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class OpusParser:
"""
Parser for Opus packets in an Ogg stream
See RFC 3533
NOTE: this parser only supports bitstreams with a single logical stream.
"""
CAPTURE_PATTERN = b'OggS'
class HeaderType(enum.IntFlag):
CONTINUED = 0x01
FIRST = 0x02
LAST = 0x04
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def packets(self) -> AsyncGenerator[OpusPacket, None]:
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
packet = b''
packet_count = 0
expected_bitstream_serial_number = None
expected_page_sequence_number = 0
channel_mode = OpusPacket.ChannelMode.STEREO
while True:
# Parse the page header
header = await self.read(27)
if len(header) != 27:
logger.debug("end of stream")
break
capture_pattern = header[:4]
if capture_pattern != self.CAPTURE_PATTERN:
print(capture_pattern.hex())
raise ValueError("invalid capture pattern at start of page")
version = header[4]
if version != 0:
raise ValueError(f"version {version} not supported")
header_type = self.HeaderType(header[5])
(granule_position,) = struct.unpack_from("<Q", header, 6)
(bitstream_serial_number,) = struct.unpack_from("<I", header, 14)
(page_sequence_number,) = struct.unpack_from("<I", header, 18)
(crc_checksum,) = struct.unpack_from("<I", header, 22)
page_segments = header[26]
segment_table = await self.read(page_segments)
if header_type & self.HeaderType.FIRST:
if expected_bitstream_serial_number is None:
# We will only accept pages for the first encountered stream
logger.debug("BOS")
expected_bitstream_serial_number = bitstream_serial_number
expected_page_sequence_number = page_sequence_number
if (
expected_bitstream_serial_number is None
or expected_bitstream_serial_number != bitstream_serial_number
):
logger.debug("skipping page (not the first logical bitstream)")
for lacing_value in segment_table:
if lacing_value:
await self.read(lacing_value)
continue
if expected_page_sequence_number != page_sequence_number:
raise ValueError(
f"expected page sequence number {expected_page_sequence_number}"
f" but got {page_sequence_number}"
)
expected_page_sequence_number = page_sequence_number + 1
# Assemble the page
if not header_type & self.HeaderType.CONTINUED:
packet = b''
for lacing_value in segment_table:
if lacing_value:
packet += await self.read(lacing_value)
if lacing_value < 255:
# End of packet
packet_count += 1
if packet_count == 1:
# The first packet contains the identification header
logger.debug("first packet (header)")
if packet[:8] != b"OpusHead":
raise ValueError("first packet is not OpusHead")
packet_count = (
OpusPacket.ChannelMode.MONO
if packet[9] == 1
else OpusPacket.ChannelMode.STEREO
)
elif packet_count == 2:
# The second packet contains the comment header
logger.debug("second packet (tags)")
if packet[:8] != b"OpusTags":
logger.warning("second packet is not OpusTags")
else:
yield OpusPacket(channel_mode, 20, 48000, packet)
packet = b''
if header_type & self.HeaderType.LAST:
logger.debug("EOS")
return generate_frames()
# -----------------------------------------------------------------------------
class OpusPacketSource:
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
elapsed_ms = 0
opus_parser = OpusParser(self.read)
async for opus_packet in opus_parser.packets:
# We only support sending one Opus frame per RTP packet
# TODO: check the spec for the first byte value here
opus_payload = bytes([1]) + opus_packet.payload
elapsed_s = elapsed_ms / 1000
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
rtp_packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
)
rtp_packet.timestamp_seconds = elapsed_s
yield rtp_packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
elapsed_ms += opus_packet.duration
return generate_packets()
# -----------------------------------------------------------------------------
# This map should be left at the end of the file so it can refer to the classes
# above
# -----------------------------------------------------------------------------
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
OPUS_VENDOR_ID: {OPUS_CODEC_ID: OpusMediaCodecInformation}
}