From 495ce62d9cfb53df8679289cab2f065f76791150 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 4 Oct 2023 20:16:21 +0800 Subject: [PATCH] Typing AVDTP --- bumble/avdtp.py | 441 +++++++++++++++++++++++++++----------------- tests/avdtp_test.py | 6 +- 2 files changed, 273 insertions(+), 174 deletions(-) diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 3988f309..a66395d7 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -20,8 +20,23 @@ import asyncio import struct import time import logging +import enum from pyee import EventEmitter -from typing import Dict, Type +from typing import ( + Any, + Awaitable, + Dict, + Type, + Tuple, + Optional, + Callable, + List, + AsyncGenerator, + Iterable, + Union, + SupportsBytes, + cast, +) from .core import ( BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, @@ -38,7 +53,7 @@ from .a2dp import ( SbcMediaCodecInformation, VendorSpecificMediaCodecInformation, ) -from . import sdp +from . import sdp, device, l2cap from .colors import color # ----------------------------------------------------------------------------- @@ -206,7 +221,9 @@ AVDTP_STATE_NAMES = { # ----------------------------------------------------------------------------- -async def find_avdtp_service_with_sdp_client(sdp_client): +async def find_avdtp_service_with_sdp_client( + sdp_client: sdp.Client, +) -> Optional[Tuple[int, int]]: ''' Find an AVDTP service, using a connected SDP client, and return its version, or None if none is found @@ -227,10 +244,13 @@ async def find_avdtp_service_with_sdp_client(sdp_client): avdtp_version_major = profile_descriptor.value[1].value >> 8 avdtp_version_minor = profile_descriptor.value[1].value & 0xFF return (avdtp_version_major, avdtp_version_minor) + return None # ----------------------------------------------------------------------------- -async def find_avdtp_service_with_connection(device, connection): +async def find_avdtp_service_with_connection( + device: device.Device, connection: device.Connection +) -> Optional[Tuple[int, int]]: ''' Find an AVDTP service, for a connection, and return its version, or None if none is found @@ -246,17 +266,17 @@ async def find_avdtp_service_with_connection(device, connection): # ----------------------------------------------------------------------------- class RealtimeClock: - def now(self): + def now(self) -> float: return time.time() - async def sleep(self, duration): + async def sleep(self, duration: float) -> None: await asyncio.sleep(duration) # ----------------------------------------------------------------------------- class MediaPacket: @staticmethod - def from_bytes(data): + def from_bytes(data: bytes) -> MediaPacket: version = (data[0] >> 6) & 0x03 padding = (data[0] >> 5) & 0x01 extension = (data[0] >> 4) & 0x01 @@ -286,17 +306,17 @@ class MediaPacket: def __init__( self, - version, - padding, - extension, - marker, - sequence_number, - timestamp, - ssrc, - csrc_list, - payload_type, - payload, - ): + version: int, + padding: int, + extension: int, + marker: int, + sequence_number: int, + timestamp: int, + ssrc: int, + csrc_list: List[int], + payload_type: int, + payload: bytes, + ) -> None: self.version = version self.padding = padding self.extension = extension @@ -308,7 +328,7 @@ class MediaPacket: self.payload_type = payload_type self.payload = payload - def __bytes__(self): + def __bytes__(self) -> bytes: header = bytes( [ self.version << 6 @@ -322,7 +342,7 @@ class MediaPacket: header += struct.pack('>I', csrc) return header + self.payload - def __str__(self): + def __str__(self) -> str: return ( f'RTP(v={self.version},' f'p={self.padding},' @@ -339,12 +359,16 @@ class MediaPacket: # ----------------------------------------------------------------------------- class MediaPacketPump: - def __init__(self, packets, clock=RealtimeClock()): + pump_task: Optional[asyncio.Task] + + def __init__( + self, packets: AsyncGenerator, clock: RealtimeClock = RealtimeClock() + ) -> None: self.packets = packets self.clock = clock self.pump_task = None - async def start(self, rtp_channel): + async def start(self, rtp_channel: l2cap.Channel) -> None: async def pump_packets(): start_time = 0 start_timestamp = 0 @@ -376,7 +400,7 @@ class MediaPacketPump: # Pump packets self.pump_task = asyncio.create_task(pump_packets()) - async def stop(self): + async def stop(self) -> None: # Stop the pump if self.pump_task: self.pump_task.cancel() @@ -385,32 +409,37 @@ class MediaPacketPump: # ----------------------------------------------------------------------------- -class MessageAssembler: # pylint: disable=attribute-defined-outside-init - def __init__(self, callback): +class MessageAssembler: + message: Optional[bytes] + + def __init__(self, callback: Callable[[int, Message], Any]) -> None: self.callback = callback self.reset() - def reset(self): + def reset(self) -> None: self.transaction_label = 0 self.message = None - self.message_type = 0 + self.message_type = Message.MessageType.COMMAND self.signal_identifier = 0 self.number_of_signal_packets = 0 self.packet_count = 0 - def on_pdu(self, pdu): + def on_pdu(self, pdu: bytes) -> None: self.packet_count += 1 transaction_label = pdu[0] >> 4 - packet_type = (pdu[0] >> 2) & 3 - message_type = pdu[0] & 3 + packet_type = Protocol.PacketType((pdu[0] >> 2) & 3) + message_type = Message.MessageType(pdu[0] & 3) logger.debug( f'transaction_label={transaction_label}, ' - f'packet_type={Protocol.packet_type_name(packet_type)}, ' - f'message_type={Message.message_type_name(message_type)}' + f'packet_type={packet_type.name}, ' + f'message_type={message_type.name}' ) - if packet_type in (Protocol.SINGLE_PACKET, Protocol.START_PACKET): + if packet_type in ( + Protocol.PacketType.SINGLE_PACKET, + Protocol.PacketType.START_PACKET, + ): if self.message is not None: # The previous message has not been terminated logger.warning( @@ -423,13 +452,16 @@ class MessageAssembler: # pylint: disable=attribute-defined-outside-init self.signal_identifier = pdu[1] & 0x3F self.message_type = message_type - if packet_type == Protocol.SINGLE_PACKET: + if packet_type == Protocol.PacketType.SINGLE_PACKET: self.message = pdu[2:] self.on_message_complete() else: self.number_of_signal_packets = pdu[2] self.message = pdu[3:] - elif packet_type in (Protocol.CONTINUE_PACKET, Protocol.END_PACKET): + elif packet_type in ( + Protocol.PacketType.CONTINUE_PACKET, + Protocol.PacketType.END_PACKET, + ): if self.packet_count == 0: logger.warning('unexpected continuation') return @@ -448,9 +480,9 @@ class MessageAssembler: # pylint: disable=attribute-defined-outside-init ) return - self.message += pdu[1:] + self.message = (self.message or b'') + pdu[1:] - if packet_type == Protocol.END_PACKET: + if packet_type == Protocol.PacketType.END_PACKET: if self.packet_count != self.number_of_signal_packets: logger.warning( 'incomplete fragmented message: ' @@ -471,24 +503,25 @@ class MessageAssembler: # pylint: disable=attribute-defined-outside-init self.reset() return - def on_message_complete(self): + def on_message_complete(self) -> None: message = Message.create( - self.signal_identifier, self.message_type, self.message + self.signal_identifier, self.message_type, self.message or b'' ) - try: self.callback(self.transaction_label, message) except Exception as error: logger.warning(color(f'!!! exception in callback: {error}')) - self.reset() # ----------------------------------------------------------------------------- class ServiceCapabilities: @staticmethod - def create(service_category, service_capabilities_bytes): + def create( + service_category: int, service_capabilities_bytes: bytes + ) -> ServiceCapabilities: # Select the appropriate subclass + cls: Type[ServiceCapabilities] if service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY: cls = MediaCodecCapabilities else: @@ -503,7 +536,7 @@ class ServiceCapabilities: return instance @staticmethod - def parse_capabilities(payload): + def parse_capabilities(payload: bytes) -> List[ServiceCapabilities]: capabilities = [] while payload: service_category = payload[0] @@ -518,7 +551,7 @@ class ServiceCapabilities: return capabilities @staticmethod - def serialize_capabilities(capabilities): + def serialize_capabilities(capabilities: Iterable[ServiceCapabilities]) -> bytes: serialized = b'' for item in capabilities: serialized += ( @@ -527,21 +560,23 @@ class ServiceCapabilities: ) return serialized - def init_from_bytes(self): + def init_from_bytes(self) -> None: pass - def __init__(self, service_category, service_capabilities_bytes=b''): + def __init__( + self, service_category: int, service_capabilities_bytes: bytes = b'' + ) -> None: self.service_category = service_category self.service_capabilities_bytes = service_capabilities_bytes - def to_string(self, details=[]): # pylint: disable=dangerous-default-value + def to_string(self, details: List[str] = []) -> str: attributes = ','.join( [name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + details ) return f'ServiceCapabilities({attributes})' - def __str__(self): + def __str__(self) -> str: if self.service_capabilities_bytes: details = [self.service_capabilities_bytes.hex()] else: @@ -551,7 +586,9 @@ class ServiceCapabilities: # ----------------------------------------------------------------------------- class MediaCodecCapabilities(ServiceCapabilities): - def init_from_bytes(self): + media_codec_information: Union[bytes, SupportsBytes] + + def init_from_bytes(self) -> None: self.media_type = self.service_capabilities_bytes[0] self.media_codec_type = self.service_capabilities_bytes[1] self.media_codec_information = self.service_capabilities_bytes[2:] @@ -571,7 +608,12 @@ class MediaCodecCapabilities(ServiceCapabilities): ) ) - def __init__(self, media_type, media_codec_type, media_codec_information): + def __init__( + self, + media_type: int, + media_codec_type: int, + media_codec_information: Union[bytes, SupportsBytes], + ) -> None: super().__init__( AVDTP_MEDIA_CODEC_SERVICE_CATEGORY, bytes([media_type, media_codec_type]) + bytes(media_codec_information), @@ -580,7 +622,7 @@ class MediaCodecCapabilities(ServiceCapabilities): self.media_codec_type = media_codec_type self.media_codec_information = media_codec_information - def __str__(self): + def __str__(self) -> str: codec_info = ( self.media_codec_information.hex() if isinstance(self.media_codec_information, bytes) @@ -598,17 +640,17 @@ class MediaCodecCapabilities(ServiceCapabilities): # ----------------------------------------------------------------------------- class EndPointInfo: @staticmethod - def from_bytes(payload): + def from_bytes(payload: bytes) -> EndPointInfo: return EndPointInfo( payload[0] >> 2, payload[0] >> 1 & 1, payload[1] >> 4, payload[1] >> 3 & 1 ) - def __bytes__(self): + def __bytes__(self) -> bytes: return bytes( [self.seid << 2 | self.in_use << 1, self.media_type << 4 | self.tsep << 3] ) - def __init__(self, seid, in_use, media_type, tsep): + def __init__(self, seid: int, in_use: int, media_type: int, tsep: int) -> None: self.seid = seid self.in_use = in_use self.media_type = media_type @@ -617,24 +659,16 @@ class EndPointInfo: # ----------------------------------------------------------------------------- class Message: # pylint:disable=attribute-defined-outside-init - COMMAND = 0 - GENERAL_REJECT = 1 - RESPONSE_ACCEPT = 2 - RESPONSE_REJECT = 3 - - MESSAGE_TYPE_NAMES = { - COMMAND: 'COMMAND', - GENERAL_REJECT: 'GENERAL_REJECT', - RESPONSE_ACCEPT: 'RESPONSE_ACCEPT', - RESPONSE_REJECT: 'RESPONSE_REJECT', - } + class MessageType(enum.IntEnum): + COMMAND = 0 + GENERAL_REJECT = 1 + RESPONSE_ACCEPT = 2 + RESPONSE_REJECT = 3 # Subclasses, by signal identifier and message type subclasses: Dict[int, Dict[int, Type[Message]]] = {} - - @staticmethod - def message_type_name(message_type): - return name_or_number(Message.MESSAGE_TYPE_NAMES, message_type) + message_type: MessageType + signal_identifier: int @staticmethod def subclass(subclass): @@ -643,23 +677,23 @@ class Message: # pylint:disable=attribute-defined-outside-init if name == 'General_Reject': subclass.signal_identifier = 0 signal_identifier_str = None - message_type = Message.COMMAND + message_type = Message.MessageType.COMMAND elif name.endswith('_Command'): signal_identifier_str = name[:-8] - message_type = Message.COMMAND + message_type = Message.MessageType.COMMAND elif name.endswith('_Response'): signal_identifier_str = name[:-9] - message_type = Message.RESPONSE_ACCEPT + message_type = Message.MessageType.RESPONSE_ACCEPT elif name.endswith('_Reject'): signal_identifier_str = name[:-7] - message_type = Message.RESPONSE_REJECT + message_type = Message.MessageType.RESPONSE_REJECT else: raise ValueError('invalid class name') subclass.message_type = message_type if signal_identifier_str is not None: - for (name, signal_identifier) in AVDTP_SIGNAL_IDENTIFIERS.items(): + for name, signal_identifier in AVDTP_SIGNAL_IDENTIFIERS.items(): if name.lower().endswith(signal_identifier_str.lower()): subclass.signal_identifier = signal_identifier break @@ -674,7 +708,9 @@ class Message: # pylint:disable=attribute-defined-outside-init # Factory method to create a subclass based on the signal identifier and message # type @staticmethod - def create(signal_identifier, message_type, payload): + def create( + signal_identifier: int, message_type: MessageType, payload: bytes + ) -> Message: # Look for a registered subclass subclasses = Message.subclasses.get(signal_identifier) if subclasses: @@ -686,7 +722,7 @@ class Message: # pylint:disable=attribute-defined-outside-init return instance # Instantiate the appropriate class based on the message type - if message_type == Message.RESPONSE_REJECT: + if message_type == Message.MessageType.RESPONSE_REJECT: # Assume a simple reject message instance = Simple_Reject(payload) instance.init_from_payload() @@ -696,16 +732,16 @@ class Message: # pylint:disable=attribute-defined-outside-init instance.message_type = message_type return instance - def init_from_payload(self): + def init_from_payload(self) -> None: pass - def __init__(self, payload=b''): + def __init__(self, payload: bytes = b'') -> None: self.payload = payload - def to_string(self, details): + def to_string(self, details: Union[str, Iterable[str]]) -> str: base = color( f'{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_' - f'{Message.message_type_name(self.message_type)}', + f'{self.message_type.name}', 'yellow', ) @@ -721,7 +757,7 @@ class Message: # pylint:disable=attribute-defined-outside-init return base - def __str__(self): + def __str__(self) -> str: return self.to_string(self.payload.hex()) @@ -738,7 +774,7 @@ class Simple_Command(Message): super().__init__(payload=bytes([seid << 2])) self.acp_seid = seid - def __str__(self): + def __str__(self) -> str: return self.to_string([f'ACP SEID: {self.acp_seid}']) @@ -755,7 +791,7 @@ class Simple_Reject(Message): super().__init__(payload=bytes([error_code])) self.error_code = error_code - def __str__(self): + def __str__(self) -> str: details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'] return self.to_string(details) @@ -775,6 +811,8 @@ class Discover_Response(Message): See Bluetooth AVDTP spec - 8.6.2 Stream End Point Discovery Response ''' + endpoints: List[EndPointInfo] + def init_from_payload(self): self.endpoints = [] endpoint_count = len(self.payload) // 2 @@ -787,7 +825,7 @@ class Discover_Response(Message): super().__init__(payload=b''.join([bytes(endpoint) for endpoint in endpoints])) self.endpoints = endpoints - def __str__(self): + def __str__(self) -> str: details = [] for endpoint in self.endpoints: details.extend( @@ -826,7 +864,7 @@ class Get_Capabilities_Response(Message): ) self.capabilities = capabilities - def __str__(self): + def __str__(self) -> str: details = [str(capability) for capability in self.capabilities] return self.to_string(details) @@ -875,7 +913,9 @@ class Set_Configuration_Command(Message): self.int_seid = self.payload[1] >> 2 self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[2:]) - def __init__(self, acp_seid, int_seid, capabilities): + def __init__( + self, acp_seid: int, int_seid: int, capabilities: Iterable[ServiceCapabilities] + ) -> None: super().__init__( payload=bytes([acp_seid << 2, int_seid << 2]) + ServiceCapabilities.serialize_capabilities(capabilities) @@ -884,7 +924,7 @@ class Set_Configuration_Command(Message): self.int_seid = int_seid self.capabilities = capabilities - def __str__(self): + def __str__(self) -> str: details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [ str(capability) for capability in self.capabilities ] @@ -915,7 +955,7 @@ class Set_Configuration_Reject(Message): self.service_category = service_category self.error_code = error_code - def __str__(self): + def __str__(self) -> str: details = [ ( 'service_category: ' @@ -947,13 +987,13 @@ class Get_Configuration_Response(Message): def init_from_payload(self): self.capabilities = ServiceCapabilities.parse_capabilities(self.payload) - def __init__(self, capabilities): + def __init__(self, capabilities: Iterable[ServiceCapabilities]) -> None: super().__init__( payload=ServiceCapabilities.serialize_capabilities(capabilities) ) self.capabilities = capabilities - def __str__(self): + def __str__(self) -> str: details = [str(capability) for capability in self.capabilities] return self.to_string(details) @@ -978,7 +1018,7 @@ class Reconfigure_Command(Message): self.acp_seid = self.payload[0] >> 2 self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[1:]) - def __str__(self): + def __str__(self) -> str: details = [ f'ACP SEID: {self.acp_seid}', ] + [str(capability) for capability in self.capabilities] @@ -1035,11 +1075,11 @@ class Start_Command(Message): def init_from_payload(self): self.acp_seids = [x >> 2 for x in self.payload] - def __init__(self, seids): + def __init__(self, seids: Iterable[int]) -> None: super().__init__(payload=bytes([seid << 2 for seid in seids])) self.acp_seids = seids - def __str__(self): + def __str__(self) -> str: return self.to_string([f'ACP SEIDs: {self.acp_seids}']) @@ -1067,7 +1107,7 @@ class Start_Reject(Message): self.acp_seid = acp_seid self.error_code = error_code - def __str__(self): + def __str__(self) -> str: details = [ f'acp_seid: {self.acp_seid}', f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}', @@ -1186,7 +1226,7 @@ class DelayReport_Command(Message): self.acp_seid = self.payload[0] >> 2 self.delay = (self.payload[1] << 8) | (self.payload[2]) - def __str__(self): + def __str__(self) -> str: return self.to_string([f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}']) @@ -1208,24 +1248,22 @@ class DelayReport_Reject(Simple_Reject): # ----------------------------------------------------------------------------- class Protocol(EventEmitter): - SINGLE_PACKET = 0 - START_PACKET = 1 - CONTINUE_PACKET = 2 - END_PACKET = 3 + local_endpoints: List[LocalStreamEndPoint] + remote_endpoints: Dict[int, DiscoveredStreamEndPoint] + streams: Dict[int, Stream] + transaction_results: List[Optional[asyncio.Future[Message]]] + channel_connector: Callable[[], Awaitable[l2cap.Channel]] - PACKET_TYPE_NAMES = { - SINGLE_PACKET: 'SINGLE_PACKET', - START_PACKET: 'START_PACKET', - CONTINUE_PACKET: 'CONTINUE_PACKET', - END_PACKET: 'END_PACKET', - } + class PacketType(enum.IntEnum): + SINGLE_PACKET = 0 + START_PACKET = 1 + CONTINUE_PACKET = 2 + END_PACKET = 3 @staticmethod - def packet_type_name(packet_type): - return name_or_number(Protocol.PACKET_TYPE_NAMES, packet_type) - - @staticmethod - async def connect(connection, version=(1, 3)): + async def connect( + connection: device.Connection, version: Tuple[int, int] = (1, 3) + ) -> Protocol: connector = connection.create_l2cap_connector(AVDTP_PSM) channel = await connector() protocol = Protocol(channel, version) @@ -1233,7 +1271,9 @@ class Protocol(EventEmitter): return protocol - def __init__(self, l2cap_channel, version=(1, 3)): + def __init__( + self, l2cap_channel: l2cap.Channel, version: Tuple[int, int] = (1, 3) + ) -> None: super().__init__() self.l2cap_channel = l2cap_channel self.version = version @@ -1243,7 +1283,6 @@ class Protocol(EventEmitter): self.transaction_semaphore = asyncio.Semaphore(16) self.transaction_count = 0 self.channel_acceptor = None - self.channel_connector = None self.local_endpoints = [] # Local endpoints, with contiguous seid values self.remote_endpoints = {} # Remote stream endpoints, by seid self.streams = {} # Streams, by seid @@ -1253,27 +1292,31 @@ class Protocol(EventEmitter): l2cap_channel.on('open', self.on_l2cap_channel_open) l2cap_channel.on('close', self.on_l2cap_channel_close) - def get_local_endpoint_by_seid(self, seid): + def get_local_endpoint_by_seid(self, seid: int) -> Optional[LocalStreamEndPoint]: if 0 < seid <= len(self.local_endpoints): return self.local_endpoints[seid - 1] return None - def add_source(self, codec_capabilities, packet_pump): + def add_source( + self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump + ) -> LocalSource: seid = len(self.local_endpoints) + 1 source = LocalSource(self, seid, codec_capabilities, packet_pump) self.local_endpoints.append(source) return source - def add_sink(self, codec_capabilities): + def add_sink(self, codec_capabilities: MediaCodecCapabilities) -> LocalSink: seid = len(self.local_endpoints) + 1 sink = LocalSink(self, seid, codec_capabilities) self.local_endpoints.append(sink) return sink - async def create_stream(self, source, sink): + async def create_stream( + self, source: LocalStreamEndPoint, sink: StreamEndPointProxy + ) -> Stream: # Check that the source isn't already used in a stream if source.in_use: raise InvalidStateError('source already in use') @@ -1290,10 +1333,10 @@ class Protocol(EventEmitter): return stream - async def discover_remote_endpoints(self): + async def discover_remote_endpoints(self) -> Iterable[DiscoveredStreamEndPoint]: self.remote_endpoints = {} - response = await self.send_command(Discover_Command()) + response: Discover_Response = await self.send_command(Discover_Command()) for endpoint_entry in response.endpoints: logger.debug( f'getting endpoint capabilities for endpoint {endpoint_entry.seid}' @@ -1311,7 +1354,9 @@ class Protocol(EventEmitter): return self.remote_endpoints.values() - def find_remote_sink_by_codec(self, media_type, codec_type): + def find_remote_sink_by_codec( + self, media_type: int, codec_type: int + ) -> Optional[DiscoveredStreamEndPoint]: for endpoint in self.remote_endpoints.values(): if ( not endpoint.in_use @@ -1330,9 +1375,10 @@ class Protocol(EventEmitter): capabilities.service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY ): + codec_capabilities = cast(MediaCodecCapabilities, capabilities) if ( - capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE - and capabilities.media_codec_type == codec_type + codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE + and codec_capabilities.media_codec_type == codec_type ): has_codec = True if has_media_transport and has_codec: @@ -1340,10 +1386,10 @@ class Protocol(EventEmitter): return None - def on_pdu(self, pdu): + def on_pdu(self, pdu: bytes) -> None: self.message_assembler.on_pdu(pdu) - def on_message(self, transaction_label, message): + def on_message(self, transaction_label: int, message: Message) -> None: logger.debug( f'{color("<<< Received AVDTP message", "magenta")}: ' f'[{transaction_label}] {message}' @@ -1362,7 +1408,7 @@ class Protocol(EventEmitter): logger.warning('!!! invalid signal identifier') self.send_message(transaction_label, General_Reject()) - if message.message_type == Message.COMMAND: + if message.message_type == Message.MessageType.COMMAND: # Command signal_name = ( AVDTP_SIGNAL_NAMES.get(message.signal_identifier, "") @@ -1407,7 +1453,7 @@ class Protocol(EventEmitter): logger.debug(color('<<< L2CAP channel close', 'magenta')) self.emit('close') - def send_message(self, transaction_label, message): + def send_message(self, transaction_label: int, message: Message) -> None: logger.debug( f'{color(">>> Sending AVDTP message", "magenta")}: ' f'[{transaction_label}] {message}' @@ -1418,9 +1464,9 @@ class Protocol(EventEmitter): payload = message.payload if len(payload) + 2 <= self.l2cap_channel.mtu: # Fits in a single packet - packet_type = self.SINGLE_PACKET + packet_type = self.PacketType.SINGLE_PACKET else: - packet_type = self.START_PACKET + packet_type = self.PacketType.START_PACKET done = False while not done: @@ -1428,9 +1474,9 @@ class Protocol(EventEmitter): transaction_label << 4 | packet_type << 2 | message.message_type ) - if packet_type == self.SINGLE_PACKET: + if packet_type == self.PacketType.SINGLE_PACKET: header = bytes([first_header_byte, message.signal_identifier]) - elif packet_type == self.START_PACKET: + elif packet_type == self.PacketType.START_PACKET: packet_count = ( max_fragment_size - 1 + len(payload) ) // max_fragment_size @@ -1447,14 +1493,14 @@ class Protocol(EventEmitter): payload = payload[max_fragment_size:] if payload: packet_type = ( - self.CONTINUE_PACKET - if payload > max_fragment_size - else self.END_PACKET + self.PacketType.CONTINUE_PACKET + if len(payload) > max_fragment_size + else self.PacketType.END_PACKET ) else: done = True - async def send_command(self, command): + async def send_command(self, command: Message): # TODO: support timeouts # Send the command (transaction_label, transaction_result) = await self.start_transaction() @@ -1464,12 +1510,16 @@ class Protocol(EventEmitter): response = await transaction_result # Check for errors - if response.message_type in (Message.GENERAL_REJECT, Message.RESPONSE_REJECT): + if response.message_type in ( + Message.MessageType.GENERAL_REJECT, + Message.MessageType.RESPONSE_REJECT, + ): + assert hasattr(response, 'error_code') raise ProtocolError(response.error_code, 'avdtp') return response - async def start_transaction(self): + async def start_transaction(self) -> Tuple[int, asyncio.Future[Message]]: # Wait until we can start a new transaction await self.transaction_semaphore.acquire() @@ -1484,34 +1534,38 @@ class Protocol(EventEmitter): assert False # Should never reach this - async def get_capabilities(self, seid): + async def get_capabilities( + self, seid: int + ) -> Union[Get_Capabilities_Response, Get_All_Capabilities_Response,]: if self.version > (1, 2): return await self.send_command(Get_All_Capabilities_Command(seid)) return await self.send_command(Get_Capabilities_Command(seid)) - async def set_configuration(self, acp_seid, int_seid, capabilities): + async def set_configuration( + self, acp_seid: int, int_seid: int, capabilities: Iterable[ServiceCapabilities] + ) -> Set_Configuration_Response: return await self.send_command( Set_Configuration_Command(acp_seid, int_seid, capabilities) ) - async def get_configuration(self, seid): + async def get_configuration(self, seid: int) -> Get_Configuration_Response: response = await self.send_command(Get_Configuration_Command(seid)) return response.capabilities - async def open(self, seid): + async def open(self, seid: int) -> Open_Response: return await self.send_command(Open_Command(seid)) - async def start(self, seids): + async def start(self, seids: Iterable[int]) -> Start_Response: return await self.send_command(Start_Command(seids)) - async def suspend(self, seids): + async def suspend(self, seids: Iterable[int]) -> Suspend_Response: return await self.send_command(Suspend_Command(seids)) - async def close(self, seid): + async def close(self, seid: int) -> Close_Response: return await self.send_command(Close_Command(seid)) - async def abort(self, seid): + async def abort(self, seid: int) -> Abort_Response: return await self.send_command(Abort_Command(seid)) def on_discover_command(self, _command): @@ -1653,14 +1707,16 @@ class Protocol(EventEmitter): # ----------------------------------------------------------------------------- class Listener(EventEmitter): + servers: Dict[int, Protocol] + @staticmethod def create_registrar(device): return device.create_l2cap_registrar(AVDTP_PSM) - def set_server(self, connection, server): + def set_server(self, connection: device.Connection, server: Protocol) -> None: self.servers[connection.handle] = server - def remove_server(self, connection): + def remove_server(self, connection: device.Connection) -> None: if connection.handle in self.servers: del self.servers[connection.handle] @@ -1672,7 +1728,7 @@ class Listener(EventEmitter): # Listen for incoming L2CAP connections registrar(self.on_l2cap_connection) - def on_l2cap_connection(self, channel): + def on_l2cap_connection(self, channel: l2cap.Channel) -> None: logger.debug(f'{color("<<< incoming L2CAP connection:", "magenta")} {channel}') if channel.connection.handle in self.servers: @@ -1701,18 +1757,21 @@ class Stream: Pair of a local and a remote stream endpoint that can stream from one to the other ''' + rtp_channel: Optional[l2cap.Channel] + @staticmethod - def state_name(state): + def state_name(state: int) -> str: return name_or_number(AVDTP_STATE_NAMES, state) - def change_state(self, state): + def change_state(self, state: int) -> None: logger.debug(f'{self} state change -> {color(self.state_name(state), "cyan")}') self.state = state - def send_media_packet(self, packet): + def send_media_packet(self, packet: MediaPacket) -> None: + assert self.rtp_channel self.rtp_channel.send_pdu(bytes(packet)) - async def configure(self): + async def configure(self) -> None: if self.state != AVDTP_IDLE_STATE: raise InvalidStateError('current state is not IDLE') @@ -1721,7 +1780,7 @@ class Stream: ) self.change_state(AVDTP_CONFIGURED_STATE) - async def open(self): + async def open(self) -> None: if self.state != AVDTP_CONFIGURED_STATE: raise InvalidStateError('current state is not CONFIGURED') @@ -1733,7 +1792,7 @@ class Stream: # Create a channel for RTP packets self.rtp_channel = await self.protocol.channel_connector() - async def start(self): + async def start(self) -> None: # Auto-open if needed if self.state == AVDTP_CONFIGURED_STATE: await self.open() @@ -1749,7 +1808,7 @@ class Stream: self.change_state(AVDTP_STREAMING_STATE) - async def stop(self): + async def stop(self) -> None: if self.state != AVDTP_STREAMING_STATE: raise InvalidStateError('current state is not STREAMING') @@ -1761,7 +1820,7 @@ class Stream: self.change_state(AVDTP_OPEN_STATE) - async def close(self): + async def close(self) -> None: if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE): raise InvalidStateError('current state is not OPEN or STREAMING') @@ -1905,7 +1964,12 @@ class Stream: else: logger.warning('unexpected channel close while not CLOSING or ABORTING') - def __init__(self, protocol, local_endpoint, remote_endpoint): + def __init__( + self, + protocol: Protocol, + local_endpoint: LocalStreamEndPoint, + remote_endpoint: StreamEndPointProxy, + ) -> None: ''' remote_endpoint must be a subclass of StreamEndPointProxy @@ -1919,7 +1983,7 @@ class Stream: local_endpoint.stream = self local_endpoint.in_use = 1 - def __str__(self): + def __str__(self) -> str: return ( f'Stream({self.local_endpoint.seid} -> ' f'{self.remote_endpoint.seid} {self.state_name(self.state)})' @@ -1928,14 +1992,21 @@ class Stream: # ----------------------------------------------------------------------------- class StreamEndPoint: - def __init__(self, seid, media_type, tsep, in_use, capabilities): + def __init__( + self, + seid: int, + media_type: int, + tsep: int, + in_use: int, + capabilities: Iterable[ServiceCapabilities], + ) -> None: self.seid = seid self.media_type = media_type self.tsep = tsep self.in_use = in_use self.capabilities = capabilities - def __str__(self): + def __str__(self) -> str: media_type = f'{name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}' tsep = f'{name_or_number(AVDTP_TSEP_NAMES, self.tsep)}' return '\n'.join( @@ -1955,40 +2026,58 @@ class StreamEndPoint: # ----------------------------------------------------------------------------- class StreamEndPointProxy: - def __init__(self, protocol, seid): + def __init__(self, protocol: Protocol, seid: int) -> None: self.seid = seid self.protocol = protocol - async def set_configuration(self, int_seid, configuration): + async def set_configuration( + self, int_seid: int, configuration: Iterable[ServiceCapabilities] + ) -> Set_Configuration_Response: return await self.protocol.set_configuration(self.seid, int_seid, configuration) - async def open(self): + async def open(self) -> Open_Response: return await self.protocol.open(self.seid) - async def start(self): + async def start(self) -> Start_Response: return await self.protocol.start([self.seid]) - async def stop(self): + async def stop(self) -> Suspend_Response: return await self.protocol.suspend([self.seid]) - async def close(self): + async def close(self) -> Close_Response: return await self.protocol.close(self.seid) - async def abort(self): + async def abort(self) -> Abort_Response: return await self.protocol.abort(self.seid) # ----------------------------------------------------------------------------- class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy): - def __init__(self, protocol, seid, media_type, tsep, in_use, capabilities): + def __init__( + self, + protocol: Protocol, + seid: int, + media_type: int, + tsep: int, + in_use: int, + capabilities: Iterable[ServiceCapabilities], + ) -> None: StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities) StreamEndPointProxy.__init__(self, protocol, seid) # ----------------------------------------------------------------------------- class LocalStreamEndPoint(StreamEndPoint, EventEmitter): + stream: Optional[Stream] + def __init__( - self, protocol, seid, media_type, tsep, capabilities, configuration=None + self, + protocol: Protocol, + seid: int, + media_type: int, + tsep: int, + capabilities: Iterable[ServiceCapabilities], + configuration: Optional[Iterable[ServiceCapabilities]] = None, ): StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities) EventEmitter.__init__(self) @@ -2043,7 +2132,13 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter): # ----------------------------------------------------------------------------- class LocalSource(LocalStreamEndPoint): - def __init__(self, protocol, seid, codec_capabilities, packet_pump): + def __init__( + self, + protocol: Protocol, + seid: int, + codec_capabilities: MediaCodecCapabilities, + packet_pump: MediaPacketPump, + ) -> None: capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), codec_capabilities, @@ -2058,13 +2153,13 @@ class LocalSource(LocalStreamEndPoint): ) self.packet_pump = packet_pump - async def start(self): - if self.packet_pump: + async def start(self) -> None: + if self.packet_pump and self.stream and self.stream.rtp_channel: return await self.packet_pump.start(self.stream.rtp_channel) self.emit('start') - async def stop(self): + async def stop(self) -> None: if self.packet_pump: return await self.packet_pump.stop() @@ -2079,7 +2174,9 @@ class LocalSource(LocalStreamEndPoint): # ----------------------------------------------------------------------------- class LocalSink(LocalStreamEndPoint): - def __init__(self, protocol, seid, codec_capabilities): + def __init__( + self, protocol: Protocol, seid: int, codec_capabilities: MediaCodecCapabilities + ) -> None: capabilities = [ ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), codec_capabilities, diff --git a/tests/avdtp_test.py b/tests/avdtp_test.py index 1ca5254d..666a84cf 100644 --- a/tests/avdtp_test.py +++ b/tests/avdtp_test.py @@ -45,12 +45,14 @@ def test_messages(): ] message = Get_Capabilities_Response(capabilities) parsed = Message.create( - AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload + AVDTP_GET_CAPABILITIES, Message.MessageType.RESPONSE_ACCEPT, message.payload ) assert message.payload == parsed.payload message = Set_Configuration_Command(3, 4, capabilities) - parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload) + parsed = Message.create( + AVDTP_SET_CONFIGURATION, Message.MessageType.COMMAND, message.payload + ) assert message.payload == parsed.payload