format with Black

This commit is contained in:
Gilles Boccon-Gibod
2022-12-10 08:53:51 -08:00
parent 297246fa4c
commit 135df0dcc0
104 changed files with 8646 additions and 5766 deletions

View File

@@ -26,7 +26,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
ProtocolError,
name_or_number
name_or_number,
)
from .a2dp import (
A2DP_CODEC_TYPE_NAMES,
@@ -35,7 +35,7 @@ from .a2dp import (
A2DP_SBC_CODEC_TYPE,
AacMediaCodecInformation,
SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation
VendorSpecificMediaCodecInformation,
)
from . import sdp
@@ -48,6 +48,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
AVDTP_PSM = 0x0019
AVDTP_DEFAULT_RTX_SIG_TIMER = 5 # Seconds
@@ -195,6 +197,8 @@ AVDTP_STATE_NAMES = {
AVDTP_ABORTING_STATE: 'AVDTP_ABORTING_STATE'
}
# fmt: on
# -----------------------------------------------------------------------------
async def find_avdtp_service_with_sdp_client(sdp_client):
@@ -206,14 +210,11 @@ async def find_avdtp_service_with_sdp_client(sdp_client):
# Search for services with an Audio Sink service class
search_result = await sdp_client.search_attributes(
[BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE],
[
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
]
[sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID],
)
for attribute_list in search_result:
profile_descriptor_list = sdp.ServiceAttribute.find_attribute_in_list(
attribute_list,
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value:
@@ -251,17 +252,19 @@ class RealtimeClock:
class MediaPacket:
@staticmethod
def from_bytes(data):
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)]
payload = data[12 + csrc_count * 4:]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
@@ -273,7 +276,7 @@ class MediaPacket:
ssrc,
csrc_list,
payload_type,
payload
payload,
)
def __init__(
@@ -287,27 +290,29 @@ class MediaPacket:
ssrc,
csrc_list,
payload_type,
payload
payload,
):
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number
self.timestamp = timestamp
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
self.timestamp = timestamp
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self):
header = (
bytes([
self.version << 6 | self.padding << 5 | self.extension << 4 | len(self.csrc_list),
self.marker << 7 | self.payload_type
]) +
struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
)
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
@@ -319,13 +324,13 @@ class MediaPacket:
# -----------------------------------------------------------------------------
class MediaPacketPump:
def __init__(self, packets, clock=RealtimeClock()):
self.packets = packets
self.clock = clock
self.packets = packets
self.clock = clock
self.pump_task = None
async def start(self, rtp_channel):
async def pump_packets():
start_time = 0
start_time = 0
start_timestamp = 0
try:
@@ -333,7 +338,7 @@ class MediaPacketPump:
async for packet in self.packets:
# Capture the timestamp of the first packet
if start_time == 0:
start_time = self.clock.now()
start_time = self.clock.now()
start_timestamp = packet.timestamp_seconds
# Wait until we can send
@@ -346,7 +351,9 @@ class MediaPacketPump:
# Emit
rtp_channel.send_pdu(bytes(packet))
logger.debug(f'{color(">>> sending RTP packet:", "green")} {packet}')
logger.debug(
f'{color(">>> sending RTP packet:", "green")} {packet}'
)
except asyncio.exceptions.CancelledError:
logger.debug('pump canceled')
@@ -368,67 +375,87 @@ class MessageAssembler:
self.reset()
def reset(self):
self.transaction_label = 0
self.message = None
self.message_type = 0
self.signal_identifier = 0
self.transaction_label = 0
self.message = None
self.message_type = 0
self.signal_identifier = 0
self.number_of_signal_packets = 0
self.packet_count = 0
self.packet_count = 0
def on_pdu(self, pdu):
self.packet_count += 1
transaction_label = pdu[0] >> 4
packet_type = (pdu[0] >> 2) & 3
message_type = pdu[0] & 3
packet_type = (pdu[0] >> 2) & 3
message_type = pdu[0] & 3
logger.debug(f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}')
if packet_type == Protocol.SINGLE_PACKET or packet_type == Protocol.START_PACKET:
logger.debug(
f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}'
)
if (
packet_type == Protocol.SINGLE_PACKET
or packet_type == Protocol.START_PACKET
):
if self.message is not None:
# The previous message has not been terminated
logger.warning('received a start or single packet when expecting an end or continuation')
logger.warning(
'received a start or single packet when expecting an end or continuation'
)
self.reset()
self.transaction_label = transaction_label
self.signal_identifier = pdu[1] & 0x3F
self.message_type = message_type
self.message_type = message_type
if packet_type == Protocol.SINGLE_PACKET:
self.message = pdu[2:]
self.on_message_complete()
else:
self.number_of_signal_packets = pdu[2]
self.message = pdu[3:]
elif packet_type == Protocol.CONTINUE_PACKET or packet_type == Protocol.END_PACKET:
self.message = pdu[3:]
elif (
packet_type == Protocol.CONTINUE_PACKET
or packet_type == Protocol.END_PACKET
):
if self.packet_count == 0:
logger.warning('unexpected continuation')
return
if transaction_label != self.transaction_label:
logger.warning(f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}')
logger.warning(
f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}'
)
return
if message_type != self.message_type:
logger.warning(f'message type mismatch: expected {self.message_type}, received {message_type}')
logger.warning(
f'message type mismatch: expected {self.message_type}, received {message_type}'
)
return
self.message += pdu[1:]
if packet_type == Protocol.END_PACKET:
if self.packet_count != self.number_of_signal_packets:
logger.warning(f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}')
logger.warning(
f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}'
)
self.reset()
return
self.on_message_complete()
else:
if self.packet_count > self.number_of_signal_packets:
logger.warning(f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}')
logger.warning(
f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}'
)
self.reset()
return
def on_message_complete(self):
message = Message.create(self.signal_identifier, self.message_type, self.message)
message = Message.create(
self.signal_identifier, self.message_type, self.message
)
try:
self.callback(self.transaction_label, message)
@@ -460,12 +487,14 @@ class ServiceCapabilities:
def parse_capabilities(payload):
capabilities = []
while payload:
service_category = payload[0]
service_category = payload[0]
length_of_service_capabilities = payload[1]
service_capabilities_bytes = payload[2:2 + length_of_service_capabilities]
capabilities.append(ServiceCapabilities.create(service_category, service_capabilities_bytes))
service_capabilities_bytes = payload[2 : 2 + length_of_service_capabilities]
capabilities.append(
ServiceCapabilities.create(service_category, service_capabilities_bytes)
)
payload = payload[2 + length_of_service_capabilities:]
payload = payload[2 + length_of_service_capabilities :]
return capabilities
@@ -473,21 +502,24 @@ class ServiceCapabilities:
def serialize_capabilities(capabilities):
serialized = b''
for item in capabilities:
serialized += bytes([
item.service_category,
len(item.service_capabilities_bytes)
]) + item.service_capabilities_bytes
serialized += (
bytes([item.service_category, len(item.service_capabilities_bytes)])
+ item.service_capabilities_bytes
)
return serialized
def init_from_bytes(self):
pass
def __init__(self, service_category, service_capabilities_bytes=b''):
self.service_category = service_category
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details=[]):
attributes = ','.join([name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + details)
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ details
)
return f'ServiceCapabilities({attributes})'
def __str__(self):
@@ -501,31 +533,39 @@ class ServiceCapabilities:
# -----------------------------------------------------------------------------
class MediaCodecCapabilities(ServiceCapabilities):
def init_from_bytes(self):
self.media_type = self.service_capabilities_bytes[0]
self.media_codec_type = self.service_capabilities_bytes[1]
self.media_type = self.service_capabilities_bytes[0]
self.media_codec_type = self.service_capabilities_bytes[1]
self.media_codec_information = self.service_capabilities_bytes[2:]
if self.media_codec_type == A2DP_SBC_CODEC_TYPE:
self.media_codec_information = SbcMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = SbcMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_MPEG_2_4_AAC_CODEC_TYPE:
self.media_codec_information = AacMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = AacMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
self.media_codec_information = VendorSpecificMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
def __init__(self, media_type, media_codec_type, media_codec_information):
super().__init__(
AVDTP_MEDIA_CODEC_SERVICE_CATEGORY,
bytes([media_type, media_codec_type]) + bytes(media_codec_information)
bytes([media_type, media_codec_type]) + bytes(media_codec_information),
)
self.media_type = media_type
self.media_codec_type = media_codec_type
self.media_type = media_type
self.media_codec_type = media_codec_type
self.media_codec_information = media_codec_information
def __str__(self):
details = [
f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}',
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}'
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}',
]
return self.to_string(details)
@@ -535,37 +575,33 @@ class EndPointInfo:
@staticmethod
def from_bytes(payload):
return EndPointInfo(
payload[0] >> 2,
payload[0] >> 1 & 1,
payload[1] >> 4,
payload[1] >> 3 & 1
payload[0] >> 2, payload[0] >> 1 & 1, payload[1] >> 4, payload[1] >> 3 & 1
)
def __bytes__(self):
return bytes([
self.seid << 2 | self.in_use << 1,
self.media_type << 4 | self.tsep << 3
])
return bytes(
[self.seid << 2 | self.in_use << 1, self.media_type << 4 | self.tsep << 3]
)
def __init__(self, seid, in_use, media_type, tsep):
self.seid = seid
self.in_use = in_use
self.seid = seid
self.in_use = in_use
self.media_type = media_type
self.tsep = tsep
self.tsep = tsep
# -----------------------------------------------------------------------------
class Message:
COMMAND = 0
GENERAL_REJECT = 1
COMMAND = 0
GENERAL_REJECT = 1
RESPONSE_ACCEPT = 2
RESPONSE_REJECT = 3
MESSAGE_TYPE_NAMES = {
COMMAND: 'COMMAND',
GENERAL_REJECT: 'GENERAL_REJECT',
COMMAND: 'COMMAND',
GENERAL_REJECT: 'GENERAL_REJECT',
RESPONSE_ACCEPT: 'RESPONSE_ACCEPT',
RESPONSE_REJECT: 'RESPONSE_REJECT'
RESPONSE_REJECT: 'RESPONSE_REJECT',
}
subclasses = {} # Subclasses, by signal identifier and message type
@@ -603,7 +639,9 @@ class Message:
break
# Register the subclass
Message.subclasses.setdefault(cls.signal_identifier, {})[cls.message_type] = cls
Message.subclasses.setdefault(cls.signal_identifier, {})[
cls.message_type
] = cls
return cls
@@ -635,7 +673,7 @@ class Message:
pass
def __init__(self, payload=b''):
self.payload = payload
self.payload = payload
def to_string(self, details):
base = f'{color(f"{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_{Message.message_type_name(self.message_type)}", "yellow")}'
@@ -643,7 +681,11 @@ class Message:
if type(details) is str:
return f'{base}: {details}'
else:
return base + ':\n' + '\n'.join([' ' + color(detail, 'cyan') for detail in details])
return (
base
+ ':\n'
+ '\n'.join([' ' + color(detail, 'cyan') for detail in details])
)
else:
return base
@@ -682,9 +724,7 @@ class Simple_Reject(Message):
self.payload = bytes([self.error_code])
def __str__(self):
details = [
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
]
details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}']
return self.to_string(details)
@@ -707,11 +747,13 @@ class Discover_Response(Message):
self.endpoints = []
endpoint_count = len(self.payload) // 2
for i in range(endpoint_count):
self.endpoints.append(EndPointInfo.from_bytes(self.payload[i * 2:(i + 1) * 2]))
self.endpoints.append(
EndPointInfo.from_bytes(self.payload[i * 2 : (i + 1) * 2])
)
def __init__(self, endpoints):
self.endpoints = endpoints
self.payload = b''.join([bytes(endpoint) for endpoint in endpoints])
self.payload = b''.join([bytes(endpoint) for endpoint in endpoints])
def __str__(self):
details = []
@@ -721,7 +763,7 @@ class Discover_Response(Message):
f'ACP SEID: {endpoint.seid}',
f' in_use: {endpoint.in_use}',
f' media_type: {name_or_number(AVDTP_MEDIA_TYPE_NAMES, endpoint.media_type)}',
f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}'
f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}',
]
)
return self.to_string(details)
@@ -794,21 +836,22 @@ class Set_Configuration_Command(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.int_seid = self.payload[1] >> 2
self.acp_seid = self.payload[0] >> 2
self.int_seid = self.payload[1] >> 2
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[2:])
def __init__(self, acp_seid, int_seid, capabilities):
self.acp_seid = acp_seid
self.int_seid = int_seid
self.acp_seid = acp_seid
self.int_seid = int_seid
self.capabilities = capabilities
self.payload = bytes([acp_seid << 2, int_seid << 2]) + ServiceCapabilities.serialize_capabilities(capabilities)
self.payload = bytes(
[acp_seid << 2, int_seid << 2]
) + ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self):
details = [
f'ACP SEID: {self.acp_seid}',
f'INT SEID: {self.int_seid}'
] + [str(capability) for capability in self.capabilities]
details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [
str(capability) for capability in self.capabilities
]
return self.to_string(details)
@@ -829,17 +872,17 @@ class Set_Configuration_Reject(Message):
def init_from_payload(self):
self.service_category = self.payload[0]
self.error_code = self.payload[1]
self.error_code = self.payload[1]
def __init__(self, service_category, error_code):
self.service_category = service_category
self.error_code = error_code
self.payload = bytes([service_category, self.error_code])
self.error_code = error_code
self.payload = bytes([service_category, self.error_code])
def __str__(self):
details = [
f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
]
return self.to_string(details)
@@ -887,7 +930,7 @@ class Reconfigure_Command(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.acp_seid = self.payload[0] >> 2
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[1:])
def __str__(self):
@@ -971,18 +1014,18 @@ class Start_Reject(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.acp_seid = self.payload[0] >> 2
self.error_code = self.payload[1]
def __init__(self, acp_seid, error_code):
self.acp_seid = acp_seid
self.acp_seid = acp_seid
self.error_code = error_code
self.payload = bytes([self.acp_seid << 2, self.error_code])
self.payload = bytes([self.acp_seid << 2, self.error_code])
def __str__(self):
details = [
f'acp_seid: {self.acp_seid}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
]
return self.to_string(details)
@@ -1095,13 +1138,10 @@ class DelayReport_Command(Message):
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.delay = (self.payload[1] << 8) | (self.payload[2])
self.delay = (self.payload[1] << 8) | (self.payload[2])
def __str__(self):
return self.to_string([
f'ACP_SEID: {self.acp_seid}',
f'delay: {self.delay}'
])
return self.to_string([f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}'])
# -----------------------------------------------------------------------------
@@ -1122,16 +1162,16 @@ class DelayReport_Reject(Simple_Reject):
# -----------------------------------------------------------------------------
class Protocol:
SINGLE_PACKET = 0
START_PACKET = 1
SINGLE_PACKET = 0
START_PACKET = 1
CONTINUE_PACKET = 2
END_PACKET = 3
END_PACKET = 3
PACKET_TYPE_NAMES = {
SINGLE_PACKET: 'SINGLE_PACKET',
START_PACKET: 'START_PACKET',
SINGLE_PACKET: 'SINGLE_PACKET',
START_PACKET: 'START_PACKET',
CONTINUE_PACKET: 'CONTINUE_PACKET',
END_PACKET: 'END_PACKET'
END_PACKET: 'END_PACKET',
}
@staticmethod
@@ -1148,18 +1188,18 @@ class Protocol:
return protocol
def __init__(self, l2cap_channel, version=(1, 3)):
self.l2cap_channel = l2cap_channel
self.version = version
self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER
self.message_assembler = MessageAssembler(self.on_message)
self.transaction_results = [None] * 16 # Futures for up to 16 transactions
self.l2cap_channel = l2cap_channel
self.version = version
self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER
self.message_assembler = MessageAssembler(self.on_message)
self.transaction_results = [None] * 16 # Futures for up to 16 transactions
self.transaction_semaphore = asyncio.Semaphore(16)
self.transaction_count = 0
self.channel_acceptor = None
self.channel_connector = None
self.local_endpoints = [] # Local endpoints, with contiguous seid values
self.remote_endpoints = {} # Remote stream endpoints, by seid
self.streams = {} # Streams, by seid
self.transaction_count = 0
self.channel_acceptor = None
self.channel_connector = None
self.local_endpoints = [] # Local endpoints, with contiguous seid values
self.remote_endpoints = {} # Remote stream endpoints, by seid
self.streams = {} # Streams, by seid
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
@@ -1205,7 +1245,9 @@ class Protocol:
response = await self.send_command(Discover_Command())
for endpoint_entry in response.endpoints:
logger.debug(f'getting endpoint capabilities for endpoint {endpoint_entry.seid}')
logger.debug(
f'getting endpoint capabilities for endpoint {endpoint_entry.seid}'
)
get_capabilities_response = await self.get_capabilities(endpoint_entry.seid)
endpoint = DiscoveredStreamEndPoint(
self,
@@ -1213,7 +1255,7 @@ class Protocol:
endpoint_entry.media_type,
endpoint_entry.tsep,
endpoint_entry.in_use,
get_capabilities_response.capabilities
get_capabilities_response.capabilities,
)
self.remote_endpoints[endpoint_entry.seid] = endpoint
@@ -1221,14 +1263,27 @@ class Protocol:
def find_remote_sink_by_codec(self, media_type, codec_type):
for endpoint in self.remote_endpoints.values():
if not endpoint.in_use and endpoint.media_type == media_type and endpoint.tsep == AVDTP_TSEP_SNK:
if (
not endpoint.in_use
and endpoint.media_type == media_type
and endpoint.tsep == AVDTP_TSEP_SNK
):
has_media_transport = False
has_codec = False
for capabilities in endpoint.capabilities:
if capabilities.service_category == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY:
if (
capabilities.service_category
== AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY
):
has_media_transport = True
elif capabilities.service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY:
if capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE and capabilities.media_codec_type == codec_type:
elif (
capabilities.service_category
== AVDTP_MEDIA_CODEC_SERVICE_CATEGORY
):
if (
capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and capabilities.media_codec_type == codec_type
):
has_codec = True
if has_media_transport and has_codec:
return endpoint
@@ -1237,7 +1292,9 @@ class Protocol:
self.message_assembler.on_pdu(pdu)
def on_message(self, transaction_label, message):
logger.debug(f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}')
logger.debug(
f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
# Check that the identifier is not reserved
if message.signal_identifier == 0:
@@ -1245,7 +1302,10 @@ class Protocol:
return
# Check that the identifier is valid
if message.signal_identifier < 0 or message.signal_identifier > AVDTP_DELAYREPORT:
if (
message.signal_identifier < 0
or message.signal_identifier > AVDTP_DELAYREPORT
):
logger.warning('!!! invalid signal identifier')
self.send_message(transaction_label, General_Reject())
@@ -1258,7 +1318,9 @@ class Protocol:
response = handler(message)
self.send_message(transaction_label, response)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in handler:", "red")} {error}'
)
else:
logger.warning('unhandled command')
else:
@@ -1281,8 +1343,12 @@ class Protocol:
logger.debug(color('<<< L2CAP channel open', 'magenta'))
def send_message(self, transaction_label, message):
logger.debug(f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}')
max_fragment_size = self.l2cap_channel.mtu - 3 # Enough space for a 3-byte start packet header
logger.debug(
f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
) # Enough space for a 3-byte start packet header
payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu:
# Fits in a single packet
@@ -1292,13 +1358,19 @@ class Protocol:
done = False
while not done:
first_header_byte = transaction_label << 4 | packet_type << 2 | message.message_type
first_header_byte = (
transaction_label << 4 | packet_type << 2 | message.message_type
)
if packet_type == self.SINGLE_PACKET:
header = bytes([first_header_byte, message.signal_identifier])
elif packet_type == self.START_PACKET:
packet_count = (max_fragment_size - 1 + len(payload)) // max_fragment_size
header = bytes([first_header_byte, message.signal_identifier, packet_count])
packet_count = (
max_fragment_size - 1 + len(payload)
) // max_fragment_size
header = bytes(
[first_header_byte, message.signal_identifier, packet_count]
)
else:
header = bytes([first_header_byte])
@@ -1308,7 +1380,11 @@ class Protocol:
# Prepare for the next packet
payload = payload[max_fragment_size:]
if payload:
packet_type = self.CONTINUE_PACKET if payload > max_fragment_size else self.END_PACKET
packet_type = (
self.CONTINUE_PACKET
if payload > max_fragment_size
else self.END_PACKET
)
else:
done = True
@@ -1322,7 +1398,10 @@ class Protocol:
response = await transaction_result
# Check for errors
if response.message_type == Message.GENERAL_REJECT or response.message_type == Message.RESPONSE_REJECT:
if (
response.message_type == Message.GENERAL_REJECT
or response.message_type == Message.RESPONSE_REJECT
):
raise ProtocolError(response.error_code, 'avdtp')
return response
@@ -1340,7 +1419,7 @@ class Protocol:
self.transaction_count += 1
return (transaction_label, transaction_result)
assert(False) # Should never reach this
assert False # Should never reach this
async def get_capabilities(self, seid):
if self.version > (1, 2):
@@ -1349,7 +1428,9 @@ class Protocol:
return await self.send_command(Get_Capabilities_Command(seid))
async def set_configuration(self, acp_seid, int_seid, capabilities):
return await self.send_command(Set_Configuration_Command(acp_seid, int_seid, capabilities))
return await self.send_command(
Set_Configuration_Command(acp_seid, int_seid, capabilities)
)
async def get_configuration(self, seid):
response = await self.send_command(Get_Configuration_Command(seid))
@@ -1537,6 +1618,7 @@ class Listener(EventEmitter):
server = Protocol(channel, self.version)
self.set_server(channel.connection, server)
self.emit('connection', server)
channel.on('open', on_channel_open)
@@ -1562,8 +1644,7 @@ class Stream:
raise InvalidStateError('current state is not IDLE')
await self.remote_endpoint.set_configuration(
self.local_endpoint.seid,
self.local_endpoint.configuration
self.local_endpoint.seid, self.local_endpoint.configuration
)
self.change_state(AVDTP_CONFIGURED_STATE)
@@ -1639,7 +1720,11 @@ class Stream:
self.change_state(AVDTP_CONFIGURED_STATE)
def on_get_configuration_command(self, configuration):
if self.state not in {AVDTP_CONFIGURED_STATE, AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}:
if self.state not in {
AVDTP_CONFIGURED_STATE,
AVDTP_OPEN_STATE,
AVDTP_STREAMING_STATE,
}:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command(configuration)
@@ -1718,7 +1803,7 @@ class Stream:
def on_l2cap_connection(self, channel):
logger.debug(color('<<< stream channel connected', 'magenta'))
self.rtp_channel = channel
channel.on('open', self.on_l2cap_channel_open)
channel.on('open', self.on_l2cap_channel_open)
channel.on('close', self.on_l2cap_channel_close)
# We don't need more channels
@@ -1744,11 +1829,11 @@ class Stream:
remote_endpoint must be a subclass of StreamEndPointProxy
'''
self.protocol = protocol
self.local_endpoint = local_endpoint
self.protocol = protocol
self.local_endpoint = local_endpoint
self.remote_endpoint = remote_endpoint
self.rtp_channel = None
self.state = AVDTP_IDLE_STATE
self.rtp_channel = None
self.state = AVDTP_IDLE_STATE
local_endpoint.stream = self
local_endpoint.in_use = 1
@@ -1760,38 +1845,36 @@ class Stream:
# -----------------------------------------------------------------------------
class StreamEndPoint:
def __init__(self, seid, media_type, tsep, in_use, capabilities):
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.in_use = in_use
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.in_use = in_use
self.capabilities = capabilities
def __str__(self):
return '\n'.join([
'SEP(',
f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}',
f' in_use={self.in_use}',
' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]),
' ]',
')'
])
return '\n'.join(
[
'SEP(',
f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}',
f' in_use={self.in_use}',
' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]),
' ]',
')',
]
)
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
def __init__(self, protocol, seid):
self.seid = seid
self.seid = seid
self.protocol = protocol
async def set_configuration(self, int_seid, configuration):
return await self.protocol.set_configuration(
self.seid,
int_seid,
configuration
)
return await self.protocol.set_configuration(self.seid, int_seid, configuration)
async def open(self):
return await self.protocol.open(self.seid)
@@ -1818,11 +1901,13 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint):
def __init__(self, protocol, seid, media_type, tsep, capabilities, configuration=[]):
def __init__(
self, protocol, seid, media_type, tsep, capabilities, configuration=[]
):
super().__init__(seid, media_type, tsep, 0, capabilities)
self.protocol = protocol
self.protocol = protocol
self.configuration = configuration
self.stream = None
self.stream = None
async def start(self):
pass
@@ -1866,9 +1951,17 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities, packet_pump):
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities
codec_capabilities,
]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SRC, capabilities, capabilities)
LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
)
EventEmitter.__init__(self)
self.packet_pump = packet_pump
@@ -1901,9 +1994,16 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities):
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities
codec_capabilities,
]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SNK, capabilities)
LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
)
EventEmitter.__init__(self)
def on_set_configuration_command(self, configuration):
@@ -1917,5 +2017,7 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def on_avdtp_packet(self, packet):
rtp_packet = MediaPacket.from_bytes(packet)
logger.debug(f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}')
logger.debug(
f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}'
)
self.emit('rtp_packet', rtp_packet)