mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
add basic support for SCO packets over USB
This commit is contained in:
@@ -45,8 +45,10 @@ from bumble.hci import (
|
||||
HCI_Read_Local_Supported_Codecs_Command,
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
HCI_Read_Voice_Setting_Command,
|
||||
LeFeature,
|
||||
SpecificationVersion,
|
||||
VoiceSetting,
|
||||
map_null_terminated_utf8_string,
|
||||
)
|
||||
from bumble.host import Host
|
||||
@@ -214,6 +216,16 @@ async def get_codecs_info(host: Host) -> None:
|
||||
if not response2.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
if host.supports_command(HCI_Read_Voice_Setting_Command.op_code):
|
||||
response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command())
|
||||
voice_setting = VoiceSetting.from_int(response3.voice_setting)
|
||||
print(color('Voice Setting:', 'yellow'))
|
||||
print(f' Air Coding Format: {voice_setting.air_coding_format.name}')
|
||||
print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}')
|
||||
print(f' Input Sample Size: {voice_setting.input_sample_size.name}')
|
||||
print(f' Input Data Format: {voice_setting.input_data_format.name}')
|
||||
print(f' Input Coding Format: {voice_setting.input_coding_format.name}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import statistics
|
||||
import struct
|
||||
import time
|
||||
|
||||
import click
|
||||
@@ -25,7 +27,9 @@ from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
Address,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_SynchronousDataPacket,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
@@ -36,34 +40,59 @@ from bumble.transport import open_transport
|
||||
class Loopback:
|
||||
"""Send and receive ACL data packets in local loopback mode"""
|
||||
|
||||
def __init__(self, packet_size: int, packet_count: int, transport: str):
|
||||
def __init__(
|
||||
self,
|
||||
packet_size: int,
|
||||
packet_count: int,
|
||||
connection_type: str,
|
||||
mode: str,
|
||||
interval: int,
|
||||
transport: str,
|
||||
):
|
||||
self.transport = transport
|
||||
self.packet_size = packet_size
|
||||
self.packet_count = packet_count
|
||||
self.connection_handle: int | None = None
|
||||
self.connection_type = connection_type
|
||||
self.connection_event = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.interval = interval
|
||||
self.done = asyncio.Event()
|
||||
self.expected_cid = 0
|
||||
self.expected_counter = 0
|
||||
self.bytes_received = 0
|
||||
self.start_timestamp = 0.0
|
||||
self.last_timestamp = 0.0
|
||||
self.send_timestamps: list[float] = []
|
||||
self.rtts: list[float] = []
|
||||
|
||||
def on_connection(self, connection_handle: int, *args):
|
||||
"""Retrieve connection handle from new connection event"""
|
||||
if not self.connection_event.is_set():
|
||||
# save first connection handle for ACL
|
||||
# subsequent connections are SCO
|
||||
# The first connection handle is of type ACL,
|
||||
# subsequent connections are of type SCO
|
||||
if self.connection_type == "sco" and self.connection_handle is None:
|
||||
self.connection_handle = connection_handle
|
||||
return
|
||||
|
||||
self.connection_handle = connection_handle
|
||||
self.connection_event.set()
|
||||
|
||||
def on_sco_connection(
|
||||
self, address: Address, connection_handle: int, link_type: int
|
||||
):
|
||||
self.on_connection(connection_handle)
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
"""Calculate packet receive speed"""
|
||||
now = time.time()
|
||||
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
|
||||
(counter,) = struct.unpack_from("H", pdu, 0)
|
||||
rtt = now - self.send_timestamps[counter]
|
||||
self.rtts.append(rtt)
|
||||
print(f'<<< Received packet {counter}: {len(pdu)} bytes, RTT={rtt:.4f}')
|
||||
assert connection_handle == self.connection_handle
|
||||
assert cid == self.expected_cid
|
||||
self.expected_cid += 1
|
||||
if cid == 0:
|
||||
assert counter == self.expected_counter
|
||||
self.expected_counter += 1
|
||||
if counter == 0:
|
||||
self.start_timestamp = now
|
||||
else:
|
||||
elapsed_since_start = now - self.start_timestamp
|
||||
@@ -71,20 +100,52 @@ class Loopback:
|
||||
self.bytes_received += len(pdu)
|
||||
instant_rx_speed = len(pdu) / elapsed_since_last
|
||||
average_rx_speed = self.bytes_received / elapsed_since_start
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f}',
|
||||
'cyan',
|
||||
if self.mode == 'throughput':
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f},',
|
||||
'cyan',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.last_timestamp = now
|
||||
|
||||
if self.expected_cid == self.packet_count:
|
||||
if self.expected_counter == self.packet_count:
|
||||
print(color('@@@ Received last packet', 'green'))
|
||||
self.done.set()
|
||||
|
||||
def on_sco_packet(self, connection_handle: int, packet) -> None:
|
||||
print("---", connection_handle, packet)
|
||||
|
||||
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
|
||||
assert self.connection_handle
|
||||
host.send_l2cap_pdu(self.connection_handle, 0, packet)
|
||||
|
||||
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
|
||||
assert self.connection_handle
|
||||
host.send_hci_packet(
|
||||
HCI_SynchronousDataPacket(
|
||||
connection_handle=self.connection_handle,
|
||||
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
|
||||
data_total_length=len(packet),
|
||||
data=packet,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_loop(self, host: Host, sender) -> None:
|
||||
for counter in range(0, self.packet_count):
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending {self.connection_type.upper()} '
|
||||
f'packet {counter}: {self.packet_size} bytes',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
self.send_timestamps.append(time.time())
|
||||
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
|
||||
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
@@ -126,8 +187,11 @@ class Loopback:
|
||||
return
|
||||
|
||||
# set event callbacks
|
||||
host.on('connection', self.on_connection)
|
||||
host.on('classic_connection', self.on_connection)
|
||||
host.on('le_connection', self.on_connection)
|
||||
host.on('sco_connection', self.on_sco_connection)
|
||||
host.on('l2cap_pdu', self.on_l2cap_pdu)
|
||||
host.on('sco_packet', self.on_sco_packet)
|
||||
|
||||
loopback_mode = LoopbackMode.LOCAL
|
||||
|
||||
@@ -148,32 +212,37 @@ class Loopback:
|
||||
|
||||
print(color('=== Start sending', 'magenta'))
|
||||
start_time = time.time()
|
||||
bytes_sent = 0
|
||||
for cid in range(0, self.packet_count):
|
||||
# using the cid as an incremental index
|
||||
host.send_l2cap_pdu(
|
||||
self.connection_handle, cid, bytes(self.packet_size)
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
|
||||
)
|
||||
)
|
||||
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
|
||||
await asyncio.sleep(0) # yield to allow packet receive
|
||||
if self.connection_type == "acl":
|
||||
sender = self.send_acl_packet
|
||||
elif self.connection_type == "sco":
|
||||
sender = self.send_sco_packet
|
||||
else:
|
||||
raise ValueError(f'Unknown connection type: {self.connection_type}')
|
||||
await self.send_loop(host, sender)
|
||||
|
||||
await self.done.wait()
|
||||
print(color('=== Done!', 'magenta'))
|
||||
|
||||
bytes_sent = self.packet_size * self.packet_count
|
||||
elapsed = time.time() - start_time
|
||||
average_tx_speed = bytes_sent / elapsed
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
|
||||
f' in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
if self.mode == 'throughput':
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} '
|
||||
f'({bytes_sent} bytes in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
if self.mode == 'rtt':
|
||||
print(
|
||||
color(
|
||||
f'RTTs: min={min(self.rtts):.4f}, '
|
||||
f'max={max(self.rtts):.4f}, '
|
||||
f'avg={statistics.mean(self.rtts):.4f}',
|
||||
'blue',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -194,11 +263,43 @@ class Loopback:
|
||||
default=10,
|
||||
help='Packet count',
|
||||
)
|
||||
@click.option(
|
||||
'--connection-type',
|
||||
'-t',
|
||||
metavar='TYPE',
|
||||
type=click.Choice(['acl', 'sco']),
|
||||
default='acl',
|
||||
help='Connection type',
|
||||
)
|
||||
@click.option(
|
||||
'--mode',
|
||||
'-m',
|
||||
metavar='MODE',
|
||||
type=click.Choice(['throughput', 'rtt']),
|
||||
default='throughput',
|
||||
help='Test mode',
|
||||
)
|
||||
@click.option(
|
||||
'--interval',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Inter-packet interval (ms) [RTT mode only]',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
def main(packet_size, packet_count, connection_type, mode, interval, transport):
|
||||
bumble.logging.setup_basic_logging()
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
if connection_type == "sco" and packet_size > 255:
|
||||
print("ERROR: the maximum packet size for SCO is 255")
|
||||
return
|
||||
|
||||
async def run():
|
||||
loopback = Loopback(
|
||||
packet_size, packet_count, connection_type, mode, interval, transport
|
||||
)
|
||||
await loopback.run()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -111,9 +111,14 @@ def show_device_details(device):
|
||||
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
|
||||
else 'IN'
|
||||
)
|
||||
endpoint_details = (
|
||||
f', Max Packet Size = {endpoint.getMaxPacketSize()}'
|
||||
if endpoint_type == 'ISOCHRONOUS'
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f' Endpoint 0x{endpoint.getAddress():02X}: '
|
||||
f'{endpoint_type} {endpoint_direction}'
|
||||
f'{endpoint_type} {endpoint_direction}{endpoint_details}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user