mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
wip
This commit is contained in:
318
bumble/a2dp.py
318
bumble/a2dp.py
@@ -22,7 +22,9 @@ import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from typing import List, Callable, Awaitable
|
||||
from typing import Awaitable, Callable, Iterable, List
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
from .codecs import AacAudioRtpPacket
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
@@ -105,6 +107,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
||||
}
|
||||
|
||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||
8000,
|
||||
11025,
|
||||
@@ -132,6 +136,11 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
||||
}
|
||||
|
||||
|
||||
OPUS_VENDOR_ID = 0x000000E0
|
||||
OPUS_CODEC_ID = 0x0001
|
||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -516,7 +525,7 @@ class VendorSpecificMediaCodecInformation:
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
@@ -530,6 +539,105 @@ class VendorSpecificMediaCodecInformation:
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
|
||||
channel_mode: int
|
||||
frame_size: int
|
||||
sampling_frequency: int
|
||||
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
CHANNEL_MODE_BITS = {
|
||||
ChannelMode.MONO: 1 << 0,
|
||||
ChannelMode.STEREO: 1 << 1,
|
||||
ChannelMode.DUAL_MONO: 1 << 2,
|
||||
}
|
||||
|
||||
class FrameSize(enum.IntFlag):
|
||||
F_10MS = 0
|
||||
F_20MS = 1
|
||||
|
||||
FRAME_SIZE_BITS = {FrameSize.F_10MS: 1 << 0, FrameSize.F_20MS: 1 << 1}
|
||||
|
||||
SAMPLING_FREQUENCIES = [48000]
|
||||
SAMPLING_FREQUENCY_BITS = {
|
||||
48000: 1 << 0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
"""Create a new instance from the `value` part of the data, not including
|
||||
the vendor id and codec id"""
|
||||
channel_mode = data[0] & 0x07
|
||||
frame_size = (data[0] >> 3) & 0x03
|
||||
sampling_frequency = (data[0] >> 7) & 0x01
|
||||
|
||||
return cls(
|
||||
OPUS_VENDOR_ID,
|
||||
OPUS_CODEC_ID,
|
||||
data,
|
||||
channel_mode,
|
||||
frame_size,
|
||||
sampling_frequency,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls, channel_mode: ChannelMode, frame_size: FrameSize, sampling_frequency: int
|
||||
) -> Self:
|
||||
channel_mode_int = cls.CHANNEL_MODE_BITS[channel_mode]
|
||||
frame_size_int = cls.FRAME_SIZE_BITS[frame_size]
|
||||
sampling_frequency_int = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency]
|
||||
value = bytes(
|
||||
[channel_mode_int | (frame_size_int << 3) | (sampling_frequency_int << 7)]
|
||||
)
|
||||
return cls(
|
||||
vendor_id=OPUS_VENDOR_ID,
|
||||
codec_id=OPUS_CODEC_ID,
|
||||
value=value,
|
||||
channel_mode=channel_mode_int,
|
||||
frame_size=frame_size_int,
|
||||
sampling_frequency=sampling_frequency_int,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
channel_modes: Iterable[ChannelMode],
|
||||
frame_sizes: Iterable[FrameSize],
|
||||
sampling_frequencies: Iterable[int],
|
||||
) -> Self:
|
||||
channel_mode = sum(channel_modes)
|
||||
frame_size = sum(frame_sizes)
|
||||
sampling_frequency = sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
)
|
||||
value = bytes([channel_mode | (frame_size << 3) | (sampling_frequency << 7)])
|
||||
return cls(
|
||||
vendor_id=OPUS_VENDOR_ID,
|
||||
codec_id=OPUS_CODEC_ID,
|
||||
value=value,
|
||||
channel_mode=channel_mode,
|
||||
frame_size=frame_size,
|
||||
sampling_frequency=sampling_frequency,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
'OpusMediaCodecInformation(',
|
||||
f' channel_mode: {",".join([x.name for x in flags_to_list(self.channel_mode, list(self.ChannelMode))])}',
|
||||
f' frame_size: {",".join([x.name for x in flags_to_list(self.frame_size, list(self.FrameSize))])}',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, self.SAMPLING_FREQUENCIES)])}',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
@@ -628,7 +736,7 @@ class SbcPacketSource:
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
frames = []
|
||||
frames_size = 0
|
||||
max_rtp_payload = self.mtu - 12 - 1
|
||||
@@ -638,26 +746,27 @@ class SbcPacketSource:
|
||||
async for frame in sbc_parser.frames:
|
||||
if (
|
||||
frames_size + len(frame.payload) > max_rtp_payload
|
||||
or len(frames) == 16
|
||||
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
|
||||
):
|
||||
# Need to flush what has been accumulated so far
|
||||
logger.debug(f"yielding {len(frames)} frames")
|
||||
|
||||
# Emit a packet
|
||||
sbc_payload = bytes([len(frames)]) + b''.join(
|
||||
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
|
||||
[frame.payload for frame in frames]
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += sum((frame.sample_count for frame in frames))
|
||||
timestamp &= 0xFFFFFFFF
|
||||
sample_count += sum((frame.sample_count for frame in frames))
|
||||
frames = [frame]
|
||||
frames_size = len(frame.payload)
|
||||
else:
|
||||
@@ -775,7 +884,7 @@ class AacPacketSource:
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
|
||||
aac_parser = AacParser(self.read)
|
||||
async for frame in aac_parser.frames:
|
||||
@@ -789,17 +898,200 @@ class AacPacketSource:
|
||||
frame.payload,
|
||||
)
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += frame.sample_count
|
||||
timestamp &= 0xFFFFFFFF
|
||||
frames = [frame]
|
||||
sample_count += frame.sample_count
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusPacket:
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
channel_mode: ChannelMode
|
||||
duration: int # Duration in ms.
|
||||
sampling_frequency: int
|
||||
payload: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Opus(ch={self.channel_mode.name}, '
|
||||
f'd={self.duration}ms, '
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusParser:
|
||||
"""
|
||||
Parser for Opus packets in an Ogg stream
|
||||
|
||||
See RFC 3533
|
||||
|
||||
NOTE: this parser only supports bitstreams with a single logical stream.
|
||||
"""
|
||||
|
||||
CAPTURE_PATTERN = b'OggS'
|
||||
|
||||
class HeaderType(enum.IntFlag):
|
||||
CONTINUED = 0x01
|
||||
FIRST = 0x02
|
||||
LAST = 0x04
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def packets(self) -> AsyncGenerator[OpusPacket, None]:
|
||||
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
|
||||
packet = b''
|
||||
packet_count = 0
|
||||
expected_bitstream_serial_number = None
|
||||
expected_page_sequence_number = 0
|
||||
channel_mode = OpusPacket.ChannelMode.STEREO
|
||||
|
||||
while True:
|
||||
# Parse the page header
|
||||
header = await self.read(27)
|
||||
if len(header) != 27:
|
||||
logger.debug("end of stream")
|
||||
break
|
||||
|
||||
capture_pattern = header[:4]
|
||||
if capture_pattern != self.CAPTURE_PATTERN:
|
||||
print(capture_pattern.hex())
|
||||
raise ValueError("invalid capture pattern at start of page")
|
||||
|
||||
version = header[4]
|
||||
if version != 0:
|
||||
raise ValueError(f"version {version} not supported")
|
||||
|
||||
header_type = self.HeaderType(header[5])
|
||||
(granule_position,) = struct.unpack_from("<Q", header, 6)
|
||||
(bitstream_serial_number,) = struct.unpack_from("<I", header, 14)
|
||||
(page_sequence_number,) = struct.unpack_from("<I", header, 18)
|
||||
(crc_checksum,) = struct.unpack_from("<I", header, 22)
|
||||
page_segments = header[26]
|
||||
segment_table = await self.read(page_segments)
|
||||
|
||||
if header_type & self.HeaderType.FIRST:
|
||||
if expected_bitstream_serial_number is None:
|
||||
# We will only accept pages for the first encountered stream
|
||||
logger.debug("BOS")
|
||||
expected_bitstream_serial_number = bitstream_serial_number
|
||||
expected_page_sequence_number = page_sequence_number
|
||||
|
||||
if (
|
||||
expected_bitstream_serial_number is None
|
||||
or expected_bitstream_serial_number != bitstream_serial_number
|
||||
):
|
||||
logger.debug("skipping page (not the first logical bitstream)")
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
await self.read(lacing_value)
|
||||
continue
|
||||
|
||||
if expected_page_sequence_number != page_sequence_number:
|
||||
raise ValueError(
|
||||
f"expected page sequence number {expected_page_sequence_number}"
|
||||
f" but got {page_sequence_number}"
|
||||
)
|
||||
expected_page_sequence_number = page_sequence_number + 1
|
||||
|
||||
# Assemble the page
|
||||
if not header_type & self.HeaderType.CONTINUED:
|
||||
packet = b''
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
packet += await self.read(lacing_value)
|
||||
if lacing_value < 255:
|
||||
# End of packet
|
||||
packet_count += 1
|
||||
|
||||
if packet_count == 1:
|
||||
# The first packet contains the identification header
|
||||
logger.debug("first packet (header)")
|
||||
if packet[:8] != b"OpusHead":
|
||||
raise ValueError("first packet is not OpusHead")
|
||||
packet_count = (
|
||||
OpusPacket.ChannelMode.MONO
|
||||
if packet[9] == 1
|
||||
else OpusPacket.ChannelMode.STEREO
|
||||
)
|
||||
|
||||
elif packet_count == 2:
|
||||
# The second packet contains the comment header
|
||||
logger.debug("second packet (tags)")
|
||||
if packet[:8] != b"OpusTags":
|
||||
logger.warning("second packet is not OpusTags")
|
||||
else:
|
||||
yield OpusPacket(channel_mode, 20, 48000, packet)
|
||||
|
||||
packet = b''
|
||||
|
||||
if header_type & self.HeaderType.LAST:
|
||||
logger.debug("EOS")
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusPacketSource:
|
||||
def __init__(
|
||||
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
|
||||
) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
elapsed_ms = 0
|
||||
|
||||
opus_parser = OpusParser(self.read)
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only support sending one Opus frame per RTP packet
|
||||
# TODO: check the spec for the first byte value here
|
||||
opus_payload = bytes([1]) + opus_packet.payload
|
||||
elapsed_s = elapsed_ms / 1000
|
||||
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
|
||||
rtp_packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
|
||||
)
|
||||
rtp_packet.timestamp_seconds = elapsed_s
|
||||
yield rtp_packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
elapsed_ms += opus_packet.duration
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# This map should be left at the end of the file so it can refer to the classes
|
||||
# above
|
||||
# -----------------------------------------------------------------------------
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
|
||||
OPUS_VENDOR_ID: {OPUS_CODEC_ID: OpusMediaCodecInformation}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ from .a2dp import (
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
A2DP_NON_A2DP_CODEC_TYPE,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES,
|
||||
AacMediaCodecInformation,
|
||||
SbcMediaCodecInformation,
|
||||
VendorSpecificMediaCodecInformation,
|
||||
@@ -328,6 +329,7 @@ class MediaPacket:
|
||||
self.marker = marker
|
||||
self.sequence_number = sequence_number & 0xFFFF
|
||||
self.timestamp = timestamp & 0xFFFFFFFF
|
||||
self.timestamp_seconds = 0.0
|
||||
self.ssrc = ssrc
|
||||
self.csrc_list = csrc_list
|
||||
self.payload_type = payload_type
|
||||
@@ -621,11 +623,25 @@ class MediaCodecCapabilities(ServiceCapabilities):
|
||||
self.media_codec_information
|
||||
)
|
||||
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
|
||||
self.media_codec_information = (
|
||||
vendor_media_codec_information = (
|
||||
VendorSpecificMediaCodecInformation.from_bytes(
|
||||
self.media_codec_information
|
||||
)
|
||||
)
|
||||
if (
|
||||
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
||||
vendor_media_codec_information.vendor_id
|
||||
)
|
||||
) and (
|
||||
media_codec_information_class := vendor_class_map.get(
|
||||
vendor_media_codec_information.codec_id
|
||||
)
|
||||
):
|
||||
self.media_codec_information = media_codec_information_class.from_bytes(
|
||||
vendor_media_codec_information.value
|
||||
)
|
||||
else:
|
||||
self.media_codec_information = vendor_media_codec_information
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1388,7 +1404,7 @@ class Protocol(EventEmitter):
|
||||
return self.remote_endpoints.values()
|
||||
|
||||
def find_remote_sink_by_codec(
|
||||
self, media_type: int, codec_type: int
|
||||
self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
|
||||
) -> Optional[DiscoveredStreamEndPoint]:
|
||||
for endpoint in self.remote_endpoints.values():
|
||||
if (
|
||||
@@ -1413,7 +1429,19 @@ class Protocol(EventEmitter):
|
||||
codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
|
||||
and codec_capabilities.media_codec_type == codec_type
|
||||
):
|
||||
has_codec = True
|
||||
if isinstance(
|
||||
codec_capabilities.media_codec_information,
|
||||
VendorSpecificMediaCodecInformation,
|
||||
):
|
||||
if (
|
||||
codec_capabilities.media_codec_information.vendor_id
|
||||
== vendor_id
|
||||
and codec_capabilities.media_codec_information.codec_id
|
||||
== codec_id
|
||||
):
|
||||
has_codec = True
|
||||
else:
|
||||
has_codec = True
|
||||
if has_media_transport and has_codec:
|
||||
return endpoint
|
||||
|
||||
|
||||
@@ -1571,14 +1571,22 @@ class Connection(CompositeEventEmitter):
|
||||
raise
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'Connection(handle=0x{self.handle:04X}, '
|
||||
f'role={self.role_name}, '
|
||||
f'self_address={self.self_address}, '
|
||||
f'self_resolvable_address={self.self_resolvable_address}, '
|
||||
f'peer_address={self.peer_address}, '
|
||||
f'peer_resolvable_address={self.peer_resolvable_address})'
|
||||
)
|
||||
if self.transport == BT_LE_TRANSPORT:
|
||||
return (
|
||||
f'Connection(transport=LE, handle=0x{self.handle:04X}, '
|
||||
f'role={self.role_name}, '
|
||||
f'self_address={self.self_address}, '
|
||||
f'self_resolvable_address={self.self_resolvable_address}, '
|
||||
f'peer_address={self.peer_address}, '
|
||||
f'peer_resolvable_address={self.peer_resolvable_address})'
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f'Connection(transport=BR/EDR, handle=0x{self.handle:04X}, '
|
||||
f'role={self.role_name}, '
|
||||
f'self_address={self.self_address}, '
|
||||
f'peer_address={self.peer_address})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user