add basic support for SCO packets over USB

This commit is contained in:
Gilles Boccon-Gibod
2024-09-26 17:26:50 -07:00
parent 2d17a5f742
commit 794a4a3ef0
11 changed files with 1190 additions and 541 deletions

View File

@@ -16,6 +16,8 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import statistics
import struct
import time
import click
@@ -25,7 +27,9 @@ from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
Address,
HCI_Read_Loopback_Mode_Command,
HCI_SynchronousDataPacket,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
@@ -36,34 +40,59 @@ from bumble.transport import open_transport
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
def __init__(
self,
packet_size: int,
packet_count: int,
connection_type: str,
mode: str,
interval: int,
transport: str,
):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: int | None = None
self.connection_type = connection_type
self.connection_event = asyncio.Event()
self.mode = mode
self.interval = interval
self.done = asyncio.Event()
self.expected_cid = 0
self.expected_counter = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
self.send_timestamps: list[float] = []
self.rtts: list[float] = []
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
# The first connection handle is of type ACL,
# subsequent connections are of type SCO
if self.connection_type == "sco" and self.connection_handle is None:
self.connection_handle = connection_handle
return
self.connection_handle = connection_handle
self.connection_event.set()
def on_sco_connection(
self, address: Address, connection_handle: int, link_type: int
):
self.on_connection(connection_handle)
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
(counter,) = struct.unpack_from("H", pdu, 0)
rtt = now - self.send_timestamps[counter]
self.rtts.append(rtt)
print(f'<<< Received packet {counter}: {len(pdu)} bytes, RTT={rtt:.4f}')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
assert counter == self.expected_counter
self.expected_counter += 1
if counter == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
@@ -71,20 +100,52 @@ class Loopback:
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
if self.mode == 'throughput':
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f},',
'cyan',
)
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
if self.expected_counter == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
def on_sco_packet(self, connection_handle: int, packet) -> None:
print("---", connection_handle, packet)
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_l2cap_pdu(self.connection_handle, 0, packet)
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_hci_packet(
HCI_SynchronousDataPacket(
connection_handle=self.connection_handle,
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
data_total_length=len(packet),
data=packet,
)
)
async def send_loop(self, host: Host, sender) -> None:
for counter in range(0, self.packet_count):
print(
color(
f'>>> Sending {self.connection_type.upper()} '
f'packet {counter}: {self.packet_size} bytes',
'yellow',
)
)
self.send_timestamps.append(time.time())
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
@@ -126,8 +187,11 @@ class Loopback:
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('classic_connection', self.on_connection)
host.on('le_connection', self.on_connection)
host.on('sco_connection', self.on_sco_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
host.on('sco_packet', self.on_sco_packet)
loopback_mode = LoopbackMode.LOCAL
@@ -148,32 +212,37 @@ class Loopback:
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
if self.connection_type == "acl":
sender = self.send_acl_packet
elif self.connection_type == "sco":
sender = self.send_sco_packet
else:
raise ValueError(f'Unknown connection type: {self.connection_type}')
await self.send_loop(host, sender)
await self.done.wait()
print(color('=== Done!', 'magenta'))
bytes_sent = self.packet_size * self.packet_count
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
if self.mode == 'throughput':
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} '
f'({bytes_sent} bytes in {elapsed:.2f} seconds)',
'green',
)
)
if self.mode == 'rtt':
print(
color(
f'RTTs: min={min(self.rtts):.4f}, '
f'max={max(self.rtts):.4f}, '
f'avg={statistics.mean(self.rtts):.4f}',
'blue',
)
)
)
# -----------------------------------------------------------------------------
@@ -194,11 +263,43 @@ class Loopback:
default=10,
help='Packet count',
)
@click.option(
'--connection-type',
'-t',
metavar='TYPE',
type=click.Choice(['acl', 'sco']),
default='acl',
help='Connection type',
)
@click.option(
'--mode',
'-m',
metavar='MODE',
type=click.Choice(['throughput', 'rtt']),
default='throughput',
help='Test mode',
)
@click.option(
'--interval',
type=int,
default=100,
help='Inter-packet interval (ms) [RTT mode only]',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
def main(packet_size, packet_count, connection_type, mode, interval, transport):
bumble.logging.setup_basic_logging()
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
if connection_type == "sco" and packet_size > 255:
print("ERROR: the maximum packet size for SCO is 255")
return
async def run():
loopback = Loopback(
packet_size, packet_count, connection_type, mode, interval, transport
)
await loopback.run()
asyncio.run(run())
# -----------------------------------------------------------------------------