Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] 834c8acd85 Bump rand in /rust in the cargo group across 1 directory
Bumps the cargo group with 1 update in the /rust directory: [rand](https://github.com/rust-random/rand).


Updates `rand` from 0.8.5 to 0.9.3
- [Release notes](https://github.com/rust-random/rand/releases)
- [Changelog](https://github.com/rust-random/rand/blob/0.9.3/CHANGELOG.md)
- [Commits](https://github.com/rust-random/rand/compare/0.8.5...0.9.3)

---
updated-dependencies:
- dependency-name: rand
  dependency-version: 0.9.3
  dependency-type: direct:production
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 02:05:17 +00:00
25 changed files with 991 additions and 2202 deletions
+1 -1
View File
@@ -489,7 +489,7 @@ class Sender:
flags=(
Packet.PacketFlags.LAST
if tx_i == self.tx_packet_count - 1
else Packet.PacketFlags(0)
else 0
),
sequence=tx_i,
timestamp=int((time.time() - self.start_time) * 1000000),
-12
View File
@@ -45,10 +45,8 @@ from bumble.hci import (
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Read_Voice_Setting_Command,
LeFeature,
SpecificationVersion,
VoiceSetting,
map_null_terminated_utf8_string,
)
from bumble.host import Host
@@ -216,16 +214,6 @@ async def get_codecs_info(host: Host) -> None:
if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Voice_Setting_Command.op_code):
response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command())
voice_setting = VoiceSetting.from_int(response3.voice_setting)
print(color('Voice Setting:', 'yellow'))
print(f' Air Coding Format: {voice_setting.air_coding_format.name}')
print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}')
print(f' Input Sample Size: {voice_setting.input_sample_size.name}')
print(f' Input Data Format: {voice_setting.input_data_format.name}')
print(f' Input Coding Format: {voice_setting.input_coding_format.name}')
# -----------------------------------------------------------------------------
async def async_main(
+41 -151
View File
@@ -16,8 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import statistics
import struct
import time
import click
@@ -27,9 +25,7 @@ from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
Address,
HCI_Read_Loopback_Mode_Command,
HCI_SynchronousDataPacket,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
@@ -40,121 +36,55 @@ from bumble.transport import open_transport
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(
self,
packet_size: int,
packet_count: int,
connection_type: str,
mode: str,
interval: int,
transport: str,
):
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: int | None = None
self.connection_type = connection_type
self.connection_event = asyncio.Event()
self.mode = mode
self.interval = interval
self.done = asyncio.Event()
self.expected_counter = 0
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
self.send_timestamps: list[float] = []
self.rtts: list[float] = []
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# The first connection handle is of type ACL,
# subsequent connections are of type SCO
if self.connection_type == "sco" and self.connection_handle is None:
self.connection_handle = connection_handle
return
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_sco_connection(
self,
address: Address,
connection_handle: int,
link_type,
rx_packet_length: int,
tx_packet_length: int,
air_mode: int,
) -> None:
self.on_connection(connection_handle)
def on_packet(self, connection_handle: int, packet: bytes):
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
(counter,) = struct.unpack_from("H", packet, 0)
rtt = now - self.send_timestamps[counter]
self.rtts.append(rtt)
print(f'<<< Received packet {counter}: {len(packet)} bytes, RTT={rtt:.4f}')
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert counter == self.expected_counter
self.expected_counter += 1
if counter == 0:
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(packet)
instant_rx_speed = len(packet) / elapsed_since_last
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
if self.mode == 'throughput':
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f},',
'cyan',
)
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_counter == self.packet_count:
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
self.on_packet(connection_handle, pdu)
def on_sco_packet(self, connection_handle: int, packet) -> None:
self.on_packet(connection_handle, packet)
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_l2cap_pdu(self.connection_handle, 0, packet)
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_hci_packet(
HCI_SynchronousDataPacket(
connection_handle=self.connection_handle,
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
data_total_length=len(packet),
data=packet,
)
)
async def send_loop(self, host: Host, sender) -> None:
for counter in range(0, self.packet_count):
print(
color(
f'>>> Sending {self.connection_type.upper()} '
f'packet {counter}: {self.packet_size} bytes',
'yellow',
)
)
self.send_timestamps.append(time.time())
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
@@ -196,11 +126,8 @@ class Loopback:
return
# set event callbacks
host.on('classic_connection', self.on_connection)
host.on('le_connection', self.on_connection)
host.on('sco_connection', self.on_sco_connection)
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
host.on('sco_packet', self.on_sco_packet)
loopback_mode = LoopbackMode.LOCAL
@@ -221,37 +148,32 @@ class Loopback:
print(color('=== Start sending', 'magenta'))
start_time = time.time()
if self.connection_type == "acl":
sender = self.send_acl_packet
elif self.connection_type == "sco":
sender = self.send_sco_packet
else:
raise ValueError(f'Unknown connection type: {self.connection_type}')
await self.send_loop(host, sender)
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
bytes_sent = self.packet_size * self.packet_count
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
if self.mode == 'throughput':
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} '
f'({bytes_sent} bytes in {elapsed:.2f} seconds)',
'green',
)
)
if self.mode == 'rtt':
print(
color(
f'RTTs: min={min(self.rtts):.4f}, '
f'max={max(self.rtts):.4f}, '
f'avg={statistics.mean(self.rtts):.4f}',
'blue',
)
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@@ -272,43 +194,11 @@ class Loopback:
default=10,
help='Packet count',
)
@click.option(
'--connection-type',
'-t',
metavar='TYPE',
type=click.Choice(['acl', 'sco']),
default='acl',
help='Connection type',
)
@click.option(
'--mode',
'-m',
metavar='MODE',
type=click.Choice(['throughput', 'rtt']),
default='throughput',
help='Test mode',
)
@click.option(
'--interval',
type=int,
default=100,
help='Inter-packet interval (ms) [RTT mode only]',
)
@click.argument('transport')
def main(packet_size, packet_count, connection_type, mode, interval, transport):
def main(packet_size, packet_count, transport):
bumble.logging.setup_basic_logging()
if connection_type == "sco" and packet_size > 255:
print("ERROR: the maximum packet size for SCO is 255")
return
async def run():
loopback = Loopback(
packet_size, packet_count, connection_type, mode, interval, transport
)
await loopback.run()
asyncio.run(run())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
+1 -6
View File
@@ -111,14 +111,9 @@ def show_device_details(device):
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
endpoint_details = (
f', Max Packet Size = {endpoint.getMaxPacketSize()}'
if endpoint_type == 'ISOCHRONOUS'
else ''
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}{endpoint_details}'
f'{endpoint_type} {endpoint_direction}'
)
+2 -4
View File
@@ -42,7 +42,7 @@ from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
# -----------------------------------------------------------------------------
@@ -249,8 +249,6 @@ class ATT_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
if not pdu:
raise InvalidPacketError("Empty ATT PDU")
op_code = pdu[0]
subclass = ATT_PDU.pdu_classes.get(op_code)
@@ -1083,7 +1081,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
else:
value_str = str(self.value)
if value_str:
value_string = f', value={value_str}'
value_string = f', value={self.value.hex()}'
else:
value_string = ''
return (
+77 -140
View File
@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
import asyncio
import enum
import logging
@@ -312,13 +311,6 @@ class MessageAssembler:
def on_pdu(self, pdu: bytes) -> None:
self.packet_count += 1
# Drop empty PDUs sent by remote — accessing pdu[0] below would
# raise IndexError, propagating up to the L2CAP read loop and
# tearing down the channel. Same class as #912 (ATT empty PDU).
if not pdu:
logger.warning('AVDTP message assembler: empty PDU dropped')
return
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
message_type = Message.MessageType(pdu[0] & 3)
@@ -332,23 +324,6 @@ class MessageAssembler:
Protocol.PacketType.SINGLE_PACKET,
Protocol.PacketType.START_PACKET,
):
# Both single and start packets carry the signal identifier in
# pdu[1]; start packets additionally carry the packet count in
# pdu[2]. Guard each access so a malformed remote frame can't
# crash the message assembler.
if len(pdu) < 2:
logger.warning(
'AVDTP %s packet too short (%d bytes); dropped',
packet_type.name,
len(pdu),
)
return
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
logger.warning(
'AVDTP START packet missing signal-packet count; dropped'
)
return
if self.message is not None:
# The previous message has not been terminated
logger.warning(
@@ -1478,23 +1453,8 @@ class Protocol(utils.EventEmitter):
handler = getattr(self, handler_name, None)
if handler:
try:
result = handler(message)
if asyncio.iscoroutine(result):
async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
response = handler(message)
self.send_message(transaction_label, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
else:
@@ -1575,7 +1535,7 @@ class Protocol(utils.EventEmitter):
async def send_command(self, command: Message):
# TODO: support timeouts
# Send the command
transaction_label, transaction_result = await self.start_transaction()
(transaction_label, transaction_result) = await self.start_transaction()
self.send_message(transaction_label, command)
# Wait for the response
@@ -1640,14 +1600,14 @@ class Protocol(utils.EventEmitter):
async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid))
async def on_discover_command(self, command: Discover_Command) -> Message | None:
def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
]
return Discover_Response(endpoint_infos)
async def on_get_capabilities_command(
def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1656,7 +1616,7 @@ class Protocol(utils.EventEmitter):
return Get_Capabilities_Response(endpoint.capabilities)
async def on_get_all_capabilities_command(
def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1665,7 +1625,7 @@ class Protocol(utils.EventEmitter):
return Get_All_Capabilities_Response(endpoint.capabilities)
async def on_set_configuration_command(
def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1680,10 +1640,10 @@ class Protocol(utils.EventEmitter):
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream
result = await stream.on_set_configuration_command(command.capabilities)
result = stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response()
async def on_get_configuration_command(
def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1692,31 +1652,29 @@ class Protocol(utils.EventEmitter):
if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return await endpoint.stream.on_get_configuration_command()
return endpoint.stream.on_get_configuration_command()
async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
result = endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response()
async def on_open_command(self, command: Open_Command) -> Message | None:
def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_open_command()
result = endpoint.stream.on_open_command()
return result or Open_Response()
async def on_start_command(self, command: Start_Command) -> Message | None:
def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1730,12 +1688,12 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := await endpoint.stream.on_start_command()) is not None:
if (result := endpoint.stream.on_start_command()) is not None:
return result
return Start_Response()
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1749,47 +1707,45 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := await endpoint.stream.on_suspend_command()) is not None:
if (result := endpoint.stream.on_suspend_command()) is not None:
return result
return Suspend_Response()
async def on_close_command(self, command: Close_Command) -> Message | None:
def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR)
result = await endpoint.stream.on_close_command()
result = endpoint.stream.on_close_command()
return result or Close_Response()
async def on_abort_command(self, command: Abort_Command) -> Message | None:
def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None:
return Abort_Response()
await endpoint.stream.on_abort_command()
endpoint.stream.on_abort_command()
return Abort_Response()
async def on_security_control_command(
def on_security_control_command(
self, command: Security_Control_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = await endpoint.on_security_control_command(command.data)
result = endpoint.on_security_control_command(command.data)
return result or Security_Control_Response()
async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = await endpoint.on_delayreport_command(command.delay)
result = endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response()
@@ -1947,22 +1903,25 @@ class Stream:
await self.rtp_channel.disconnect()
self.rtp_channel = None
# Release the endpoint
self.local_endpoint.in_use = 0
self.change_state(State.IDLE)
async def on_set_configuration_command(
def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_set_configuration_command(configuration)
result = self.local_endpoint.on_set_configuration_command(configuration)
if result is not None:
return result
self.change_state(State.CONFIGURED)
return None
async def on_get_configuration_command(self) -> Message | None:
def on_get_configuration_command(self) -> Message | None:
if self.state not in (
State.CONFIGURED,
State.OPEN,
@@ -1970,25 +1929,25 @@ class Stream:
):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
return await self.local_endpoint.on_get_configuration_command()
return self.local_endpoint.on_get_configuration_command()
async def on_reconfigure_command(
def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_reconfigure_command(configuration)
result = self.local_endpoint.on_reconfigure_command(configuration)
if result is not None:
return result
return None
async def on_open_command(self) -> Message | None:
def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_open_command()
result = self.local_endpoint.on_open_command()
if result is not None:
return result
@@ -1998,7 +1957,7 @@ class Stream:
self.change_state(State.OPEN)
return None
async def on_start_command(self) -> Message | None:
def on_start_command(self) -> Message | None:
if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -2007,29 +1966,29 @@ class Stream:
logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_start_command()
result = self.local_endpoint.on_start_command()
if result is not None:
return result
self.change_state(State.STREAMING)
return None
async def on_suspend_command(self) -> Message | None:
def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_suspend_command()
result = self.local_endpoint.on_suspend_command()
if result is not None:
return result
self.change_state(State.OPEN)
return None
async def on_close_command(self) -> Message | None:
def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = await self.local_endpoint.on_close_command()
result = self.local_endpoint.on_close_command()
if result is not None:
return result
@@ -2044,8 +2003,7 @@ class Stream:
return None
async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
def on_abort_command(self) -> Message | None:
if self.rtp_channel is None:
# No need to wait
self.change_state(State.IDLE)
@@ -2070,6 +2028,7 @@ class Stream:
def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< stream channel closed', 'magenta'))
self.local_endpoint.on_rtp_channel_close()
self.local_endpoint.in_use = 0
self.rtp_channel = None
if self.state in (State.CLOSING, State.ABORTING):
@@ -2094,6 +2053,7 @@ class Stream:
self.state = State.IDLE
local_endpoint.stream = self
local_endpoint.in_use = 1
def __str__(self) -> str:
return (
@@ -2103,16 +2063,14 @@ class Stream:
# -----------------------------------------------------------------------------
class StreamEndPoint(abc.ABC):
@dataclass
class StreamEndPoint:
seid: int
media_type: MediaType
tsep: StreamEndPointType
in_use: int
capabilities: Iterable[ServiceCapabilities]
@property
def in_use(self) -> int:
raise NotImplementedError()
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
@@ -2152,30 +2110,14 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
in_use: int,
capabilities: Iterable[ServiceCapabilities],
) -> None:
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self._in_use = in_use
self.capabilities = capabilities
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
@property
def in_use(self) -> int:
return self._in_use
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
StreamEndPointProxy.__init__(self, protocol, seid)
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
stream: Stream | None
@property
def in_use(self) -> int:
if self.stream and self.stream.state != State.IDLE:
return 1
return 0
EVENT_CONFIGURATION = "configuration"
EVENT_OPEN = "open"
EVENT_START = "start"
@@ -2198,13 +2140,8 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
capabilities: Iterable[ServiceCapabilities],
configuration: Iterable[ServiceCapabilities] | None = None,
):
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
utils.EventEmitter.__init__(self)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.capabilities = capabilities
self.protocol = protocol
self.configuration = configuration if configuration is not None else []
self.stream = None
@@ -2218,13 +2155,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
async def close(self) -> None:
"""[Source Only] Handles when receiving close command."""
async def on_reconfigure_command(
def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities]
) -> Message | None:
del command # unused.
return None
async def on_set_configuration_command(
def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
logger.debug(
@@ -2235,34 +2172,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
self.emit(self.EVENT_CONFIGURATION)
return None
async def on_get_configuration_command(self) -> Message | None:
def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration)
async def on_open_command(self) -> Message | None:
def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN)
return None
async def on_start_command(self) -> Message | None:
def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START)
return None
async def on_suspend_command(self) -> Message | None:
def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND)
return None
async def on_close_command(self) -> Message | None:
def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE)
return None
async def on_abort_command(self) -> Message | None:
def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT)
return None
async def on_delayreport_command(self, delay: int) -> Message | None:
def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay)
return None
async def on_security_control_command(self, data: bytes) -> Message | None:
def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data)
return None
@@ -2290,12 +2227,12 @@ class LocalSource(LocalStreamEndPoint):
codec_capabilities,
] + list(other_capabilities)
super().__init__(
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SRC,
capabilities=capabilities,
configuration=capabilities,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
)
self.packet_pump = packet_pump
@@ -2314,13 +2251,13 @@ class LocalSource(LocalStreamEndPoint):
self.emit(self.EVENT_STOP)
@override
async def on_start_command(self) -> Message | None:
await self.start()
def on_start_command(self) -> Message | None:
asyncio.create_task(self.start())
return None
@override
async def on_suspend_command(self) -> Message | None:
await self.stop()
def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop())
return None
@@ -2334,11 +2271,11 @@ class LocalSink(LocalStreamEndPoint):
codec_capabilities,
]
super().__init__(
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SNK,
capabilities=capabilities,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
)
def on_rtp_channel_open(self) -> None:
+16 -31
View File
@@ -1423,9 +1423,6 @@ class ScoLink(utils.CompositeEventEmitter):
acl_connection: Connection
handle: int
link_type: int
rx_packet_length: int
tx_packet_length: int
air_mode: hci.CodecID
sink: Callable[[hci.HCI_SynchronousDataPacket], Any] | None = None
EVENT_DISCONNECTION: ClassVar[str] = "disconnection"
@@ -2346,9 +2343,6 @@ class Device(utils.CompositeEventEmitter):
_pending_cis: dict[int, tuple[int, int]]
gatt_service: gatt_service.GenericAttributeProfileService | None = None
keystore: KeyStore | None = None
inquiry_response: bytes | None = None
address_resolver: smp.AddressResolver | None = None
connect_own_address_type: hci.OwnAddressType | None = None
EVENT_ADVERTISEMENT = "advertisement"
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
@@ -2467,12 +2461,17 @@ class Device(utils.CompositeEventEmitter):
self.bis_links = {}
self.big_syncs = {}
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.classic_pending_accepts = {
hci.Address.ANY: []
} # Futures, by BD address OR [Futures] for hci.Address.ANY
self._cis_lock = asyncio.Lock()
# Own address type cache
self.connect_own_address_type = None
self.name = config.name
self.public_address = hci.Address.ANY
self.random_address = config.address
@@ -5619,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute[_T],
value: _T | None = None,
attribute: Attribute,
value: Any | None = None,
force: bool = False,
) -> None:
"""
@@ -5639,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
self, attribute: Attribute, value: Any | None = None, force: bool = False
) -> None:
"""
Send a notification to all the subscribers of an attribute.
@@ -5658,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute[_T],
value: _T | None = None,
attribute: Attribute,
value: Any | None = None,
force: bool = False,
):
"""
@@ -5680,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
self, attribute: Attribute, value: Any | None = None, force: bool = False
):
"""
Send an indication to all the subscribers of an attribute.
@@ -6052,7 +6051,7 @@ class Device(utils.CompositeEventEmitter):
def on_connection_request(
self, bd_addr: hci.Address, class_of_device: int, link_type: int
):
logger.debug(f'*** Connection request: {bd_addr} link_type={link_type}')
logger.debug(f'*** Connection request: {bd_addr}')
# Handle SCO request.
if link_type in (
@@ -6062,7 +6061,6 @@ class Device(utils.CompositeEventEmitter):
if connection := self.find_connection_by_bd_addr(
bd_addr, transport=PhysicalTransport.BR_EDR
):
connection.emit(self.EVENT_SCO_REQUEST, link_type)
self.emit(self.EVENT_SCO_REQUEST, connection, link_type)
else:
logger.error(f'SCO request from a non-connected device {bd_addr}')
@@ -6422,7 +6420,8 @@ class Device(utils.CompositeEventEmitter):
logger.warning('peer name is not valid UTF-8')
if connection:
connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error)
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
else:
self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error)
# [Classic only]
@host_event_handler
@@ -6439,13 +6438,7 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_address
@utils.experimental('Only for testing.')
def on_sco_connection(
self,
acl_connection: Connection,
sco_handle: int,
link_type: int,
rx_packet_length: int,
tx_packet_length: int,
air_mode: int,
self, acl_connection: Connection, sco_handle: int, link_type: int
) -> None:
logger.debug(
f'*** SCO connected: {acl_connection.peer_address}, '
@@ -6457,11 +6450,7 @@ class Device(utils.CompositeEventEmitter):
acl_connection=acl_connection,
handle=sco_handle,
link_type=link_type,
rx_packet_length=rx_packet_length,
tx_packet_length=tx_packet_length,
air_mode=hci.CodecID(air_mode),
)
acl_connection.emit(self.EVENT_SCO_CONNECTION, sco_link)
self.emit(self.EVENT_SCO_CONNECTION, sco_link)
# [Classic only]
@@ -6472,8 +6461,7 @@ class Device(utils.CompositeEventEmitter):
self, acl_connection: Connection, status: int
) -> None:
logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***')
acl_connection.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
self.emit(self.EVENT_SCO_CONNECTION_FAILURE, status)
self.emit(self.EVENT_SCO_CONNECTION_FAILURE)
# [Classic only]
@host_event_handler
@@ -6936,18 +6924,15 @@ class Device(utils.CompositeEventEmitter):
@with_connection_from_address
def on_classic_pairing(self, connection: Connection) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING)
self.emit(connection.EVENT_CLASSIC_PAIRING, connection)
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status: int) -> None:
connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status)
self.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, connection, status)
def on_pairing_start(self, connection: Connection) -> None:
connection.emit(connection.EVENT_PAIRING_START)
self.emit(connection.EVENT_PAIRING_START, connection)
def on_pairing(
self,
+22 -24
View File
@@ -67,8 +67,6 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
@@ -371,8 +369,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -392,8 +390,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -413,19 +411,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value_as_bytes = (
value = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value_as_bytes
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
@@ -433,8 +431,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -454,8 +452,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -475,19 +473,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value_as_bytes = (
value = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value_as_bytes
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -512,8 +510,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
@@ -539,8 +537,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -549,8 +547,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute[_T],
value: _T | None = None,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
+28 -121
View File
@@ -1721,15 +1721,6 @@ class CodecID(SpecableEnum):
VENDOR_SPECIFIC = 0xFF
# From Bluetooth Assigned Numbers, 2.10 PCM_Data_Format
class PcmDataFormat(SpecableEnum):
NA = 0x00
ONES_COMPLEMENT = 0x01
TWOS_COMPLEMENT = 0x02
SIGN_MAGNITUDE = 0x03
UNSIGNED = 0x04
@dataclasses.dataclass(frozen=True)
class CodingFormat:
codec_id: CodecID
@@ -1738,7 +1729,7 @@ class CodingFormat:
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, CodingFormat]:
codec_id, company_id, vendor_specific_codec_id = struct.unpack_from(
(codec_id, company_id, vendor_specific_codec_id) = struct.unpack_from(
'<BHH', data, offset
)
return offset + 5, cls(
@@ -1757,61 +1748,6 @@ class CodingFormat:
)
@dataclasses.dataclass(frozen=True)
class VoiceSetting:
class AirCodingFormat(enum.IntEnum):
CVSD = 0
U_LAW = 1
A_LAW = 2
TRANSPARENT_DATA = 3
class InputSampleSize(enum.IntEnum):
SIZE_8_BITS = 0
SIZE_16_BITS = 1
class InputDataFormat(enum.IntEnum):
ONES_COMPLEMENT = 0
TWOS_COMPLEMENT = 1
SIGN_AND_MAGNITUDE = 2
UNSIGNED = 3
class InputCodingFormat(enum.IntEnum):
LINEAR = 0
U_LAW = 1
A_LAW = 2
RESERVED = 3
air_coding_format: AirCodingFormat = AirCodingFormat.CVSD
linear_pcm_bit_position: int = 0
input_sample_size: InputSampleSize = InputSampleSize.SIZE_8_BITS
input_data_format: InputDataFormat = InputDataFormat.ONES_COMPLEMENT
input_coding_format: InputCodingFormat = InputCodingFormat.LINEAR
@classmethod
def from_int(cls, value: int) -> VoiceSetting:
air_coding_format = cls.AirCodingFormat(value & 0b11)
linear_pcm_bit_position = (value >> 2) & 0b111
input_sample_size = cls.InputSampleSize((value >> 5) & 0b1)
input_data_format = cls.InputDataFormat((value >> 6) & 0b11)
input_coding_format = cls.InputCodingFormat((value >> 8) & 0b11)
return cls(
air_coding_format=air_coding_format,
linear_pcm_bit_position=linear_pcm_bit_position,
input_sample_size=input_sample_size,
input_data_format=input_data_format,
input_coding_format=input_coding_format,
)
def __int__(self) -> int:
return (
self.air_coding_format
| (self.linear_pcm_bit_position << 2)
| (self.input_sample_size << 5)
| (self.input_data_format << 6)
| (self.input_coding_format << 8)
)
# -----------------------------------------------------------------------------
class HCI_Constant:
@staticmethod
@@ -2072,7 +2008,7 @@ class HCI_Object:
)
continue
field_name, field_type = object_field
(field_name, field_type) = object_field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -2950,23 +2886,6 @@ class HCI_Read_Clock_Offset_Command(HCI_AsyncCommand):
connection_handle: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_Accept_Synchronous_Connection_Request_Command(HCI_AsyncCommand):
'''
See Bluetooth spec @ 7.1.27 Accept Synchronous Connection Request Command
'''
bd_addr: Address = field(metadata=metadata(Address.parse_address))
transmit_bandwidth: int = field(metadata=metadata(4))
receive_bandwidth: int = field(metadata=metadata(4))
max_latency: int = field(metadata=metadata(2))
voice_setting: int = field(metadata=metadata(2))
retransmission_effort: int = field(metadata=metadata(1))
packet_type: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -3115,8 +3034,8 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_AsyncCommand):
output_coding_format: int = field(metadata=metadata(CodingFormat.parse_from_bytes))
input_coded_data_size: int = field(metadata=metadata(2))
output_coded_data_size: int = field(metadata=metadata(2))
input_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
output_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
input_pcm_data_format: int = field(metadata=metadata(1))
output_pcm_data_format: int = field(metadata=metadata(1))
input_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
output_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
input_data_path: int = field(metadata=metadata(1))
@@ -3127,6 +3046,13 @@ class HCI_Enhanced_Setup_Synchronous_Connection_Command(HCI_AsyncCommand):
packet_type: int = field(metadata=metadata(2))
retransmission_effort: int = field(metadata=metadata(1))
class PcmDataFormat(SpecableEnum):
NA = 0x00
ONES_COMPLEMENT = 0x01
TWOS_COMPLEMENT = 0x02
SIGN_MAGNITUDE = 0x03
UNSIGNED = 0x04
class DataPath(SpecableEnum):
HCI = 0x00
PCM = 0x01
@@ -3173,8 +3099,8 @@ class HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(HCI_AsyncComman
output_coding_format: int = field(metadata=metadata(CodingFormat.parse_from_bytes))
input_coded_data_size: int = field(metadata=metadata(2))
output_coded_data_size: int = field(metadata=metadata(2))
input_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
output_pcm_data_format: int = field(metadata=PcmDataFormat.type_metadata(1))
input_pcm_data_format: int = field(metadata=metadata(1))
output_pcm_data_format: int = field(metadata=metadata(1))
input_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
output_pcm_sample_payload_msb_position: int = field(metadata=metadata(1))
input_data_path: int = field(metadata=metadata(1))
@@ -4018,23 +3944,6 @@ class HCI_Read_Local_OOB_Extended_Data_Command(
'''
# -----------------------------------------------------------------------------
@HCI_SyncCommand.sync_command(HCI_StatusReturnParameters)
@dataclasses.dataclass
class HCI_Configure_Data_Path_Command(HCI_SyncCommand[HCI_StatusReturnParameters]):
'''
See Bluetooth spec @ 7.3.101 Configure Data Path Command
'''
class DataPathDirection(SpecableEnum):
INPUT = 0x00
OUTPUT = 0x01
data_path_direction: DataPathDirection = field(metadata=metadata(1))
data_path_id: int = field(metadata=metadata(1))
vendor_specific_config: bytes = field(metadata=metadata('*'))
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class HCI_Read_Local_Version_Information_ReturnParameters(HCI_StatusReturnParameters):
@@ -7425,7 +7334,7 @@ class HCI_Connection_Complete_Event(HCI_Event):
status: int = field(metadata=metadata(STATUS_SPEC))
connection_handle: int = field(metadata=metadata(2))
bd_addr: Address = field(metadata=metadata(Address.parse_address))
link_type: LinkType = field(metadata=LinkType.type_metadata(1))
link_type: int = field(metadata=LinkType.type_metadata(1))
encryption_enabled: int = field(metadata=metadata(1))
@@ -7821,6 +7730,12 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
SCO = 0x00
ESCO = 0x02
class AirMode(SpecableEnum):
U_LAW_LOG = 0x00
A_LAW_LOG_AIR_MORE = 0x01
CVSD = 0x02
TRANSPARENT_DATA = 0x03
status: int = field(metadata=metadata(STATUS_SPEC))
connection_handle: int = field(metadata=metadata(2))
bd_addr: Address = field(metadata=metadata(Address.parse_address))
@@ -7829,7 +7744,7 @@ class HCI_Synchronous_Connection_Complete_Event(HCI_Event):
retransmission_window: int = field(metadata=metadata(1))
rx_packet_length: int = field(metadata=metadata(2))
tx_packet_length: int = field(metadata=metadata(2))
air_mode: int = field(metadata=CodecID.type_metadata(1))
air_mode: int = field(metadata=AirMode.type_metadata(1))
# -----------------------------------------------------------------------------
@@ -8061,9 +7976,7 @@ class HCI_AclDataPacket(HCI_Packet):
bc_flag = (h >> 14) & 3
data = packet[5:]
if len(data) != data_total_length:
raise InvalidPacketError(
f'invalid packet length {len(data)} != {data_total_length}'
)
raise InvalidPacketError('invalid packet length')
return cls(
connection_handle=connection_handle,
pb_flag=pb_flag,
@@ -8096,16 +8009,10 @@ class HCI_SynchronousDataPacket(HCI_Packet):
See Bluetooth spec @ 5.4.3 HCI SCO Data Packets
'''
class Status(enum.IntEnum):
CORRECTLY_RECEIVED_DATA = 0b00
POSSIBLY_INVALID_DATA = 0b01
NO_DATA = 0b10
DATA_PARTIALLY_LOST = 0b11
hci_packet_type = HCI_SYNCHRONOUS_DATA_PACKET
connection_handle: int
packet_status: Status
packet_status: int
data_total_length: int
data: bytes
@@ -8114,7 +8021,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
# Read the header
h, data_total_length = struct.unpack_from('<HB', packet, 1)
connection_handle = h & 0xFFF
packet_status = cls.Status((h >> 12) & 0b11)
packet_status = (h >> 12) & 0b11
data = packet[4:]
if len(data) != data_total_length:
raise InvalidPacketError(
@@ -8138,7 +8045,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
return (
f'{color("SCO", "blue")}: '
f'handle=0x{self.connection_handle:04x}, '
f'ps={self.packet_status.name}, '
f'ps={self.packet_status}, '
f'data_total_length={self.data_total_length}, '
f'data={self.data.hex()}'
)
@@ -8166,8 +8073,8 @@ class HCI_IsoDataPacket(HCI_Packet):
def __post_init__(self) -> None:
self.ts_flag = self.time_stamp is not None
@classmethod
def from_bytes(cls, packet: bytes) -> HCI_IsoDataPacket:
@staticmethod
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
@@ -8196,7 +8103,7 @@ class HCI_IsoDataPacket(HCI_Packet):
pos += 4
iso_sdu_fragment = packet[pos:]
return cls(
return HCI_IsoDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
ts_flag=ts_flag,
+20 -28
View File
@@ -44,7 +44,6 @@ from bumble.hci import (
CodecID,
CodingFormat,
HCI_Enhanced_Setup_Synchronous_Connection_Command,
PcmDataFormat,
)
# -----------------------------------------------------------------------------
@@ -69,8 +68,6 @@ class HfpProtocolError(ProtocolError):
# -----------------------------------------------------------------------------
class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
@@ -87,19 +84,10 @@ class HfpProtocol:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8', errors='replace')
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
@@ -178,7 +166,7 @@ class AgFeature(enum.IntFlag):
VOICE_RECOGNITION_TEXT = 0x2000
class AudioCodec(utils.OpenIntEnum):
class AudioCodec(enum.IntEnum):
"""
Audio Codec IDs (normative).
@@ -190,7 +178,7 @@ class AudioCodec(utils.OpenIntEnum):
LC3_SWB = 0x03 # Support for LC3-SWB audio codec
class HfIndicator(utils.OpenIntEnum):
class HfIndicator(enum.IntEnum):
"""
HF Indicators (normative).
@@ -219,7 +207,7 @@ class CallHoldOperation(enum.Enum):
)
class ResponseHoldStatus(utils.OpenIntEnum):
class ResponseHoldStatus(enum.IntEnum):
"""
Response Hold status (normative).
@@ -247,7 +235,7 @@ class AgIndicator(enum.Enum):
BATTERY_CHARGE = 'battchg'
class CallSetupAgIndicator(utils.OpenIntEnum):
class CallSetupAgIndicator(enum.IntEnum):
"""
Values for the Call Setup AG indicator (normative).
@@ -260,7 +248,7 @@ class CallSetupAgIndicator(utils.OpenIntEnum):
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
class CallHeldAgIndicator(utils.OpenIntEnum):
class CallHeldAgIndicator(enum.IntEnum):
"""
Values for the Call Held AG indicator (normative).
@@ -274,7 +262,7 @@ class CallHeldAgIndicator(utils.OpenIntEnum):
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
class CallInfoDirection(utils.OpenIntEnum):
class CallInfoDirection(enum.IntEnum):
"""
Call Info direction (normative).
@@ -285,7 +273,7 @@ class CallInfoDirection(utils.OpenIntEnum):
MOBILE_TERMINATED_CALL = 1
class CallInfoStatus(utils.OpenIntEnum):
class CallInfoStatus(enum.IntEnum):
"""
Call Info status (normative).
@@ -300,7 +288,7 @@ class CallInfoStatus(utils.OpenIntEnum):
WAITING = 5
class CallInfoMode(utils.OpenIntEnum):
class CallInfoMode(enum.IntEnum):
"""
Call Info mode (normative).
@@ -313,7 +301,7 @@ class CallInfoMode(utils.OpenIntEnum):
UNKNOWN = 9
class CallInfoMultiParty(utils.OpenIntEnum):
class CallInfoMultiParty(enum.IntEnum):
"""
Call Info Multi-Party state (normative).
@@ -400,7 +388,7 @@ class CallLineIdentification:
)
class VoiceRecognitionState(utils.OpenIntEnum):
class VoiceRecognitionState(enum.IntEnum):
"""
vrec values provided in AT+BVRA command.
@@ -413,7 +401,7 @@ class VoiceRecognitionState(utils.OpenIntEnum):
ENHANCED_READY = 2
class CmeError(utils.OpenIntEnum):
class CmeError(enum.IntEnum):
"""
CME ERROR codes (partial listed).
@@ -1607,7 +1595,7 @@ class AgProtocol(utils.EventEmitter):
# -----------------------------------------------------------------------------
class ProfileVersion(utils.OpenIntEnum):
class ProfileVersion(enum.IntEnum):
"""
Profile version (normative).
@@ -1955,8 +1943,12 @@ class EscoParameters:
output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
input_coded_data_size: int = 16
output_coded_data_size: int = 16
input_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT
output_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT
input_pcm_data_format: (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
output_pcm_data_format: (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
input_pcm_sample_payload_msb_position: int = 0
output_pcm_sample_payload_msb_position: int = 0
input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
@@ -2055,7 +2047,6 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
max_latency=0x0008,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
@@ -2071,6 +2062,7 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
max_latency=0x000D,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
+7 -121
View File
@@ -247,7 +247,6 @@ class Host(utils.EventEmitter):
bis_links: dict[int, IsoLink]
sco_links: dict[int, ScoLink]
bigs: dict[int, set[int]]
link_ts_flags: dict[int, int]
acl_packet_queue: DataPacketQueue | None = None
le_acl_packet_queue: DataPacketQueue | None = None
iso_packet_queue: DataPacketQueue | None = None
@@ -270,7 +269,6 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.link_ts_flags = {} # TS_Flag for ISO links, by handle
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: (
asyncio.Future[
@@ -488,7 +486,6 @@ class Host(utils.EventEmitter):
hci.HCI_LE_PHY_UPDATE_COMPLETE_EVENT,
hci.HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_EVENT,
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_V2_EVENT,
hci.HCI_LE_PERIODIC_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_LOST_EVENT,
hci.HCI_LE_SCAN_TIMEOUT_EVENT,
@@ -689,16 +686,16 @@ class Host(utils.EventEmitter):
self.pending_response, timeout=response_timeout
)
return response
except asyncio.TimeoutError:
raise
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
if response is None or (
response.num_hci_command_packets and self.command_semaphore.locked()
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
):
self.command_semaphore.release()
@@ -869,7 +866,7 @@ class Host(utils.EventEmitter):
self.send_hci_packet(
hci.HCI_SynchronousDataPacket(
connection_handle=connection_handle,
packet_status=hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
packet_status=0,
data_total_length=len(sdu),
data=sdu,
)
@@ -1031,82 +1028,6 @@ class Host(utils.EventEmitter):
# Look for the connection to which this data belongs
if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet)
return
# WORKAROUND: Some controllers (e.g. Intel BE200) send ISO data wrapped in ACL packets
# using the CIS handle.
is_cis = packet.connection_handle in self.cis_links
is_bis = packet.connection_handle in self.bis_links
if is_cis or is_bis:
logger.debug(
f"Received ISO data wrapped in ACL packet for handle 0x{packet.connection_handle:04X}"
)
payload = packet.data
ts_flag = self.link_ts_flags.get(packet.connection_handle)
if ts_flag is None:
# Learn TS flag from the first packet on this link
if is_bis:
# BIS packets always have Timestamp according to spec
ts_flag = 1
elif len(payload) < 8:
# Too short to have 8-byte header (TS), must be No TS
ts_flag = 0
else:
psn_no_ts = int.from_bytes(payload[0:2], 'little')
psn_has_ts = int.from_bytes(payload[4:6], 'little')
if psn_has_ts == 0:
ts_flag = 1
elif psn_no_ts == 0:
ts_flag = 0
else:
# Fallback heuristic
ts_flag = 1 if psn_has_ts < psn_no_ts else 0
self.link_ts_flags[packet.connection_handle] = ts_flag
logger.info(
f"Learned TS_Flag = {ts_flag} for handle 0x{packet.connection_handle:04X}"
)
if ts_flag:
header_size = 8
sdu_length_offset = 6
else:
header_size = 4
sdu_length_offset = 2
pb_flag = 0b10
if len(payload) >= header_size:
sdu_length = int.from_bytes(
payload[sdu_length_offset : sdu_length_offset + 2], 'little'
)
if sdu_length == len(payload) - header_size:
pb_flag = 0b10 # Complete SDU
else:
pb_flag = 0b00 # First fragment
else:
pb_flag = 0b01 # Continuation
ts_flag = 0
# Reconstruct the raw ISO packet (excluding packet indicator 0x05)
pdu_info = packet.connection_handle | (pb_flag << 12) | (ts_flag << 14)
header = bytes(
[
pdu_info & 0xFF,
(pdu_info >> 8) & 0xFF,
len(payload) & 0xFF,
(len(payload) >> 8) & 0xFF,
]
)
raw_iso_packet = header + payload
try:
iso_packet = hci.HCI_IsoDataPacket.from_bytes(
bytes([hci.HCI_ISO_DATA_PACKET]) + raw_iso_packet
)
self.on_hci_iso_data_packet(iso_packet)
except Exception as e:
logger.warning(f"Failed to reconstruct ISO packet from ACL: {e}")
def on_hci_sco_data_packet(self, packet: hci.HCI_SynchronousDataPacket) -> None:
# Experimental
@@ -1257,28 +1178,11 @@ class Host(utils.EventEmitter):
def on_hci_connection_complete_event(
self, event: hci.HCI_Connection_Complete_Event
):
if event.link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
# Pass this on to the synchronous connection handler
forwarded_event = hci.HCI_Synchronous_Connection_Complete_Event(
status=event.status,
connection_handle=event.connection_handle,
bd_addr=event.bd_addr,
link_type=event.link_type,
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
self.on_hci_synchronous_connection_complete_event(forwarded_event)
return
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### BR/EDR ACL CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr} '
f'{event.link_type.name}'
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
connection = self.connections.get(event.connection_handle)
@@ -1330,7 +1234,6 @@ class Host(utils.EventEmitter):
self.emit('disconnection', handle, event.reason)
# Remove the handle reference
self.link_ts_flags.pop(handle, None)
_ = (
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
@@ -1451,20 +1354,6 @@ class Host(utils.EventEmitter):
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_established_v2_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_V2_Event
):
self.emit(
'periodic_advertising_sync_establishment',
event.status,
event.sync_handle,
event.advertising_sid,
event.advertiser_address,
event.advertiser_phy,
event.periodic_advertising_interval,
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_lost_event(
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event
):
@@ -1693,9 +1582,6 @@ class Host(utils.EventEmitter):
event.bd_addr,
event.connection_handle,
event.link_type,
event.rx_packet_length,
event.tx_packet_length,
event.air_mode,
)
else:
logger.debug(f'### SCO CONNECTION FAILED: {event.status}')
+1 -1
View File
@@ -110,7 +110,7 @@ RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_INITIAL_CREDITS = 7
RFCOMM_DEFAULT_MAX_CREDITS = 32
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 1000
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
+264 -324
View File
@@ -44,12 +44,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# SDP data elements are nested (SEQUENCE, ALTERNATIVE). Cap parse recursion to
# prevent a malicious peer from crashing the process via a deeply nested PDU.
# 32 levels is well beyond anything a legitimate service record uses.
_MAX_DATA_ELEMENT_NESTING = 32
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -154,6 +148,32 @@ class DataElement:
ALTERNATIVE = Type.ALTERNATIVE
URL = Type.URL
TYPE_CONSTRUCTORS = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
type: Type
value: Any
value_size: int | None = None
@@ -170,354 +190,279 @@ class DataElement:
'integer types must have a value size specified'
)
@classmethod
def nil(cls) -> DataElement:
return cls(cls.NIL, None)
@staticmethod
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@classmethod
def unsigned_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@classmethod
def unsigned_integer_8(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@classmethod
def unsigned_integer_16(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@classmethod
def unsigned_integer_32(cls, value: int) -> DataElement:
return cls(cls.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@classmethod
def signed_integer(cls, value: int, value_size: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@classmethod
def signed_integer_8(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@classmethod
def signed_integer_16(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@classmethod
def signed_integer_32(cls, value: int) -> DataElement:
return cls(cls.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@classmethod
def uuid(cls, value: core.UUID) -> DataElement:
return cls(cls.UUID, value)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@classmethod
def text_string(cls, value: bytes) -> DataElement:
return cls(cls.TEXT_STRING, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@classmethod
def boolean(cls, value: bool) -> DataElement:
return cls(cls.BOOLEAN, value)
@staticmethod
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@classmethod
def sequence(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.SEQUENCE, value)
@staticmethod
def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@classmethod
def alternative(cls, value: Iterable[DataElement]) -> DataElement:
return cls(cls.ALTERNATIVE, value)
@staticmethod
def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@classmethod
def url(cls, value: str) -> DataElement:
return cls(cls.URL, value)
@staticmethod
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@classmethod
def unsigned_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return data[offset]
case 2:
return struct.unpack_from('>H', data, offset)[0]
case 4:
return struct.unpack_from('>I', data, offset)[0]
case 8:
return struct.unpack_from('>Q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
@classmethod
def signed_integer_from_bytes(cls, data: bytes, offset: int, length: int) -> int:
match length:
case 1:
return struct.unpack_from('b', data, offset)[0]
case 2:
return struct.unpack_from('>h', data, offset)[0]
case 4:
return struct.unpack_from('>i', data, offset)[0]
case 8:
return struct.unpack_from('>q', data, offset)[0]
case invalid_length:
raise InvalidPacketError(f'invalid integer length {invalid_length}')
if len(data) == 2:
return struct.unpack('>H', data)[0]
@classmethod
def parse_from_bytes(cls, data: bytes, offset: int) -> tuple[int, DataElement]:
parser = DataElementParser(data, offset)
element = parser.parse_next()
return parser.offset, element
if len(data) == 4:
return struct.unpack('>I', data)[0]
@classmethod
def from_bytes(cls, data: bytes) -> DataElement:
return DataElementParser(data).parse_next()
if len(data) == 8:
return struct.unpack('>Q', data)[0]
def __bytes__(self) -> bytes:
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
elements = []
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
return elements
@staticmethod
def parse_from_bytes(data, offset):
element = DataElement.from_bytes(data[offset:])
return offset + len(bytes(element)), element
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
elif size_index == 1:
value_size = 2
elif size_index == 2:
value_size = 4
elif size_index == 3:
value_size = 8
elif size_index == 4:
value_size = 16
elif size_index == 5:
value_size = data[1]
value_offset = 1
elif size_index == 6:
value_size = struct.unpack('>H', data[1:3])[0]
value_offset = 2
else: # size_index == 7
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.TYPE_CONSTRUCTORS.get(element_type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result._bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def __bytes__(self):
# Return early if we have a cache
if self._bytes:
return self._bytes
match self.type:
case DataElement.NIL:
data = b''
case DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.type == DataElement.NIL:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
match self.value_size:
case 1:
data = struct.pack('B', self.value)
case 2:
data = struct.pack('>H', self.value)
case 4:
data = struct.pack('>I', self.value)
case 8:
data = struct.pack('>Q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.SIGNED_INTEGER:
match self.value_size:
case 1:
data = struct.pack('b', self.value)
case 2:
data = struct.pack('>h', self.value)
case 4:
data = struct.pack('>i', self.value)
case 8:
data = struct.pack('>q', self.value)
case invalid_length:
raise InvalidArgumentError(
f'invalid value_size of {invalid_length}'
)
case DataElement.UUID:
data = bytes(self.value)[::-1]
case DataElement.URL:
data = self.value.encode('utf8')
case DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
case _:
data = self.value
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
elif self.value_size == 4:
data = struct.pack('>I', self.value)
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
elif self.value_size == 2:
data = struct.pack('>h', self.value)
elif self.value_size == 4:
data = struct.pack('>i', self.value)
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
size = len(data)
size_bytes = b''
match self.type:
case DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
if self.type == DataElement.NIL:
if size != 0:
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
size_index = 0
case (
DataElement.UNSIGNED_INTEGER
| DataElement.SIGNED_INTEGER
| DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
case (
DataElement.TEXT_STRING
| DataElement.SEQUENCE
| DataElement.ALTERNATIVE
| DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
case DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
case unsupported_type:
raise core.InvalidPacketError(
f"internal error - {unsupported_type} not supported"
)
elif size == 2:
size_index = 1
elif size == 4:
size_index = 2
elif size == 8:
size_index = 3
elif size == 16:
size_index = 4
else:
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
elif size <= 0xFFFF:
size_index = 6
size_bytes = struct.pack('>H', size)
elif size <= 0xFFFFFFFF:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
self._bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self._bytes
def to_string(self, pretty: bool = False, indentation: int = 0) -> str:
def to_string(self, pretty=False, indentation=0):
prefix = ' ' * indentation
type_name = self.type.name
match self.type:
case DataElement.NIL:
value_string = ''
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
case DataElement.UNSIGNED_INTEGER | DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
case _:
if isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
else:
value_string = str(self.value)
return f'{prefix}{type_name}({value_string})'
def __str__(self) -> str:
def __str__(self):
return self.to_string()
class DataElementParser:
def __init__(
self, data: bytes, offset: int = 0, max_depth: int = _MAX_DATA_ELEMENT_NESTING
) -> None:
self.data = data
self.offset = offset
self.depth = 0
self.max_depth = max_depth
def parse_next(self) -> DataElement:
if self.offset >= len(self.data):
raise core.InvalidStateError(
f"offset {self.offset} exceeds len(data) {len(self.data)}"
)
start_offset = self.offset
element_type = DataElement.Type(self.data[self.offset] >> 3)
size_index = self.data[self.offset] & 7
self.offset += 1
value_size: int
match size_index:
case 0:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
case 1:
value_size = 2
case 2:
value_size = 4
case 3:
value_size = 8
case 4:
value_size = 16
case 5:
value_size = self.data[self.offset]
self.offset += 1
case 6:
value_size = struct.unpack_from('>H', self.data, self.offset)[0]
self.offset += 2
case 7:
value_size = struct.unpack_from('>I', self.data, self.offset)[0]
self.offset += 4
case _:
raise core.UnreachableError()
value_start = self.offset
value_end = self.offset + value_size
match element_type:
case DataElement.NIL:
result = DataElement(DataElement.NIL, None)
case DataElement.UNSIGNED_INTEGER:
result = DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.SIGNED_INTEGER:
result = DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(
self.data, value_start, value_size
),
value_size=value_size,
)
case DataElement.UUID:
result = DataElement(
DataElement.UUID,
core.UUID.from_bytes(self.data[value_start:value_end][::-1]),
)
case DataElement.TEXT_STRING:
result = DataElement(
DataElement.TEXT_STRING, self.data[value_start:value_end]
)
case DataElement.BOOLEAN:
result = DataElement(DataElement.BOOLEAN, self.data[value_start] == 1)
case DataElement.SEQUENCE | DataElement.ALTERNATIVE:
self.offset = value_start
result = DataElement(
element_type,
self._list_from_bytes(value_end),
)
if self.offset != value_end:
logger.warning(
"Expect parsing until offset %d, but ends at %d",
value_end,
self.offset,
)
case DataElement.URL:
result = DataElement(
DataElement.URL, self.data[value_start:value_end].decode('utf8')
)
case other_type:
result = DataElement(other_type, self.data[value_start:value_end])
self.offset = value_end
result._bytes = self.data[start_offset:value_end]
return result
def _list_from_bytes(self, end_offset: int) -> list[DataElement]:
if self.depth >= self.max_depth:
raise InvalidPacketError(
f"SDP data element nesting exceeds max depth " f"({self.max_depth})"
)
self.depth += 1
elements = []
while self.offset < end_offset:
elements.append(self.parse_next())
self.depth -= 1
return elements
# -----------------------------------------------------------------------------
@dataclass
class ServiceAttribute:
@@ -649,10 +594,7 @@ class SDP_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> SDP_PDU:
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
if len(pdu) != 5 + parameters_length:
logger.warning("Expect %d bytes, got %d", 5 + parameters_length, len(pdu))
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
subclass = cls.subclasses.get(pdu_id)
if not (subclass := cls.subclasses.get(pdu_id)):
@@ -674,11 +616,9 @@ class SDP_PDU:
def __bytes__(self):
if self._payload is None:
parameters = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
self._payload = (
struct.pack('>BHH', self.pdu_id, self.transaction_id, len(parameters))
+ parameters
)
self._payload = struct.pack(
'>BHH', self.pdu_id, self.transaction_id, 0
) + hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@property
-3
View File
@@ -36,7 +36,6 @@ from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
InvalidPacketError,
PhysicalTransport,
ProtocolError,
)
@@ -216,8 +215,6 @@ class SMP_Command:
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
if not pdu:
raise InvalidPacketError("Empty SMP PDU")
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)
-3
View File
@@ -104,9 +104,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
0,
packet[1:],
)
elif packet_type == hci.HCI_ISO_DATA_PACKET:
# Workaround: Send ISO packets over Bulk Out
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
else:
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
+353 -720
View File
File diff suppressed because it is too large Load Diff
+111 -242
View File
@@ -20,119 +20,17 @@ import contextlib
import functools
import json
import sys
import wave
import websockets.asyncio.server
import bumble.logging
from bumble import hci, hfp, rfcomm
from bumble.device import Connection, Device, ScoLink
from bumble.device import Connection, Device
from bumble.hfp import HfProtocol
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
ws: websockets.asyncio.server.ServerConnection | None = None
hf_protocol: HfProtocol | None = None
input_wav: wave.Wave_read | None = None
output_wav: wave.Wave_write | None = None
# -----------------------------------------------------------------------------
def on_audio_packet(packet: hci.HCI_SynchronousDataPacket) -> None:
if (
packet.packet_status
!= hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA
):
print('!!! discarding packet with status ', packet.packet_status.name)
return
frame_count = len(packet.data) // 2
print(f">>> received {frame_count} PCM samples")
if output_wav:
# Save the PCM audio to the output
output_wav.writeframes(packet.data)
if input_wav and hf_protocol:
# Send PCM audio from the input, same amount as what was received
while not (pcm_data := input_wav.readframes(frame_count)):
input_wav.setpos(0) # Loop
print(f">>> sending {frame_count} PCM samples")
hf_protocol.dlc.multiplexer.l2cap_channel.connection.device.host.send_sco_sdu(
connection_handle=packet.connection_handle,
sdu=pcm_data,
)
# -----------------------------------------------------------------------------
def on_sco_connection(link: ScoLink) -> None:
print('### SCO connection established:', link)
if link.air_mode == hci.CodecID.TRANSPARENT:
print("@@@ The controller does not encode/decode voice")
return
link.sink = on_audio_packet
# -----------------------------------------------------------------------------
def on_sco_request(
link_type: int, connection: Connection, protocol: HfProtocol
) -> None:
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.SCO_CVSD_D1]
elif protocol.active_codec == hfp.AudioCodec.MSBC:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_MSBC_T2]
elif protocol.active_codec == hfp.AudioCodec.CVSD:
esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S4]
else:
raise RuntimeError("unknown active codec")
if connection.device.host.supports_command(
hci.HCI_ENHANCED_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
):
connection.cancel_on_disconnection(
connection.device.send_async_command(
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address, **esco_parameters.asdict()
)
)
)
elif connection.device.host.supports_command(
hci.HCI_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND
):
connection.cancel_on_disconnection(
connection.device.send_async_command(
hci.HCI_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address,
transmit_bandwidth=esco_parameters.transmit_bandwidth,
receive_bandwidth=esco_parameters.receive_bandwidth,
max_latency=esco_parameters.max_latency,
voice_setting=int(
hci.VoiceSetting(
input_sample_size=hci.VoiceSetting.InputSampleSize.SIZE_16_BITS,
input_data_format=hci.VoiceSetting.InputDataFormat.TWOS_COMPLEMENT,
)
),
retransmission_effort=esco_parameters.retransmission_effort,
packet_type=esco_parameters.packet_type,
)
)
)
else:
print('!!! no supported command for SCO connection request')
return
global output_wav
if output_wav:
output_wav.setnchannels(1)
output_wav.setsampwidth(2)
match protocol.active_codec:
case hfp.AudioCodec.CVSD:
output_wav.setframerate(8000)
case hfp.AudioCodec.MSBC:
output_wav.setframerate(16000)
connection.on('sco_connection', on_sco_connection)
# -----------------------------------------------------------------------------
@@ -142,163 +40,134 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration):
hf_protocol = HfProtocol(dlc, configuration)
asyncio.create_task(hf_protocol.run())
connection = dlc.multiplexer.l2cap_channel.connection
handler = functools.partial(
on_sco_request,
connection=connection,
protocol=hf_protocol,
)
connection.on('sco_request', handler)
def on_sco_request(connection: Connection, link_type: int, protocol: HfProtocol):
if connection == protocol.dlc.multiplexer.l2cap_channel.connection:
if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.SCO_CVSD_D1
]
elif protocol.active_codec == hfp.AudioCodec.MSBC:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.ESCO_MSBC_T2
]
elif protocol.active_codec == hfp.AudioCodec.CVSD:
esco_parameters = hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.ESCO_CVSD_S4
]
else:
raise RuntimeError("unknown active codec")
connection.cancel_on_disconnection(
connection.device.send_command(
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
bd_addr=connection.peer_address, **esco_parameters.asdict()
)
)
)
handler = functools.partial(on_sco_request, protocol=hf_protocol)
dlc.multiplexer.l2cap_channel.connection.device.on('sco_request', handler)
dlc.multiplexer.l2cap_channel.once(
'close',
lambda: connection.remove_listener('sco_request', handler),
lambda: dlc.multiplexer.l2cap_channel.connection.device.remove_listener(
'sco_request', handler
),
)
def on_ag_indicator(indicator):
global ws
if ws:
asyncio.create_task(ws.send(str(indicator)))
hf_protocol.on('ag_indicator', on_ag_indicator)
hf_protocol.on('codec_negotiation', on_codec_negotiation)
# -----------------------------------------------------------------------------
def on_ag_indicator(indicator):
global ws
if ws:
asyncio.create_task(ws.send(str(indicator)))
# -----------------------------------------------------------------------------
def on_codec_negotiation(codec: hfp.AudioCodec):
print(f'### Negotiated codec: {codec.name}')
# -----------------------------------------------------------------------------
async def run(device: Device, codec: str | None) -> None:
if codec is None:
supported_audio_codecs = [hfp.AudioCodec.CVSD, hfp.AudioCodec.MSBC]
else:
if codec == 'cvsd':
supported_audio_codecs = [hfp.AudioCodec.CVSD]
elif codec == 'msbc':
supported_audio_codecs = [hfp.AudioCodec.MSBC]
else:
print('Unknown codec: ', codec)
return
# Hands-Free profile configuration.
# TODO: load configuration from file.
configuration = hfp.HfConfiguration(
supported_hf_features=[
hfp.HfFeature.THREE_WAY_CALLING,
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
hfp.HfFeature.ENHANCED_CALL_STATUS,
hfp.HfFeature.ENHANCED_CALL_CONTROL,
hfp.HfFeature.CODEC_NEGOTIATION,
hfp.HfFeature.HF_INDICATORS,
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
],
supported_hf_indicators=[
hfp.HfIndicator.BATTERY_LEVEL,
],
supported_audio_codecs=supported_audio_codecs,
)
# Create and register a server
rfcomm_server = rfcomm.Server(device)
# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
# Advertise the HFP RFComm channel in the SDP
device.sdp_service_records = {
0x00010001: hfp.make_hf_sdp_records(0x00010001, channel_number, configuration)
}
# Let's go!
await device.power_on()
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Start the UI websocket server to offer a few buttons and input boxes
async def serve(websocket: websockets.asyncio.server.ServerConnection):
global ws
ws = websocket
async for message in websocket:
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'at_command':
if hf_protocol is not None:
response = str(
await hf_protocol.execute_command(
parsed['command'],
response_type=hfp.AtResponseType.MULTIPLE,
)
)
await websocket.send(response)
elif message_type == 'query_call':
if hf_protocol:
response = str(await hf_protocol.query_current_calls())
await websocket.send(response)
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await asyncio.get_running_loop().create_future() # run forever
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_hfp_handsfree.py <device-config> <transport-spec> '
'[codec] [input] [output]'
)
print('example: run_hfp_handsfree.py classic2.json usb:0')
print('Usage: run_classic_hfp.py <device-config> <transport-spec>')
print('example: run_classic_hfp.py classic2.json usb:04b4:f901')
return
device_config = sys.argv[1]
transport_spec = sys.argv[2]
print('<<< connecting to HCI...')
async with await open_transport(sys.argv[2]) as hci_transport:
print('<<< connected')
codec: str | None = None
if len(sys.argv) >= 4:
codec = sys.argv[3]
# Hands-Free profile configuration.
# TODO: load configuration from file.
configuration = hfp.HfConfiguration(
supported_hf_features=[
hfp.HfFeature.THREE_WAY_CALLING,
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
hfp.HfFeature.ENHANCED_CALL_STATUS,
hfp.HfFeature.ENHANCED_CALL_CONTROL,
hfp.HfFeature.CODEC_NEGOTIATION,
hfp.HfFeature.HF_INDICATORS,
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
],
supported_hf_indicators=[
hfp.HfIndicator.BATTERY_LEVEL,
],
supported_audio_codecs=[
hfp.AudioCodec.CVSD,
hfp.AudioCodec.MSBC,
],
)
input_file_name: str | None = None
if len(sys.argv) >= 5:
input_file_name = sys.argv[4]
# Create a device
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True
output_file_name: str | None = None
if len(sys.argv) >= 6:
output_file_name = sys.argv[5]
# Create and register a server
rfcomm_server = rfcomm.Server(device)
global input_wav, output_wav
input_cm: contextlib.AbstractContextManager[wave.Wave_read | None] = (
wave.open(input_file_name, "rb")
if input_file_name
else contextlib.nullcontext(None)
)
output_cm: contextlib.AbstractContextManager[wave.Wave_write | None] = (
wave.open(output_file_name, "wb")
if output_file_name
else contextlib.nullcontext(None)
)
with input_cm as input_wav, output_cm as output_wav:
if input_wav and input_wav.getnchannels() != 1:
print("Mono input required")
return
if input_wav and input_wav.getsampwidth() != 2:
print("16-bit input required")
return
# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
async with await open_transport(transport_spec) as transport:
device = Device.from_config_file_with_hci(
device_config, transport.source, transport.sink
# Advertise the HFP RFComm channel in the SDP
device.sdp_service_records = {
0x00010001: hfp.make_hf_sdp_records(
0x00010001, channel_number, configuration
)
device.classic_enabled = True
await run(device, codec)
}
# Let's go!
await device.power_on()
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Start the UI websocket server to offer a few buttons and input boxes
async def serve(websocket: websockets.asyncio.server.ServerConnection):
global ws
ws = websocket
async for message in websocket:
with contextlib.suppress(websockets.exceptions.ConnectionClosedOK):
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'at_command':
if hf_protocol is not None:
response = str(
await hf_protocol.execute_command(
parsed['command'],
response_type=hfp.AtResponseType.MULTIPLE,
)
)
await websocket.send(response)
elif message_type == 'query_call':
if hf_protocol:
response = str(await hf_protocol.query_current_calls())
await websocket.send(response)
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+42 -10
View File
@@ -657,6 +657,18 @@ dependencies = [
"wasi",
]
[[package]]
name = "getrandom"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasip2",
]
[[package]]
name = "gimli"
version = "0.28.0"
@@ -1402,21 +1414,26 @@ dependencies = [
]
[[package]]
name = "rand"
version = "0.8.5"
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
@@ -1424,11 +1441,11 @@ dependencies = [
[[package]]
name = "rand_core"
version = "0.6.4"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
dependencies = [
"getrandom",
"getrandom 0.3.4",
]
[[package]]
@@ -1455,7 +1472,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.10",
"redox_syscall 0.2.16",
"thiserror",
]
@@ -2028,6 +2045,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasip2"
version = "1.0.2+wasi-0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.87"
@@ -2283,3 +2309,9 @@ dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
+1 -1
View File
@@ -57,7 +57,7 @@ anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
rusb = "0.9.2"
rand = "0.8.5"
rand = "0.9.3"
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"
-2
View File
@@ -170,9 +170,7 @@ def format_code(ctx, check=False, diff=False):
@task
def check_types(ctx):
checklist = ["apps", "bumble", "examples", "tests", "tasks.py"]
print(">>> Running the type checker...")
try:
print("+++ Checking with mypy...")
ctx.run(f"mypy {' '.join(checklist)}")
except UnexpectedExit as exc:
print("Please check your code against the mypy messages.")
-25
View File
@@ -120,31 +120,6 @@ def test_messages(message: avdtp.Message):
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'pdu',
(
b'', # empty PDU — would IndexError on pdu[0]
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
same DoS class as #912 (ATT empty PDU). The assembler is required to
log + drop and stay alive so the L2CAP channel survives."""
completed = []
def callback(transaction_label, message):
completed.append((transaction_label, message))
assembler = avdtp.MessageAssembler(callback)
# Must not raise; nothing should be delivered to callback either.
assembler.on_pdu(pdu)
assert not completed
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex(
+3 -21
View File
@@ -171,15 +171,14 @@ class Source:
class Sink:
response: HCI_Event | None
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event | None) -> None:
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
if self.response is not None:
self.source.sink.on_packet(bytes(self.response))
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
@@ -229,23 +228,6 @@ async def test_send_sync_command() -> None:
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
source = Source()
sink = Sink(source, None)
host = Host(source, sink)
host.ready = True
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
# The sending semaphore should have been released, so this should not block
# indefinitely
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()
-52
View File
@@ -18,11 +18,9 @@
import asyncio
import logging
import os
import re
import pytest
from bumble import sdp
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
@@ -208,16 +206,6 @@ def sdp_records(record_count=1):
}
# -----------------------------------------------------------------------------
def test_pdu_parameter_length(caplog) -> None:
caplog.set_level(logging.WARNING)
pdu = sdp.SDP_ErrorResponse(
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
)
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search():
@@ -440,43 +428,3 @@ async def run():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
# -----------------------------------------------------------------------------
def test_nested_sequence_recursion_guard():
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
the parser with RecursionError. Instead a ValueError is raised once the
configured nesting limit is exceeded.
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
without a depth limit. A malicious SDP peer could craft a PDU exceeding
Pythons default recursion limit (~1000 frames) and crash the host.
"""
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
inner = b"\x35\x00" # empty SEQUENCE terminator
for _ in range(1500):
size = len(inner)
if size >= 65535:
break
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
with pytest.raises(ValueError, match="nesting exceeds max depth"):
DataElement.from_bytes(inner)
def test_nested_sequence_within_limit_still_works():
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
leaf = DataElement.unsigned_integer(1, value_size=2)
payload = leaf
for _ in range(16): # under the 32-depth limit
payload = DataElement.sequence([payload])
raw = bytes(payload)
parsed = DataElement.from_bytes(raw)
# Walk back down to confirm structural integrity preserved
cur = parsed
for _ in range(16):
assert cur.type == DataElement.SEQUENCE
cur = cur.value[0]
assert cur.type == DataElement.UNSIGNED_INTEGER
assert cur.value == 1
+1 -64
View File
@@ -24,7 +24,7 @@ import sys
import pytest
from bumble import controller, device, hci, link, transport
from bumble.transport import common, usb
from bumble.transport import common
# -----------------------------------------------------------------------------
@@ -252,69 +252,6 @@ async def test_open_transport_with_metadata(spec):
await controller_transport.close()
# -----------------------------------------------------------------------------
def test_packet_splitter_complete():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
splitter.feed(packet)
assert emitted == [packet]
def test_packet_splitter_chunks():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
splitter.feed(packet[:4])
assert emitted == []
splitter.feed(packet[4:])
assert emitted == [packet]
def test_packet_splitter_multiple():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
splitter.feed(packet1 + packet2)
assert emitted == [packet1, packet2]
def test_packet_splitter_partial():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
splitter.feed(packet1 + packet2[:4])
assert emitted == [packet1]
splitter.feed(packet2[4:])
assert emitted == [packet1, packet2]
def test_packet_splitter_empty_payload():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x00, 0x00])
splitter.feed(packet)
assert emitted == [packet]
def test_sco_packet_splitter():
emitted = []
splitter = usb.ScoPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x03, 0x11, 0x22, 0x33])
splitter.feed(packet)
assert emitted == [packet]
def test_event_packet_splitter():
emitted = []
splitter = usb.EventPacketSplitter(emitted.append)
packet = bytes([0x04, 0x02, 0x11, 0x22])
splitter.feed(packet)
assert emitted == [packet]
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_parser()
-95
View File
@@ -1,95 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock
import pytest
from bumble import hci
from bumble.transport import usb
@pytest.mark.asyncio
async def test_usb_packet_sink_iso_routing():
# Mock usb1 device and endpoints
mock_device = mock.Mock()
mock_bulk_out = mock.Mock()
mock_bulk_out.getAddress.return_value = 0x02
# Scenario 1: Isochronous endpoints are not enabled (isochronous_out is None)
mock_transfer = mock.Mock()
mock_device.getTransfer.return_value = mock_transfer
sink = usb.UsbPacketSink(mock_device, mock_bulk_out, isochronous_out=None)
sink.start()
# Send HCI_ISO_DATA_PACKET
iso_packet = bytes([hci.HCI_ISO_DATA_PACKET, 0x01, 0x02, 0x03])
sink.on_packet(iso_packet)
# Yield control to let the queue processor run
await asyncio.sleep(0.01)
# Verify it was sent via bulk transfer
mock_transfer.setBulk.assert_called_once_with(
0x02,
bytes([0x01, 0x02, 0x03]),
callback=sink.transfer_callback,
)
mock_transfer.submit.assert_called_once()
if sink.queue_task:
sink.queue_task.cancel()
try:
await sink.queue_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_usb_packet_sink_iso_routing_with_iso_endpoint():
# Mock usb1 device and endpoints
mock_device = mock.Mock()
mock_bulk_out = mock.Mock()
mock_bulk_out.getAddress.return_value = 0x02
mock_iso_out = mock.Mock()
mock_iso_out.getMaxPacketSize.return_value = 64
# Scenario 2: Isochronous endpoints are enabled
mock_transfer_bulk = mock.Mock()
mock_transfer_iso = mock.Mock()
# getTransfer is called twice: once for bulk_or_control and once for isochronous
mock_device.getTransfer.side_effect = [mock_transfer_bulk, mock_transfer_iso]
sink = usb.UsbPacketSink(mock_device, mock_bulk_out, isochronous_out=mock_iso_out)
sink.start()
# Send HCI_ISO_DATA_PACKET
iso_packet = bytes([hci.HCI_ISO_DATA_PACKET, 0x01, 0x02, 0x03])
sink.on_packet(iso_packet)
# Yield control to let the queue processor run
await asyncio.sleep(0.01)
# Verify it was NOT sent via bulk transfer
mock_transfer_bulk.setBulk.assert_not_called()
if sink.queue_task:
sink.queue_task.cancel()
try:
await sink.queue_task
except asyncio.CancelledError:
pass