mirror of
https://github.com/google/bumble.git
synced 2026-06-04 08:07:03 +00:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b6a21fa3c6 | |||
| e44eaf2147 | |||
| ef634953f0 | |||
| 71672ec64f | |||
| 5ee2d80ce4 | |||
| 72d821b1f6 | |||
| 9b2e345a1e | |||
| f9bd3084b9 | |||
| 808ea1abeb | |||
| afe064b4ea | |||
| 8d0cef70c2 | |||
| 9cefde1c3e | |||
| ffb9d5f117 | |||
| 7d3be8157a | |||
| 9dc9c348e5 | |||
| b18555539e | |||
| 8a853d5b2f | |||
| 8988a85245 | |||
| 0813da2278 | |||
| a1ff183d44 | |||
| 7adf44eddf | |||
| 05accbf805 | |||
| 80f54f2a09 | |||
| 07b5e33e09 | |||
| b874e26a4f | |||
| baa5257780 | |||
| a91ea9110c | |||
| 1686c5b11b | |||
| d9481992bb | |||
| 16d0ed56cf | |||
| c55eb156b8 | |||
| 8614881fb3 |
+1
-1
@@ -489,7 +489,7 @@ class Sender:
|
||||
flags=(
|
||||
Packet.PacketFlags.LAST
|
||||
if tx_i == self.tx_packet_count - 1
|
||||
else 0
|
||||
else Packet.PacketFlags(0)
|
||||
),
|
||||
sequence=tx_i,
|
||||
timestamp=int((time.time() - self.start_time) * 1000000),
|
||||
|
||||
@@ -45,8 +45,10 @@ from bumble.hci import (
|
||||
HCI_Read_Local_Supported_Codecs_Command,
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
HCI_Read_Voice_Setting_Command,
|
||||
LeFeature,
|
||||
SpecificationVersion,
|
||||
VoiceSetting,
|
||||
map_null_terminated_utf8_string,
|
||||
)
|
||||
from bumble.host import Host
|
||||
@@ -214,6 +216,16 @@ async def get_codecs_info(host: Host) -> None:
|
||||
if not response2.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
if host.supports_command(HCI_Read_Voice_Setting_Command.op_code):
|
||||
response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command())
|
||||
voice_setting = VoiceSetting.from_int(response3.voice_setting)
|
||||
print(color('Voice Setting:', 'yellow'))
|
||||
print(f' Air Coding Format: {voice_setting.air_coding_format.name}')
|
||||
print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}')
|
||||
print(f' Input Sample Size: {voice_setting.input_sample_size.name}')
|
||||
print(f' Input Data Format: {voice_setting.input_data_format.name}')
|
||||
print(f' Input Coding Format: {voice_setting.input_coding_format.name}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(
|
||||
|
||||
+151
-41
@@ -16,6 +16,8 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import statistics
|
||||
import struct
|
||||
import time
|
||||
|
||||
import click
|
||||
@@ -25,7 +27,9 @@ from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
Address,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_SynchronousDataPacket,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
@@ -36,55 +40,121 @@ from bumble.transport import open_transport
|
||||
class Loopback:
|
||||
"""Send and receive ACL data packets in local loopback mode"""
|
||||
|
||||
def __init__(self, packet_size: int, packet_count: int, transport: str):
|
||||
def __init__(
|
||||
self,
|
||||
packet_size: int,
|
||||
packet_count: int,
|
||||
connection_type: str,
|
||||
mode: str,
|
||||
interval: int,
|
||||
transport: str,
|
||||
):
|
||||
self.transport = transport
|
||||
self.packet_size = packet_size
|
||||
self.packet_count = packet_count
|
||||
self.connection_handle: int | None = None
|
||||
self.connection_type = connection_type
|
||||
self.connection_event = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.interval = interval
|
||||
self.done = asyncio.Event()
|
||||
self.expected_cid = 0
|
||||
self.expected_counter = 0
|
||||
self.bytes_received = 0
|
||||
self.start_timestamp = 0.0
|
||||
self.last_timestamp = 0.0
|
||||
self.send_timestamps: list[float] = []
|
||||
self.rtts: list[float] = []
|
||||
|
||||
def on_connection(self, connection_handle: int, *args):
|
||||
"""Retrieve connection handle from new connection event"""
|
||||
if not self.connection_event.is_set():
|
||||
# save first connection handle for ACL
|
||||
# subsequent connections are SCO
|
||||
# The first connection handle is of type ACL,
|
||||
# subsequent connections are of type SCO
|
||||
if self.connection_type == "sco" and self.connection_handle is None:
|
||||
self.connection_handle = connection_handle
|
||||
return
|
||||
|
||||
self.connection_handle = connection_handle
|
||||
self.connection_event.set()
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
def on_sco_connection(
|
||||
self,
|
||||
address: Address,
|
||||
connection_handle: int,
|
||||
link_type,
|
||||
rx_packet_length: int,
|
||||
tx_packet_length: int,
|
||||
air_mode: int,
|
||||
) -> None:
|
||||
self.on_connection(connection_handle)
|
||||
|
||||
def on_packet(self, connection_handle: int, packet: bytes):
|
||||
"""Calculate packet receive speed"""
|
||||
now = time.time()
|
||||
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
|
||||
(counter,) = struct.unpack_from("H", packet, 0)
|
||||
rtt = now - self.send_timestamps[counter]
|
||||
self.rtts.append(rtt)
|
||||
print(f'<<< Received packet {counter}: {len(packet)} bytes, RTT={rtt:.4f}')
|
||||
assert connection_handle == self.connection_handle
|
||||
assert cid == self.expected_cid
|
||||
self.expected_cid += 1
|
||||
if cid == 0:
|
||||
assert counter == self.expected_counter
|
||||
self.expected_counter += 1
|
||||
if counter == 0:
|
||||
self.start_timestamp = now
|
||||
else:
|
||||
elapsed_since_start = now - self.start_timestamp
|
||||
elapsed_since_last = now - self.last_timestamp
|
||||
self.bytes_received += len(pdu)
|
||||
instant_rx_speed = len(pdu) / elapsed_since_last
|
||||
self.bytes_received += len(packet)
|
||||
instant_rx_speed = len(packet) / elapsed_since_last
|
||||
average_rx_speed = self.bytes_received / elapsed_since_start
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f}',
|
||||
'cyan',
|
||||
if self.mode == 'throughput':
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f},',
|
||||
'cyan',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.last_timestamp = now
|
||||
|
||||
if self.expected_cid == self.packet_count:
|
||||
if self.expected_counter == self.packet_count:
|
||||
print(color('@@@ Received last packet', 'green'))
|
||||
self.done.set()
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
self.on_packet(connection_handle, pdu)
|
||||
|
||||
def on_sco_packet(self, connection_handle: int, packet) -> None:
|
||||
self.on_packet(connection_handle, packet)
|
||||
|
||||
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
|
||||
assert self.connection_handle
|
||||
host.send_l2cap_pdu(self.connection_handle, 0, packet)
|
||||
|
||||
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
|
||||
assert self.connection_handle
|
||||
host.send_hci_packet(
|
||||
HCI_SynchronousDataPacket(
|
||||
connection_handle=self.connection_handle,
|
||||
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
|
||||
data_total_length=len(packet),
|
||||
data=packet,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_loop(self, host: Host, sender) -> None:
|
||||
for counter in range(0, self.packet_count):
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending {self.connection_type.upper()} '
|
||||
f'packet {counter}: {self.packet_size} bytes',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
self.send_timestamps.append(time.time())
|
||||
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
|
||||
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
@@ -126,8 +196,11 @@ class Loopback:
|
||||
return
|
||||
|
||||
# set event callbacks
|
||||
host.on('connection', self.on_connection)
|
||||
host.on('classic_connection', self.on_connection)
|
||||
host.on('le_connection', self.on_connection)
|
||||
host.on('sco_connection', self.on_sco_connection)
|
||||
host.on('l2cap_pdu', self.on_l2cap_pdu)
|
||||
host.on('sco_packet', self.on_sco_packet)
|
||||
|
||||
loopback_mode = LoopbackMode.LOCAL
|
||||
|
||||
@@ -148,32 +221,37 @@ class Loopback:
|
||||
|
||||
print(color('=== Start sending', 'magenta'))
|
||||
start_time = time.time()
|
||||
bytes_sent = 0
|
||||
for cid in range(0, self.packet_count):
|
||||
# using the cid as an incremental index
|
||||
host.send_l2cap_pdu(
|
||||
self.connection_handle, cid, bytes(self.packet_size)
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
|
||||
)
|
||||
)
|
||||
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
|
||||
await asyncio.sleep(0) # yield to allow packet receive
|
||||
if self.connection_type == "acl":
|
||||
sender = self.send_acl_packet
|
||||
elif self.connection_type == "sco":
|
||||
sender = self.send_sco_packet
|
||||
else:
|
||||
raise ValueError(f'Unknown connection type: {self.connection_type}')
|
||||
await self.send_loop(host, sender)
|
||||
|
||||
await self.done.wait()
|
||||
print(color('=== Done!', 'magenta'))
|
||||
|
||||
bytes_sent = self.packet_size * self.packet_count
|
||||
elapsed = time.time() - start_time
|
||||
average_tx_speed = bytes_sent / elapsed
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
|
||||
f' in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
if self.mode == 'throughput':
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} '
|
||||
f'({bytes_sent} bytes in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
if self.mode == 'rtt':
|
||||
print(
|
||||
color(
|
||||
f'RTTs: min={min(self.rtts):.4f}, '
|
||||
f'max={max(self.rtts):.4f}, '
|
||||
f'avg={statistics.mean(self.rtts):.4f}',
|
||||
'blue',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -194,11 +272,43 @@ class Loopback:
|
||||
default=10,
|
||||
help='Packet count',
|
||||
)
|
||||
@click.option(
|
||||
'--connection-type',
|
||||
'-t',
|
||||
metavar='TYPE',
|
||||
type=click.Choice(['acl', 'sco']),
|
||||
default='acl',
|
||||
help='Connection type',
|
||||
)
|
||||
@click.option(
|
||||
'--mode',
|
||||
'-m',
|
||||
metavar='MODE',
|
||||
type=click.Choice(['throughput', 'rtt']),
|
||||
default='throughput',
|
||||
help='Test mode',
|
||||
)
|
||||
@click.option(
|
||||
'--interval',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Inter-packet interval (ms) [RTT mode only]',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
def main(packet_size, packet_count, connection_type, mode, interval, transport):
|
||||
bumble.logging.setup_basic_logging()
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
if connection_type == "sco" and packet_size > 255:
|
||||
print("ERROR: the maximum packet size for SCO is 255")
|
||||
return
|
||||
|
||||
async def run():
|
||||
loopback = Loopback(
|
||||
packet_size, packet_count, connection_type, mode, interval, transport
|
||||
)
|
||||
await loopback.run()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
+6
-1
@@ -111,9 +111,14 @@ def show_device_details(device):
|
||||
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
|
||||
else 'IN'
|
||||
)
|
||||
endpoint_details = (
|
||||
f', Max Packet Size = {endpoint.getMaxPacketSize()}'
|
||||
if endpoint_type == 'ISOCHRONOUS'
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f' Endpoint 0x{endpoint.getAddress():02X}: '
|
||||
f'{endpoint_type} {endpoint_direction}'
|
||||
f'{endpoint_type} {endpoint_direction}{endpoint_details}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -1083,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
|
||||
else:
|
||||
value_str = str(self.value)
|
||||
if value_str:
|
||||
value_string = f', value={self.value.hex()}'
|
||||
value_string = f', value={value_str}'
|
||||
else:
|
||||
value_string = ''
|
||||
return (
|
||||
|
||||
+140
-77
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
@@ -311,6 +312,13 @@ class MessageAssembler:
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packet_count += 1
|
||||
|
||||
# Drop empty PDUs sent by remote — accessing pdu[0] below would
|
||||
# raise IndexError, propagating up to the L2CAP read loop and
|
||||
# tearing down the channel. Same class as #912 (ATT empty PDU).
|
||||
if not pdu:
|
||||
logger.warning('AVDTP message assembler: empty PDU dropped')
|
||||
return
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
message_type = Message.MessageType(pdu[0] & 3)
|
||||
@@ -324,6 +332,23 @@ class MessageAssembler:
|
||||
Protocol.PacketType.SINGLE_PACKET,
|
||||
Protocol.PacketType.START_PACKET,
|
||||
):
|
||||
# Both single and start packets carry the signal identifier in
|
||||
# pdu[1]; start packets additionally carry the packet count in
|
||||
# pdu[2]. Guard each access so a malformed remote frame can't
|
||||
# crash the message assembler.
|
||||
if len(pdu) < 2:
|
||||
logger.warning(
|
||||
'AVDTP %s packet too short (%d bytes); dropped',
|
||||
packet_type.name,
|
||||
len(pdu),
|
||||
)
|
||||
return
|
||||
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
|
||||
logger.warning(
|
||||
'AVDTP START packet missing signal-packet count; dropped'
|
||||
)
|
||||
return
|
||||
|
||||
if self.message is not None:
|
||||
# The previous message has not been terminated
|
||||
logger.warning(
|
||||
@@ -1453,8 +1478,23 @@ class Protocol(utils.EventEmitter):
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler:
|
||||
try:
|
||||
response = handler(message)
|
||||
self.send_message(transaction_label, response)
|
||||
result = handler(message)
|
||||
if asyncio.iscoroutine(result):
|
||||
|
||||
async def wait_and_send() -> None:
|
||||
try:
|
||||
response = await result
|
||||
if response:
|
||||
self.send_message(transaction_label, response)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
color("!!! Exception in handler:", "red")
|
||||
)
|
||||
|
||||
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
|
||||
else:
|
||||
if result:
|
||||
self.send_message(transaction_label, result)
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
else:
|
||||
@@ -1535,7 +1575,7 @@ class Protocol(utils.EventEmitter):
|
||||
async def send_command(self, command: Message):
|
||||
# TODO: support timeouts
|
||||
# Send the command
|
||||
(transaction_label, transaction_result) = await self.start_transaction()
|
||||
transaction_label, transaction_result = await self.start_transaction()
|
||||
self.send_message(transaction_label, command)
|
||||
|
||||
# Wait for the response
|
||||
@@ -1600,14 +1640,14 @@ class Protocol(utils.EventEmitter):
|
||||
async def abort(self, seid: int) -> Abort_Response:
|
||||
return await self.send_command(Abort_Command(seid))
|
||||
|
||||
def on_discover_command(self, command: Discover_Command) -> Message | None:
|
||||
async def on_discover_command(self, command: Discover_Command) -> Message | None:
|
||||
endpoint_infos = [
|
||||
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
|
||||
for endpoint in self.local_endpoints
|
||||
]
|
||||
return Discover_Response(endpoint_infos)
|
||||
|
||||
def on_get_capabilities_command(
|
||||
async def on_get_capabilities_command(
|
||||
self, command: Get_Capabilities_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
@@ -1616,7 +1656,7 @@ class Protocol(utils.EventEmitter):
|
||||
|
||||
return Get_Capabilities_Response(endpoint.capabilities)
|
||||
|
||||
def on_get_all_capabilities_command(
|
||||
async def on_get_all_capabilities_command(
|
||||
self, command: Get_All_Capabilities_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
@@ -1625,7 +1665,7 @@ class Protocol(utils.EventEmitter):
|
||||
|
||||
return Get_All_Capabilities_Response(endpoint.capabilities)
|
||||
|
||||
def on_set_configuration_command(
|
||||
async def on_set_configuration_command(
|
||||
self, command: Set_Configuration_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
@@ -1640,10 +1680,10 @@ class Protocol(utils.EventEmitter):
|
||||
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
|
||||
self.streams[command.acp_seid] = stream
|
||||
|
||||
result = stream.on_set_configuration_command(command.capabilities)
|
||||
result = await stream.on_set_configuration_command(command.capabilities)
|
||||
return result or Set_Configuration_Response()
|
||||
|
||||
def on_get_configuration_command(
|
||||
async def on_get_configuration_command(
|
||||
self, command: Get_Configuration_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
@@ -1652,29 +1692,31 @@ class Protocol(utils.EventEmitter):
|
||||
if endpoint.stream is None:
|
||||
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
return endpoint.stream.on_get_configuration_command()
|
||||
return await endpoint.stream.on_get_configuration_command()
|
||||
|
||||
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
|
||||
async def on_reconfigure_command(
|
||||
self, command: Reconfigure_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
|
||||
if endpoint.stream is None:
|
||||
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = endpoint.stream.on_reconfigure_command(command.capabilities)
|
||||
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
|
||||
return result or Reconfigure_Response()
|
||||
|
||||
def on_open_command(self, command: Open_Command) -> Message | None:
|
||||
async def on_open_command(self, command: Open_Command) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
if endpoint.stream is None:
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = endpoint.stream.on_open_command()
|
||||
result = await endpoint.stream.on_open_command()
|
||||
return result or Open_Response()
|
||||
|
||||
def on_start_command(self, command: Start_Command) -> Message | None:
|
||||
async def on_start_command(self, command: Start_Command) -> Message | None:
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if endpoint is None:
|
||||
@@ -1688,12 +1730,12 @@ class Protocol(utils.EventEmitter):
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if not endpoint or not endpoint.stream:
|
||||
raise InvalidStateError("Should already be checked!")
|
||||
if (result := endpoint.stream.on_start_command()) is not None:
|
||||
if (result := await endpoint.stream.on_start_command()) is not None:
|
||||
return result
|
||||
|
||||
return Start_Response()
|
||||
|
||||
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
|
||||
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if endpoint is None:
|
||||
@@ -1707,45 +1749,47 @@ class Protocol(utils.EventEmitter):
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if not endpoint or not endpoint.stream:
|
||||
raise InvalidStateError("Should already be checked!")
|
||||
if (result := endpoint.stream.on_suspend_command()) is not None:
|
||||
if (result := await endpoint.stream.on_suspend_command()) is not None:
|
||||
return result
|
||||
|
||||
return Suspend_Response()
|
||||
|
||||
def on_close_command(self, command: Close_Command) -> Message | None:
|
||||
async def on_close_command(self, command: Close_Command) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
if endpoint.stream is None:
|
||||
return Close_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = endpoint.stream.on_close_command()
|
||||
result = await endpoint.stream.on_close_command()
|
||||
return result or Close_Response()
|
||||
|
||||
def on_abort_command(self, command: Abort_Command) -> Message | None:
|
||||
async def on_abort_command(self, command: Abort_Command) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None or endpoint.stream is None:
|
||||
return Abort_Response()
|
||||
|
||||
endpoint.stream.on_abort_command()
|
||||
await endpoint.stream.on_abort_command()
|
||||
return Abort_Response()
|
||||
|
||||
def on_security_control_command(
|
||||
async def on_security_control_command(
|
||||
self, command: Security_Control_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
|
||||
result = endpoint.on_security_control_command(command.data)
|
||||
result = await endpoint.on_security_control_command(command.data)
|
||||
return result or Security_Control_Response()
|
||||
|
||||
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
|
||||
async def on_delayreport_command(
|
||||
self, command: DelayReport_Command
|
||||
) -> Message | None:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
|
||||
result = endpoint.on_delayreport_command(command.delay)
|
||||
result = await endpoint.on_delayreport_command(command.delay)
|
||||
return result or DelayReport_Response()
|
||||
|
||||
|
||||
@@ -1903,25 +1947,22 @@ class Stream:
|
||||
await self.rtp_channel.disconnect()
|
||||
self.rtp_channel = None
|
||||
|
||||
# Release the endpoint
|
||||
self.local_endpoint.in_use = 0
|
||||
|
||||
self.change_state(State.IDLE)
|
||||
|
||||
def on_set_configuration_command(
|
||||
async def on_set_configuration_command(
|
||||
self, configuration: Iterable[ServiceCapabilities]
|
||||
) -> Message | None:
|
||||
if self.state != State.IDLE:
|
||||
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_set_configuration_command(configuration)
|
||||
result = await self.local_endpoint.on_set_configuration_command(configuration)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
self.change_state(State.CONFIGURED)
|
||||
return None
|
||||
|
||||
def on_get_configuration_command(self) -> Message | None:
|
||||
async def on_get_configuration_command(self) -> Message | None:
|
||||
if self.state not in (
|
||||
State.CONFIGURED,
|
||||
State.OPEN,
|
||||
@@ -1929,25 +1970,25 @@ class Stream:
|
||||
):
|
||||
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
return self.local_endpoint.on_get_configuration_command()
|
||||
return await self.local_endpoint.on_get_configuration_command()
|
||||
|
||||
def on_reconfigure_command(
|
||||
async def on_reconfigure_command(
|
||||
self, configuration: Iterable[ServiceCapabilities]
|
||||
) -> Message | None:
|
||||
if self.state != State.OPEN:
|
||||
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_reconfigure_command(configuration)
|
||||
result = await self.local_endpoint.on_reconfigure_command(configuration)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def on_open_command(self) -> Message | None:
|
||||
async def on_open_command(self) -> Message | None:
|
||||
if self.state != State.CONFIGURED:
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_open_command()
|
||||
result = await self.local_endpoint.on_open_command()
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
@@ -1957,7 +1998,7 @@ class Stream:
|
||||
self.change_state(State.OPEN)
|
||||
return None
|
||||
|
||||
def on_start_command(self) -> Message | None:
|
||||
async def on_start_command(self) -> Message | None:
|
||||
if self.state != State.OPEN:
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
@@ -1966,29 +2007,29 @@ class Stream:
|
||||
logger.warning('received start command before RTP channel establishment')
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_start_command()
|
||||
result = await self.local_endpoint.on_start_command()
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
self.change_state(State.STREAMING)
|
||||
return None
|
||||
|
||||
def on_suspend_command(self) -> Message | None:
|
||||
async def on_suspend_command(self) -> Message | None:
|
||||
if self.state != State.STREAMING:
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_suspend_command()
|
||||
result = await self.local_endpoint.on_suspend_command()
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
self.change_state(State.OPEN)
|
||||
return None
|
||||
|
||||
def on_close_command(self) -> Message | None:
|
||||
async def on_close_command(self) -> Message | None:
|
||||
if self.state not in (State.OPEN, State.STREAMING):
|
||||
return Open_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
result = self.local_endpoint.on_close_command()
|
||||
result = await self.local_endpoint.on_close_command()
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
@@ -2003,7 +2044,8 @@ class Stream:
|
||||
|
||||
return None
|
||||
|
||||
def on_abort_command(self) -> Message | None:
|
||||
async def on_abort_command(self) -> Message | None:
|
||||
await self.local_endpoint.on_abort_command()
|
||||
if self.rtp_channel is None:
|
||||
# No need to wait
|
||||
self.change_state(State.IDLE)
|
||||
@@ -2028,7 +2070,6 @@ class Stream:
|
||||
def on_l2cap_channel_close(self) -> None:
|
||||
logger.debug(color('<<< stream channel closed', 'magenta'))
|
||||
self.local_endpoint.on_rtp_channel_close()
|
||||
self.local_endpoint.in_use = 0
|
||||
self.rtp_channel = None
|
||||
|
||||
if self.state in (State.CLOSING, State.ABORTING):
|
||||
@@ -2053,7 +2094,6 @@ class Stream:
|
||||
self.state = State.IDLE
|
||||
|
||||
local_endpoint.stream = self
|
||||
local_endpoint.in_use = 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
@@ -2063,14 +2103,16 @@ class Stream:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class StreamEndPoint:
|
||||
class StreamEndPoint(abc.ABC):
|
||||
seid: int
|
||||
media_type: MediaType
|
||||
tsep: StreamEndPointType
|
||||
in_use: int
|
||||
capabilities: Iterable[ServiceCapabilities]
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamEndPointProxy:
|
||||
@@ -2110,14 +2152,30 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
|
||||
in_use: int,
|
||||
capabilities: Iterable[ServiceCapabilities],
|
||||
) -> None:
|
||||
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
|
||||
StreamEndPointProxy.__init__(self, protocol, seid)
|
||||
# StreamEndPoint attributes
|
||||
self.seid = seid
|
||||
self.media_type = media_type
|
||||
self.tsep = tsep
|
||||
self._in_use = in_use
|
||||
self.capabilities = capabilities
|
||||
|
||||
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
return self._in_use
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
stream: Stream | None
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
if self.stream and self.stream.state != State.IDLE:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
EVENT_CONFIGURATION = "configuration"
|
||||
EVENT_OPEN = "open"
|
||||
EVENT_START = "start"
|
||||
@@ -2140,8 +2198,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
capabilities: Iterable[ServiceCapabilities],
|
||||
configuration: Iterable[ServiceCapabilities] | None = None,
|
||||
):
|
||||
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
|
||||
utils.EventEmitter.__init__(self)
|
||||
# StreamEndPoint attributes
|
||||
self.seid = seid
|
||||
self.media_type = media_type
|
||||
self.tsep = tsep
|
||||
self.capabilities = capabilities
|
||||
|
||||
self.protocol = protocol
|
||||
self.configuration = configuration if configuration is not None else []
|
||||
self.stream = None
|
||||
@@ -2155,13 +2218,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
async def close(self) -> None:
|
||||
"""[Source Only] Handles when receiving close command."""
|
||||
|
||||
def on_reconfigure_command(
|
||||
async def on_reconfigure_command(
|
||||
self, command: Iterable[ServiceCapabilities]
|
||||
) -> Message | None:
|
||||
del command # unused.
|
||||
return None
|
||||
|
||||
def on_set_configuration_command(
|
||||
async def on_set_configuration_command(
|
||||
self, configuration: Iterable[ServiceCapabilities]
|
||||
) -> Message | None:
|
||||
logger.debug(
|
||||
@@ -2172,34 +2235,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
self.emit(self.EVENT_CONFIGURATION)
|
||||
return None
|
||||
|
||||
def on_get_configuration_command(self) -> Message | None:
|
||||
async def on_get_configuration_command(self) -> Message | None:
|
||||
return Get_Configuration_Response(self.configuration)
|
||||
|
||||
def on_open_command(self) -> Message | None:
|
||||
async def on_open_command(self) -> Message | None:
|
||||
self.emit(self.EVENT_OPEN)
|
||||
return None
|
||||
|
||||
def on_start_command(self) -> Message | None:
|
||||
async def on_start_command(self) -> Message | None:
|
||||
self.emit(self.EVENT_START)
|
||||
return None
|
||||
|
||||
def on_suspend_command(self) -> Message | None:
|
||||
async def on_suspend_command(self) -> Message | None:
|
||||
self.emit(self.EVENT_SUSPEND)
|
||||
return None
|
||||
|
||||
def on_close_command(self) -> Message | None:
|
||||
async def on_close_command(self) -> Message | None:
|
||||
self.emit(self.EVENT_CLOSE)
|
||||
return None
|
||||
|
||||
def on_abort_command(self) -> Message | None:
|
||||
async def on_abort_command(self) -> Message | None:
|
||||
self.emit(self.EVENT_ABORT)
|
||||
return None
|
||||
|
||||
def on_delayreport_command(self, delay: int) -> Message | None:
|
||||
async def on_delayreport_command(self, delay: int) -> Message | None:
|
||||
self.emit(self.EVENT_DELAY_REPORT, delay)
|
||||
return None
|
||||
|
||||
def on_security_control_command(self, data: bytes) -> Message | None:
|
||||
async def on_security_control_command(self, data: bytes) -> Message | None:
|
||||
self.emit(self.EVENT_SECURITY_CONTROL, data)
|
||||
return None
|
||||
|
||||
@@ -2227,12 +2290,12 @@ class LocalSource(LocalStreamEndPoint):
|
||||
codec_capabilities,
|
||||
] + list(other_capabilities)
|
||||
super().__init__(
|
||||
protocol,
|
||||
seid,
|
||||
codec_capabilities.media_type,
|
||||
AVDTP_TSEP_SRC,
|
||||
capabilities,
|
||||
capabilities,
|
||||
protocol=protocol,
|
||||
seid=seid,
|
||||
media_type=codec_capabilities.media_type,
|
||||
tsep=AVDTP_TSEP_SRC,
|
||||
capabilities=capabilities,
|
||||
configuration=capabilities,
|
||||
)
|
||||
self.packet_pump = packet_pump
|
||||
|
||||
@@ -2251,13 +2314,13 @@ class LocalSource(LocalStreamEndPoint):
|
||||
self.emit(self.EVENT_STOP)
|
||||
|
||||
@override
|
||||
def on_start_command(self) -> Message | None:
|
||||
asyncio.create_task(self.start())
|
||||
async def on_start_command(self) -> Message | None:
|
||||
await self.start()
|
||||
return None
|
||||
|
||||
@override
|
||||
def on_suspend_command(self) -> Message | None:
|
||||
asyncio.create_task(self.stop())
|
||||
async def on_suspend_command(self) -> Message | None:
|
||||
await self.stop()
|
||||
return None
|
||||
|
||||
|
||||
@@ -2271,11 +2334,11 @@ class LocalSink(LocalStreamEndPoint):
|
||||
codec_capabilities,
|
||||
]
|
||||
super().__init__(
|
||||
protocol,
|
||||
seid,
|
||||
codec_capabilities.media_type,
|
||||
AVDTP_TSEP_SNK,
|
||||
capabilities,
|
||||
protocol=protocol,
|
||||
seid=seid,
|
||||
media_type=codec_capabilities.media_type,
|
||||
tsep=AVDTP_TSEP_SNK,
|
||||
capabilities=capabilities,
|
||||
)
|
||||
|
||||
def on_rtp_channel_open(self) -> None:
|
||||
|
||||
+31
-16
@@ -1423,6 +1423,9 @@ class ScoLink(utils.CompositeEventEmitter):
|
||||
acl_connection: Connection
|
||||
handle: int
|
||||
link_type: int
|
||||
rx_packet_length: int
|
||||
tx_packet_length: int
|
||||
air_mode: hci.CodecID
|
||||
sink: Callable[[hci.HCI_SynchronousDataPacket], Any] | None = None
|
||||
|
||||
EVENT_DISCONNECTION: ClassVar[str] = "disconnection"
|
||||
@@ -2343,6 +2346,9 @@ class Device(utils.CompositeEventEmitter):
|
||||
_pending_cis: dict[int, tuple[int, int]]
|
||||
gatt_service: gatt_service.GenericAttributeProfileService | None = None
|
||||
keystore: KeyStore | None = None
|
||||
inquiry_response: bytes | None = None
|
||||
address_resolver: smp.AddressResolver | None = None
|
||||
connect_own_address_type: hci.OwnAddressType | None = None
|
||||
|
||||
EVENT_ADVERTISEMENT = "advertisement"
|
||||
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
|
||||
@@ -2461,17 +2467,12 @@ class Device(utils.CompositeEventEmitter):
|
||||
self.bis_links = {}
|
||||
self.big_syncs = {}
|
||||
self.classic_enabled = False
|
||||
self.inquiry_response = None
|
||||
self.address_resolver = None
|
||||
self.classic_pending_accepts = {
|
||||
hci.Address.ANY: []
|
||||
} # Futures, by BD address OR [Futures] for hci.Address.ANY
|
||||
|
||||
self._cis_lock = asyncio.Lock()
|
||||
|
||||
# Own address type cache
|
||||
self.connect_own_address_type = None
|
||||
|
||||
self.name = config.name
|
||||
self.public_address = hci.Address.ANY
|
||||
self.random_address = config.address
|
||||
@@ -5618,8 +5619,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Any | None = None,
|
||||
attribute: Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -5638,7 +5639,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def notify_subscribers(
|
||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
||||
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Send a notification to all the subscribers of an attribute.
|
||||
@@ -5657,8 +5658,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Any | None = None,
|
||||
attribute: Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -5679,7 +5680,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(
|
||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
||||
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||
):
|
||||
"""
|
||||
Send an indication to all the subscribers of an attribute.
|
||||
@@ -6051,7 +6052,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
def on_connection_request(
|
||||
self, bd_addr: hci.Address, class_of_device: int, link_type: int
|
||||
):
|
||||
logger.debug(f'*** Connection request: {bd_addr}')
|
||||
logger.debug(f'*** Connection request: {bd_addr} link_type={link_type}')
|
||||
|
||||
# Handle SCO request.
|
||||
if link_type in (
|
||||
@@ -6061,6 +6062,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
if connection := self.find_connection_by_bd_addr(
|
||||
bd_addr, transport=PhysicalTransport.BR_EDR
|
||||
):
|
||||
connection.emit(self.EVENT_SCO_REQUEST, link_type)
|
||||
self.emit(self.EVENT_SCO_REQUEST, connection, link_type)
|
||||
else:
|
||||
logger.error(f'SCO request from a non-connected device {bd_addr}')
|
||||
@@ -6420,8 +6422,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
logger.warning('peer name is not valid UTF-8')
|
||||
if connection:
|
||||
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
|
||||
else:
|
||||
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
|
||||
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -6438,7 +6439,13 @@ class Device(utils.CompositeEventEmitter):
|
||||
@with_connection_from_address
|
||||
@utils.experimental('Only for testing.')
|
||||
def on_sco_connection(
|
||||
self, acl_connection: Connection, sco_handle: int, link_type: int
|
||||
self,
|
||||
acl_connection: Connection,
|
||||
sco_handle: int,
|
||||
link_type: int,
|
||||
rx_packet_length: int,
|
||||
tx_packet_length: int,
|
||||
air_mode: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'*** SCO connected: {acl_connection.peer_address}, '
|
||||
@@ -6450,7 +6457,11 @@ class Device(utils.CompositeEventEmitter):
|
||||
acl_connection=acl_connection,
|
||||
handle=sco_handle,
|
||||
link_type=link_type,
|
||||
rx_packet_length=rx_packet_length,
|
||||
tx_packet_length=tx_packet_length,
|
||||
air_mode=hci.CodecID(air_mode),
|
||||
)
|
||||
acl_connection.emit(self.EVENT_SCO_CONNECTION, sco_link)
|
||||
self.emit(self.EVENT_SCO_CONNECTION, sco_link)
|
||||
|
||||
# [Classic only]
|
||||
@@ -6461,7 +6472,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
self, acl_connection: Connection, status: int
|
||||
) -> None:
|
||||
logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***')
|
||||
self.emit(self.EVENT_SCO_CONNECTION_FAILURE)
|
||||
acl_connection.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
|
||||
self.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@@ -6924,15 +6936,18 @@ class Device(utils.CompositeEventEmitter):
|
||||
@with_connection_from_address
|
||||
def on_classic_pairing(self, connection: Connection) -> None:
|
||||
connection.emit(connection.EVENT_CLASSIC_PAIRING)
|
||||
self.emit(connection.EVENT_CLASSIC_PAIRING, connection)
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@with_connection_from_address
|
||||
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
|
||||
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
|
||||
self.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, connection, status)
|
||||
|
||||
def on_pairing_start(self, connection: Connection) -> None:
|
||||
connection.emit(connection.EVENT_PAIRING_START)
|
||||
self.emit(connection.EVENT_PAIRING_START, connection)
|
||||
|
||||
def on_pairing(
|
||||
self,
|
||||
|
||||
+24
-22
@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def _bearer_id(bearer: att.Bearer) -> str:
|
||||
if att.is_enhanced_bearer(bearer):
|
||||
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
|
||||
async def _notify_single_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
|
||||
return
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
value_as_bytes = (
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||
|
||||
# Notify
|
||||
notification = att.ATT_Handle_Value_Notification(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||
)
|
||||
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
||||
self.send_gatt_pdu(bearer, bytes(notification))
|
||||
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if att.is_enhanced_bearer(bearer) or force:
|
||||
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
|
||||
async def _indicate_single_bearer(
|
||||
self,
|
||||
bearer: att.Bearer,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None,
|
||||
force: bool,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
|
||||
return
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
value_as_bytes = (
|
||||
await attribute.read_value(bearer)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
|
||||
# Truncate if needed
|
||||
if len(value) > bearer.att_mtu - 3:
|
||||
value = value[: bearer.att_mtu - 3]
|
||||
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||
|
||||
# Indicate
|
||||
indication = att.ATT_Handle_Value_Indication(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||
)
|
||||
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
||||
|
||||
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
|
||||
async def _notify_or_indicate_subscribers(
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the bearers for which there's at least one subscription
|
||||
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
|
||||
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self._notify_or_indicate_subscribers(
|
||||
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
|
||||
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: att.Attribute,
|
||||
value: bytes | None = None,
|
||||
attribute: att.Attribute[_T],
|
||||
value: _T | None = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
+121
-28
@@ -1721,6 +1721,15 @@ class CodecID(SpecableEnum):
|
||||
VENDOR_SPECIFIC = 0xFF
|
||||
|
||||
|
||||
# From Bluetooth Assigned Numbers, 2.10 PCM_Data_Format
|
||||
class PcmDataFormat(SpecableEnum):
|
||||
NA = 0x00
|
||||
ONES_COMPLEMENT = 0x01
|
||||
TWOS_COMPLEMENT = 0x02
|
||||
SIGN_MAGNITUDE = 0x03
|
||||
UNSIGNED = 0x04
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CodingFormat:
|
||||
codec_id: CodecID
|
||||
@@ -1729,7 +1738,7 @@ class CodingFormat:
|
||||
|
||||
@classmethod
|
||||
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, CodingFormat]:
|
||||
(codec_id, company_id, vendor_specific_codec_id) = struct.unpack_from(
|
||||
codec_id, company_id, vendor_specific_codec_id = struct.unpack_from(
|
||||
'<BHH', data, offset
|
||||
)
|
||||
return offset + 5, cls(
|
||||
@@ -1748,6 +1757,61 @@ class CodingFormat:
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class VoiceSetting:
|
||||
class AirCodingFormat(enum.IntEnum):
|
||||
CVSD = 0
|
||||
U_LAW = 1
|
||||
A_LAW = 2
|
||||
TRANSPARENT_DATA = 3
|
||||
|
||||
class InputSampleSize(enum.IntEnum):
|
||||
SIZE_8_BITS = 0
|
||||
SIZE_16_BITS = 1
|
||||
|
||||
class InputDataFormat(enum.IntEnum):
|
||||
ONES_COMPLEMENT = 0
|
||||
TWOS_COMPLEMENT = 1
|
||||
SIGN_AND_MAGNITUDE = 2
|
||||
UNSIGNED = 3
|
||||
|
||||
class InputCodingFormat(enum.IntEnum):
|
||||
LINEAR = 0
|
||||
U_LAW = 1
|
||||
A_LAW = 2
|
||||
RESERVED = 3
|
||||
|
||||
air_coding_format: AirCodingFormat = AirCodingFormat.CVSD
|
||||
linear_pcm_bit_position: int = 0
|
||||
input_sample_size: InputSampleSize = InputSampleSize.SIZE_8_BITS
|
||||
input_data_format: InputDataFormat = InputDataFormat.ONES_COMPLEMENT
|
||||
input_coding_format: InputCodingFormat = InputCodingFormat.LINEAR
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, value: int) -> VoiceSetting:
|
||||
air_coding_format = cls.AirCodingFormat(value & 0b11)
|
||||
linear_pcm_bit_position = (value >> 2) & 0b111
|
||||
input_sample_size = cls.InputSampleSize((value >> 5) & 0b1)
|
||||
input_data_format = cls.InputDataFormat((value >> 6) & 0b11)
|
||||
input_coding_format = cls.InputCodingFormat((value >> 8) & 0b11)
|
||||
return cls(
|
||||
air_coding_format=air_coding_format,
|
||||
linear_pcm_bit_position=linear_pcm_bit_position,
|
||||
input_sample_size=input_sample_size,
|
||||
input_data_format=input_data_format,
|
||||
input_coding_format=input_coding_format,
|
||||
)
|
||||
|
||||
def __int__(self) -> int:
|
||||
return (
|
||||
self.air_coding_format
|
||||
| (self.linear_pcm_bit_position << 2)
|
||||
| (self.input_sample_size << 5)
|
||||
| (self.input_data_format << 6)
|
||||
| (self.input_coding_format << 8)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HCI_Constant:
|
||||
@staticmethod
|
||||
@@ -2008,7 +2072,7 @@ class HCI_Object:
|
||||
)
|
||||
continue
|
||||
|
||||
(field_name, field_type) = object_field
|
||||
field_name, field_type = object_field
|
||||
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
|
||||
|
||||
return bytes(result)
|
||||
@@ -2886,6 +2950,23 @@ class HCI_Read_Clock_Offset_Command(HCI_AsyncCommand):
|
||||
connection_handle: int = field(metadata=metadata(2))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
class HCI_Accept_Synchronous_Connection_Request_Command(HCI_AsyncCommand):
|
||||
'''
|
||||
See Bluetooth spec @ 7.1.27 Accept Synchronous Connection Request Command
|
||||
'''
|
||||
|
||||
bd_addr: Address = field(metadata=metadata(Address.parse_address))
|
||||
transmit_bandwidth: int = field(metadata=metadata(4))
|
||||
receive_bandwidth: int = field(metadata=metadata(4))
|
||||
max_latency: int = field(metadata=metadata(2))
|
||||
voice_setting: int = field(metadata=metadata(2))
|
||||
retransmission_effort: int = field(metadata=metadata(1))
|
||||
packet_type: int = field(metadata=metadata(2))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
@@ -3034,8 +3115,8 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_AsyncCommand):
|
||||
output_coding_format: int = field(metadata=metadata(CodingFormat.parse_from_bytes))
|
||||
input_coded_data_size: int = field(metadata=metadata(2))
|
||||
output_coded_data_size: int = field(metadata=metadata(2))
|
||||
input_pcm_data_format: int = field(metadata=metadata(1))
|
||||
output_pcm_data_format: int = field(metadata=metadata(1))
|
||||
input_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
|
||||
output_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
|
||||
input_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
|
||||
output_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
|
||||
input_data_path: int = field(metadata=metadata(1))
|
||||
@@ -3046,13 +3127,6 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_AsyncCommand):
|
||||
packet_type: int = field(metadata=metadata(2))
|
||||
retransmission_effort: int = field(metadata=metadata(1))
|
||||
|
||||
class PcmDataFormat(SpecableEnum):
|
||||
NA = 0x00
|
||||
ONES_COMPLEMENT = 0x01
|
||||
TWOS_COMPLEMENT = 0x02
|
||||
SIGN_MAGNITUDE = 0x03
|
||||
UNSIGNED = 0x04
|
||||
|
||||
class DataPath(SpecableEnum):
|
||||
HCI = 0x00
|
||||
PCM = 0x01
|
||||
@@ -3099,8 +3173,8 @@ class HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(HCI_AsyncComman
|
||||
output_coding_format: int = field(metadata=metadata(CodingFormat.parse_from_bytes))
|
||||
input_coded_data_size: int = field(metadata=metadata(2))
|
||||
output_coded_data_size: int = field(metadata=metadata(2))
|
||||
input_pcm_data_format: int = field(metadata=metadata(1))
|
||||
output_pcm_data_format: int = field(metadata=metadata(1))
|
||||
input_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
|
||||
output_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
|
||||
input_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
|
||||
output_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
|
||||
input_data_path: int = field(metadata=metadata(1))
|
||||
@@ -3944,6 +4018,23 @@ class HCI_Read_Local_OOB_Extended_Data_Command(
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_SyncCommand.sync_command(HCI_StatusReturnParameters)
|
||||
@dataclasses.dataclass
|
||||
class HCI_Configure_Data_Path_Command(HCI_SyncCommand[HCI_StatusReturnParameters]):
|
||||
'''
|
||||
See Bluetooth spec @ 7.3.101 Configure Data Path Command
|
||||
'''
|
||||
|
||||
class DataPathDirection(SpecableEnum):
|
||||
INPUT = 0x00
|
||||
OUTPUT = 0x01
|
||||
|
||||
data_path_direction: DataPathDirection = field(metadata=metadata(1))
|
||||
data_path_id: int = field(metadata=metadata(1))
|
||||
vendor_specific_config: bytes = field(metadata=metadata('*'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class HCI_Read_Local_Version_Information_ReturnParameters(HCI_StatusReturnParameters):
|
||||
@@ -7334,7 +7425,7 @@ class HCI_Connection_Complete_Event(HCI_Event):
|
||||
status: int = field(metadata=metadata(STATUS_SPEC))
|
||||
connection_handle: int = field(metadata=metadata(2))
|
||||
bd_addr: Address = field(metadata=metadata(Address.parse_address))
|
||||
link_type: int = field(metadata=LinkType.type_metadata(1))
|
||||
link_type: LinkType = field(metadata=LinkType.type_metadata(1))
|
||||
encryption_enabled: int = field(metadata=metadata(1))
|
||||
|
||||
|
||||
@@ -7730,12 +7821,6 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
|
||||
SCO = 0x00
|
||||
ESCO = 0x02
|
||||
|
||||
class AirMode(SpecableEnum):
|
||||
U_LAW_LOG = 0x00
|
||||
A_LAW_LOG_AIR_MORE = 0x01
|
||||
CVSD = 0x02
|
||||
TRANSPARENT_DATA = 0x03
|
||||
|
||||
status: int = field(metadata=metadata(STATUS_SPEC))
|
||||
connection_handle: int = field(metadata=metadata(2))
|
||||
bd_addr: Address = field(metadata=metadata(Address.parse_address))
|
||||
@@ -7744,7 +7829,7 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
|
||||
retransmission_window: int = field(metadata=metadata(1))
|
||||
rx_packet_length: int = field(metadata=metadata(2))
|
||||
tx_packet_length: int = field(metadata=metadata(2))
|
||||
air_mode: int = field(metadata=AirMode.type_metadata(1))
|
||||
air_mode: int = field(metadata=CodecID.type_metadata(1))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -7976,7 +8061,9 @@ class HCI_AclDataPacket(HCI_Packet):
|
||||
bc_flag = (h >> 14) & 3
|
||||
data = packet[5:]
|
||||
if len(data) != data_total_length:
|
||||
raise InvalidPacketError('invalid packet length')
|
||||
raise InvalidPacketError(
|
||||
f'invalid packet length {len(data)} != {data_total_length}'
|
||||
)
|
||||
return cls(
|
||||
connection_handle=connection_handle,
|
||||
pb_flag=pb_flag,
|
||||
@@ -8009,10 +8096,16 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
||||
See Bluetooth spec @ 5.4.3 HCI SCO Data Packets
|
||||
'''
|
||||
|
||||
class Status(enum.IntEnum):
|
||||
CORRECTLY_RECEIVED_DATA = 0b00
|
||||
POSSIBLY_INVALID_DATA = 0b01
|
||||
NO_DATA = 0b10
|
||||
DATA_PARTIALLY_LOST = 0b11
|
||||
|
||||
hci_packet_type = HCI_SYNCHRONOUS_DATA_PACKET
|
||||
|
||||
connection_handle: int
|
||||
packet_status: int
|
||||
packet_status: Status
|
||||
data_total_length: int
|
||||
data: bytes
|
||||
|
||||
@@ -8021,7 +8114,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
||||
# Read the header
|
||||
h, data_total_length = struct.unpack_from('<HB', packet, 1)
|
||||
connection_handle = h & 0xFFF
|
||||
packet_status = (h >> 12) & 0b11
|
||||
packet_status = cls.Status((h >> 12) & 0b11)
|
||||
data = packet[4:]
|
||||
if len(data) != data_total_length:
|
||||
raise InvalidPacketError(
|
||||
@@ -8045,7 +8138,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
|
||||
return (
|
||||
f'{color("SCO", "blue")}: '
|
||||
f'handle=0x{self.connection_handle:04x}, '
|
||||
f'ps={self.packet_status}, '
|
||||
f'ps={self.packet_status.name}, '
|
||||
f'data_total_length={self.data_total_length}, '
|
||||
f'data={self.data.hex()}'
|
||||
)
|
||||
@@ -8073,8 +8166,8 @@ class HCI_IsoDataPacket(HCI_Packet):
|
||||
def __post_init__(self) -> None:
|
||||
self.ts_flag = self.time_stamp is not None
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
|
||||
@classmethod
|
||||
def from_bytes(cls, packet: bytes) -> HCI_IsoDataPacket:
|
||||
time_stamp: int | None = None
|
||||
packet_sequence_number: int | None = None
|
||||
iso_sdu_length: int | None = None
|
||||
@@ -8103,7 +8196,7 @@ class HCI_IsoDataPacket(HCI_Packet):
|
||||
pos += 4
|
||||
|
||||
iso_sdu_fragment = packet[pos:]
|
||||
return HCI_IsoDataPacket(
|
||||
return cls(
|
||||
connection_handle=connection_handle,
|
||||
pb_flag=pb_flag,
|
||||
ts_flag=ts_flag,
|
||||
|
||||
+16
-19
@@ -44,6 +44,7 @@ from bumble.hci import (
|
||||
CodecID,
|
||||
CodingFormat,
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command,
|
||||
PcmDataFormat,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -177,7 +178,7 @@ class AgFeature(enum.IntFlag):
|
||||
VOICE_RECOGNITION_TEXT = 0x2000
|
||||
|
||||
|
||||
class AudioCodec(enum.IntEnum):
|
||||
class AudioCodec(utils.OpenIntEnum):
|
||||
"""
|
||||
Audio Codec IDs (normative).
|
||||
|
||||
@@ -189,7 +190,7 @@ class AudioCodec(enum.IntEnum):
|
||||
LC3_SWB = 0x03 # Support for LC3-SWB audio codec
|
||||
|
||||
|
||||
class HfIndicator(enum.IntEnum):
|
||||
class HfIndicator(utils.OpenIntEnum):
|
||||
"""
|
||||
HF Indicators (normative).
|
||||
|
||||
@@ -218,7 +219,7 @@ class CallHoldOperation(enum.Enum):
|
||||
)
|
||||
|
||||
|
||||
class ResponseHoldStatus(enum.IntEnum):
|
||||
class ResponseHoldStatus(utils.OpenIntEnum):
|
||||
"""
|
||||
Response Hold status (normative).
|
||||
|
||||
@@ -246,7 +247,7 @@ class AgIndicator(enum.Enum):
|
||||
BATTERY_CHARGE = 'battchg'
|
||||
|
||||
|
||||
class CallSetupAgIndicator(enum.IntEnum):
|
||||
class CallSetupAgIndicator(utils.OpenIntEnum):
|
||||
"""
|
||||
Values for the Call Setup AG indicator (normative).
|
||||
|
||||
@@ -259,7 +260,7 @@ class CallSetupAgIndicator(enum.IntEnum):
|
||||
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
|
||||
|
||||
|
||||
class CallHeldAgIndicator(enum.IntEnum):
|
||||
class CallHeldAgIndicator(utils.OpenIntEnum):
|
||||
"""
|
||||
Values for the Call Held AG indicator (normative).
|
||||
|
||||
@@ -273,7 +274,7 @@ class CallHeldAgIndicator(enum.IntEnum):
|
||||
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
|
||||
|
||||
|
||||
class CallInfoDirection(enum.IntEnum):
|
||||
class CallInfoDirection(utils.OpenIntEnum):
|
||||
"""
|
||||
Call Info direction (normative).
|
||||
|
||||
@@ -284,7 +285,7 @@ class CallInfoDirection(enum.IntEnum):
|
||||
MOBILE_TERMINATED_CALL = 1
|
||||
|
||||
|
||||
class CallInfoStatus(enum.IntEnum):
|
||||
class CallInfoStatus(utils.OpenIntEnum):
|
||||
"""
|
||||
Call Info status (normative).
|
||||
|
||||
@@ -299,7 +300,7 @@ class CallInfoStatus(enum.IntEnum):
|
||||
WAITING = 5
|
||||
|
||||
|
||||
class CallInfoMode(enum.IntEnum):
|
||||
class CallInfoMode(utils.OpenIntEnum):
|
||||
"""
|
||||
Call Info mode (normative).
|
||||
|
||||
@@ -312,7 +313,7 @@ class CallInfoMode(enum.IntEnum):
|
||||
UNKNOWN = 9
|
||||
|
||||
|
||||
class CallInfoMultiParty(enum.IntEnum):
|
||||
class CallInfoMultiParty(utils.OpenIntEnum):
|
||||
"""
|
||||
Call Info Multi-Party state (normative).
|
||||
|
||||
@@ -399,7 +400,7 @@ class CallLineIdentification:
|
||||
)
|
||||
|
||||
|
||||
class VoiceRecognitionState(enum.IntEnum):
|
||||
class VoiceRecognitionState(utils.OpenIntEnum):
|
||||
"""
|
||||
vrec values provided in AT+BVRA command.
|
||||
|
||||
@@ -412,7 +413,7 @@ class VoiceRecognitionState(enum.IntEnum):
|
||||
ENHANCED_READY = 2
|
||||
|
||||
|
||||
class CmeError(enum.IntEnum):
|
||||
class CmeError(utils.OpenIntEnum):
|
||||
"""
|
||||
CME ERROR codes (partial listed).
|
||||
|
||||
@@ -1606,7 +1607,7 @@ class AgProtocol(utils.EventEmitter):
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProfileVersion(enum.IntEnum):
|
||||
class ProfileVersion(utils.OpenIntEnum):
|
||||
"""
|
||||
Profile version (normative).
|
||||
|
||||
@@ -1954,12 +1955,8 @@ class EscoParameters:
|
||||
output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
|
||||
input_coded_data_size: int = 16
|
||||
output_coded_data_size: int = 16
|
||||
input_pcm_data_format: (
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
|
||||
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
|
||||
output_pcm_data_format: (
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
|
||||
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
|
||||
input_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT
|
||||
output_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT
|
||||
input_pcm_sample_payload_msb_position: int = 0
|
||||
output_pcm_sample_payload_msb_position: int = 0
|
||||
input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
|
||||
@@ -2058,6 +2055,7 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
|
||||
max_latency=0x0008,
|
||||
packet_type=(
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
|
||||
@@ -2073,7 +2071,6 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
|
||||
max_latency=0x000D,
|
||||
packet_type=(
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
|
||||
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
|
||||
|
||||
+25
-3
@@ -686,6 +686,8 @@ class Host(utils.EventEmitter):
|
||||
self.pending_response, timeout=response_timeout
|
||||
)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception while sending command:", "red"))
|
||||
raise
|
||||
@@ -864,7 +866,7 @@ class Host(utils.EventEmitter):
|
||||
self.send_hci_packet(
|
||||
hci.HCI_SynchronousDataPacket(
|
||||
connection_handle=connection_handle,
|
||||
packet_status=0,
|
||||
packet_status=hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
|
||||
data_total_length=len(sdu),
|
||||
data=sdu,
|
||||
)
|
||||
@@ -1176,11 +1178,28 @@ class Host(utils.EventEmitter):
|
||||
def on_hci_connection_complete_event(
|
||||
self, event: hci.HCI_Connection_Complete_Event
|
||||
):
|
||||
if event.link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
|
||||
# Pass this on to the synchronous connection handler
|
||||
forwarded_event = hci.HCI_Synchronous_Connection_Complete_Event(
|
||||
status=event.status,
|
||||
connection_handle=event.connection_handle,
|
||||
bd_addr=event.bd_addr,
|
||||
link_type=event.link_type,
|
||||
transmission_interval=0,
|
||||
retransmission_window=0,
|
||||
rx_packet_length=0,
|
||||
tx_packet_length=0,
|
||||
air_mode=0,
|
||||
)
|
||||
self.on_hci_synchronous_connection_complete_event(forwarded_event)
|
||||
return
|
||||
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
# Create/update the connection
|
||||
logger.debug(
|
||||
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
|
||||
f'{event.bd_addr}'
|
||||
f'### BR/EDR ACL CONNECTION: [0x{event.connection_handle:04X}] '
|
||||
f'{event.bd_addr} '
|
||||
f'{event.link_type.name}'
|
||||
)
|
||||
|
||||
connection = self.connections.get(event.connection_handle)
|
||||
@@ -1580,6 +1599,9 @@ class Host(utils.EventEmitter):
|
||||
event.bd_addr,
|
||||
event.connection_handle,
|
||||
event.link_type,
|
||||
event.rx_packet_length,
|
||||
event.tx_packet_length,
|
||||
event.air_mode,
|
||||
)
|
||||
else:
|
||||
logger.debug(f'### SCO CONNECTION FAILED: {event.status}')
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ RFCOMM_DEFAULT_L2CAP_MTU = 2048
|
||||
RFCOMM_DEFAULT_INITIAL_CREDITS = 7
|
||||
RFCOMM_DEFAULT_MAX_CREDITS = 32
|
||||
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
|
||||
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
|
||||
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 1000
|
||||
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
|
||||
+315
-260
@@ -44,6 +44,12 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
|
||||
# prevent a malicious peer from crashing the process via a deeply nested PDU.
|
||||
# 32 levels is well beyond anything a legitimate service record uses.
|
||||
_MAX_DATA_ELEMENT_NESTING = 32
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -148,32 +154,6 @@ class DataElement:
|
||||
ALTERNATIVE = Type.ALTERNATIVE
|
||||
URL = Type.URL
|
||||
|
||||
TYPE_CONSTRUCTORS = {
|
||||
NIL: lambda x: DataElement(DataElement.NIL, None),
|
||||
UNSIGNED_INTEGER: lambda x, y: DataElement(
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.unsigned_integer_from_bytes(x),
|
||||
value_size=y,
|
||||
),
|
||||
SIGNED_INTEGER: lambda x, y: DataElement(
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.signed_integer_from_bytes(x),
|
||||
value_size=y,
|
||||
),
|
||||
UUID: lambda x: DataElement(
|
||||
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
|
||||
),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
|
||||
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
|
||||
SEQUENCE: lambda x: DataElement(
|
||||
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
|
||||
),
|
||||
ALTERNATIVE: lambda x: DataElement(
|
||||
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
|
||||
),
|
||||
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
|
||||
}
|
||||
|
||||
type: Type
|
||||
value: Any
|
||||
value_size: int | None = None
|
||||
@@ -190,279 +170,354 @@ class DataElement:
|
||||
'integer types must have a value size specified'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def nil() -> DataElement:
|
||||
return DataElement(DataElement.NIL, None)
|
||||
@classmethod
|
||||
def nil(cls) -> DataElement:
|
||||
return cls(cls.NIL, None)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
|
||||
@classmethod
|
||||
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
|
||||
@classmethod
|
||||
def unsigned_integer_8(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
|
||||
@classmethod
|
||||
def unsigned_integer_16(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
|
||||
@classmethod
|
||||
def unsigned_integer_32(cls, value: int) -> DataElement:
|
||||
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
|
||||
@classmethod
|
||||
def signed_integer(cls, value: int, value_size: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
|
||||
@classmethod
|
||||
def signed_integer_8(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
|
||||
@classmethod
|
||||
def signed_integer_16(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
|
||||
@classmethod
|
||||
def signed_integer_32(cls, value: int) -> DataElement:
|
||||
return cls(cls.SIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def uuid(value: core.UUID) -> DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
@classmethod
|
||||
def uuid(cls, value: core.UUID) -> DataElement:
|
||||
return cls(cls.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value: bytes) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
@classmethod
|
||||
def text_string(cls, value: bytes) -> DataElement:
|
||||
return cls(cls.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
def boolean(value: bool) -> DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
@classmethod
|
||||
def boolean(cls, value: bool) -> DataElement:
|
||||
return cls(cls.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
@classmethod
|
||||
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
|
||||
return cls(cls.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
@classmethod
|
||||
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
|
||||
return cls(cls.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
def url(value: str) -> DataElement:
|
||||
return DataElement(DataElement.URL, value)
|
||||
@classmethod
|
||||
def url(cls, value: str) -> DataElement:
|
||||
return cls(cls.URL, value)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_from_bytes(data):
|
||||
if len(data) == 1:
|
||||
return data[0]
|
||||
@classmethod
|
||||
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
|
||||
match length:
|
||||
case 1:
|
||||
return data[offset]
|
||||
case 2:
|
||||
return struct.unpack_from('>H', data, offset)[0]
|
||||
case 4:
|
||||
return struct.unpack_from('>I', data, offset)[0]
|
||||
case 8:
|
||||
return struct.unpack_from('>Q', data, offset)[0]
|
||||
case invalid_length:
|
||||
raise InvalidPacketError(f'invalid integer length {invalid_length}')
|
||||
|
||||
if len(data) == 2:
|
||||
return struct.unpack('>H', data)[0]
|
||||
@classmethod
|
||||
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
|
||||
match length:
|
||||
case 1:
|
||||
return struct.unpack_from('b', data, offset)[0]
|
||||
case 2:
|
||||
return struct.unpack_from('>h', data, offset)[0]
|
||||
case 4:
|
||||
return struct.unpack_from('>i', data, offset)[0]
|
||||
case 8:
|
||||
return struct.unpack_from('>q', data, offset)[0]
|
||||
case invalid_length:
|
||||
raise InvalidPacketError(f'invalid integer length {invalid_length}')
|
||||
|
||||
if len(data) == 4:
|
||||
return struct.unpack('>I', data)[0]
|
||||
@classmethod
|
||||
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
|
||||
parser = DataElementParser(data, offset)
|
||||
element = parser.parse_next()
|
||||
return parser.offset, element
|
||||
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>Q', data)[0]
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> DataElement:
|
||||
return DataElementParser(data).parse_next()
|
||||
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_from_bytes(data):
|
||||
if len(data) == 1:
|
||||
return struct.unpack('b', data)[0]
|
||||
|
||||
if len(data) == 2:
|
||||
return struct.unpack('>h', data)[0]
|
||||
|
||||
if len(data) == 4:
|
||||
return struct.unpack('>i', data)[0]
|
||||
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>q', data)[0]
|
||||
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def list_from_bytes(data):
|
||||
elements = []
|
||||
while data:
|
||||
element = DataElement.from_bytes(data)
|
||||
elements.append(element)
|
||||
data = data[len(bytes(element)) :]
|
||||
return elements
|
||||
|
||||
@staticmethod
|
||||
def parse_from_bytes(data, offset):
|
||||
element = DataElement.from_bytes(data[offset:])
|
||||
return offset + len(bytes(element)), element
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
element_type = data[0] >> 3
|
||||
size_index = data[0] & 7
|
||||
value_offset = 0
|
||||
if size_index == 0:
|
||||
if element_type == DataElement.NIL:
|
||||
value_size = 0
|
||||
else:
|
||||
value_size = 1
|
||||
elif size_index == 1:
|
||||
value_size = 2
|
||||
elif size_index == 2:
|
||||
value_size = 4
|
||||
elif size_index == 3:
|
||||
value_size = 8
|
||||
elif size_index == 4:
|
||||
value_size = 16
|
||||
elif size_index == 5:
|
||||
value_size = data[1]
|
||||
value_offset = 1
|
||||
elif size_index == 6:
|
||||
value_size = struct.unpack('>H', data[1:3])[0]
|
||||
value_offset = 2
|
||||
else: # size_index == 7
|
||||
value_size = struct.unpack('>I', data[1:5])[0]
|
||||
value_offset = 4
|
||||
|
||||
value_data = data[1 + value_offset : 1 + value_offset + value_size]
|
||||
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
|
||||
if constructor:
|
||||
if element_type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.SIGNED_INTEGER,
|
||||
):
|
||||
result = constructor(value_data, value_size)
|
||||
else:
|
||||
result = constructor(value_data)
|
||||
else:
|
||||
result = DataElement(element_type, value_data)
|
||||
result._bytes = data[
|
||||
: 1 + value_offset + value_size
|
||||
] # Keep a copy so we can re-serialize to an exact replica
|
||||
return result
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
# Return early if we have a cache
|
||||
if self._bytes:
|
||||
return self._bytes
|
||||
|
||||
if self.type == DataElement.NIL:
|
||||
data = b''
|
||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
data = b''
|
||||
case DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('B', self.value)
|
||||
elif self.value_size == 2:
|
||||
data = struct.pack('>H', self.value)
|
||||
elif self.value_size == 4:
|
||||
data = struct.pack('>I', self.value)
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.SIGNED_INTEGER:
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('b', self.value)
|
||||
elif self.value_size == 2:
|
||||
data = struct.pack('>h', self.value)
|
||||
elif self.value_size == 4:
|
||||
data = struct.pack('>i', self.value)
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type == DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
|
||||
data = b''.join([bytes(element) for element in self.value])
|
||||
else:
|
||||
data = self.value
|
||||
match self.value_size:
|
||||
case 1:
|
||||
data = struct.pack('B', self.value)
|
||||
case 2:
|
||||
data = struct.pack('>H', self.value)
|
||||
case 4:
|
||||
data = struct.pack('>I', self.value)
|
||||
case 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
case invalid_length:
|
||||
raise InvalidArgumentError(
|
||||
f'invalid value_size of {invalid_length}'
|
||||
)
|
||||
case DataElement.SIGNED_INTEGER:
|
||||
match self.value_size:
|
||||
case 1:
|
||||
data = struct.pack('b', self.value)
|
||||
case 2:
|
||||
data = struct.pack('>h', self.value)
|
||||
case 4:
|
||||
data = struct.pack('>i', self.value)
|
||||
case 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
case invalid_length:
|
||||
raise InvalidArgumentError(
|
||||
f'invalid value_size of {invalid_length}'
|
||||
)
|
||||
case DataElement.UUID:
|
||||
data = bytes(self.value)[::-1]
|
||||
case DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
case DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
data = b''.join([bytes(element) for element in self.value])
|
||||
case _:
|
||||
data = self.value
|
||||
|
||||
size = len(data)
|
||||
size_bytes = b''
|
||||
if self.type == DataElement.NIL:
|
||||
if size != 0:
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif self.type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.UUID,
|
||||
):
|
||||
if size <= 1:
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
if size != 0:
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif size == 2:
|
||||
size_index = 1
|
||||
elif size == 4:
|
||||
size_index = 2
|
||||
elif size == 8:
|
||||
size_index = 3
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type in (
|
||||
DataElement.TEXT_STRING,
|
||||
DataElement.SEQUENCE,
|
||||
DataElement.ALTERNATIVE,
|
||||
DataElement.URL,
|
||||
):
|
||||
if size <= 0xFF:
|
||||
size_index = 5
|
||||
size_bytes = bytes([size])
|
||||
elif size <= 0xFFFF:
|
||||
size_index = 6
|
||||
size_bytes = struct.pack('>H', size)
|
||||
elif size <= 0xFFFFFFFF:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
else:
|
||||
raise RuntimeError("internal error - self.type not supported")
|
||||
case (
|
||||
DataElement.UNSIGNED_INTEGER
|
||||
| DataElement.SIGNED_INTEGER
|
||||
| DataElement.UUID
|
||||
):
|
||||
if size <= 1:
|
||||
size_index = 0
|
||||
elif size == 2:
|
||||
size_index = 1
|
||||
elif size == 4:
|
||||
size_index = 2
|
||||
elif size == 8:
|
||||
size_index = 3
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
case (
|
||||
DataElement.TEXT_STRING
|
||||
| DataElement.SEQUENCE
|
||||
| DataElement.ALTERNATIVE
|
||||
| DataElement.URL
|
||||
):
|
||||
if size <= 0xFF:
|
||||
size_index = 5
|
||||
size_bytes = bytes([size])
|
||||
elif size <= 0xFFFF:
|
||||
size_index = 6
|
||||
size_bytes = struct.pack('>H', size)
|
||||
elif size <= 0xFFFFFFFF:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
case DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
case unsupported_type:
|
||||
raise core.InvalidPacketError(
|
||||
f"internal error - {unsupported_type} not supported"
|
||||
)
|
||||
|
||||
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
return self._bytes
|
||||
|
||||
def to_string(self, pretty=False, indentation=0):
|
||||
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
|
||||
prefix = ' ' * indentation
|
||||
type_name = self.type.name
|
||||
if self.type == DataElement.NIL:
|
||||
value_string = ''
|
||||
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
|
||||
container_separator = '\n' if pretty else ''
|
||||
element_separator = '\n' if pretty else ','
|
||||
elements = [
|
||||
element.to_string(pretty, indentation + 1 if pretty else 0)
|
||||
for element in self.value
|
||||
]
|
||||
value_string = (
|
||||
f'[{container_separator}'
|
||||
f'{element_separator.join(elements)}'
|
||||
f'{container_separator}{prefix}]'
|
||||
)
|
||||
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||
value_string = f'{self.value}#{self.value_size}'
|
||||
elif isinstance(self.value, DataElement):
|
||||
value_string = self.value.to_string(pretty, indentation)
|
||||
else:
|
||||
value_string = str(self.value)
|
||||
match self.type:
|
||||
case DataElement.NIL:
|
||||
value_string = ''
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
container_separator = '\n' if pretty else ''
|
||||
element_separator = '\n' if pretty else ','
|
||||
elements = [
|
||||
element.to_string(pretty, indentation + 1 if pretty else 0)
|
||||
for element in self.value
|
||||
]
|
||||
value_string = (
|
||||
f'[{container_separator}'
|
||||
f'{element_separator.join(elements)}'
|
||||
f'{container_separator}{prefix}]'
|
||||
)
|
||||
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
|
||||
value_string = f'{self.value}#{self.value_size}'
|
||||
case _:
|
||||
if isinstance(self.value, DataElement):
|
||||
value_string = self.value.to_string(pretty, indentation)
|
||||
else:
|
||||
value_string = str(self.value)
|
||||
return f'{prefix}{type_name}({value_string})'
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.to_string()
|
||||
|
||||
|
||||
class DataElementParser:
|
||||
def __init__(
|
||||
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
|
||||
) -> None:
|
||||
self.data = data
|
||||
self.offset = offset
|
||||
self.depth = 0
|
||||
self.max_depth = max_depth
|
||||
|
||||
def parse_next(self) -> DataElement:
|
||||
if self.offset >= len(self.data):
|
||||
raise core.InvalidStateError(
|
||||
f"offset {self.offset} exceeds len(data) {len(self.data)}"
|
||||
)
|
||||
start_offset = self.offset
|
||||
element_type = DataElement.Type(self.data[self.offset] >> 3)
|
||||
size_index = self.data[self.offset] & 7
|
||||
self.offset += 1
|
||||
|
||||
value_size: int
|
||||
match size_index:
|
||||
case 0:
|
||||
if element_type == DataElement.NIL:
|
||||
value_size = 0
|
||||
else:
|
||||
value_size = 1
|
||||
case 1:
|
||||
value_size = 2
|
||||
case 2:
|
||||
value_size = 4
|
||||
case 3:
|
||||
value_size = 8
|
||||
case 4:
|
||||
value_size = 16
|
||||
case 5:
|
||||
value_size = self.data[self.offset]
|
||||
self.offset += 1
|
||||
case 6:
|
||||
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
|
||||
self.offset += 2
|
||||
case 7:
|
||||
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
|
||||
self.offset += 4
|
||||
case _:
|
||||
raise core.UnreachableError()
|
||||
|
||||
value_start = self.offset
|
||||
value_end = self.offset + value_size
|
||||
|
||||
match element_type:
|
||||
case DataElement.NIL:
|
||||
result = DataElement(DataElement.NIL, None)
|
||||
case DataElement.UNSIGNED_INTEGER:
|
||||
result = DataElement(
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
DataElement.unsigned_integer_from_bytes(
|
||||
self.data, value_start, value_size
|
||||
),
|
||||
value_size=value_size,
|
||||
)
|
||||
case DataElement.SIGNED_INTEGER:
|
||||
result = DataElement(
|
||||
DataElement.SIGNED_INTEGER,
|
||||
DataElement.signed_integer_from_bytes(
|
||||
self.data, value_start, value_size
|
||||
),
|
||||
value_size=value_size,
|
||||
)
|
||||
case DataElement.UUID:
|
||||
result = DataElement(
|
||||
DataElement.UUID,
|
||||
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
|
||||
)
|
||||
case DataElement.TEXT_STRING:
|
||||
result = DataElement(
|
||||
DataElement.TEXT_STRING, self.data[value_start:value_end]
|
||||
)
|
||||
case DataElement.BOOLEAN:
|
||||
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
|
||||
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
|
||||
self.offset = value_start
|
||||
result = DataElement(
|
||||
element_type,
|
||||
self._list_from_bytes(value_end),
|
||||
)
|
||||
if self.offset != value_end:
|
||||
logger.warning(
|
||||
"Expect parsing until offset %d, but ends at %d",
|
||||
value_end,
|
||||
self.offset,
|
||||
)
|
||||
case DataElement.URL:
|
||||
result = DataElement(
|
||||
DataElement.URL, self.data[value_start:value_end].decode('utf8')
|
||||
)
|
||||
case other_type:
|
||||
result = DataElement(other_type, self.data[value_start:value_end])
|
||||
|
||||
self.offset = value_end
|
||||
result._bytes = self.data[start_offset:value_end]
|
||||
|
||||
return result
|
||||
|
||||
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
|
||||
if self.depth >= self.max_depth:
|
||||
raise InvalidPacketError(
|
||||
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
|
||||
)
|
||||
self.depth += 1
|
||||
elements = []
|
||||
while self.offset < end_offset:
|
||||
elements.append(self.parse_next())
|
||||
self.depth -= 1
|
||||
return elements
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class ServiceAttribute:
|
||||
|
||||
+700
-353
File diff suppressed because it is too large
Load Diff
+242
-111
@@ -20,17 +20,119 @@ import contextlib
|
||||
import functools
|
||||
import json
|
||||
import sys
|
||||
import wave
|
||||
|
||||
import websockets.asyncio.server
|
||||
|
||||
import bumble.logging
|
||||
from bumble import hci, hfp, rfcomm
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.device import Connection, Device, ScoLink
|
||||
from bumble.hfp import HfProtocol
|
||||
from bumble.transport import open_transport
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
ws: websockets.asyncio.server.ServerConnection | None = None
|
||||
hf_protocol: HfProtocol | None = None
|
||||
input_wav: wave.Wave_read | None = None
|
||||
output_wav: wave.Wave_write | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_audio_packet(packet: hci.HCI_SynchronousDataPacket) -> None:
|
||||
if (
|
||||
packet.packet_status
|
||||
!= hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA
|
||||
):
|
||||
print('!!! discarding packet with status ', packet.packet_status.name)
|
||||
return
|
||||
|
||||
frame_count = len(packet.data) // 2
|
||||
print(f">>> received {frame_count} PCM samples")
|
||||
|
||||
if output_wav:
|
||||
# Save the PCM audio to the output
|
||||
output_wav.writeframes(packet.data)
|
||||
|
||||
if input_wav and hf_protocol:
|
||||
# Send PCM audio from the input, same amount as what was received
|
||||
while not (pcm_data := input_wav.readframes(frame_count)):
|
||||
input_wav.setpos(0) # Loop
|
||||
print(f">>> sending {frame_count} PCM samples")
|
||||
hf_protocol.dlc.multiplexer.l2cap_channel.connection.device.host.send_sco_sdu(
|
||||
connection_handle=packet.connection_handle,
|
||||
sdu=pcm_data,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_sco_connection(link: ScoLink) -> None:
|
||||
print('### SCO connection established:', link)
|
||||
if link.air_mode == hci.CodecID.TRANSPARENT:
|
||||
print("@@@ The controller does not encode/decode voice")
|
||||
return
|
||||
|
||||
link.sink = on_audio_packet
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_sco_request(
|
||||
link_type: int, connection: Connection, protocol: HfProtocol
|
||||
) -> None:
|
||||
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.SCO_CVSD_D1]
|
||||
elif protocol.active_codec == hfp.AudioCodec.MSBC:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_MSBC_T2]
|
||||
elif protocol.active_codec == hfp.AudioCodec.CVSD:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S4]
|
||||
else:
|
||||
raise RuntimeError("unknown active codec")
|
||||
|
||||
if connection.device.host.supports_command(
|
||||
hci.HCI_ENHANCED_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
|
||||
):
|
||||
connection.cancel_on_disconnection(
|
||||
connection.device.send_async_command(
|
||||
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
|
||||
bd_addr=connection.peer_address, **esco_parameters.asdict()
|
||||
)
|
||||
)
|
||||
)
|
||||
elif connection.device.host.supports_command(
|
||||
hci.HCI_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
|
||||
):
|
||||
connection.cancel_on_disconnection(
|
||||
connection.device.send_async_command(
|
||||
hci.HCI_Accept_Synchronous_Connection_Request_Command(
|
||||
bd_addr=connection.peer_address,
|
||||
transmit_bandwidth=esco_parameters.transmit_bandwidth,
|
||||
receive_bandwidth=esco_parameters.receive_bandwidth,
|
||||
max_latency=esco_parameters.max_latency,
|
||||
voice_setting=int(
|
||||
hci.VoiceSetting(
|
||||
input_sample_size=hci.VoiceSetting.InputSampleSize.SIZE_16_BITS,
|
||||
input_data_format=hci.VoiceSetting.InputDataFormat.TWOS_COMPLEMENT,
|
||||
)
|
||||
),
|
||||
retransmission_effort=esco_parameters.retransmission_effort,
|
||||
packet_type=esco_parameters.packet_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
print('!!! no supported command for SCO connection request')
|
||||
return
|
||||
|
||||
global output_wav
|
||||
if output_wav:
|
||||
output_wav.setnchannels(1)
|
||||
output_wav.setsampwidth(2)
|
||||
match protocol.active_codec:
|
||||
case hfp.AudioCodec.CVSD:
|
||||
output_wav.setframerate(8000)
|
||||
case hfp.AudioCodec.MSBC:
|
||||
output_wav.setframerate(16000)
|
||||
|
||||
connection.on('sco_connection', on_sco_connection)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -40,134 +142,163 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration):
|
||||
hf_protocol = HfProtocol(dlc, configuration)
|
||||
asyncio.create_task(hf_protocol.run())
|
||||
|
||||
def on_sco_request(connection: Connection, link_type: int, protocol: HfProtocol):
|
||||
if connection == protocol.dlc.multiplexer.l2cap_channel.connection:
|
||||
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[
|
||||
hfp.DefaultCodecParameters.SCO_CVSD_D1
|
||||
]
|
||||
elif protocol.active_codec == hfp.AudioCodec.MSBC:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[
|
||||
hfp.DefaultCodecParameters.ESCO_MSBC_T2
|
||||
]
|
||||
elif protocol.active_codec == hfp.AudioCodec.CVSD:
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[
|
||||
hfp.DefaultCodecParameters.ESCO_CVSD_S4
|
||||
]
|
||||
else:
|
||||
raise RuntimeError("unknown active codec")
|
||||
|
||||
connection.cancel_on_disconnection(
|
||||
connection.device.send_command(
|
||||
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
|
||||
bd_addr=connection.peer_address, **esco_parameters.asdict()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
handler = functools.partial(on_sco_request, protocol=hf_protocol)
|
||||
dlc.multiplexer.l2cap_channel.connection.device.on('sco_request', handler)
|
||||
connection = dlc.multiplexer.l2cap_channel.connection
|
||||
handler = functools.partial(
|
||||
on_sco_request,
|
||||
connection=connection,
|
||||
protocol=hf_protocol,
|
||||
)
|
||||
connection.on('sco_request', handler)
|
||||
dlc.multiplexer.l2cap_channel.once(
|
||||
'close',
|
||||
lambda: dlc.multiplexer.l2cap_channel.connection.device.remove_listener(
|
||||
'sco_request', handler
|
||||
),
|
||||
lambda: connection.remove_listener('sco_request', handler),
|
||||
)
|
||||
|
||||
def on_ag_indicator(indicator):
|
||||
global ws
|
||||
if ws:
|
||||
asyncio.create_task(ws.send(str(indicator)))
|
||||
|
||||
hf_protocol.on('ag_indicator', on_ag_indicator)
|
||||
hf_protocol.on('codec_negotiation', on_codec_negotiation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_ag_indicator(indicator):
|
||||
global ws
|
||||
if ws:
|
||||
asyncio.create_task(ws.send(str(indicator)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_codec_negotiation(codec: hfp.AudioCodec):
|
||||
print(f'### Negotiated codec: {codec.name}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device: Device, codec: str | None) -> None:
|
||||
if codec is None:
|
||||
supported_audio_codecs = [hfp.AudioCodec.CVSD, hfp.AudioCodec.MSBC]
|
||||
else:
|
||||
if codec == 'cvsd':
|
||||
supported_audio_codecs = [hfp.AudioCodec.CVSD]
|
||||
elif codec == 'msbc':
|
||||
supported_audio_codecs = [hfp.AudioCodec.MSBC]
|
||||
else:
|
||||
print('Unknown codec: ', codec)
|
||||
return
|
||||
|
||||
# Hands-Free profile configuration.
|
||||
# TODO: load configuration from file.
|
||||
configuration = hfp.HfConfiguration(
|
||||
supported_hf_features=[
|
||||
hfp.HfFeature.THREE_WAY_CALLING,
|
||||
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
|
||||
hfp.HfFeature.ENHANCED_CALL_STATUS,
|
||||
hfp.HfFeature.ENHANCED_CALL_CONTROL,
|
||||
hfp.HfFeature.CODEC_NEGOTIATION,
|
||||
hfp.HfFeature.HF_INDICATORS,
|
||||
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
|
||||
],
|
||||
supported_hf_indicators=[
|
||||
hfp.HfIndicator.BATTERY_LEVEL,
|
||||
],
|
||||
supported_audio_codecs=supported_audio_codecs,
|
||||
)
|
||||
|
||||
# Create and register a server
|
||||
rfcomm_server = rfcomm.Server(device)
|
||||
|
||||
# Listen for incoming DLC connections
|
||||
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
|
||||
print(f'### Listening for connection on channel {channel_number}')
|
||||
|
||||
# Advertise the HFP RFComm channel in the SDP
|
||||
device.sdp_service_records = {
|
||||
0x00010001: hfp.make_hf_sdp_records(0x00010001, channel_number, configuration)
|
||||
}
|
||||
|
||||
# Let's go!
|
||||
await device.power_on()
|
||||
|
||||
# Start being discoverable and connectable
|
||||
await device.set_discoverable(True)
|
||||
await device.set_connectable(True)
|
||||
|
||||
# Start the UI websocket server to offer a few buttons and input boxes
|
||||
async def serve(websocket: websockets.asyncio.server.ServerConnection):
|
||||
global ws
|
||||
ws = websocket
|
||||
async for message in websocket:
|
||||
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
|
||||
print('Received: ', str(message))
|
||||
|
||||
parsed = json.loads(message)
|
||||
message_type = parsed['type']
|
||||
if message_type == 'at_command':
|
||||
if hf_protocol is not None:
|
||||
response = str(
|
||||
await hf_protocol.execute_command(
|
||||
parsed['command'],
|
||||
response_type=hfp.AtResponseType.MULTIPLE,
|
||||
)
|
||||
)
|
||||
await websocket.send(response)
|
||||
elif message_type == 'query_call':
|
||||
if hf_protocol:
|
||||
response = str(await hf_protocol.query_current_calls())
|
||||
await websocket.send(response)
|
||||
|
||||
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
|
||||
|
||||
await asyncio.get_running_loop().create_future() # run forever
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: run_classic_hfp.py <device-config> <transport-spec>')
|
||||
print('example: run_classic_hfp.py classic2.json usb:04b4:f901')
|
||||
print(
|
||||
'Usage: run_hfp_handsfree.py <device-config> <transport-spec> '
|
||||
'[codec] [input] [output]'
|
||||
)
|
||||
print('example: run_hfp_handsfree.py classic2.json usb:0')
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport(sys.argv[2]) as hci_transport:
|
||||
print('<<< connected')
|
||||
device_config = sys.argv[1]
|
||||
transport_spec = sys.argv[2]
|
||||
|
||||
# Hands-Free profile configuration.
|
||||
# TODO: load configuration from file.
|
||||
configuration = hfp.HfConfiguration(
|
||||
supported_hf_features=[
|
||||
hfp.HfFeature.THREE_WAY_CALLING,
|
||||
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
|
||||
hfp.HfFeature.ENHANCED_CALL_STATUS,
|
||||
hfp.HfFeature.ENHANCED_CALL_CONTROL,
|
||||
hfp.HfFeature.CODEC_NEGOTIATION,
|
||||
hfp.HfFeature.HF_INDICATORS,
|
||||
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
|
||||
],
|
||||
supported_hf_indicators=[
|
||||
hfp.HfIndicator.BATTERY_LEVEL,
|
||||
],
|
||||
supported_audio_codecs=[
|
||||
hfp.AudioCodec.CVSD,
|
||||
hfp.AudioCodec.MSBC,
|
||||
],
|
||||
)
|
||||
codec: str | None = None
|
||||
if len(sys.argv) >= 4:
|
||||
codec = sys.argv[3]
|
||||
|
||||
# Create a device
|
||||
device = Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
device.classic_enabled = True
|
||||
input_file_name: str | None = None
|
||||
if len(sys.argv) >= 5:
|
||||
input_file_name = sys.argv[4]
|
||||
|
||||
# Create and register a server
|
||||
rfcomm_server = rfcomm.Server(device)
|
||||
output_file_name: str | None = None
|
||||
if len(sys.argv) >= 6:
|
||||
output_file_name = sys.argv[5]
|
||||
|
||||
# Listen for incoming DLC connections
|
||||
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
|
||||
print(f'### Listening for connection on channel {channel_number}')
|
||||
global input_wav, output_wav
|
||||
input_cm: contextlib.AbstractContextManager[wave.Wave_read | None] = (
|
||||
wave.open(input_file_name, "rb")
|
||||
if input_file_name
|
||||
else contextlib.nullcontext(None)
|
||||
)
|
||||
output_cm: contextlib.AbstractContextManager[wave.Wave_write | None] = (
|
||||
wave.open(output_file_name, "wb")
|
||||
if output_file_name
|
||||
else contextlib.nullcontext(None)
|
||||
)
|
||||
with input_cm as input_wav, output_cm as output_wav:
|
||||
if input_wav and input_wav.getnchannels() != 1:
|
||||
print("Mono input required")
|
||||
return
|
||||
if input_wav and input_wav.getsampwidth() != 2:
|
||||
print("16-bit input required")
|
||||
return
|
||||
|
||||
# Advertise the HFP RFComm channel in the SDP
|
||||
device.sdp_service_records = {
|
||||
0x00010001: hfp.make_hf_sdp_records(
|
||||
0x00010001, channel_number, configuration
|
||||
async with await open_transport(transport_spec) as transport:
|
||||
device = Device.from_config_file_with_hci(
|
||||
device_config, transport.source, transport.sink
|
||||
)
|
||||
}
|
||||
|
||||
# Let's go!
|
||||
await device.power_on()
|
||||
|
||||
# Start being discoverable and connectable
|
||||
await device.set_discoverable(True)
|
||||
await device.set_connectable(True)
|
||||
|
||||
# Start the UI websocket server to offer a few buttons and input boxes
|
||||
async def serve(websocket: websockets.asyncio.server.ServerConnection):
|
||||
global ws
|
||||
ws = websocket
|
||||
async for message in websocket:
|
||||
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
|
||||
print('Received: ', str(message))
|
||||
|
||||
parsed = json.loads(message)
|
||||
message_type = parsed['type']
|
||||
if message_type == 'at_command':
|
||||
if hf_protocol is not None:
|
||||
response = str(
|
||||
await hf_protocol.execute_command(
|
||||
parsed['command'],
|
||||
response_type=hfp.AtResponseType.MULTIPLE,
|
||||
)
|
||||
)
|
||||
await websocket.send(response)
|
||||
elif message_type == 'query_call':
|
||||
if hf_protocol:
|
||||
response = str(await hf_protocol.query_current_calls())
|
||||
await websocket.send(response)
|
||||
|
||||
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
device.classic_enabled = True
|
||||
await run(device, codec)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -170,7 +170,9 @@ def format_code(ctx, check=False, diff=False):
|
||||
@task
|
||||
def check_types(ctx):
|
||||
checklist = ["apps", "bumble", "examples", "tests", "tasks.py"]
|
||||
print(">>> Running the type checker...")
|
||||
try:
|
||||
print("+++ Checking with mypy...")
|
||||
ctx.run(f"mypy {' '.join(checklist)}")
|
||||
except UnexpectedExit as exc:
|
||||
print("Please check your code against the mypy messages.")
|
||||
|
||||
@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
|
||||
assert message.payload == parsed.payload
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'pdu',
|
||||
(
|
||||
b'', # empty PDU — would IndexError on pdu[0]
|
||||
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
|
||||
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
|
||||
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
|
||||
),
|
||||
)
|
||||
def test_message_assembler_truncated_pdu(pdu: bytes):
|
||||
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
|
||||
same DoS class as #912 (ATT empty PDU). The assembler is required to
|
||||
log + drop and stay alive so the L2CAP channel survives."""
|
||||
completed = []
|
||||
|
||||
def callback(transaction_label, message):
|
||||
completed.append((transaction_label, message))
|
||||
|
||||
assembler = avdtp.MessageAssembler(callback)
|
||||
# Must not raise; nothing should be delivered to callback either.
|
||||
assembler.on_pdu(pdu)
|
||||
assert not completed
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_rtp():
|
||||
packet = bytes.fromhex(
|
||||
|
||||
+41
-1
@@ -215,7 +215,7 @@ def test_pdu_parameter_length(caplog) -> None:
|
||||
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
|
||||
)
|
||||
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
|
||||
assert not re.search("Expect \d+ bytes, got \d+", caplog.text)
|
||||
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -440,3 +440,43 @@ async def run():
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_nested_sequence_recursion_guard():
|
||||
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
|
||||
the parser with RecursionError. Instead a ValueError is raised once the
|
||||
configured nesting limit is exceeded.
|
||||
|
||||
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
|
||||
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
|
||||
without a depth limit. A malicious SDP peer could craft a PDU exceeding
|
||||
Pythons default recursion limit (~1000 frames) and crash the host.
|
||||
"""
|
||||
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
|
||||
inner = b"\x35\x00" # empty SEQUENCE terminator
|
||||
for _ in range(1500):
|
||||
size = len(inner)
|
||||
if size >= 65535:
|
||||
break
|
||||
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
|
||||
|
||||
with pytest.raises(ValueError, match="nesting exceeds max depth"):
|
||||
DataElement.from_bytes(inner)
|
||||
|
||||
|
||||
def test_nested_sequence_within_limit_still_works():
|
||||
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
|
||||
leaf = DataElement.unsigned_integer(1, value_size=2)
|
||||
payload = leaf
|
||||
for _ in range(16): # under the 32-depth limit
|
||||
payload = DataElement.sequence([payload])
|
||||
raw = bytes(payload)
|
||||
parsed = DataElement.from_bytes(raw)
|
||||
# Walk back down to confirm structural integrity preserved
|
||||
cur = parsed
|
||||
for _ in range(16):
|
||||
assert cur.type == DataElement.SEQUENCE
|
||||
cur = cur.value[0]
|
||||
assert cur.type == DataElement.UNSIGNED_INTEGER
|
||||
assert cur.value == 1
|
||||
|
||||
Reference in New Issue
Block a user