# Mirror of https://github.com/google/bumble.git
# (synced 2026-04-16 00:25:31 +00:00; 2301 lines, 76 KiB, Python)
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Imports
|
|
# -----------------------------------------------------------------------------
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import enum
|
|
import logging
|
|
import time
|
|
import warnings
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
|
|
from dataclasses import dataclass, field
|
|
from typing import (
|
|
Any,
|
|
ClassVar,
|
|
SupportsBytes,
|
|
TypeVar,
|
|
cast,
|
|
)
|
|
|
|
from typing_extensions import override
|
|
|
|
from bumble import a2dp, device, hci, l2cap, sdp, utils
|
|
from bumble.colors import color
|
|
from bumble.core import (
|
|
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
|
InvalidStateError,
|
|
ProtocolError,
|
|
)
|
|
from bumble.rtp import MediaPacket
|
|
|
|
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long

# L2CAP PSM assigned to AVDTP
AVDTP_PSM = 0x0019

# Default value for the AVDTP RTX_SIG_TIMER, presumably the signaling
# response timeout (its use is outside this view — confirm).
AVDTP_DEFAULT_RTX_SIG_TIMER = 5  # Seconds
|
|
|
|
# Signal Identifiers (AVDTP spec - 8.5 Signal Command Set)
class SignalIdentifier(hci.SpecableEnum):
    # Carried in the low 6 bits of the second byte of a single/start
    # signaling packet (the assembler masks it with 0x3F).
    DISCOVER = 0x01
    GET_CAPABILITIES = 0x02
    SET_CONFIGURATION = 0x03
    GET_CONFIGURATION = 0x04
    RECONFIGURE = 0x05
    OPEN = 0x06
    START = 0x07
    CLOSE = 0x08
    SUSPEND = 0x09
    ABORT = 0x0A
    SECURITY_CONTROL = 0x0B
    GET_ALL_CAPABILITIES = 0x0C
    DELAYREPORT = 0x0D


# Module-level aliases for the SignalIdentifier members (presumably kept for
# backward compatibility with code written before the enum existed).
AVDTP_DISCOVER = SignalIdentifier.DISCOVER
AVDTP_GET_CAPABILITIES = SignalIdentifier.GET_CAPABILITIES
AVDTP_SET_CONFIGURATION = SignalIdentifier.SET_CONFIGURATION
AVDTP_GET_CONFIGURATION = SignalIdentifier.GET_CONFIGURATION
AVDTP_RECONFIGURE = SignalIdentifier.RECONFIGURE
AVDTP_OPEN = SignalIdentifier.OPEN
AVDTP_START = SignalIdentifier.START
AVDTP_CLOSE = SignalIdentifier.CLOSE
AVDTP_SUSPEND = SignalIdentifier.SUSPEND
AVDTP_ABORT = SignalIdentifier.ABORT
AVDTP_SECURITY_CONTROL = SignalIdentifier.SECURITY_CONTROL
AVDTP_GET_ALL_CAPABILITIES = SignalIdentifier.GET_ALL_CAPABILITIES
AVDTP_DELAYREPORT = SignalIdentifier.DELAYREPORT
|
|
|
|
class ErrorCode(hci.SpecableEnum):
    '''Error codes (AVDTP spec - 8.20.6.2 ERROR_CODE tables)'''
    # Encoded as a single byte in reject messages.
    BAD_HEADER_FORMAT = 0x01
    BAD_LENGTH = 0x11
    BAD_ACP_SEID = 0x12
    SEP_IN_USE = 0x13
    SEP_NOT_IN_USE = 0x14
    BAD_SERV_CATEGORY = 0x17
    BAD_PAYLOAD_FORMAT = 0x18
    NOT_SUPPORTED_COMMAND = 0x19
    INVALID_CAPABILITIES = 0x1A
    BAD_RECOVERY_TYPE = 0x22
    BAD_MEDIA_TRANSPORT_FORMAT = 0x23
    BAD_RECOVERY_FORMAT = 0x25
    BAD_ROHC_FORMAT = 0x26
    BAD_CP_FORMAT = 0x27
    BAD_MULTIPLEXING_FORMAT = 0x28
    UNSUPPORTED_CONFIGURATION = 0x29
    BAD_STATE = 0x31


# Module-level aliases for the ErrorCode members (presumably kept for
# backward compatibility).
AVDTP_BAD_HEADER_FORMAT_ERROR = ErrorCode.BAD_HEADER_FORMAT
AVDTP_BAD_LENGTH_ERROR = ErrorCode.BAD_LENGTH
AVDTP_BAD_ACP_SEID_ERROR = ErrorCode.BAD_ACP_SEID
AVDTP_SEP_IN_USE_ERROR = ErrorCode.SEP_IN_USE
AVDTP_SEP_NOT_IN_USE_ERROR = ErrorCode.SEP_NOT_IN_USE
AVDTP_BAD_SERV_CATEGORY_ERROR = ErrorCode.BAD_SERV_CATEGORY
AVDTP_BAD_PAYLOAD_FORMAT_ERROR = ErrorCode.BAD_PAYLOAD_FORMAT
AVDTP_NOT_SUPPORTED_COMMAND_ERROR = ErrorCode.NOT_SUPPORTED_COMMAND
AVDTP_INVALID_CAPABILITIES_ERROR = ErrorCode.INVALID_CAPABILITIES
AVDTP_BAD_RECOVERY_TYPE_ERROR = ErrorCode.BAD_RECOVERY_TYPE
AVDTP_BAD_MEDIA_TRANSPORT_FORMAT_ERROR = ErrorCode.BAD_MEDIA_TRANSPORT_FORMAT
AVDTP_BAD_RECOVERY_FORMAT_ERROR = ErrorCode.BAD_RECOVERY_FORMAT
AVDTP_BAD_ROHC_FORMAT_ERROR = ErrorCode.BAD_ROHC_FORMAT
AVDTP_BAD_CP_FORMAT_ERROR = ErrorCode.BAD_CP_FORMAT
AVDTP_BAD_MULTIPLEXING_FORMAT_ERROR = ErrorCode.BAD_MULTIPLEXING_FORMAT
AVDTP_UNSUPPORTED_CONFIGURATION_ERROR = ErrorCode.UNSUPPORTED_CONFIGURATION
AVDTP_BAD_STATE_ERROR = ErrorCode.BAD_STATE
|
|
|
|
class MediaType(utils.OpenIntEnum):
    # Media type of a stream end point; stored in the upper 4 bits of the
    # second byte of an end-point info record, and as the first byte of
    # media codec capabilities.
    AUDIO = 0x00
    VIDEO = 0x01
    MULTIMEDIA = 0x02


# Module-level aliases for the MediaType members.
AVDTP_AUDIO_MEDIA_TYPE = MediaType.AUDIO
AVDTP_VIDEO_MEDIA_TYPE = MediaType.VIDEO
AVDTP_MULTIMEDIA_MEDIA_TYPE = MediaType.MULTIMEDIA
|
|
|
|
class StreamEndPointType(utils.OpenIntEnum):
    '''TSEP (AVDTP spec - 8.20.3 Stream End-point Type, Source or Sink (TSEP)).'''
    # Stored in bit 3 of the second byte of an end-point info record.
    SRC = 0x00
    SNK = 0x01


# Module-level aliases for the StreamEndPointType members.
AVDTP_TSEP_SRC = StreamEndPointType.SRC
AVDTP_TSEP_SNK = StreamEndPointType.SNK
|
|
|
|
class ServiceCategory(hci.SpecableEnum):
    '''Service Categories (AVDTP spec - Table 8.47: Service Category information element field values).'''
    # Each entry of a capabilities list starts with one of these as its
    # first byte, followed by a length byte.
    MEDIA_TRANSPORT = 0x01
    REPORTING = 0x02
    RECOVERY = 0x03
    CONTENT_PROTECTION = 0x04
    HEADER_COMPRESSION = 0x05
    MULTIPLEXING = 0x06
    MEDIA_CODEC = 0x07
    DELAY_REPORTING = 0x08


# Module-level aliases for the ServiceCategory members.
AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY = ServiceCategory.MEDIA_TRANSPORT
AVDTP_REPORTING_SERVICE_CATEGORY = ServiceCategory.REPORTING
AVDTP_RECOVERY_SERVICE_CATEGORY = ServiceCategory.RECOVERY
AVDTP_CONTENT_PROTECTION_SERVICE_CATEGORY = ServiceCategory.CONTENT_PROTECTION
AVDTP_HEADER_COMPRESSION_SERVICE_CATEGORY = ServiceCategory.HEADER_COMPRESSION
AVDTP_MULTIPLEXING_SERVICE_CATEGORY = ServiceCategory.MULTIPLEXING
AVDTP_MEDIA_CODEC_SERVICE_CATEGORY = ServiceCategory.MEDIA_CODEC
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY = ServiceCategory.DELAY_REPORTING
|
|
|
|
class State(utils.OpenIntEnum):
    '''States (AVDTP spec - 9.1 State Definitions)'''
    IDLE = 0x00
    CONFIGURED = 0x01
    OPEN = 0x02
    STREAMING = 0x03
    CLOSING = 0x04
    ABORTING = 0x05


# fmt: on
# pylint: enable=line-too-long
# Presumably disabled because the message classes in this module use
# Spec_Style names (e.g. Simple_Command) rather than PascalCase.
# pylint: disable=invalid-name
|
|
|
|
|
|
# -----------------------------------------------------------------------------
async def find_avdtp_service_with_sdp_client(
    sdp_client: sdp.Client,
) -> tuple[int, int] | None:
    """Look up the AVDTP version advertised through an already-connected SDP client.

    Service records of the BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE class are
    searched, and each record's Bluetooth Profile Descriptor List is examined.

    Returns:
        ``(major, minor)`` taken from the version element (high byte = major,
        low byte = minor) of the first profile descriptor that is a sequence
        of at least two elements, or ``None`` when no such descriptor exists.
    """
    descriptor_list_id = sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID

    # Only the profile descriptor list attribute is needed from each record
    records = await sdp_client.search_attributes(
        [BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE],
        [descriptor_list_id],
    )

    for record in records:
        descriptors = sdp.ServiceAttribute.find_attribute_in_list(
            record, descriptor_list_id
        )
        if not descriptors:
            continue
        for descriptor in descriptors.value:
            if descriptor.type != sdp.DataElement.SEQUENCE:
                continue
            if len(descriptor.value) < 2:
                continue
            version = descriptor.value[1].value
            return (version >> 8, version & 0xFF)

    return None
|
|
|
|
|
|
# -----------------------------------------------------------------------------
async def find_avdtp_service_with_connection(
    connection: device.Connection,
) -> tuple[int, int] | None:
    '''
    Find an AVDTP service, for a connection, and return its version,
    or None if none is found.

    A temporary SDP client is opened on ``connection`` for the lookup. It is
    always disconnected afterwards, even if the lookup raises or the calling
    task is cancelled.
    '''

    sdp_client = sdp.Client(connection)
    await sdp_client.connect()
    try:
        return await find_avdtp_service_with_sdp_client(sdp_client)
    finally:
        # Release the SDP channel on every exit path so a failed search
        # (protocol error, timeout, cancellation) does not leak it.
        await sdp_client.disconnect()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
class RealtimeClock:
    """Wall-clock time source used to pace media packets.

    ``now()`` reports ``time.time()`` and ``sleep()`` suspends through
    ``asyncio.sleep``; other clocks can be substituted by providing the same
    two methods.
    """

    def now(self) -> float:
        """Return the current wall-clock time, in seconds since the epoch."""
        current_time = time.time()
        return current_time

    async def sleep(self, duration: float) -> None:
        """Suspend the calling coroutine for ``duration`` seconds."""
        return await asyncio.sleep(duration)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
class MediaPacketPump:
    '''
    Send media packets to an RTP channel, paced by their timestamps.

    Packets are pulled from an async iterator. The first packet is sent as soon
    as it is available; each following packet is held back until
    ``clock.now()`` has advanced, relative to the first packet, by as much as
    the packet's ``timestamp_seconds`` has.
    '''

    pump_task: asyncio.Task | None

    def __init__(
        self, packets: AsyncGenerator, clock: RealtimeClock | None = None
    ) -> None:
        '''
        Args:
          packets: async iterator of packets exposing ``timestamp_seconds`` and
            convertible with ``bytes()``.
          clock: time source with ``now()`` and async ``sleep()``; a new
            ``RealtimeClock`` is used when omitted (created per instance rather
            than shared through a default argument).
        '''
        self.packets = packets
        self.clock = clock if clock is not None else RealtimeClock()
        self.pump_task = None
        self.completed = asyncio.Event()

    async def start(self, rtp_channel: l2cap.ClassicChannel) -> None:
        '''Start pumping packets to ``rtp_channel`` in a background task.'''

        async def pump_packets():
            # Clock time and media timestamp of the first packet: the reference
            # point for pacing. ``None`` until the first packet arrives. A clock
            # may legitimately report 0.0, so 0 is not a usable sentinel.
            start_time: float | None = None
            start_timestamp = 0

            try:
                logger.debug('pump starting')
                async for packet in self.packets:
                    # Capture the timestamp of the first packet
                    if start_time is None:
                        start_time = self.clock.now()
                        start_timestamp = packet.timestamp_seconds

                    # Wait until we can send
                    when = start_time + (packet.timestamp_seconds - start_timestamp)
                    now = self.clock.now()
                    if when > now:
                        delay = when - now
                        logger.debug(f'waiting for {delay}')
                        await self.clock.sleep(delay)

                    # Emit
                    rtp_channel.write(bytes(packet))
                    logger.debug(
                        f'{color(">>> sending RTP packet:", "green")} {packet}'
                    )
            except asyncio.exceptions.CancelledError:
                logger.debug('pump canceled')
            finally:
                self.completed.set()

        # Pump packets
        self.pump_task = asyncio.create_task(pump_packets())

    async def stop(self) -> None:
        '''
        Stop the pump, if it is running, and wait for its task to finish.

        An exception raised by the pump itself (other than cancellation) is
        re-raised here.
        '''
        if self.pump_task is None:
            return

        task, self.pump_task = self.pump_task, None
        task.cancel()

        # asyncio.wait() does not raise if the task ends up cancelled (which
        # happens when it is cancelled before its first step), but still
        # propagates a cancellation of the task calling stop().
        await asyncio.wait([task])

        if task.cancelled():
            # The pump coroutine never ran, so its ``finally`` did not either.
            self.completed.set()
        elif (exception := task.exception()) is not None:
            raise exception

    async def wait_for_completion(self) -> None:
        '''Wait until the pump has finished (exhausted, canceled or failed).'''
        await self.completed.wait()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
class MessageAssembler:
    '''
    Reassemble AVDTP signaling messages from individual signaling packets.

    A SINGLE_PACKET carries a whole message and is delivered immediately. A
    START_PACKET opens a fragmented message and announces how many packets it
    spans in total. CONTINUE_PACKETs and the final END_PACKET are appended to
    it. Each complete message is decoded with ``Message.create`` and passed to
    ``callback(transaction_label, message)``.

    Malformed or out-of-sequence packets are logged and dropped rather than
    raising, since they originate from the remote peer.
    '''

    message: bytes | None
    signal_identifier: SignalIdentifier

    def __init__(self, callback: Callable[[int, Message], Any]) -> None:
        self.callback = callback
        self.reset()

    def reset(self) -> None:
        '''Discard any partially assembled message.'''
        self.transaction_label = 0
        self.message = None
        self.message_type = Message.MessageType.COMMAND
        self.signal_identifier = SignalIdentifier(0)
        self.number_of_signal_packets = 0
        self.packet_count = 0

    def on_pdu(self, pdu: bytes) -> None:
        '''Process one received signaling packet.'''
        if not pdu:
            logger.warning('ignoring empty signaling packet')
            return

        # First header byte: transaction label (bits 7-4), packet type
        # (bits 3-2) and message type (bits 1-0).
        transaction_label = pdu[0] >> 4
        packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
        message_type = Message.MessageType(pdu[0] & 3)

        logger.debug(
            f'transaction_label={transaction_label}, '
            f'packet_type={packet_type.name}, '
            f'message_type={message_type.name}'
        )
        if packet_type in (
            Protocol.PacketType.SINGLE_PACKET,
            Protocol.PacketType.START_PACKET,
        ):
            self._on_first_packet(pdu, transaction_label, packet_type, message_type)
        elif packet_type in (
            Protocol.PacketType.CONTINUE_PACKET,
            Protocol.PacketType.END_PACKET,
        ):
            self._on_continuation_packet(
                pdu, transaction_label, packet_type, message_type
            )

    def _on_first_packet(
        self,
        pdu: bytes,
        transaction_label: int,
        packet_type: Protocol.PacketType,
        message_type: Message.MessageType,
    ) -> None:
        '''Handle a SINGLE_PACKET or START_PACKET, which begins a new message.'''
        if self.message is not None:
            # The previous message has not been terminated
            logger.warning(
                'received a start or single packet when expecting an end or '
                'continuation'
            )
            self.reset()

        is_single = packet_type == Protocol.PacketType.SINGLE_PACKET
        # A single packet has a 2-byte header; a start packet adds a third
        # byte holding the total number of packets in the message.
        header_size = 2 if is_single else 3
        if len(pdu) < header_size:
            logger.warning(f'signaling packet too short: {len(pdu)} bytes')
            return

        self.transaction_label = transaction_label
        self.signal_identifier = SignalIdentifier(pdu[1] & 0x3F)
        self.message_type = message_type
        # This packet is the first of the message. It is counted only after any
        # reset above, so the count stays right when a new message pre-empts
        # an unterminated one.
        self.packet_count = 1

        if is_single:
            self.message = pdu[2:]
            self.on_message_complete()
        else:
            self.number_of_signal_packets = pdu[2]
            self.message = pdu[3:]

    def _on_continuation_packet(
        self,
        pdu: bytes,
        transaction_label: int,
        packet_type: Protocol.PacketType,
        message_type: Message.MessageType,
    ) -> None:
        '''Handle a CONTINUE_PACKET or END_PACKET of a fragmented message.'''
        if self.message is None:
            # No START_PACKET has opened a message to continue
            logger.warning('unexpected continuation')
            return

        if transaction_label != self.transaction_label:
            logger.warning(
                f'transaction label mismatch: expected {self.transaction_label}, '
                f'received {transaction_label}'
            )
            return

        if message_type != self.message_type:
            logger.warning(
                f'message type mismatch: expected {self.message_type}, '
                f'received {message_type}'
            )
            return

        # Only packets that belong to the message in progress are counted
        self.packet_count += 1
        # Continuation and end packets have a 1-byte header
        self.message += pdu[1:]

        if packet_type == Protocol.PacketType.END_PACKET:
            if self.packet_count != self.number_of_signal_packets:
                logger.warning(
                    'incomplete fragmented message: '
                    f'expected {self.number_of_signal_packets} packets, '
                    f'received {self.packet_count}'
                )
                self.reset()
                return

            self.on_message_complete()
        elif self.packet_count > self.number_of_signal_packets:
            logger.warning(
                'too many packets: '
                f'expected {self.number_of_signal_packets}, '
                f'received {self.packet_count}'
            )
            self.reset()

    def on_message_complete(self) -> None:
        '''Decode the assembled message, deliver it, then reset.'''
        message = Message.create(
            self.signal_identifier,
            self.message_type,
            self.message or b'',
        )
        try:
            self.callback(self.transaction_label, message)
        except Exception:
            # A failing callback must not break reassembly of later messages
            logger.exception(color('!!! exception in callback', 'red'))

        self.reset()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@dataclass
class ServiceCapabilities:
    '''
    One service capability entry: a service category plus its raw bytes.

    ``METADATA`` is the ``hci.metadata`` field descriptor used by messages
    whose payload ends with a list of service capabilities. The list is parsed
    up to the end of the data.
    '''

    METADATA = hci.metadata(
        {
            'parser': lambda data, offset: (
                len(data),
                ServiceCapabilities.parse_capabilities(data[offset:]),
            ),
            'serializer': lambda capabilities: ServiceCapabilities.serialize_capabilities(
                capabilities
            ),
        }
    )
    service_category: int
    service_capabilities_bytes: bytes = b''

    @classmethod
    def create(
        cls, service_category: int, service_capabilities_bytes: bytes
    ) -> ServiceCapabilities:
        '''
        Build a capability object for ``service_category``.

        Media codec capabilities are decoded into ``MediaCodecCapabilities``.
        Every other category keeps its bytes as-is.
        '''
        # Select the appropriate subclass
        if service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY:
            return MediaCodecCapabilities.from_bytes(service_capabilities_bytes)
        return ServiceCapabilities(
            service_category=service_category,
            service_capabilities_bytes=service_capabilities_bytes,
        )

    @classmethod
    def parse_capabilities(cls, payload: bytes) -> list[ServiceCapabilities]:
        '''
        Parse a sequence of (category, length, bytes...) capability entries.

        A dangling final byte, which is too short to hold an entry header, is
        logged and ignored instead of raising ``IndexError``. This payload
        comes from the remote peer.
        '''
        capabilities = []
        offset = 0
        # Each entry: category (1 byte), length (1 byte), then `length` bytes
        while offset + 2 <= len(payload):
            service_category = payload[offset]
            length_of_service_capabilities = payload[offset + 1]
            start = offset + 2
            end = start + length_of_service_capabilities
            service_capabilities_bytes = payload[start:end]
            capabilities.append(
                ServiceCapabilities.create(service_category, service_capabilities_bytes)
            )
            offset = end

        if offset < len(payload):
            logger.warning(
                f'ignoring {len(payload) - offset} trailing byte(s) '
                'in service capabilities'
            )

        return capabilities

    @classmethod
    def serialize_capabilities(
        cls, capabilities: Iterable[ServiceCapabilities]
    ) -> bytes:
        '''Encode capabilities as consecutive (category, length, bytes...) entries.'''
        return b''.join(
            bytes([item.service_category, len(item.service_capabilities_bytes)])
            + item.service_capabilities_bytes
            for item in capabilities
        )
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@dataclass(init=False)
class MediaCodecCapabilities(ServiceCapabilities):
    '''
    Media Codec service capability: media type, codec type and
    codec-specific information.

    ``service_category`` is fixed at class level to the media codec category.
    The custom ``__init__`` does not set it per instance.
    '''

    service_category = AVDTP_MEDIA_CODEC_SERVICE_CATEGORY
    # Redeclare this attribute to suppress inheritance error.
    service_capabilities_bytes: bytes

    media_type: MediaType
    media_codec_type: a2dp.CodecType
    media_codec_information: bytes | SupportsBytes

    # Override init to allow passing service_capabilities_bytes.
    def __init__(
        self,
        media_type: MediaType,
        media_codec_type: a2dp.CodecType,
        media_codec_information: bytes | SupportsBytes,
        service_capabilities_bytes: bytes | None = None,
    ) -> None:
        '''
        Args:
          media_type: media type of the stream.
          media_codec_type: A2DP codec type.
          media_codec_information: codec-specific information. Raw ``bytes``
            are parsed with ``a2dp.MediaCodecInformation.create`` for
            ``media_codec_type``; any other object is stored as-is.
          service_capabilities_bytes: encoded capability bytes to store
            verbatim. When omitted they are built as the media type byte,
            then the codec type byte, then ``bytes(media_codec_information)``.
        '''
        self.media_type = media_type
        self.media_codec_type = media_codec_type

        if isinstance(media_codec_information, bytes):
            self.media_codec_information = a2dp.MediaCodecInformation.create(
                media_codec_type, media_codec_information
            )
        else:
            self.media_codec_information = media_codec_information

        if service_capabilities_bytes is not None:
            self.service_capabilities_bytes = service_capabilities_bytes
        else:
            self.service_capabilities_bytes = bytes(
                [self.media_type, self.media_codec_type]
            ) + bytes(self.media_codec_information)

    @classmethod
    def from_bytes(cls, data: bytes) -> ServiceCapabilities:
        '''
        Decode capability bytes: byte 0 is the media type, byte 1 the codec
        type, and the remainder the codec-specific information.

        ``service_capabilities_bytes`` of the result is re-encoded from the
        parsed values rather than copied from ``data``.
        NOTE(review): assumes len(data) >= 2; shorter input raises IndexError.
        '''
        media_type = MediaType(data[0])
        media_codec_type = a2dp.CodecType(data[1])
        return cls(
            media_type=media_type,
            media_codec_type=media_codec_type,
            media_codec_information=a2dp.MediaCodecInformation.create(
                media_codec_type, data[2:]
            ),
        )
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@dataclass
class EndPointInfo:
    """Stream end-point record, as listed in a Discover response.

    Two-byte wire layout:
      byte 0: SEID in bits 7-2, in-use flag in bit 1
      byte 1: media type in bits 7-4, TSEP in bit 3
    """

    seid: int
    in_use: int
    media_type: MediaType
    tsep: StreamEndPointType

    @classmethod
    def from_bytes(cls, payload: bytes) -> EndPointInfo:
        """Decode a record from the first two bytes of ``payload``."""
        first, second = payload[0], payload[1]
        return cls(
            seid=first >> 2,
            in_use=(first >> 1) & 1,
            media_type=MediaType(second >> 4),
            tsep=StreamEndPointType((second >> 3) & 1),
        )

    def __bytes__(self) -> bytes:
        """Encode this record into its two-byte wire form."""
        first = (self.seid << 2) | (self.in_use << 1)
        second = (self.media_type << 4) | (self.tsep << 3)
        return bytes((first, second))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
class Message:
    '''
    Base class for AVDTP signaling messages.

    Concrete messages are dataclasses registered with ``@Message.subclass``.
    ``Message.create`` looks them up by (signal identifier, message type) and
    decodes their fields from the payload bytes.
    '''

    # Message type, carried in the low 2 bits of the first header byte
    class MessageType(enum.IntEnum):
        COMMAND = 0
        GENERAL_REJECT = 1
        RESPONSE_ACCEPT = 2
        RESPONSE_REJECT = 3

    # Field metadata for a one-byte SEID field: the SEID sits in the upper
    # 6 bits of the byte (shifted left by 2 on the wire).
    SEID_METADATA = hci.metadata(
        {
            'serializer': lambda seid: bytes([seid << 2]),
            'parser': lambda data, offset: (offset + 1, data[offset] >> 2),
        }
    )

    # Subclasses, by signal identifier and message type
    subclasses: ClassVar[dict[int, dict[int, type[Message]]]] = {}

    message_type: MessageType
    signal_identifier: SignalIdentifier
    # Cached encoded payload; serialized lazily from the fields while None
    _payload: bytes | None = None
    # Field descriptors, derived from the dataclass fields by `subclass`
    fields: ClassVar[hci.Fields] = ()

    @property
    def payload(self) -> bytes:
        '''Encoded payload: serialized from the instance fields on first access, then cached.'''
        if self._payload is None:
            self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
        return self._payload

    @payload.setter
    def payload(self, payload: bytes) -> None:
        self._payload = payload

    _Message = TypeVar("_Message", bound="Message")

    @classmethod
    def subclass(cls, subclass: type[_Message]) -> type[_Message]:
        '''
        Class decorator: register ``subclass`` under its (signal_identifier,
        message_type) pair and derive its ``fields`` from its dataclass fields.
        '''
        cls.subclasses.setdefault(subclass.signal_identifier, {})[
            subclass.message_type
        ] = subclass
        subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
        return subclass

    # Factory method to create a subclass based on the signal identifier and message
    # type
    @classmethod
    def create(
        cls,
        signal_identifier: SignalIdentifier,
        message_type: MessageType,
        payload: bytes,
    ) -> Message:
        '''
        Decode a message.

        If a subclass is registered for (signal_identifier, message_type), its
        fields are parsed from ``payload``. Otherwise a RESPONSE_REJECT becomes
        a ``Simple_Reject`` built from the first payload byte, and any other
        message becomes a plain ``Message``. In every case the raw ``payload``
        is kept as the instance's payload.

        NOTE(review): an unregistered RESPONSE_REJECT with an empty payload
        raises IndexError on ``payload[0]``.
        '''
        instance: Message
        # Look for a registered subclass
        if (subclasses := Message.subclasses.get(signal_identifier)) and (
            subclass := subclasses.get(message_type)
        ):
            instance = subclass(
                **hci.HCI_Object.dict_from_bytes(payload, 0, subclass.fields),
            )
            instance.payload = payload
            return instance

        # Instantiate the appropriate class based on the message type
        if message_type == Message.MessageType.RESPONSE_REJECT:
            # Assume a simple reject message
            instance = Simple_Reject(ErrorCode(payload[0]))
        else:
            instance = Message()
        instance.payload = payload
        # Generic instances have no class-level identifiers, so set them here
        instance.message_type = message_type
        instance.signal_identifier = signal_identifier
        return instance

    def to_string(self, details: str | Iterable[str]) -> str:
        '''
        Format as the colored ``SIGNAL_MESSAGETYPE`` name, followed by either
        ``: details`` (string) or one indented, colored line per detail
        (iterable). Falsy details yield just the name.
        '''
        base = color(
            f'{self.signal_identifier.name}_{self.message_type.name}',
            'yellow',
        )

        if details:
            if isinstance(details, str):
                return f'{base}: {details}'

            return (
                base
                + ':\n'
                + '\n'.join([' ' + color(detail, 'cyan') for detail in details])
            )

        return base

    def __str__(self) -> str:
        # Generic rendering: the payload as hex (subclasses override this)
        return self.to_string(self.payload.hex())
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@dataclass
class Simple_Command(Message):
    """Base for command messages whose only field is the ACP SEID.

    Concrete commands subclass this and set ``signal_identifier``.
    """

    message_type = Message.MessageType.COMMAND

    acp_seid: int = field(metadata=Message.SEID_METADATA)

    def __str__(self) -> str:
        seid_line = f'ACP SEID: {self.acp_seid}'
        return self.to_string([seid_line])
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@dataclass
class Simple_Reject(Message):
    """Base for reject responses whose only field is a one-byte error code."""

    message_type = Message.MessageType.RESPONSE_REJECT

    error_code: ErrorCode = field(metadata=ErrorCode.type_metadata(1))

    def __str__(self) -> str:
        return self.to_string([f'error_code: {self.error_code.name}'])
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Discover_Command(Message):
    """
    Stream End Point Discovery command (see Bluetooth AVDTP spec - 8.6.1).
    Carries no fields.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.DISCOVER
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Discover_Response(Message):
    '''
    See Bluetooth AVDTP spec - 8.6.2 Stream End Point Discovery Response

    The payload is a sequence of 2-byte end-point records running to the end
    of the message.
    '''

    signal_identifier = AVDTP_DISCOVER
    message_type = Message.MessageType.RESPONSE_ACCEPT

    @classmethod
    def parse_endpoints(
        cls, data: bytes, offset: int
    ) -> tuple[int, list[EndPointInfo]]:
        '''
        Parse the 2-byte end-point records from ``offset`` to the end of
        ``data``; a trailing odd byte is ignored.

        Returns the new offset (the end of ``data``) and the parsed records.
        '''
        # Step through the records starting at `offset`. The previous version
        # used `offset` as a record index into data[0:], which was only
        # correct for offset == 0.
        return len(data), [
            EndPointInfo.from_bytes(data[position : position + 2])
            for position in range(offset, len(data) - 1, 2)
        ]

    @classmethod
    def serialize_endpoints(cls, endpoints: Iterable[EndPointInfo]) -> bytes:
        '''Concatenate the 2-byte encodings of ``endpoints``.'''
        return b''.join([bytes(endpoint) for endpoint in endpoints])

    endpoints: Iterable[EndPointInfo] = field(
        metadata=hci.metadata(
            {
                'parser': lambda data, offset: Discover_Response.parse_endpoints(
                    data, offset
                ),
                'serializer': lambda endpoints: Discover_Response.serialize_endpoints(
                    endpoints
                ),
            }
        )
    )

    def __str__(self) -> str:
        details = []
        for endpoint in self.endpoints:
            details.extend(
                [
                    f'ACP SEID: {endpoint.seid}',
                    f' in_use: {endpoint.in_use}',
                    f' media_type: {endpoint.media_type.name}',
                    f' tsep: {endpoint.tsep.name}',
                ]
            )
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Capabilities_Command(Simple_Command):
    """
    Get Capabilities command (see Bluetooth AVDTP spec - 8.7.1).
    Payload: the ACP SEID.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.GET_CAPABILITIES
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Capabilities_Response(Message):
    '''
    See Bluetooth AVDTP spec - 8.7.2 Get Capabilities Response
    '''

    signal_identifier = AVDTP_GET_CAPABILITIES
    message_type = Message.MessageType.RESPONSE_ACCEPT

    # Service capabilities; the list runs to the end of the payload
    capabilities: Iterable[ServiceCapabilities] = field(
        metadata=ServiceCapabilities.METADATA
    )

    def __str__(self) -> str:
        details = [str(capability) for capability in self.capabilities]
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Capabilities_Reject(Simple_Reject):
    """
    Get Capabilities reject (see Bluetooth AVDTP spec - 8.7.3).
    Payload: a one-byte error code.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.GET_CAPABILITIES
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_All_Capabilities_Command(Get_Capabilities_Command):
    """
    Get All Capabilities command (see Bluetooth AVDTP spec - 8.8.1).
    Same layout as Get Capabilities: the ACP SEID.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.GET_ALL_CAPABILITIES
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_All_Capabilities_Response(Get_Capabilities_Response):
    """
    Get All Capabilities response (see Bluetooth AVDTP spec - 8.8.2).
    Same layout as the Get Capabilities response: a capabilities list.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.GET_ALL_CAPABILITIES
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_All_Capabilities_Reject(Simple_Reject):
    """
    Get All Capabilities reject (see Bluetooth AVDTP spec - 8.8.3).
    Payload: a one-byte error code.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.GET_ALL_CAPABILITIES
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Set_Configuration_Command(Message):
    """
    Set Configuration command (see Bluetooth AVDTP spec - 8.9.1).

    Payload: ACP SEID, INT SEID, then a list of service capabilities that
    runs to the end of the payload.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.SET_CONFIGURATION

    acp_seid: int = field(metadata=Message.SEID_METADATA)
    int_seid: int = field(metadata=Message.SEID_METADATA)
    capabilities: Iterable[ServiceCapabilities] = field(
        metadata=ServiceCapabilities.METADATA
    )

    def __str__(self) -> str:
        details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}']
        details.extend(str(capability) for capability in self.capabilities)
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Set_Configuration_Response(Message):
    """
    Set Configuration accept (see Bluetooth AVDTP spec - 8.9.2).
    Carries no fields.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.SET_CONFIGURATION
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Set_Configuration_Reject(Message):
    """
    Set Configuration reject (see Bluetooth AVDTP spec - 8.9.3).

    Payload: one service-category byte followed by one error-code byte;
    both default to 0.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.SET_CONFIGURATION

    service_category: ServiceCategory = field(
        metadata=ServiceCategory.type_metadata(1), default=ServiceCategory(0)
    )
    error_code: ErrorCode = field(
        metadata=ErrorCode.type_metadata(1), default=ErrorCode(0)
    )

    def __str__(self) -> str:
        return self.to_string(
            [
                f'service_category: {self.service_category.name}',
                f'error_code: {self.error_code.name}',
            ]
        )
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Configuration_Command(Simple_Command):
    """
    Get Configuration command (see Bluetooth AVDTP spec - 8.10.1).
    Payload: the ACP SEID.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.GET_CONFIGURATION
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Configuration_Response(Message):
    """
    Get Configuration accept (see Bluetooth AVDTP spec - 8.10.2).
    Payload: a list of service capabilities running to the end.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.GET_CONFIGURATION

    capabilities: Iterable[ServiceCapabilities] = field(
        metadata=ServiceCapabilities.METADATA
    )

    def __str__(self) -> str:
        return self.to_string(list(map(str, self.capabilities)))
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Get_Configuration_Reject(Simple_Reject):
    """
    Get Configuration reject (see Bluetooth AVDTP spec - 8.10.3).
    Payload: a one-byte error code.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.GET_CONFIGURATION
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Reconfigure_Command(Message):
    """
    Reconfigure command (see Bluetooth AVDTP spec - 8.11.1).
    Payload: ACP SEID, then a list of service capabilities running to the end.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.RECONFIGURE

    acp_seid: int = field(metadata=Message.SEID_METADATA)
    capabilities: Iterable[ServiceCapabilities] = field(
        metadata=ServiceCapabilities.METADATA
    )

    def __str__(self) -> str:
        lines = [f'ACP SEID: {self.acp_seid}', *map(str, self.capabilities)]
        return self.to_string(lines)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Reconfigure_Response(Message):
    """
    Reconfigure accept (see Bluetooth AVDTP spec - 8.11.2).
    Carries no fields.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.RECONFIGURE
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Reconfigure_Reject(Set_Configuration_Reject):
    """
    Reconfigure reject (see Bluetooth AVDTP spec - 8.11.3).
    Same layout as the Set Configuration reject: service category, error code.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.RECONFIGURE
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Open_Command(Simple_Command):
    """
    Open Stream command (see Bluetooth AVDTP spec - 8.12.1).
    Payload: the ACP SEID.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.OPEN
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Open_Response(Message):
    """
    Open Stream accept (see Bluetooth AVDTP spec - 8.12.2).
    Carries no fields.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.OPEN
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Open_Reject(Simple_Reject):
    """
    Open Stream reject (see Bluetooth AVDTP spec - 8.12.3).
    Payload: a one-byte error code.
    """

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = SignalIdentifier.OPEN
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Start_Command(Message):
    """
    Start Stream command (see Bluetooth AVDTP spec - 8.13.1).

    Payload: one byte per ACP SEID (SEID in the upper 6 bits), running to
    the end of the payload.
    """

    message_type = Message.MessageType.COMMAND
    signal_identifier = SignalIdentifier.START

    acp_seids: Iterable[int] = field(
        metadata=hci.metadata(
            {
                'serializer': lambda seids: bytes(seid << 2 for seid in seids),
                'parser': lambda data, offset: (
                    len(data),
                    [byte >> 2 for byte in data[offset:]],
                ),
            }
        )
    )

    def __str__(self) -> str:
        details = [f'ACP SEIDs: {self.acp_seids}']
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
@Message.subclass
@dataclass
class Start_Response(Message):
    """
    Start Stream accept (see Bluetooth AVDTP spec - 8.13.2).
    Carries no fields.
    """

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = SignalIdentifier.START
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Start_Reject(Message):
    '''
    See Bluetooth AVDTP spec - 8.13.3 Start Stream Reject

    Carries the SEID of the (first) acceptor stream endpoint that could not
    be started, along with the reason.
    '''

    signal_identifier = AVDTP_START
    message_type = Message.MessageType.RESPONSE_REJECT

    # SEID of the endpoint the reject refers to
    acp_seid: int = field(metadata=Message.SEID_METADATA)
    # One-byte AVDTP error code
    error_code: ErrorCode = field(metadata=ErrorCode.type_metadata(1))

    def __str__(self) -> str:
        details = [
            f'acp_seid: {self.acp_seid}',
            f'error_code: {self.error_code.name}',
        ]
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Close_Command(Simple_Command):
    '''
    AVDTP Close Stream command (Bluetooth AVDTP spec - 8.14.1).
    '''

    message_type = Message.MessageType.COMMAND
    signal_identifier = AVDTP_CLOSE
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Close_Response(Message):
    '''
    AVDTP Close Stream accept response (Bluetooth AVDTP spec - 8.14.2).
    '''

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = AVDTP_CLOSE
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Close_Reject(Simple_Reject):
    '''
    AVDTP Close Stream reject response (Bluetooth AVDTP spec - 8.14.3).
    '''

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = AVDTP_CLOSE
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Suspend_Command(Start_Command):
    '''
    AVDTP Suspend command (Bluetooth AVDTP spec - 8.15.1).

    Shares the ACP SEID list payload format with Start_Command.
    '''

    message_type = Message.MessageType.COMMAND
    signal_identifier = AVDTP_SUSPEND
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Suspend_Response(Message):
    '''
    AVDTP Suspend accept response (Bluetooth AVDTP spec - 8.15.2).
    '''

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = AVDTP_SUSPEND
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Suspend_Reject(Start_Reject):
    '''
    AVDTP Suspend reject response (Bluetooth AVDTP spec - 8.15.3).

    Uses the same (acp_seid, error_code) payload as Start_Reject.
    '''

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = AVDTP_SUSPEND
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Abort_Command(Simple_Command):
    '''
    AVDTP Abort command (Bluetooth AVDTP spec - 8.16.1).
    '''

    message_type = Message.MessageType.COMMAND
    signal_identifier = AVDTP_ABORT
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Abort_Response(Message):
    '''
    AVDTP Abort accept response (Bluetooth AVDTP spec - 8.16.2).
    '''

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = AVDTP_ABORT
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Security_Control_Command(Message):
    '''
    AVDTP Security Control command (Bluetooth AVDTP spec - 8.17.1).

    Carries the target ACP SEID followed by opaque content-protection data
    that fills the remainder of the payload.
    '''

    message_type = Message.MessageType.COMMAND
    signal_identifier = AVDTP_SECURITY_CONTROL

    acp_seid: int = field(metadata=Message.SEID_METADATA)
    # '*' metadata: consumes all remaining payload bytes
    data: bytes = field(metadata=hci.metadata('*'))

    def __str__(self) -> str:
        seid_detail = f'ACP_SEID: {self.acp_seid}'
        data_detail = f'data: {self.data.hex()}'
        return self.to_string([seid_detail, data_detail])
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Security_Control_Response(Message):
    '''
    AVDTP Security Control accept response (Bluetooth AVDTP spec - 8.17.2).
    '''

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = AVDTP_SECURITY_CONTROL
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class Security_Control_Reject(Simple_Reject):
    '''
    AVDTP Security Control reject response (Bluetooth AVDTP spec - 8.17.3).
    '''

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = AVDTP_SECURITY_CONTROL
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class General_Reject(Message):
    '''
    AVDTP General Reject (Bluetooth AVDTP spec - 8.18).

    Uses the reserved signal identifier 0 and carries no payload fields.
    '''

    message_type = Message.MessageType.GENERAL_REJECT
    signal_identifier = SignalIdentifier(0)

    def to_string(self, details):
        # There are no fields to show, so ``details`` is intentionally unused
        label = color('GENERAL_REJECT', 'yellow')
        return label
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class DelayReport_Command(Message):
    '''
    AVDTP Delay Report command (Bluetooth AVDTP spec - 8.19.1).

    Carries the ACP SEID followed by a 16-bit big-endian delay value.
    '''

    message_type = Message.MessageType.COMMAND
    signal_identifier = AVDTP_DELAYREPORT

    # 2-byte big-endian codec for the delay field
    DELAY_METADATA = hci.metadata(
        {
            'serializer': lambda value: bytes([value >> 8, value & 0xFF]),
            'parser': lambda buffer, position: (
                position + 2,
                (buffer[position] << 8) | buffer[position + 1],
            ),
        }
    )

    acp_seid: int = field(metadata=Message.SEID_METADATA)
    delay: int = field(metadata=DELAY_METADATA)

    def __str__(self) -> str:
        details = [f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}']
        return self.to_string(details)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class DelayReport_Response(Message):
    '''
    AVDTP Delay Report accept response (Bluetooth AVDTP spec - 8.19.2).
    '''

    message_type = Message.MessageType.RESPONSE_ACCEPT
    signal_identifier = AVDTP_DELAYREPORT
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@Message.subclass
@dataclass
class DelayReport_Reject(Simple_Reject):
    '''
    AVDTP Delay Report reject response (Bluetooth AVDTP spec - 8.19.3).
    '''

    message_type = Message.MessageType.RESPONSE_REJECT
    signal_identifier = AVDTP_DELAYREPORT
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class Protocol(utils.EventEmitter):
    '''
    AVDTP signaling protocol running over one L2CAP channel.

    Holds the local stream endpoints (SEID n is ``local_endpoints[n - 1]``),
    the remote endpoints found by discovery, and the streams built between
    them. Outgoing commands are matched to their responses by a transaction
    label in the range 0-15, so at most 16 commands can be outstanding at
    once. Incoming commands are dispatched to ``on_<signal>_command`` methods.
    '''

    local_endpoints: list[LocalStreamEndPoint]
    remote_endpoints: dict[int, DiscoveredStreamEndPoint]
    streams: dict[int, Stream]
    transaction_results: list[asyncio.Future[Message] | None]
    channel_connector: Callable[[], Awaitable[l2cap.ClassicChannel]]
    channel_acceptor: Stream | None

    EVENT_OPEN = "open"
    EVENT_CLOSE = "close"

    class PacketType(enum.IntEnum):
        SINGLE_PACKET = 0
        START_PACKET = 1
        CONTINUE_PACKET = 2
        END_PACKET = 3

    @staticmethod
    async def connect(
        connection: device.Connection, version: tuple[int, int] = (1, 3)
    ) -> Protocol:
        '''
        Open an AVDTP signaling channel on ``connection`` and return a
        Protocol bound to it.
        '''
        channel = await connection.create_l2cap_channel(
            spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
        )
        protocol = Protocol(channel, version)

        return protocol

    def __init__(
        self, l2cap_channel: l2cap.ClassicChannel, version: tuple[int, int] = (1, 3)
    ) -> None:
        super().__init__()
        self.l2cap_channel = l2cap_channel
        self.version = version
        self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER
        self.message_assembler = MessageAssembler(self.on_message)
        self.transaction_results = [None] * 16  # Futures for up to 16 transactions
        self.transaction_semaphore = asyncio.Semaphore(16)
        self.transaction_count = 0
        self.channel_acceptor = None
        self.local_endpoints = []  # Local endpoints, with contiguous seid values
        self.remote_endpoints = {}  # Remote stream endpoints, by seid
        self.streams = {}  # Streams, by seid

        # Register to receive PDUs from the channel
        l2cap_channel.sink = self.on_pdu
        l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
        l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)

    def get_local_endpoint_by_seid(self, seid: int) -> LocalStreamEndPoint | None:
        '''Return the local endpoint with the given 1-based SEID, or None.'''
        if 0 < seid <= len(self.local_endpoints):
            return self.local_endpoints[seid - 1]

        return None

    def add_source(
        self,
        codec_capabilities: MediaCodecCapabilities,
        packet_pump: MediaPacketPump,
        delay_reporting: bool = False,
    ) -> LocalSource:
        '''
        Register a new local source endpoint and return it.

        The SEID is the next contiguous value. When ``delay_reporting`` is
        set, the delay-reporting service capability is advertised as well.
        '''
        seid = len(self.local_endpoints) + 1
        service_capabilities = (
            [ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)]
            if delay_reporting
            else []
        )
        source = LocalSource(
            self, seid, codec_capabilities, service_capabilities, packet_pump
        )
        self.local_endpoints.append(source)

        return source

    def add_sink(self, codec_capabilities: MediaCodecCapabilities) -> LocalSink:
        '''Register a new local sink endpoint and return it.'''
        seid = len(self.local_endpoints) + 1
        sink = LocalSink(self, seid, codec_capabilities)
        self.local_endpoints.append(sink)

        return sink

    async def create_stream(
        self, source: LocalStreamEndPoint, sink: StreamEndPointProxy
    ) -> Stream:
        '''
        Create (or reuse) the stream between ``source`` and ``sink`` and
        configure it on the remote side.

        Raises:
            InvalidStateError: if ``source`` is already in use.
        '''
        # Check that the source isn't already used in a stream
        if source.in_use:
            raise InvalidStateError('source already in use')

        # Create or reuse a new stream to associate the source and the sink
        if source.seid in self.streams:
            stream = self.streams[source.seid]
        else:
            stream = Stream(self, source, sink)
            self.streams[source.seid] = stream

        # The stream can now be configured
        await stream.configure()

        return stream

    async def discover_remote_endpoints(self) -> Iterable[DiscoveredStreamEndPoint]:
        '''
        Run AVDTP discovery, then fetch the capabilities of every remote
        endpoint. Replaces any previously discovered endpoints.
        '''
        self.remote_endpoints = {}

        response: Discover_Response = await self.send_command(Discover_Command())
        for endpoint_entry in response.endpoints:
            logger.debug(
                f'getting endpoint capabilities for endpoint {endpoint_entry.seid}'
            )
            get_capabilities_response = await self.get_capabilities(endpoint_entry.seid)
            endpoint = DiscoveredStreamEndPoint(
                self,
                endpoint_entry.seid,
                endpoint_entry.media_type,
                endpoint_entry.tsep,
                endpoint_entry.in_use,
                get_capabilities_response.capabilities,
            )
            self.remote_endpoints[endpoint_entry.seid] = endpoint

        return self.remote_endpoints.values()

    def find_remote_sink_by_codec(
        self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
    ) -> DiscoveredStreamEndPoint | None:
        '''
        Return the first discovered remote sink that is not in use, has the
        requested media type, and advertises both a media transport
        capability and a matching media codec capability, or None.

        For vendor-specific codecs, ``vendor_id`` and ``codec_id`` must also
        match; for other codecs they are ignored.
        '''
        for endpoint in self.remote_endpoints.values():
            if (
                not endpoint.in_use
                and endpoint.media_type == media_type
                and endpoint.tsep == AVDTP_TSEP_SNK
            ):
                has_media_transport = False
                has_codec = False
                for capabilities in endpoint.capabilities:
                    if (
                        capabilities.service_category
                        == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY
                    ):
                        has_media_transport = True
                    elif (
                        capabilities.service_category
                        == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY
                    ):
                        codec_capabilities = cast(MediaCodecCapabilities, capabilities)
                        # Match the codec against the requested media type
                        # (previously hard-coded to audio, which made the
                        # media_type argument ineffective for codecs).
                        if (
                            codec_capabilities.media_type == media_type
                            and codec_capabilities.media_codec_type == codec_type
                        ):
                            if isinstance(
                                codec_capabilities.media_codec_information,
                                a2dp.VendorSpecificMediaCodecInformation,
                            ):
                                if (
                                    codec_capabilities.media_codec_information.vendor_id
                                    == vendor_id
                                    and codec_capabilities.media_codec_information.codec_id
                                    == codec_id
                                ):
                                    has_codec = True
                            else:
                                has_codec = True
                if has_media_transport and has_codec:
                    return endpoint

        return None

    def on_pdu(self, pdu: bytes) -> None:
        '''L2CAP sink: feed each received PDU to the message reassembler.'''
        self.message_assembler.on_pdu(pdu)

    def on_message(self, transaction_label: int, message: Message) -> None:
        '''
        Handle one reassembled signaling message.

        Commands are dispatched to the matching ``on_<signal>_command``
        handler, and its return value is sent back with the same transaction
        label. Responses complete the pending transaction with that label.
        '''
        logger.debug(
            f'{color("<<< Received AVDTP message", "magenta")}: '
            f'[{transaction_label}] {message}'
        )

        # Check that the identifier is not reserved
        if message.signal_identifier == 0:
            logger.warning('!!! reserved signal identifier')
            return

        # Check that the identifier is valid
        if (
            message.signal_identifier < 0
            or message.signal_identifier > AVDTP_DELAYREPORT
        ):
            logger.warning('!!! invalid signal identifier')
            self.send_message(transaction_label, General_Reject())
            # The message has been rejected; do not process it any further
            return

        if message.message_type == Message.MessageType.COMMAND:
            # Command
            signal_name = message.signal_identifier.name.lower()
            handler_name = f'on_{signal_name}_command'
            handler = getattr(self, handler_name, None)
            if handler:
                try:
                    response = handler(message)
                    self.send_message(transaction_label, response)
                except Exception:
                    logger.exception(color("!!! Exception in handler:", "red"))
            else:
                logger.warning('unhandled command')
        else:
            # Response, look for a pending transaction with the same label
            transaction_result = self.transaction_results[transaction_label]
            if transaction_result is None:
                logger.warning(color('!!! no pending transaction for label', 'red'))
                return

            # The waiter may have been cancelled while the response was in
            # flight. Setting a result on a done future raises, which would
            # leak both the label slot and the semaphore permit.
            if not transaction_result.done():
                transaction_result.set_result(message)
            self.transaction_results[transaction_label] = None
            self.transaction_semaphore.release()

    def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
        '''Hand an incoming L2CAP channel to the stream waiting for one.'''
        # Forward the channel to the endpoint that's expecting it
        if self.channel_acceptor is None:
            logger.warning(color('!!! l2cap connection with no acceptor', 'red'))
            return
        self.channel_acceptor.on_l2cap_connection(channel)

    def on_l2cap_channel_open(self) -> None:
        logger.debug(color('<<< L2CAP channel open', 'magenta'))
        self.emit(self.EVENT_OPEN)

    def on_l2cap_channel_close(self) -> None:
        logger.debug(color('<<< L2CAP channel close', 'magenta'))
        self.emit(self.EVENT_CLOSE)

    def send_message(self, transaction_label: int, message: Message) -> None:
        '''
        Send ``message`` with ``transaction_label``, fragmenting as needed.

        A payload that fits in one packet (2-byte header) is sent as a
        SINGLE packet. Otherwise it is split into chunks of ``peer_mtu - 3``
        bytes: a START packet (3-byte header, including the packet count),
        then CONTINUE packets and a final END packet (1-byte header each).
        '''
        logger.debug(
            f'{color(">>> Sending AVDTP message", "magenta")}: '
            f'[{transaction_label}] {message}'
        )
        max_fragment_size = (
            self.l2cap_channel.peer_mtu - 3
        )  # Enough space for a 3-byte start packet header
        payload = message.payload
        if len(payload) + 2 <= self.l2cap_channel.peer_mtu:
            # Fits in a single packet
            packet_type = self.PacketType.SINGLE_PACKET
        else:
            packet_type = self.PacketType.START_PACKET

        done = False
        while not done:
            first_header_byte = (
                transaction_label << 4 | packet_type << 2 | message.message_type
            )

            if packet_type == self.PacketType.SINGLE_PACKET:
                header = bytes([first_header_byte, message.signal_identifier])
            elif packet_type == self.PacketType.START_PACKET:
                # ceil(len(payload) / max_fragment_size)
                packet_count = (
                    max_fragment_size - 1 + len(payload)
                ) // max_fragment_size
                header = bytes(
                    [first_header_byte, message.signal_identifier, packet_count]
                )
            else:
                header = bytes([first_header_byte])

            # Send one packet
            self.l2cap_channel.write(header + payload[:max_fragment_size])

            # Prepare for the next packet
            payload = payload[max_fragment_size:]
            if payload:
                packet_type = (
                    self.PacketType.CONTINUE_PACKET
                    if len(payload) > max_fragment_size
                    else self.PacketType.END_PACKET
                )
            else:
                done = True

    async def send_command(self, command: Message):
        '''
        Send ``command`` and wait for its response.

        Raises:
            ProtocolError: if the peer answers with a reject or a general
                reject. General rejects carry no error code, in which case
                the error code is None.
        '''
        # TODO: support timeouts
        # Send the command
        (transaction_label, transaction_result) = await self.start_transaction()
        self.send_message(transaction_label, command)

        # Wait for the response
        response = await transaction_result

        # Check for errors
        if response.message_type in (
            Message.MessageType.GENERAL_REJECT,
            Message.MessageType.RESPONSE_REJECT,
        ):
            # General_Reject has no error_code field; don't let that turn
            # into an AssertionError instead of a protocol error.
            raise ProtocolError(getattr(response, 'error_code', None), 'avdtp')

        return response

    async def start_transaction(self) -> tuple[int, asyncio.Future[Message]]:
        '''
        Reserve a free transaction label and a future for its response.

        Blocks while all 16 labels are in use.
        '''
        # Wait until we can start a new transaction
        await self.transaction_semaphore.acquire()

        # Look for the next free entry to store the transaction result
        for i in range(16):
            transaction_label = (self.transaction_count + i) % 16
            if self.transaction_results[transaction_label] is None:
                transaction_result = asyncio.get_running_loop().create_future()
                self.transaction_results[transaction_label] = transaction_result
                self.transaction_count += 1
                return (transaction_label, transaction_result)

        assert False  # Should never reach this

    async def get_capabilities(
        self, seid: int
    ) -> Get_Capabilities_Response | Get_All_Capabilities_Response:
        '''
        Query a remote endpoint's capabilities. Uses Get All Capabilities
        when the negotiated version is above 1.2.
        '''
        if self.version > (1, 2):
            return await self.send_command(Get_All_Capabilities_Command(seid))

        return await self.send_command(Get_Capabilities_Command(seid))

    async def set_configuration(
        self, acp_seid: int, int_seid: int, capabilities: Iterable[ServiceCapabilities]
    ) -> Set_Configuration_Response:
        return await self.send_command(
            Set_Configuration_Command(acp_seid, int_seid, capabilities)
        )

    async def get_configuration(self, seid: int) -> Iterable[ServiceCapabilities]:
        '''Return the capabilities currently configured on remote ``seid``.'''
        response = await self.send_command(Get_Configuration_Command(seid))
        return response.capabilities

    async def open(self, seid: int) -> Open_Response:
        return await self.send_command(Open_Command(seid))

    async def start(self, seids: Iterable[int]) -> Start_Response:
        return await self.send_command(Start_Command(seids))

    async def suspend(self, seids: Iterable[int]) -> Suspend_Response:
        return await self.send_command(Suspend_Command(seids))

    async def close(self, seid: int) -> Close_Response:
        return await self.send_command(Close_Command(seid))

    async def abort(self, seid: int) -> Abort_Response:
        return await self.send_command(Abort_Command(seid))

    def on_discover_command(self, command: Discover_Command) -> Message | None:
        endpoint_infos = [
            EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
            for endpoint in self.local_endpoints
        ]
        return Discover_Response(endpoint_infos)

    def on_get_capabilities_command(
        self, command: Get_Capabilities_Command
    ) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Get_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)

        return Get_Capabilities_Response(endpoint.capabilities)

    def on_get_all_capabilities_command(
        self, command: Get_All_Capabilities_Command
    ) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Get_All_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)

        return Get_All_Capabilities_Response(endpoint.capabilities)

    def on_set_configuration_command(
        self, command: Set_Configuration_Command
    ) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Set_Configuration_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)

        # Check that the local endpoint isn't in use
        if endpoint.in_use:
            return Set_Configuration_Reject(error_code=AVDTP_SEP_IN_USE_ERROR)

        # Create a stream object for the pair of endpoints
        stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
        self.streams[command.acp_seid] = stream

        result = stream.on_set_configuration_command(command.capabilities)
        return result or Set_Configuration_Response()

    def on_get_configuration_command(
        self, command: Get_Configuration_Command
    ) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Get_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR)
        if endpoint.stream is None:
            return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)

        return endpoint.stream.on_get_configuration_command()

    def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
        if endpoint.stream is None:
            return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)

        result = endpoint.stream.on_reconfigure_command(command.capabilities)
        return result or Reconfigure_Response()

    def on_open_command(self, command: Open_Command) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
        if endpoint.stream is None:
            return Open_Reject(AVDTP_BAD_STATE_ERROR)

        result = endpoint.stream.on_open_command()
        return result or Open_Response()

    def on_start_command(self, command: Start_Command) -> Message | None:
        # Validate every SEID before starting anything
        for seid in command.acp_seids:
            endpoint = self.get_local_endpoint_by_seid(seid)
            if endpoint is None:
                return Start_Reject(seid, AVDTP_BAD_ACP_SEID_ERROR)
            if endpoint.stream is None:
                return Start_Reject(seid, AVDTP_BAD_STATE_ERROR)

        # Start all streams
        # TODO: deal with partial failures
        for seid in command.acp_seids:
            endpoint = self.get_local_endpoint_by_seid(seid)
            if not endpoint or not endpoint.stream:
                raise InvalidStateError("Should already be checked!")
            if (result := endpoint.stream.on_start_command()) is not None:
                return result

        return Start_Response()

    def on_suspend_command(self, command: Suspend_Command) -> Message | None:
        # Validate every SEID before suspending anything
        for seid in command.acp_seids:
            endpoint = self.get_local_endpoint_by_seid(seid)
            if endpoint is None:
                return Suspend_Reject(seid, AVDTP_BAD_ACP_SEID_ERROR)
            if endpoint.stream is None:
                return Suspend_Reject(seid, AVDTP_BAD_STATE_ERROR)

        # Suspend all streams
        # TODO: deal with partial failures
        for seid in command.acp_seids:
            endpoint = self.get_local_endpoint_by_seid(seid)
            if not endpoint or not endpoint.stream:
                raise InvalidStateError("Should already be checked!")
            if (result := endpoint.stream.on_suspend_command()) is not None:
                return result

        return Suspend_Response()

    def on_close_command(self, command: Close_Command) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
        if endpoint.stream is None:
            return Close_Reject(AVDTP_BAD_STATE_ERROR)

        result = endpoint.stream.on_close_command()
        return result or Close_Response()

    def on_abort_command(self, command: Abort_Command) -> Message | None:
        # Abort is always accepted, even for unknown or idle endpoints
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None or endpoint.stream is None:
            return Abort_Response()

        endpoint.stream.on_abort_command()
        return Abort_Response()

    def on_security_control_command(
        self, command: Security_Control_Command
    ) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)

        result = endpoint.on_security_control_command(command.data)
        return result or Security_Control_Response()

    def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
        endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
        if endpoint is None:
            return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)

        result = endpoint.on_delayreport_command(command.delay)
        return result or DelayReport_Response()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class Listener(utils.EventEmitter):
    '''
    Accepts incoming AVDTP L2CAP channels and creates one Protocol per ACL
    connection.

    The first channel opened on a connection becomes its signaling channel:
    a new Protocol is created for it and announced with EVENT_CONNECTION.
    Later channels on the same connection are forwarded to that Protocol,
    which hands them to the stream waiting for a media channel.
    '''

    servers: dict[int, Protocol]

    EVENT_CONNECTION = "connection"

    @staticmethod
    def create_registrar(device: device.Device):
        '''Deprecated: use ``Listener.for_device()`` instead.'''
        warnings.warn("Please use Listener.for_device()", DeprecationWarning)

        def wrapper(handler: Callable[[l2cap.ClassicChannel], None]) -> None:
            device.create_l2cap_server(l2cap.ClassicChannelSpec(psm=AVDTP_PSM), handler)

        return wrapper

    def set_server(self, connection: device.Connection, server: Protocol) -> None:
        self.servers[connection.handle] = server

    def remove_server(self, connection: device.Connection) -> None:
        if connection.handle in self.servers:
            del self.servers[connection.handle]

    def __init__(
        self,
        registrar: (
            Callable[[Callable[[l2cap.ClassicChannel], None]], None] | None
        ) = None,
        version: tuple[int, int] = (1, 3),
    ) -> None:
        '''
        Args:
            registrar: deprecated hook that is called with this listener's
                connection handler; prefer ``Listener.for_device()``.
            version: AVDTP version passed to each Protocol created.
        '''
        super().__init__()
        self.version = version
        self.servers = {}  # Servers, by connection handle

        # Listen for incoming L2CAP connections
        if registrar:
            warnings.warn("Please use Listener.for_device()", DeprecationWarning)
            registrar(self.on_l2cap_connection)

    @classmethod
    def for_device(
        cls, device: device.Device, version: tuple[int, int] = (1, 3)
    ) -> Listener:
        '''
        Create a listener and register an AVDTP L2CAP server on ``device``
        that feeds it incoming connections.
        '''
        # Use cls so subclasses get an instance of their own type
        listener = cls(registrar=None, version=version)
        l2cap_server = device.create_l2cap_server(
            spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
        )
        l2cap_server.on(l2cap_server.EVENT_CONNECTION, listener.on_l2cap_connection)
        return listener

    def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
        logger.debug(f'{color("<<< incoming L2CAP connection:", "magenta")} {channel}')

        if channel.connection.handle in self.servers:
            # This is a channel for a stream endpoint
            server = self.servers[channel.connection.handle]
            server.on_l2cap_connection(channel)
        else:
            # This is a new command/response channel
            def on_channel_open():
                logger.debug('setting up new Protocol for the connection')
                server = Protocol(channel, self.version)
                self.set_server(channel.connection, server)
                self.emit(self.EVENT_CONNECTION, server)

            def on_channel_close():
                logger.debug('removing Protocol for the connection')
                self.remove_server(channel.connection)

            channel.on(channel.EVENT_OPEN, on_channel_open)
            channel.on(channel.EVENT_CLOSE, on_channel_close)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class Stream:
|
|
'''
|
|
Pair of a local and a remote stream endpoint that can stream from one to the other
|
|
'''
|
|
|
|
rtp_channel: l2cap.ClassicChannel | None
|
|
|
|
def change_state(self, state: State) -> None:
|
|
logger.debug(f'{self} state change -> {color(state.name, "cyan")}')
|
|
self.state = state
|
|
|
|
def send_media_packet(self, packet: MediaPacket) -> None:
|
|
assert self.rtp_channel
|
|
self.rtp_channel.write(bytes(packet))
|
|
|
|
async def configure(self) -> None:
|
|
if self.state != State.IDLE:
|
|
raise InvalidStateError('current state is not IDLE')
|
|
|
|
await self.remote_endpoint.set_configuration(
|
|
self.local_endpoint.seid, self.local_endpoint.configuration
|
|
)
|
|
self.change_state(State.CONFIGURED)
|
|
|
|
async def open(self) -> None:
|
|
if self.state != State.CONFIGURED:
|
|
raise InvalidStateError('current state is not CONFIGURED')
|
|
|
|
logger.debug('opening remote endpoint')
|
|
await self.remote_endpoint.open()
|
|
|
|
self.change_state(State.OPEN)
|
|
|
|
# Create a channel for RTP packets
|
|
self.rtp_channel = (
|
|
await self.protocol.l2cap_channel.connection.create_l2cap_channel(
|
|
l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
|
|
)
|
|
)
|
|
|
|
async def start(self) -> None:
|
|
"""[Source] Start streaming."""
|
|
# Auto-open if needed
|
|
if self.state == State.CONFIGURED:
|
|
await self.open()
|
|
|
|
if self.state != State.OPEN:
|
|
raise InvalidStateError('current state is not OPEN')
|
|
|
|
logger.debug('starting remote endpoint')
|
|
await self.remote_endpoint.start()
|
|
|
|
logger.debug('starting local endpoint')
|
|
await self.local_endpoint.start()
|
|
|
|
self.change_state(State.STREAMING)
|
|
|
|
async def stop(self) -> None:
|
|
"""[Source] Stop streaming and transit to OPEN state."""
|
|
if self.state != State.STREAMING:
|
|
raise InvalidStateError('current state is not STREAMING')
|
|
|
|
logger.debug('stopping local endpoint')
|
|
await self.local_endpoint.stop()
|
|
|
|
logger.debug('stopping remote endpoint')
|
|
await self.remote_endpoint.stop()
|
|
|
|
self.change_state(State.OPEN)
|
|
|
|
async def close(self) -> None:
|
|
"""[Source] Close channel and transit to IDLE state."""
|
|
if self.state not in (State.OPEN, State.STREAMING):
|
|
raise InvalidStateError('current state is not OPEN or STREAMING')
|
|
|
|
logger.debug('closing local endpoint')
|
|
await self.local_endpoint.close()
|
|
|
|
logger.debug('closing remote endpoint')
|
|
await self.remote_endpoint.close()
|
|
|
|
# Release any channels we may have created
|
|
self.change_state(State.CLOSING)
|
|
if self.rtp_channel:
|
|
await self.rtp_channel.disconnect()
|
|
self.rtp_channel = None
|
|
|
|
# Release the endpoint
|
|
self.local_endpoint.in_use = 0
|
|
|
|
self.change_state(State.IDLE)
|
|
|
|
def on_set_configuration_command(
|
|
self, configuration: Iterable[ServiceCapabilities]
|
|
) -> Message | None:
|
|
if self.state != State.IDLE:
|
|
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_set_configuration_command(configuration)
|
|
if result is not None:
|
|
return result
|
|
|
|
self.change_state(State.CONFIGURED)
|
|
return None
|
|
|
|
def on_get_configuration_command(self) -> Message | None:
|
|
if self.state not in (
|
|
State.CONFIGURED,
|
|
State.OPEN,
|
|
State.STREAMING,
|
|
):
|
|
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
|
|
|
return self.local_endpoint.on_get_configuration_command()
|
|
|
|
def on_reconfigure_command(
|
|
self, configuration: Iterable[ServiceCapabilities]
|
|
) -> Message | None:
|
|
if self.state != State.OPEN:
|
|
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_reconfigure_command(configuration)
|
|
if result is not None:
|
|
return result
|
|
|
|
return None
|
|
|
|
def on_open_command(self) -> Message | None:
|
|
if self.state != State.CONFIGURED:
|
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_open_command()
|
|
if result is not None:
|
|
return result
|
|
|
|
# Register to accept the next channel
|
|
self.protocol.channel_acceptor = self
|
|
|
|
self.change_state(State.OPEN)
|
|
return None
|
|
|
|
def on_start_command(self) -> Message | None:
|
|
if self.state != State.OPEN:
|
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
|
|
|
# Check that we have an RTP channel
|
|
if self.rtp_channel is None:
|
|
logger.warning('received start command before RTP channel establishment')
|
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_start_command()
|
|
if result is not None:
|
|
return result
|
|
|
|
self.change_state(State.STREAMING)
|
|
return None
|
|
|
|
def on_suspend_command(self) -> Message | None:
|
|
if self.state != State.STREAMING:
|
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_suspend_command()
|
|
if result is not None:
|
|
return result
|
|
|
|
self.change_state(State.OPEN)
|
|
return None
|
|
|
|
def on_close_command(self) -> Message | None:
|
|
if self.state not in (State.OPEN, State.STREAMING):
|
|
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
|
|
|
result = self.local_endpoint.on_close_command()
|
|
if result is not None:
|
|
return result
|
|
|
|
self.change_state(State.CLOSING)
|
|
|
|
if self.rtp_channel is None:
|
|
# No channel to release, we're done
|
|
self.change_state(State.IDLE)
|
|
else:
|
|
# TODO: set a timer as we wait for the RTP channel to be closed
|
|
pass
|
|
|
|
return None
|
|
|
|
def on_abort_command(self) -> Message | None:
    """Handle an incoming AVDTP Abort command for this stream.

    With no RTP channel the stream returns directly to IDLE. Otherwise it
    enters ABORTING, and on_l2cap_channel_close() moves it to IDLE once the
    channel closes. Always returns None.
    """
    # NOTE(review): unlike the other command handlers, this does not call
    # self.local_endpoint.on_abort_command(); confirm the endpoint's
    # EVENT_ABORT is emitted elsewhere.
    channel_pending = self.rtp_channel is not None
    next_state = State.ABORTING if channel_pending else State.IDLE
    self.change_state(next_state)
    return None
|
|
|
|
def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
    """Adopt a newly connected L2CAP channel as this stream's RTP channel.

    Subscribes to the channel's open/close events, then clears the protocol's
    channel acceptor so no further channels are routed to this stream.
    """
    logger.debug(color('<<< stream channel connected', 'magenta'))
    self.rtp_channel = channel

    subscriptions = (
        (channel.EVENT_OPEN, self.on_l2cap_channel_open),
        (channel.EVENT_CLOSE, self.on_l2cap_channel_close),
    )
    for event_name, handler in subscriptions:
        channel.on(event_name, handler)

    # One channel is all this stream needs.
    self.protocol.channel_acceptor = None
|
|
|
|
def on_l2cap_channel_open(self) -> None:
    """Relay the RTP channel's open event to the local endpoint."""
    logger.debug(color('<<< stream channel open', 'magenta'))
    endpoint = self.local_endpoint
    endpoint.on_rtp_channel_open()
|
|
|
|
def on_l2cap_channel_close(self) -> None:
    """Handle the closing of this stream's RTP channel.

    Notifies the local endpoint, marks it as no longer in use and forgets the
    channel. If a close or abort was pending, the stream moves to IDLE. A close
    in any other state is only logged as a warning.
    """
    logger.debug(color('<<< stream channel closed', 'magenta'))
    endpoint = self.local_endpoint
    # The endpoint is notified while self.rtp_channel is still set.
    endpoint.on_rtp_channel_close()
    endpoint.in_use = 0
    self.rtp_channel = None

    if self.state not in (State.CLOSING, State.ABORTING):
        logger.warning('unexpected channel close while not CLOSING or ABORTING')
        return

    self.change_state(State.IDLE)
|
|
|
|
def __init__(
    self,
    protocol: Protocol,
    local_endpoint: LocalStreamEndPoint,
    remote_endpoint: StreamEndPointProxy,
) -> None:
    '''
    Create a stream binding a local endpoint to a remote one.

    remote_endpoint must be a subclass of StreamEndPointProxy.

    The stream starts IDLE with no RTP channel. It attaches itself to
    local_endpoint by setting ``local_endpoint.stream`` and marking the
    endpoint in use.
    '''
    self.protocol = protocol
    self.local_endpoint = local_endpoint
    self.remote_endpoint = remote_endpoint
    self.rtp_channel: l2cap.ClassicChannel | None = None
    self.state = State.IDLE

    # Register with the local endpoint.
    local_endpoint.stream = self
    local_endpoint.in_use = 1
|
|
|
|
def __str__(self) -> str:
    """Describe the stream as local SEID -> remote SEID plus the state name."""
    local_seid = self.local_endpoint.seid
    remote_seid = self.remote_endpoint.seid
    return f'Stream({local_seid} -> {remote_seid} {self.state.name})'
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
@dataclass
class StreamEndPoint:
    """Description of an AVDTP stream endpoint (SEP).

    Base record shared by LocalStreamEndPoint and DiscoveredStreamEndPoint in
    this module.
    """

    # Stream endpoint identifier.
    seid: int
    # Media type of the endpoint.
    media_type: MediaType
    # Endpoint role; local endpoints in this module use AVDTP_TSEP_SRC or
    # AVDTP_TSEP_SNK.
    tsep: StreamEndPointType
    # In-use flag as an int. Local endpoints start at 0; Stream.__init__ sets
    # it to 1 and Stream.on_l2cap_channel_close resets it to 0. Discovered
    # endpoints keep the value they were constructed with (presumably what
    # the peer reported -- confirm).
    in_use: int
    # Service capabilities of this endpoint.
    capabilities: Iterable[ServiceCapabilities]
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class StreamEndPointProxy:
    """Handle for issuing commands to a remote stream endpoint.

    Every method delegates to the matching ``protocol`` coroutine, passing
    this endpoint's SEID, and returns whatever the protocol returns.
    """

    def __init__(self, protocol: Protocol, seid: int) -> None:
        self.protocol = protocol
        self.seid = seid

    async def set_configuration(
        self, int_seid: int, configuration: Iterable[ServiceCapabilities]
    ) -> Set_Configuration_Response:
        """Request ``configuration`` for this endpoint on behalf of ``int_seid``."""
        remote_seid = self.seid
        return await self.protocol.set_configuration(
            remote_seid, int_seid, configuration
        )

    async def open(self) -> Open_Response:
        """Request that this endpoint's stream be opened."""
        remote_seid = self.seid
        return await self.protocol.open(remote_seid)

    async def start(self) -> Start_Response:
        """Request that this endpoint start streaming."""
        seid_list = [self.seid]
        return await self.protocol.start(seid_list)

    async def stop(self) -> Suspend_Response:
        """Request that this endpoint suspend streaming (via ``protocol.suspend``)."""
        seid_list = [self.seid]
        return await self.protocol.suspend(seid_list)

    async def close(self) -> Close_Response:
        """Request that this endpoint's stream be closed."""
        remote_seid = self.seid
        return await self.protocol.close(remote_seid)

    async def abort(self) -> Abort_Response:
        """Request that this endpoint's stream be aborted."""
        remote_seid = self.seid
        return await self.protocol.abort(remote_seid)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
    """A remote stream endpoint (presumably one reported by discovery).

    Combines the endpoint description from ``StreamEndPoint`` with the
    command methods of ``StreamEndPointProxy``.
    """

    def __init__(
        self,
        protocol: Protocol,
        seid: int,
        media_type: MediaType,
        tsep: StreamEndPointType,
        in_use: int,
        capabilities: Iterable[ServiceCapabilities],
    ) -> None:
        # Each base is initialized explicitly: the dataclass-generated
        # StreamEndPoint.__init__ does not chain to the next __init__ in the MRO.
        StreamEndPoint.__init__(
            self,
            seid=seid,
            media_type=media_type,
            tsep=tsep,
            in_use=in_use,
            capabilities=capabilities,
        )
        StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
    """A stream endpoint hosted on this device.

    Incoming commands reach the ``on_*_command`` hooks, which emit the
    matching ``EVENT_*`` event. Each hook returns a Message to send back, or
    None to let the caller continue normally.
    """

    stream: Stream | None

    EVENT_CONFIGURATION = "configuration"
    EVENT_OPEN = "open"
    EVENT_START = "start"
    EVENT_STOP = "stop"
    EVENT_RTP_PACKET = "rtp_packet"
    EVENT_SUSPEND = "suspend"
    EVENT_CLOSE = "close"
    EVENT_ABORT = "abort"
    EVENT_DELAY_REPORT = "delay_report"
    EVENT_SECURITY_CONTROL = "security_control"
    EVENT_RTP_CHANNEL_OPEN = "rtp_channel_open"
    EVENT_RTP_CHANNEL_CLOSE = "rtp_channel_close"

    def __init__(
        self,
        protocol: Protocol,
        seid: int,
        media_type: MediaType,
        tsep: StreamEndPointType,
        capabilities: Iterable[ServiceCapabilities],
        configuration: Iterable[ServiceCapabilities] | None = None,
    ) -> None:
        # Local endpoints always start out not in use.
        StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
        utils.EventEmitter.__init__(self)
        self.protocol = protocol
        self.configuration = configuration if configuration is not None else []
        self.stream = None

    async def start(self) -> None:
        """[Source Only] Handles when receiving start command."""

    async def stop(self) -> None:
        """[Source Only] Handles when receiving stop command."""

    async def close(self) -> None:
        """[Source Only] Handles when receiving close command."""

    def on_reconfigure_command(
        self, command: Iterable[ServiceCapabilities]
    ) -> Message | None:
        """Accept any reconfiguration by default (returns None)."""
        del command  # unused.
        return None

    def on_set_configuration_command(
        self, configuration: Iterable[ServiceCapabilities]
    ) -> Message | None:
        """Store the configuration chosen by the peer and emit EVENT_CONFIGURATION."""
        # Materialize first: the iterable is traversed for the log line below
        # and then kept for later use (e.g. on_get_configuration_command), so
        # a one-shot iterator would otherwise be stored already exhausted.
        configuration = list(configuration)
        logger.debug(
            '<<< received configuration: '
            f'{",".join([str(capability) for capability in configuration])}'
        )
        self.configuration = configuration
        self.emit(self.EVENT_CONFIGURATION)
        return None

    def on_get_configuration_command(self) -> Message | None:
        """Reply with the currently stored configuration."""
        return Get_Configuration_Response(self.configuration)

    def on_open_command(self) -> Message | None:
        self.emit(self.EVENT_OPEN)
        return None

    def on_start_command(self) -> Message | None:
        self.emit(self.EVENT_START)
        return None

    def on_suspend_command(self) -> Message | None:
        self.emit(self.EVENT_SUSPEND)
        return None

    def on_close_command(self) -> Message | None:
        self.emit(self.EVENT_CLOSE)
        return None

    def on_abort_command(self) -> Message | None:
        self.emit(self.EVENT_ABORT)
        return None

    def on_delayreport_command(self, delay: int) -> Message | None:
        """Emit EVENT_DELAY_REPORT with the reported ``delay`` value."""
        self.emit(self.EVENT_DELAY_REPORT, delay)
        return None

    def on_security_control_command(self, data: bytes) -> Message | None:
        """Emit EVENT_SECURITY_CONTROL with the raw ``data`` payload."""
        self.emit(self.EVENT_SECURITY_CONTROL, data)
        return None

    def on_rtp_channel_open(self) -> None:
        self.emit(self.EVENT_RTP_CHANNEL_OPEN)
        return None

    def on_rtp_channel_close(self) -> None:
        self.emit(self.EVENT_RTP_CHANNEL_CLOSE)
        return None
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class LocalSource(LocalStreamEndPoint):
    """A local source endpoint, optionally driven by a media packet pump.

    On start, if a pump is set and the stream has an RTP channel, the pump is
    started on that channel. Otherwise EVENT_START is emitted. On stop, a set
    pump is stopped. Otherwise EVENT_STOP is emitted.
    """

    # Strong references to the fire-and-forget tasks spawned by the command
    # handlers: asyncio only keeps weak references to running tasks, so an
    # unreferenced task may be garbage-collected before it finishes.
    _background_tasks: ClassVar[set[asyncio.Task[None]]] = set()

    def __init__(
        self,
        protocol: Protocol,
        seid: int,
        codec_capabilities: MediaCodecCapabilities,
        other_capabilities: Iterable[ServiceCapabilities],
        packet_pump: MediaPacketPump | None,
    ) -> None:
        capabilities = [
            ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
            codec_capabilities,
            *other_capabilities,
        ]
        # The same list is used both as the capabilities and as the initial
        # configuration.
        super().__init__(
            protocol,
            seid,
            codec_capabilities.media_type,
            AVDTP_TSEP_SRC,
            capabilities,
            capabilities,
        )
        self.packet_pump = packet_pump

    @override
    async def start(self) -> None:
        if self.packet_pump and self.stream and self.stream.rtp_channel:
            # The pump sends the media; EVENT_START is not emitted here.
            await self.packet_pump.start(self.stream.rtp_channel)
            return

        self.emit(self.EVENT_START)

    @override
    async def stop(self) -> None:
        if self.packet_pump:
            await self.packet_pump.stop()
            return

        self.emit(self.EVENT_STOP)

    @override
    def on_start_command(self) -> Message | None:
        # Run start() in the background, keeping the task referenced until done.
        task = asyncio.create_task(self.start())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return None

    @override
    def on_suspend_command(self) -> Message | None:
        # Run stop() in the background, keeping the task referenced until done.
        task = asyncio.create_task(self.stop())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return None
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class LocalSink(LocalStreamEndPoint):
    """A local sink endpoint that emits received media as EVENT_RTP_PACKET."""

    def __init__(
        self, protocol: Protocol, seid: int, codec_capabilities: MediaCodecCapabilities
    ) -> None:
        capabilities = [
            ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
            codec_capabilities,
        ]
        super().__init__(
            protocol,
            seid,
            codec_capabilities.media_type,
            AVDTP_TSEP_SNK,
            capabilities,
        )

    def on_rtp_channel_open(self) -> None:
        """Route data from the newly opened RTP channel to on_avdtp_packet.

        Raises:
            InvalidStateError: if there is no stream, or the stream has no RTP
                channel.
        """
        logger.debug(color('<<< RTP channel open', 'magenta'))
        if not self.stream:
            raise InvalidStateError('Stream is None')
        if not self.stream.rtp_channel:
            raise InvalidStateError('RTP channel is None')
        self.stream.rtp_channel.sink = self.on_avdtp_packet
        super().on_rtp_channel_open()

    def on_rtp_channel_close(self) -> None:
        logger.debug(color('<<< RTP channel close', 'magenta'))
        super().on_rtp_channel_close()

    def on_avdtp_packet(self, packet: bytes) -> None:
        """Parse one RTP media packet and emit it as EVENT_RTP_PACKET."""
        rtp_packet = MediaPacket.from_bytes(packet)
        # This runs for every packet received on the RTP channel: only build the
        # colored message and hex dump when debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f'{color("<<< RTP Packet:", "green")} '
                f'{rtp_packet} {rtp_packet.payload[:16].hex()}'
            )
        self.emit(self.EVENT_RTP_PACKET, rtp_packet)
|