Merge pull request #571 from google/gbg/a2dp-player

a2dp player
This commit is contained in:
Gilles Boccon-Gibod
2024-10-19 07:40:43 -07:00
committed by GitHub
19 changed files with 1884 additions and 547 deletions

View File

@@ -17,12 +17,16 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
import logging
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable
import dataclasses
import enum
import logging
import struct
from typing import Awaitable, Callable
from typing_extensions import ClassVar, Self
from .codecs import AacAudioRtpPacket
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
DataElement,
@@ -42,6 +46,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number,
)
from .rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -103,6 +108,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
8000,
11025,
@@ -130,6 +137,9 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
# fmt: on
@@ -257,38 +267,61 @@ class SbcMediaCodecInformation:
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
sampling_frequency: SamplingFrequency
channel_mode: ChannelMode
block_length: BlockLength
subbands: Subbands
allocation_method: AllocationMethod
minimum_bitpool_value: int
maximum_bitpool_value: int
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
}
class SamplingFrequency(enum.IntFlag):
SF_16000 = 1 << 3
SF_32000 = 1 << 2
SF_44100 = 1 << 1
SF_48000 = 1 << 0
@staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
16000,
32000,
44100,
48000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class ChannelMode(enum.IntFlag):
MONO = 1 << 3
DUAL_CHANNEL = 1 << 2
STEREO = 1 << 1
JOINT_STEREO = 1 << 0
class BlockLength(enum.IntFlag):
BL_4 = 1 << 3
BL_8 = 1 << 2
BL_12 = 1 << 1
BL_16 = 1 << 0
class Subbands(enum.IntFlag):
S_4 = 1 << 1
S_8 = 1 << 0
class AllocationMethod(enum.IntFlag):
SNR = 1 << 1
LOUDNESS = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> Self:
sampling_frequency = cls.SamplingFrequency((data[0] >> 4) & 0x0F)
channel_mode = cls.ChannelMode((data[0] >> 0) & 0x0F)
block_length = cls.BlockLength((data[1] >> 4) & 0x0F)
subbands = cls.Subbands((data[1] >> 2) & 0x03)
allocation_method = cls.AllocationMethod((data[1] >> 0) & 0x03)
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return SbcMediaCodecInformation(
return cls(
sampling_frequency,
channel_mode,
block_length,
@@ -298,52 +331,6 @@ class SbcMediaCodecInformation:
maximum_bitpool_value,
)
@classmethod
def from_discrete_values(
cls,
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
@classmethod
def from_lists(
cls,
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self) -> bytes:
return bytes(
[
@@ -356,23 +343,6 @@ class SbcMediaCodecInformation:
]
)
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
@@ -381,83 +351,66 @@ class AacMediaCodecInformation:
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
object_type: int
sampling_frequency: int
channels: int
rfa: int
object_type: ObjectType
sampling_frequency: SamplingFrequency
channels: Channels
vbr: int
bitrate: int
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
11025: 1 << 10,
12000: 1 << 9,
16000: 1 << 8,
22050: 1 << 7,
24000: 1 << 6,
32000: 1 << 5,
44100: 1 << 4,
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1,
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
class ObjectType(enum.IntFlag):
MPEG_2_AAC_LC = 1 << 7
MPEG_4_AAC_LC = 1 << 6
MPEG_4_AAC_LTP = 1 << 5
MPEG_4_AAC_SCALABLE = 1 << 4
@staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
class SamplingFrequency(enum.IntFlag):
SF_8000 = 1 << 11
SF_11025 = 1 << 10
SF_12000 = 1 << 9
SF_16000 = 1 << 8
SF_22050 = 1 << 7
SF_24000 = 1 << 6
SF_32000 = 1 << 5
SF_44100 = 1 << 4
SF_48000 = 1 << 3
SF_64000 = 1 << 2
SF_88200 = 1 << 1
SF_96000 = 1 << 0
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
8000,
11025,
12000,
16000,
22050,
24000,
32000,
44100,
48000,
64000,
88200,
96000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class Channels(enum.IntFlag):
MONO = 1 << 1
STEREO = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> AacMediaCodecInformation:
object_type = cls.ObjectType(data[0])
sampling_frequency = cls.SamplingFrequency(
(data[1] << 4) | ((data[2] >> 4) & 0x0F)
)
channels = cls.Channels((data[2] >> 2) & 0x03)
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
@classmethod
def from_discrete_values(
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
@classmethod
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
rfa=0,
vbr=vbr,
bitrate=bitrate,
object_type, sampling_frequency, channels, vbr, bitrate
)
def __bytes__(self) -> bytes:
@@ -472,30 +425,6 @@ class AacMediaCodecInformation:
]
)
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2]
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' ')',
]
)
@dataclasses.dataclass
# -----------------------------------------------------------------------------
@@ -514,7 +443,7 @@ class VendorSpecificMediaCodecInformation:
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
def __str__(self) -> str:
# pylint: disable=line-too-long
@@ -528,13 +457,69 @@ class VendorSpecificMediaCodecInformation:
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
vendor_id: int = dataclasses.field(init=False, repr=False)
codec_id: int = dataclasses.field(init=False, repr=False)
value: bytes = dataclasses.field(init=False, repr=False)
channel_mode: ChannelMode
frame_size: FrameSize
sampling_frequency: SamplingFrequency
class ChannelMode(enum.IntFlag):
MONO = 1 << 0
STEREO = 1 << 1
DUAL_MONO = 1 << 2
class FrameSize(enum.IntFlag):
FS_10MS = 1 << 0
FS_20MS = 1 << 1
class SamplingFrequency(enum.IntFlag):
SF_48000 = 1 << 0
VENDOR_ID: ClassVar[int] = 0x000000E0
CODEC_ID: ClassVar[int] = 0x0001
def __post_init__(self) -> None:
self.vendor_id = self.VENDOR_ID
self.codec_id = self.CODEC_ID
self.value = bytes(
[
self.channel_mode
| (self.frame_size << 3)
| (self.sampling_frequency << 7)
]
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
"""Create a new instance from the `value` part of the data, not including
the vendor id and codec id"""
channel_mode = cls.ChannelMode(data[0] & 0x07)
frame_size = cls.FrameSize((data[0] >> 3) & 0x03)
sampling_frequency = cls.SamplingFrequency((data[0] >> 7) & 0x01)
return cls(
channel_mode,
frame_size,
sampling_frequency,
)
def __str__(self) -> str:
return repr(self)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
sampling_frequency: int
block_count: int
channel_mode: int
allocation_method: int
subband_count: int
bitpool: int
payload: bytes
@property
@@ -553,8 +538,10 @@ class SbcFrame:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'am={self.allocation_method},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'bp={self.bitpool},'
f'size={len(self.payload)})'
)
@@ -583,6 +570,7 @@ class SbcParser:
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
allocation_method = (header[1] >> 1) & 1
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
@@ -602,7 +590,13 @@ class SbcParser:
# Emit the next frame
yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
sampling_frequency,
blocks,
channel_mode,
allocation_method,
subbands,
bitpool,
payload,
)
return generate_frames()
@@ -610,21 +604,15 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
sample_count = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
@@ -632,29 +620,29 @@ class SbcPacketSource:
# NOTE: this doesn't support frame fragments
sbc_parser = SbcParser(self.read)
async for frame in sbc_parser.frames:
print(frame)
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
):
# Need to flush what has been accumulated so far
logger.debug(f"yielding {len(frames)} frames")
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join(
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
[frame.payload for frame in frames]
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
sample_count += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:
@@ -663,3 +651,315 @@ class SbcPacketSource:
frames_size += len(frame.payload)
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacFrame:
class Profile(enum.IntEnum):
MAIN = 0
LC = 1
SSR = 2
LTP = 3
profile: Profile
sampling_frequency: int
channel_configuration: int
payload: bytes
@property
def sample_count(self) -> int:
return 1024
@property
def duration(self) -> float:
return self.sample_count / self.sampling_frequency
def __str__(self) -> str:
return (
f'AAC(sf={self.sampling_frequency},'
f'ch={self.channel_configuration},'
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
ADTS_AAC_SAMPLING_FREQUENCIES = [
96000,
88200,
64000,
48000,
44100,
32000,
24000,
22050,
16000,
12000,
11025,
8000,
7350,
0,
0,
0,
]
# -----------------------------------------------------------------------------
class AacParser:
"""Parser for AAC frames in an ADTS stream"""
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def frames(self) -> AsyncGenerator[AacFrame, None]:
async def generate_frames() -> AsyncGenerator[AacFrame, None]:
while True:
header = await self.read(7)
if not header:
return
sync_word = (header[0] << 4) | (header[1] >> 4)
if sync_word != 0b111111111111:
raise ValueError(f"invalid sync word ({sync_word:06x})")
layer = (header[1] >> 1) & 0b11
profile = AacFrame.Profile((header[2] >> 6) & 0b11)
sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[
(header[2] >> 2) & 0b1111
]
channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6)
frame_length = (
((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5)
)
if layer != 0:
raise ValueError("layer must be 0")
payload = await self.read(frame_length - 7)
if payload:
yield AacFrame(
profile, sampling_frequency, channel_configuration, payload
)
return generate_frames()
# -----------------------------------------------------------------------------
class AacPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
sample_count = 0
aac_parser = AacParser(self.read)
async for frame in aac_parser.frames:
logger.debug("yielding one AAC frame")
# Emit a packet
aac_payload = bytes(
AacAudioRtpPacket.for_simple_aac(
frame.sampling_frequency,
frame.channel_configuration,
frame.payload,
)
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
)
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += frame.sample_count
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusPacket:
class ChannelMode(enum.IntEnum):
MONO = 0
STEREO = 1
DUAL_MONO = 2
channel_mode: ChannelMode
duration: int # Duration in ms.
sampling_frequency: int
payload: bytes
def __str__(self) -> str:
return (
f'Opus(ch={self.channel_mode.name}, '
f'd={self.duration}ms, '
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class OpusParser:
"""
Parser for Opus packets in an Ogg stream
See RFC 3533
NOTE: this parser only supports bitstreams with a single logical stream.
"""
CAPTURE_PATTERN = b'OggS'
class HeaderType(enum.IntFlag):
CONTINUED = 0x01
FIRST = 0x02
LAST = 0x04
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def packets(self) -> AsyncGenerator[OpusPacket, None]:
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
packet = b''
packet_count = 0
expected_bitstream_serial_number = None
expected_page_sequence_number = 0
channel_mode = OpusPacket.ChannelMode.STEREO
while True:
# Parse the page header
header = await self.read(27)
if len(header) != 27:
logger.debug("end of stream")
break
capture_pattern = header[:4]
if capture_pattern != self.CAPTURE_PATTERN:
print(capture_pattern.hex())
raise ValueError("invalid capture pattern at start of page")
version = header[4]
if version != 0:
raise ValueError(f"version {version} not supported")
header_type = self.HeaderType(header[5])
(
granule_position,
bitstream_serial_number,
page_sequence_number,
crc_checksum,
page_segments,
) = struct.unpack_from("<QIIIB", header, 6)
segment_table = await self.read(page_segments)
if header_type & self.HeaderType.FIRST:
if expected_bitstream_serial_number is None:
# We will only accept pages for the first encountered stream
logger.debug("BOS")
expected_bitstream_serial_number = bitstream_serial_number
expected_page_sequence_number = page_sequence_number
if (
expected_bitstream_serial_number is None
or expected_bitstream_serial_number != bitstream_serial_number
):
logger.debug("skipping page (not the first logical bitstream)")
for lacing_value in segment_table:
if lacing_value:
await self.read(lacing_value)
continue
if expected_page_sequence_number != page_sequence_number:
raise ValueError(
f"expected page sequence number {expected_page_sequence_number}"
f" but got {page_sequence_number}"
)
expected_page_sequence_number = page_sequence_number + 1
# Assemble the page
if not header_type & self.HeaderType.CONTINUED:
packet = b''
for lacing_value in segment_table:
if lacing_value:
packet += await self.read(lacing_value)
if lacing_value < 255:
# End of packet
packet_count += 1
if packet_count == 1:
# The first packet contains the identification header
logger.debug("first packet (header)")
if packet[:8] != b"OpusHead":
raise ValueError("first packet is not OpusHead")
packet_count = (
OpusPacket.ChannelMode.MONO
if packet[9] == 1
else OpusPacket.ChannelMode.STEREO
)
elif packet_count == 2:
# The second packet contains the comment header
logger.debug("second packet (tags)")
if packet[:8] != b"OpusTags":
logger.warning("second packet is not OpusTags")
else:
yield OpusPacket(channel_mode, 20, 48000, packet)
packet = b''
if header_type & self.HeaderType.LAST:
logger.debug("EOS")
return generate_frames()
# -----------------------------------------------------------------------------
class OpusPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
elapsed_ms = 0
opus_parser = OpusParser(self.read)
async for opus_packet in opus_parser.packets:
# We only support sending one Opus frame per RTP packet
# TODO: check the spec for the first byte value here
opus_payload = bytes([1]) + opus_packet.payload
elapsed_s = elapsed_ms / 1000
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
rtp_packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
)
rtp_packet.timestamp_seconds = elapsed_s
yield rtp_packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
elapsed_ms += opus_packet.duration
return generate_packets()
# -----------------------------------------------------------------------------
# This map should be left at the end of the file so it can refer to the classes
# above
# -----------------------------------------------------------------------------
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
OpusMediaCodecInformation.VENDOR_ID: {
OpusMediaCodecInformation.CODEC_ID: OpusMediaCodecInformation
}
}

View File

@@ -119,7 +119,7 @@ class Frame:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
if subunit_id < 5 or subunit_id == 7:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
@@ -132,7 +132,6 @@ class Frame:
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise core.InvalidPacketError("reserved subunit ID")

View File

@@ -17,12 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import struct
import time
import logging
import enum
import warnings
from pyee import EventEmitter
from typing import (
Any,
Awaitable,
@@ -39,6 +37,8 @@ from typing import (
cast,
)
from pyee import EventEmitter
from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
@@ -51,13 +51,16 @@ from .a2dp import (
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE,
A2DP_SBC_CODEC_TYPE,
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES,
AacMediaCodecInformation,
SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation,
)
from .rtp import MediaPacket
from . import sdp, device, l2cap
from .colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -278,95 +281,6 @@ class RealtimeClock:
await asyncio.sleep(duration)
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number & 0xFFFF
self.timestamp = timestamp & 0xFFFFFFFF
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack(
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class MediaPacketPump:
pump_task: Optional[asyncio.Task]
@@ -377,6 +291,7 @@ class MediaPacketPump:
self.packets = packets
self.clock = clock
self.pump_task = None
self.completed = asyncio.Event()
async def start(self, rtp_channel: l2cap.ClassicChannel) -> None:
async def pump_packets():
@@ -406,6 +321,8 @@ class MediaPacketPump:
)
except asyncio.exceptions.CancelledError:
logger.debug('pump canceled')
finally:
self.completed.set()
# Pump packets
self.pump_task = asyncio.create_task(pump_packets())
@@ -417,6 +334,9 @@ class MediaPacketPump:
await self.pump_task
self.pump_task = None
async def wait_for_completion(self) -> None:
await self.completed.wait()
# -----------------------------------------------------------------------------
class MessageAssembler:
@@ -615,11 +535,25 @@ class MediaCodecCapabilities(ServiceCapabilities):
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
self.media_codec_information = (
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
self.media_codec_information = media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
else:
self.media_codec_information = vendor_media_codec_information
def __init__(
self,
@@ -1316,10 +1250,20 @@ class Protocol(EventEmitter):
return None
def add_source(
self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump
self,
codec_capabilities: MediaCodecCapabilities,
packet_pump: MediaPacketPump,
delay_reporting: bool = False,
) -> LocalSource:
seid = len(self.local_endpoints) + 1
source = LocalSource(self, seid, codec_capabilities, packet_pump)
service_capabilities = (
[ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)]
if delay_reporting
else []
)
source = LocalSource(
self, seid, codec_capabilities, service_capabilities, packet_pump
)
self.local_endpoints.append(source)
return source
@@ -1372,7 +1316,7 @@ class Protocol(EventEmitter):
return self.remote_endpoints.values()
def find_remote_sink_by_codec(
self, media_type: int, codec_type: int
self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
) -> Optional[DiscoveredStreamEndPoint]:
for endpoint in self.remote_endpoints.values():
if (
@@ -1397,7 +1341,19 @@ class Protocol(EventEmitter):
codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and codec_capabilities.media_codec_type == codec_type
):
has_codec = True
if isinstance(
codec_capabilities.media_codec_information,
VendorSpecificMediaCodecInformation,
):
if (
codec_capabilities.media_codec_information.vendor_id
== vendor_id
and codec_capabilities.media_codec_information.codec_id
== codec_id
):
has_codec = True
else:
has_codec = True
if has_media_transport and has_codec:
return endpoint
@@ -2180,12 +2136,13 @@ class LocalSource(LocalStreamEndPoint):
protocol: Protocol,
seid: int,
codec_capabilities: MediaCodecCapabilities,
other_capabilitiles: Iterable[ServiceCapabilities],
packet_pump: MediaPacketPump,
) -> None:
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities,
]
] + list(other_capabilitiles)
super().__init__(
protocol,
seid,

View File

@@ -1491,10 +1491,14 @@ class Protocol(pyee.EventEmitter):
f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}"
)
# Only the PANEL subunit type with subunit ID 0 is supported in this profile.
if (
command.subunit_type != avc.Frame.SubunitType.PANEL
or command.subunit_id != 0
# Only addressing the unit, or the PANEL subunit with subunit ID 0 is supported
# in this profile.
if not (
command.subunit_type == avc.Frame.SubunitType.UNIT
and command.subunit_id == 7
) and not (
command.subunit_type == avc.Frame.SubunitType.PANEL
and command.subunit_id == 0
):
logger.debug("subunit not supported")
self.send_not_implemented_response(transaction_label, command)
@@ -1528,8 +1532,8 @@ class Protocol(pyee.EventEmitter):
# TODO: delegate
response = avc.PassThroughResponseFrame(
avc.ResponseFrame.ResponseCode.ACCEPTED,
avc.Frame.SubunitType.PANEL,
0,
command.subunit_type,
command.subunit_id,
command.state_flag,
command.operation_id,
command.operation_data,
@@ -1846,6 +1850,15 @@ class Protocol(pyee.EventEmitter):
RejectedResponse(pdu_id, status_code),
)
def send_not_implemented_avrcp_response(
self, transaction_label: int, pdu_id: Protocol.PduId
) -> None:
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED,
NotImplementedResponse(pdu_id, b''),
)
def _on_get_capabilities_command(
self, transaction_label: int, command: GetCapabilitiesCommand
) -> None:
@@ -1891,29 +1904,35 @@ class Protocol(pyee.EventEmitter):
async def register_notification():
# Check if the event is supported.
supported_events = await self.delegate.get_supported_events()
if command.event_id in supported_events:
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id not in supported_events:
logger.debug("event not supported")
self.send_not_implemented_avrcp_response(
transaction_label, self.PduId.REGISTER_NOTIFICATION
)
return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
# TODO: testing only, use delegate
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING)
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
# TODO: testing only, use delegate
response = RegisterNotificationResponse(
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING)
)
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
self._delegate_command(transaction_label, command, register_notification())

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
from typing_extensions import Self
from bumble import core
@@ -101,12 +102,40 @@ class BitReader:
break
# -----------------------------------------------------------------------------
class BitWriter:
"""Simple but not optimized bit stream writer."""
data: int
bit_count: int
def __init__(self) -> None:
self.data = 0
self.bit_count = 0
def write(self, value: int, bit_count: int) -> None:
self.data = (self.data << bit_count) | value
self.bit_count += bit_count
def write_bytes(self, data: bytes) -> None:
bit_count = 8 * len(data)
self.data = (self.data << bit_count) | int.from_bytes(data, 'big')
self.bit_count += bit_count
def __bytes__(self) -> bytes:
return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes(
(self.bit_count + 7) // 8, 'big'
)
# -----------------------------------------------------------------------------
class AacAudioRtpPacket:
"""AAC payload encapsulated in an RTP packet payload"""
audio_mux_element: AudioMuxElement
@staticmethod
def latm_value(reader: BitReader) -> int:
def read_latm_value(reader: BitReader) -> int:
bytes_for_value = reader.read(2)
value = 0
for _ in range(bytes_for_value + 1):
@@ -114,24 +143,33 @@ class AacAudioRtpPacket:
return value
@staticmethod
def program_config_element(reader: BitReader):
raise core.InvalidPacketError('program_config_element not supported')
def read_audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return audio_object_type
@dataclass
class GASpecificConfig:
def __init__(
self, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> None:
audio_object_type: int
# NOTE: other fields not supported
@classmethod
def from_bits(
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
self.core_coder_delay = reader.read(14)
core_coder_delay = reader.read(14)
extension_flag = reader.read(1)
if not channel_configuration:
AacAudioRtpPacket.program_config_element(reader)
raise core.InvalidPacketError('program_config_element not supported')
if audio_object_type in (6, 20):
self.layer_nr = reader.read(3)
layer_nr = reader.read(3)
if extension_flag:
if audio_object_type == 22:
num_of_sub_frame = reader.read(5)
@@ -144,14 +182,13 @@ class AacAudioRtpPacket:
if extension_flag_3 == 1:
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod
def audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return cls(audio_object_type)
return audio_object_type
def to_bits(self, writer: BitWriter) -> None:
assert self.audio_object_type in (1, 2)
writer.write(0, 1) # frame_length_flag = 0
writer.write(0, 1) # depends_on_core_coder = 0
writer.write(0, 1) # extension_flag = 0
@dataclass
class AudioSpecificConfig:
@@ -159,6 +196,7 @@ class AacAudioRtpPacket:
sampling_frequency_index: int
sampling_frequency: int
channel_configuration: int
ga_specific_config: AacAudioRtpPacket.GASpecificConfig
sbr_present_flag: int
ps_present_flag: int
extension_audio_object_type: int
@@ -182,44 +220,73 @@ class AacAudioRtpPacket:
7350,
]
def __init__(self, reader: BitReader) -> None:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
self.sampling_frequency_index = reader.read(4)
if self.sampling_frequency_index == 0xF:
self.sampling_frequency = reader.read(24)
else:
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
self.sampling_frequency_index
]
self.channel_configuration = reader.read(4)
self.sbr_present_flag = -1
self.ps_present_flag = -1
if self.audio_object_type in (5, 29):
self.extension_audio_object_type = 5
self.sbc_present_flag = 1
if self.audio_object_type == 29:
self.ps_present_flag = 1
self.extension_sampling_frequency_index = reader.read(4)
if self.extension_sampling_frequency_index == 0xF:
self.extension_sampling_frequency = reader.read(24)
else:
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
self.extension_sampling_frequency_index
]
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
if self.audio_object_type == 22:
self.extension_channel_configuration = reader.read(4)
else:
self.extension_audio_object_type = 0
@classmethod
def for_simple_aac(
cls,
audio_object_type: int,
sampling_frequency: int,
channel_configuration: int,
) -> Self:
if sampling_frequency not in cls.SAMPLING_FREQUENCIES:
raise ValueError(f'invalid sampling frequency {sampling_frequency}')
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
reader, self.channel_configuration, self.audio_object_type
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type)
return cls(
audio_object_type=audio_object_type,
sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index(
sampling_frequency
),
sampling_frequency=sampling_frequency,
channel_configuration=channel_configuration,
ga_specific_config=ga_specific_config,
sbr_present_flag=0,
ps_present_flag=0,
extension_audio_object_type=0,
extension_sampling_frequency_index=0,
extension_sampling_frequency=0,
extension_channel_configuration=0,
)
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
sampling_frequency_index = reader.read(4)
if sampling_frequency_index == 0xF:
sampling_frequency = reader.read(24)
else:
sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index]
channel_configuration = reader.read(4)
sbr_present_flag = 0
ps_present_flag = 0
extension_sampling_frequency_index = 0
extension_sampling_frequency = 0
extension_channel_configuration = 0
extension_audio_object_type = 0
if audio_object_type in (5, 29):
extension_audio_object_type = 5
sbr_present_flag = 1
if audio_object_type == 29:
ps_present_flag = 1
extension_sampling_frequency_index = reader.read(4)
if extension_sampling_frequency_index == 0xF:
extension_sampling_frequency = reader.read(24)
else:
extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[
extension_sampling_frequency_index
]
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
if audio_object_type == 22:
extension_channel_configuration = reader.read(4)
if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits(
reader, channel_configuration, audio_object_type
)
else:
raise core.InvalidPacketError(
f'audioObjectType {self.audio_object_type} not supported'
f'audioObjectType {audio_object_type} not supported'
)
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
@@ -248,13 +315,44 @@ class AacAudioRtpPacket:
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
# self.extension_channel_configuration = reader.read(4)
return cls(
audio_object_type,
sampling_frequency_index,
sampling_frequency,
channel_configuration,
ga_specific_config,
sbr_present_flag,
ps_present_flag,
extension_audio_object_type,
extension_sampling_frequency_index,
extension_sampling_frequency,
extension_channel_configuration,
)
def to_bits(self, writer: BitWriter) -> None:
if self.sampling_frequency_index >= 15:
raise ValueError(
f"unsupported sampling frequency index {self.sampling_frequency_index}"
)
if self.audio_object_type not in (1, 2):
raise ValueError(
f"unsupported audio object type {self.audio_object_type} "
)
writer.write(self.audio_object_type, 5)
writer.write(self.sampling_frequency_index, 4)
writer.write(self.channel_configuration, 4)
self.ga_specific_config.to_bits(writer)
@dataclass
class StreamMuxConfig:
other_data_present: int
other_data_len_bits: int
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
def __init__(self, reader: BitReader) -> None:
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
audio_mux_version = reader.read(1)
if audio_mux_version == 1:
@@ -264,7 +362,7 @@ class AacAudioRtpPacket:
if audio_mux_version_a != 0:
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
@@ -275,13 +373,13 @@ class AacAudioRtpPacket:
if num_layer != 0:
raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
reader
)
else:
asc_len = AacAudioRtpPacket.latm_value(reader)
asc_len = AacAudioRtpPacket.read_latm_value(reader)
marker = reader.bit_position
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
reader
)
audio_specific_config_len = reader.bit_position - marker
@@ -299,36 +397,49 @@ class AacAudioRtpPacket:
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1)
if self.other_data_present:
other_data_present = reader.read(1)
other_data_len_bits = 0
if other_data_present:
if audio_mux_version == 1:
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader)
else:
self.other_data_len_bits = 0
while True:
self.other_data_len_bits *= 256
other_data_len_bits *= 256
other_data_len_esc = reader.read(1)
self.other_data_len_bits += reader.read(8)
other_data_len_bits += reader.read(8)
if other_data_len_esc == 0:
break
crc_check_present = reader.read(1)
if crc_check_present:
crc_checksum = reader.read(8)
return cls(other_data_present, other_data_len_bits, audio_specific_config)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # audioMuxVersion = 0
writer.write(1, 1) # allStreamsSameTimeFraming = 1
writer.write(0, 6) # numSubFrames = 0
writer.write(0, 4) # numProgram = 0
writer.write(0, 3) # numLayer = 0
self.audio_specific_config.to_bits(writer)
writer.write(0, 3) # frameLengthType = 0
writer.write(0, 8) # latmBufferFullness = 0
writer.write(0, 1) # otherDataPresent = 0
writer.write(0, 1) # crcCheckPresent = 0
@dataclass
class AudioMuxElement:
payload: bytes
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
payload: bytes
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
# (only supports mux_config_present=1)
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader)
# We only support:
# allStreamsSameTimeFraming == 1
@@ -344,19 +455,46 @@ class AacAudioRtpPacket:
if tmp != 255:
break
self.payload = reader.read_bytes(mux_slot_length_bytes)
payload = reader.read_bytes(mux_slot_length_bytes)
if self.stream_mux_config.other_data_present:
reader.skip(self.stream_mux_config.other_data_len_bits)
if stream_mux_config.other_data_present:
reader.skip(stream_mux_config.other_data_len_bits)
# ByteAlign
while reader.bit_position % 8:
reader.read(1)
def __init__(self, data: bytes) -> None:
return cls(stream_mux_config, payload)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # useSameStreamMux = 0
self.stream_mux_config.to_bits(writer)
mux_slot_length_bytes = len(self.payload)
while mux_slot_length_bytes > 255:
writer.write(255, 8)
mux_slot_length_bytes -= 255
writer.write(mux_slot_length_bytes, 8)
if mux_slot_length_bytes == 255:
writer.write(0, 8)
writer.write_bytes(self.payload)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
# Parse the bit stream
reader = BitReader(data)
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
return cls(cls.AudioMuxElement.from_bits(reader))
@classmethod
def for_simple_aac(
cls, sampling_frequency: int, channel_configuration: int, payload: bytes
) -> Self:
audio_specific_config = cls.AudioSpecificConfig.for_simple_aac(
2, sampling_frequency, channel_configuration
)
stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config)
audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload)
return cls(audio_mux_element)
def to_adts(self):
# pylint: disable=line-too-long
@@ -383,3 +521,11 @@ class AacAudioRtpPacket:
)
+ self.audio_mux_element.payload
)
def __init__(self, audio_mux_element: AudioMuxElement) -> None:
self.audio_mux_element = audio_mux_element
def __bytes__(self) -> bytes:
writer = BitWriter()
self.audio_mux_element.to_bits(writer)
return bytes(writer)

View File

@@ -1571,14 +1571,22 @@ class Connection(CompositeEventEmitter):
raise
def __str__(self):
return (
f'Connection(handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
)
if self.transport == BT_LE_TRANSPORT:
return (
f'Connection(transport=LE, handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'self_resolvable_address={self.self_resolvable_address}, '
f'peer_address={self.peer_address}, '
f'peer_resolvable_address={self.peer_resolvable_address})'
)
else:
return (
f'Connection(transport=BR/EDR, handle=0x{self.handle:04X}, '
f'role={self.role_name}, '
f'self_address={self.self_address}, '
f'peer_address={self.peer_address})'
)
# -----------------------------------------------------------------------------

View File

@@ -6065,6 +6065,32 @@ class HCI_Read_Remote_Version_Information_Complete_Event(HCI_Event):
'''
# -----------------------------------------------------------------------------
@HCI_Event.event(
[
('status', STATUS_SPEC),
('connection_handle', 2),
('unused', 1),
(
'service_type',
{
'size': 1,
'mapper': lambda x: HCI_QOS_Setup_Complete_Event.ServiceType(x).name,
},
),
]
)
class HCI_QOS_Setup_Complete_Event(HCI_Event):
'''
See Bluetooth spec @ 7.7.13 QoS Setup Complete Event
'''
class ServiceType(OpenIntEnum):
NO_TRAFFIC_AVAILABLE = 0x00
BEST_EFFORT_AVAILABLE = 0x01
GUARANTEED_AVAILABLE = 0x02
# -----------------------------------------------------------------------------
@HCI_Event.event(
[

View File

@@ -1106,6 +1106,18 @@ class Host(AbortableEventEmitter):
event.status,
)
def on_hci_qos_setup_complete_event(self, event):
if event.status == hci.HCI_SUCCESS:
self.emit(
'connection_qos_setup', event.connection_handle, event.service_type
)
else:
self.emit(
'connection_qos_setup_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
pass

110
bumble/rtp.py Normal file
View File

@@ -0,0 +1,110 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number & 0xFFFF
self.timestamp = timestamp & 0xFFFFFFFF
self.timestamp_seconds = 0.0
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack(
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)