Files
bumble_mirror/apps/bench.py
2026-01-01 03:25:32 +08:00

2329 lines
75 KiB
Python

# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import dataclasses
import enum
import logging
import statistics
import struct
import time
import click
import bumble.core
import bumble.logging
import bumble.rfcomm
from bumble import l2cap
from bumble.colors import color
from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
UUID,
CommandTimeoutError,
ConnectionPHY,
PhysicalTransport,
)
from bumble.device import (
CigParameters,
CisLink,
Connection,
ConnectionParametersPreferences,
Device,
Peer,
)
from bumble.gatt import Characteristic, CharacteristicValue, Service
from bumble.hci import (
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
HCI_Constant,
HCI_Error,
HCI_IsoDataPacket,
HCI_StatusError,
Role,
)
from bumble.pairing import PairingConfig
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
)
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# Module logger. NOTE(review): much of the code in this module logs through the
# root `logging.info(...)`/`logging.warning(...)` functions rather than through
# this logger.
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Default Bluetooth addresses and device names for the two benchmark roles.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'
# Default LE advertising interval. Presumably milliseconds (post_power_on()
# logs the advertising interval it is given with an "ms" suffix) -- confirm at
# the use site.
DEFAULT_ADVERTISING_INTERVAL = 100
# GATT speed-test service: the client writes outgoing packets to the TX
# characteristic and receives packets as notifications on the RX
# characteristic.
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'
# Service class UUID advertised in the RFCOMM SDP record (make_sdp_records()).
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
# Defaults for the LE credit-based L2CAP channel (L2capClient / L2capServer).
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024
# ISO (CIS) defaults per direction: C_TO_P = central -> peripheral,
# P_TO_C = peripheral -> central. NOTE(review): presumably used to build the
# CIG parameters elsewhere; SDU intervals look like microseconds and transport
# latencies like milliseconds -- confirm against the use site.
DEFAULT_ISO_MAX_SDU_C_TO_P = 251
DEFAULT_ISO_MAX_SDU_P_TO_C = 251
DEFAULT_ISO_SDU_INTERVAL_C_TO_P = 10000
DEFAULT_ISO_SDU_INTERVAL_P_TO_C = 10000
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P = 35
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C = 35
DEFAULT_ISO_RTN_C_TO_P = 3
DEFAULT_ISO_RTN_P_TO_C = 3
# NOTE(review): timing defaults, presumably in seconds -- confirm at use site.
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
# Default RFCOMM channel and MTU. NOTE(review): their use is not visible here.
DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def le_phy_name(phy_id):
    """Return a short display name for an LE PHY identifier.

    The 1M, 2M and Coded PHYs map to '1M', '2M' and 'CODED'; any other value
    falls back to HCI_Constant.le_phy_name().
    """
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    # The generic name is computed unconditionally, as a dict.get() default would be.
    generic_name = HCI_Constant.le_phy_name(phy_id)
    return short_names[phy_id] if phy_id in short_names else generic_name
def print_connection_phy(phy: ConnectionPHY) -> None:
    """Log the transmit and receive PHYs of a connection."""
    tx_name = le_phy_name(phy.tx_phy)
    rx_name = le_phy_name(phy.rx_phy)
    logging.info(color('@@@ PHY: ', 'yellow') + f'TX:{tx_name}/RX:{rx_name}')
def print_connection(connection: Connection) -> None:
    """Log a one-line summary of a connection.

    LE connections report data length (TX and RX pairs), connection parameters
    (interval/peripheral latency/supervision timeout) and the ATT MTU; other
    transports report the connection role.
    """
    if connection.transport != PhysicalTransport.LE:
        params = [f'Role={HCI_Constant.role_name(connection.role)}']
    else:
        data_length = connection.data_length
        conn_params = connection.parameters
        params = [
            f'DL=(TX:{data_length[0]}/{data_length[1]},'
            f'RX:{data_length[2]}/{data_length[3]})',
            'Parameters='
            f'{conn_params.connection_interval:.2f}/'
            f'{conn_params.peripheral_latency}/'
            f'{conn_params.supervision_timeout:.2f} ',
            f'MTU={connection.att_mtu}',
        ]
    logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(params))
def print_cis_link(cis_link: CisLink) -> None:
    """Log the parameters of an established CIS, one section per direction."""

    def log_direction(heading, phy, latency, burst_number, flush_timeout, max_pdu):
        # The PHY line is only shown when the link reports a PHY.
        logging.info(color(heading, 'green'))
        if phy is not None:
            logging.info(color('@@@ PHY: ', 'green') + f"{phy.name}")
        logging.info(color('@@@ Latency: ', 'green') + f"{latency}µs")
        logging.info(color('@@@ BN: ', 'green') + f"{burst_number}")
        logging.info(color('@@@ FT: ', 'green') + f"{flush_timeout}")
        logging.info(color('@@@ Max PDU: ', 'green') + f"{max_pdu}")

    logging.info(color("@@@ CIS established", "green"))
    logging.info(color('@@@ ISO interval: ', 'green') + f"{cis_link.iso_interval}ms")
    logging.info(color('@@@ NSE: ', 'green') + f"{cis_link.nse}")
    log_direction(
        '@@@ Central->Peripheral:',
        cis_link.phy_c_to_p,
        cis_link.transport_latency_c_to_p,
        cis_link.bn_c_to_p,
        cis_link.ft_c_to_p,
        cis_link.max_pdu_c_to_p,
    )
    log_direction(
        '@@@ Peripheral->Central:',
        cis_link.phy_p_to_c,
        cis_link.transport_latency_p_to_c,
        cis_link.bn_p_to_c,
        cis_link.ft_p_to_c,
        cis_link.max_pdu_p_to_c,
    )
def make_sdp_records(channel):
    """Build the SDP records advertising the benchmark RFCOMM service.

    Returns a dict holding one record (handle 0x00010001) that declares the
    public browse group, the DEFAULT_RFCOMM_UUID service class, and an
    L2CAP + RFCOMM protocol descriptor list pointing at ``channel``.
    """
    record_handle = 0x00010001
    l2cap_descriptor = DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)])
    rfcomm_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            DataElement.unsigned_integer_8(channel),
        ]
    )
    attributes = [
        ServiceAttribute(
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(record_handle),
        ),
        ServiceAttribute(
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        ServiceAttribute(
            SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(UUID(DEFAULT_RFCOMM_UUID))]),
        ),
        ServiceAttribute(
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([l2cap_descriptor, rfcomm_descriptor]),
        ),
    ]
    return {record_handle: attributes}
def log_stats(title, stats, precision=2):
    """Log the min, max, average and standard deviation of ``stats``.

    Args:
        title: label printed in front of the statistics.
        stats: sequence of numeric samples.
        precision: number of decimal places used when formatting.

    An empty ``stats`` is logged as "no data" rather than raising (``min()``
    and ``max()`` reject empty sequences). With a single sample the standard
    deviation is reported as 0, since ``statistics.stdev`` needs two points.
    """
    if not stats:
        logging.info(f'### {title} stats: no data')
        return
    stats_min = min(stats)
    stats_max = max(stats)
    stats_avg = statistics.mean(stats)
    stats_stdev = statistics.stdev(stats) if len(stats) >= 2 else 0
    logging.info(
        color(
            (
                f'### {title} stats: '
                f'min={stats_min:.{precision}f}, '
                f'max={stats_max:.{precision}f}, '
                f'average={stats_avg:.{precision}f}, '
                f'stdev={stats_stdev:.{precision}f}'
            ),
            'cyan',
        )
    )
async def switch_roles(connection, role):
    """Ask for ``connection`` to switch to ``role`` if it is not already in it.

    ``role`` is the string "central" for Role.CENTRAL; any other value selects
    Role.PERIPHERAL. An HCI_Error from the switch is logged, not raised.
    """
    target_role = Role.PERIPHERAL
    if role == "central":
        target_role = Role.CENTRAL
    if connection.role == target_role:
        return

    logging.info(f'{color("### Switching roles to:", "cyan")} {role}')
    try:
        await connection.switch_role(target_role)
    except HCI_Error as error:
        logging.info(f'{color("### Role switch failed:", "red")} {error}')
    else:
        logging.info(color('### Role switch complete', 'cyan'))
async def pre_power_on(device: Device, classic: bool) -> None:
    """Configure ``device`` before power-on.

    Enables or disables Classic according to ``classic``, selects the
    "JsonKeyStore" keystore, and installs a pairing-config factory that asks
    for no Secure Connections, no MITM protection and no bonding.
    """

    def minimal_pairing_config(_):
        return PairingConfig(sc=False, mitm=False, bonding=False)

    device.classic_enabled = classic
    device.config.keystore = "JsonKeyStore"
    device.pairing_config_factory = minimal_pairing_config
async def post_power_on(
    device: Device,
    le_scan: tuple[int, int] | None,
    le_advertise: int | None,
    classic_page_scan: bool,
    classic_inquiry_scan: bool,
) -> None:
    """Start the requested scanning/advertising modes on a powered-on device.

    Args:
        device: the device to configure.
        le_scan: optional ``(scan_window, scan_interval)`` pair (logged as ms).
        le_advertise: optional advertising interval (logged as ms); used for
            both the minimum and maximum interval, with auto-restart.
        classic_page_scan: make the device connectable.
        classic_inquiry_scan: make the device discoverable.
    """
    if classic_page_scan:
        logging.info(color("*** Enabling page scan", "blue"))
        await device.set_connectable(True)

    if classic_inquiry_scan:
        logging.info(color("*** Enabling inquiry scan", "blue"))
        await device.set_discoverable(True)

    if le_scan:
        window, interval = le_scan
        scan_message = f"*** Starting LE scanning [{window}ms/{interval}ms]"
        logging.info(color(scan_message, "blue"))
        await device.start_scanning(scan_interval=interval, scan_window=window)

    if le_advertise:
        advertise_message = f"*** Starting LE advertising [{le_advertise}ms]"
        logging.info(color(advertise_message, "blue"))
        await device.start_advertising(
            advertising_interval_min=le_advertise,
            advertising_interval_max=le_advertise,
            auto_restart=True,
        )
# -----------------------------------------------------------------------------
# Packet
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Packet:
    """A benchmark packet and its little-endian wire format.

    * RESET:    packet_type (u8)                                   -- 1 byte
    * ACK:      packet_type (u8), flags (u8), sequence (u32)       -- 6 bytes
    * SEQUENCE: packet_type (u8), flags (u8), sequence (u32),
                timestamp (u32), then the payload                  -- 10 + n bytes
    """

    class PacketType(enum.IntEnum):
        RESET = 0
        SEQUENCE = 1
        ACK = 2

    class PacketFlags(enum.IntFlag):
        LAST = 1

    packet_type: PacketType
    flags: PacketFlags = PacketFlags(0)
    sequence: int = 0
    timestamp: int = 0
    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a packet from its wire representation.

        Raises:
            ValueError: if ``data`` is empty, starts with an unknown packet
                type, or is shorter than the fixed header of its type
                (6 bytes for ACK, 10 bytes for SEQUENCE).
        """
        if len(data) < 1:
            logging.warning(
                color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red')
            )
            raise ValueError('packet too short')
        try:
            packet_type = cls.PacketType(data[0])
        except ValueError:
            logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red'))
            raise
        if packet_type == cls.PacketType.RESET:
            return cls(packet_type)

        # Check the length *before* reading any header field, so a truncated
        # packet surfaces as ValueError (what callers catch) rather than as
        # IndexError / struct.error.
        min_size = 6 if packet_type == cls.PacketType.ACK else 10
        if len(data) < min_size:
            logging.warning(
                color(
                    f'!!! Packet too short (got {len(data)} bytes, need >= {min_size})',
                    'red',
                )
            )
            raise ValueError('packet too short')

        flags = cls.PacketFlags(data[1])
        (sequence,) = struct.unpack_from("<I", data, 2)
        if packet_type == cls.PacketType.ACK:
            return cls(packet_type, flags, sequence)
        (timestamp,) = struct.unpack_from("<I", data, 6)
        return cls(packet_type, flags, sequence, timestamp, data[10:])

    def __bytes__(self):
        """Serialize the packet to its wire representation."""
        if self.packet_type == self.PacketType.RESET:
            return bytes([self.packet_type])
        if self.packet_type == self.PacketType.ACK:
            return struct.pack("<BBI", self.packet_type, self.flags, self.sequence)
        return (
            struct.pack(
                "<BBII", self.packet_type, self.flags, self.sequence, self.timestamp
            )
            + self.payload
        )
# -----------------------------------------------------------------------------
# Jitter Stats
# -----------------------------------------------------------------------------
class JitterStats:
    """Track packet arrival jitter relative to the sender's timestamps.

    The first recorded packet anchors the timeline. For each later packet with
    a non-zero timestamp, the expected arrival time is the first arrival time
    plus the timestamp difference (microseconds, converted to seconds); the
    jitter is the actual arrival time minus that expectation.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded packets, arrival times and jitter values."""
        self.packets = []
        self.receive_times = []
        self.jitter = []

    def on_packet_received(self, packet):
        """Record ``packet`` as arriving now and return its jitter (seconds)."""
        arrival = time.time()
        self.packets.append(packet)
        self.receive_times.append(arrival)

        jitter = 0.0
        if packet.timestamp and len(self.packets) > 1:
            sender_elapsed = (packet.timestamp - self.packets[0].timestamp) / 1000000
            jitter = arrival - (self.receive_times[0] + sender_elapsed)

        self.jitter.append(jitter)
        return jitter

    def show_stats(self):
        """Log jitter statistics and a 20-bin histogram (needs >= 3 samples)."""
        if len(self.jitter) < 3:
            return

        mean_jitter = sum(self.jitter) / len(self.jitter)
        centered = [value - mean_jitter for value in self.jitter]
        log_stats('Jitter (signed)', centered, 3)
        log_stats('Jitter (absolute)', [abs(value) for value in centered], 3)

        # Histogram: equal-width bins whose lower edges start at the minimum.
        bin_count = 20
        low = min(centered)
        bin_width = (max(centered) - low) / bin_count
        thresholds = [low + index * bin_width for index in range(bin_count)]
        counts = [0] * bin_count
        for value in centered:
            # Highest bin whose lower edge does not exceed the value.
            bin_index = next(
                (i for i in reversed(range(bin_count)) if value >= thresholds[i]),
                None,
            )
            if bin_index is not None:
                counts[bin_index] += 1

        for threshold, count in zip(thresholds, counts):
            logging.info(f'@@@ >= {threshold:.4f}: {count}')
# -----------------------------------------------------------------------------
# Sender
# -----------------------------------------------------------------------------
class Sender:
    """Throughput sender.

    Each run sends a RESET packet, then ``packet_count`` SEQUENCE packets (the
    last one flagged LAST) whose size, including the packet I/O's framing
    overhead, is ``packet_size`` bytes. When the packet I/O can receive, the
    run ends when the peer's ACK arrives and the average TX speed (bytes per
    second, framing overhead excluded) is appended to ``self.stats``.

    Args:
        packet_io: packet transport; this sender registers itself as its
            ``packet_listener``.
        start_delay: seconds to sleep before each run (falsy to skip).
        repeat: number of additional runs after the first one.
        repeat_delay: seconds to sleep between runs.
        pace: milliseconds between packet send times; when <= 0 the sender
            instead waits for ``packet_io.drain()`` before each packet.
        packet_size: size of each SEQUENCE packet including framing overhead.
        packet_count: number of SEQUENCE packets per run.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.start_time = 0
        self.bytes_sent = 0
        self.stats = []
        self.done = asyncio.Event()

    def reset(self):
        pass

    async def run(self):
        """Wait for the I/O to be ready, then perform ``repeat + 1`` runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(
                bytes(Packet(packet_type=Packet.PacketType.RESET))
            )

            self.start_time = time.time()
            self.bytes_sent = 0
            for tx_i in range(self.tx_packet_count):
                if self.pace > 0:
                    # Wait until it is time to send the next packet
                    target_time = self.start_time + (tx_i * self.pace / 1000)
                    now = time.time()
                    if now < target_time:
                        await asyncio.sleep(target_time - now)
                else:
                    # Unpaced: send as fast as the transport drains.
                    await self.packet_io.drain()

                packet = bytes(
                    Packet(
                        packet_type=Packet.PacketType.SEQUENCE,
                        flags=(
                            Packet.PacketFlags.LAST
                            if tx_i == self.tx_packet_count - 1
                            else 0
                        ),
                        sequence=tx_i,
                        timestamp=int((time.time() - self.start_time) * 1000000),
                        # 10-byte packet header + I/O framing overhead are
                        # part of the requested packet size.
                        payload=bytes(
                            self.tx_packet_size - 10 - self.packet_io.overhead_size
                        ),
                    )
                )
                logging.info(
                    color(
                        f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
                    )
                )
                self.bytes_sent += len(packet)
                await self.packet_io.send_packet(packet)

            if self.packet_io.can_receive():
                # Set by on_packet_received() when the ACK arrives.
                await self.done.wait()

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))
            # Speed samples only come from ACKs, so an I/O that cannot receive
            # never produces any; an empty list cannot be summarised.
            if self.repeat and self.stats:
                log_stats('Run', self.stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    def on_packet_received(self, data):
        """Handle the receiver's ACK: record the run's average TX speed."""
        try:
            packet = Packet.from_bytes(data)
        except ValueError:
            return
        if packet.packet_type == Packet.PacketType.ACK:
            elapsed = time.time() - self.start_time
            average_tx_speed = self.bytes_sent / elapsed
            self.stats.append(average_tx_speed)
            logging.info(
                color(
                    f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                    f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                    'green',
                )
            )
            self.done.set()

    def is_sender(self):
        return True
# -----------------------------------------------------------------------------
# Receiver
# -----------------------------------------------------------------------------
class Receiver:
    """Throughput receiver.

    Logs each SEQUENCE packet with its jitter and its instant, windowed (over
    the last 64 measurements) and average receive speed in bytes per second.
    A RESET restarts the measurement. On the packet flagged LAST it sends back
    an ACK echoing the flags and sequence, logs jitter statistics and, unless
    ``linger`` is set, signals completion.
    """

    expected_packet_index: int
    start_timestamp: float
    last_timestamp: float

    def __init__(self, packet_io, linger):
        self.jitter_stats = JitterStats()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()
        self.reset()

    def reset(self):
        """Restart sequence tracking, speed measurements and jitter stats."""
        self.expected_packet_index = 0
        # (arrival time, byte count) pairs; the first entry is the reference
        # point for the average speed.
        self.measurements = [(time.time(), 0)]
        self.total_bytes_received = 0
        self.jitter_stats.reset()

    @staticmethod
    def _rate(byte_count, seconds):
        """Return ``byte_count / seconds``, or 0.0 if no time has elapsed.

        ``time.time()`` has finite resolution, so two packets can be stamped
        with the same time; dividing by that zero interval would raise
        ZeroDivisionError inside the packet callback.
        """
        return byte_count / seconds if seconds > 0 else 0.0

    def on_packet_received(self, data):
        """Account for one received packet and ACK the LAST one."""
        try:
            packet = Packet.from_bytes(data)
        except ValueError:
            logging.exception("invalid packet")
            return
        if packet.packet_type == Packet.PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        jitter = self.jitter_stats.on_packet_received(packet)
        logging.info(
            f'<<< Received packet {packet.sequence}: '
            f'flags={packet.flags}, '
            f'jitter={jitter:.4f}, '
            f'{len(data) + self.packet_io.overhead_size} bytes',
        )
        if packet.sequence != self.expected_packet_index:
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet.sequence}',
                    'red',
                )
            )

        now = time.time()
        elapsed_since_start = now - self.measurements[0][0]
        elapsed_since_last = now - self.measurements[-1][0]
        self.measurements.append((now, len(data)))
        self.total_bytes_received += len(data)
        instant_rx_speed = self._rate(len(data), elapsed_since_last)
        average_rx_speed = self._rate(self.total_bytes_received, elapsed_since_start)
        # The window's first entry only marks the start of the time span, so
        # its bytes are not counted.
        window = self.measurements[-64:]
        windowed_rx_speed = self._rate(
            sum(measurement[1] for measurement in window[1:]),
            window[-1][0] - window[0][0],
        )
        logging.info(
            color(
                'Speed: '
                f'instant={instant_rx_speed:.4f}, '
                f'windowed={windowed_rx_speed:.4f}, '
                f'average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.expected_packet_index = packet.sequence + 1

        if packet.flags & Packet.PacketFlags.LAST:
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
                )
            )
            logging.info(color('@@@ Received last packet', 'green'))
            self.jitter_stats.show_stats()
            if not self.linger:
                self.done.set()

    async def run(self):
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))

    def is_sender(self):
        return False
# -----------------------------------------------------------------------------
# Ping
# -----------------------------------------------------------------------------
class Ping:
    """Latency sender: sends SEQUENCE packets and times the ACK of each one.

    Each run sends a RESET, then ``packet_count`` packets spaced ``pace``
    milliseconds apart (the last flagged LAST), and waits for the ACK of the
    last packet. The run's RTT min/max/average/stdev (milliseconds) are
    logged; with ``repeat``, the per-run min/max/average values are also
    summarised across runs.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.done = asyncio.Event()
        self.ping_times = []
        self.rtts = []
        self.next_expected_packet_index = 0
        self.min_stats = []
        self.max_stats = []
        self.avg_stats = []

    def reset(self):
        pass

    async def run(self):
        """Wait for the I/O to be ready, then perform ``repeat + 1`` runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()
            self.ping_times = []
            # RTTs are summarised per run; without this reset each run's
            # statistics would also include every earlier run's samples.
            self.rtts = []

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET)))

            start_time = time.time()
            self.next_expected_packet_index = 0
            for i in range(self.tx_packet_count):
                target_time = start_time + (i * self.pace / 1000)
                now = time.time()
                if now < target_time:
                    await asyncio.sleep(target_time - now)
                    now = time.time()

                packet = bytes(
                    Packet(
                        packet_type=Packet.PacketType.SEQUENCE,
                        flags=(
                            Packet.PacketFlags.LAST
                            if i == self.tx_packet_count - 1
                            else 0
                        ),
                        sequence=i,
                        timestamp=int((now - start_time) * 1000000),
                        payload=bytes(self.tx_packet_size - 10),
                    )
                )
                logging.info(color(f'Sending packet {i}', 'yellow'))
                # Send time of packet i, indexed by sequence number for RTTs.
                self.ping_times.append(now)
                await self.packet_io.send_packet(packet)

            # Set by on_packet_received() when the ACK of the LAST packet arrives.
            await self.done.wait()

            min_rtt = min(self.rtts)
            max_rtt = max(self.rtts)
            avg_rtt = statistics.mean(self.rtts)
            # statistics.stdev() needs at least two samples.
            stdev_rtt = statistics.stdev(self.rtts) if len(self.rtts) >= 2 else 0
            logging.info(
                color(
                    '@@@ RTTs: '
                    f'min={min_rtt:.2f}, '
                    f'max={max_rtt:.2f}, '
                    f'average={avg_rtt:.2f}, '
                    f'stdev={stdev_rtt:.2f}'
                )
            )

            self.min_stats.append(min_rtt)
            self.max_stats.append(max_rtt)
            self.avg_stats.append(avg_rtt)

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Min RTT', self.min_stats)
                log_stats('Max RTT', self.max_stats)
                log_stats('Average RTT', self.avg_stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    def on_packet_received(self, data):
        """Record the RTT of an ACK; finish the run on the LAST one."""
        try:
            packet = Packet.from_bytes(data)
        except ValueError:
            return
        if packet.packet_type != Packet.PacketType.ACK:
            return

        if not 0 <= packet.sequence < len(self.ping_times):
            # No recorded send time for this sequence in the current run;
            # indexing ping_times would raise IndexError in this callback.
            logging.info(
                color(f'!!! Unexpected ACK [{packet.sequence}] (no matching ping)', 'red')
            )
            return

        elapsed = time.time() - self.ping_times[packet.sequence]
        rtt = elapsed * 1000
        self.rtts.append(rtt)
        logging.info(
            color(
                f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms',
                'green',
            )
        )

        if packet.sequence == self.next_expected_packet_index:
            self.next_expected_packet_index += 1
        else:
            logging.info(
                color(
                    f'!!! Unexpected packet, '
                    f'expected {self.next_expected_packet_index} '
                    f'but received {packet.sequence}',
                    'red',
                )
            )

        if packet.flags & Packet.PacketFlags.LAST:
            self.done.set()

    def is_sender(self):
        return True
# -----------------------------------------------------------------------------
# Pong
# -----------------------------------------------------------------------------
class Pong:
    """Echo side of the ping benchmark.

    Every SEQUENCE packet is ACKed (echoing its flags and sequence). A RESET
    restarts sequence tracking and jitter stats. On the packet flagged LAST,
    jitter statistics are logged and, unless ``linger`` is set, completion is
    signalled.
    """

    expected_packet_index: int

    def __init__(self, packet_io, linger):
        self.jitter_stats = JitterStats()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()
        self.reset()

    def reset(self):
        """Restart sequence tracking and jitter statistics."""
        self.expected_packet_index = 0
        self.jitter_stats.reset()

    def on_packet_received(self, data):
        """Log, check and ACK one received packet."""
        try:
            packet = Packet.from_bytes(data)
        except ValueError:
            return

        if packet.packet_type == Packet.PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        jitter = self.jitter_stats.on_packet_received(packet)
        summary = (
            f'<<< Received packet {packet.sequence}: '
            f'flags={packet.flags}, {len(data)} bytes, '
            f'jitter={jitter:.4f}'
        )
        logging.info(color(summary, 'green'))

        expected = self.expected_packet_index
        if packet.sequence != expected:
            mismatch = (
                f'!!! Unexpected packet, expected {expected} '
                f'but received {packet.sequence}'
            )
            logging.info(color(mismatch, 'red'))
        self.expected_packet_index = packet.sequence + 1

        ack = Packet(Packet.PacketType.ACK, packet.flags, packet.sequence)
        AsyncRunner.spawn(self.packet_io.send_packet(bytes(ack)))

        if not packet.flags & Packet.PacketFlags.LAST:
            return
        self.jitter_stats.show_stats()
        if not self.linger:
            self.done.set()

    async def run(self):
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))

    def is_sender(self):
        return False
# -----------------------------------------------------------------------------
# GattClient
# -----------------------------------------------------------------------------
class GattClient:
    """Packet I/O over the GATT speed service, client side.

    On connection it optionally requests an ATT MTU, discovers the speed
    service and its TX/RX characteristics, and subscribes to RX notifications
    (incoming packets). Outgoing packets are written to TX. ``ready`` is set
    once discovery and subscription succeed, and cleared on disconnection.
    """

    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()
        # Each packet is one attribute value: no extra framing bytes.
        self.overhead_size = 0

    async def on_connection(self, connection):
        """Discover the speed service on ``connection`` and subscribe to RX."""
        peer = Peer(connection)

        if self.att_mtu:
            logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        logging.info(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        speed_services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not speed_services:
            logging.info(color('!!! Speed Service not found', 'red'))
            return
        speed_service = speed_services[0]
        logging.info(color('*** Discovering characteristics...', 'blue'))
        await speed_service.discover_characteristics()

        speed_txs = speed_service.get_characteristics_by_uuid(SPEED_TX_UUID)
        if not speed_txs:
            logging.info(color('!!! Speed TX not found', 'red'))
            return
        self.speed_tx = speed_txs[0]

        speed_rxs = speed_service.get_characteristics_by_uuid(SPEED_RX_UUID)
        if not speed_rxs:
            logging.info(color('!!! Speed RX not found', 'red'))
            return
        self.speed_rx = speed_rxs[0]

        logging.info(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)
        logging.info(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        self.ready.clear()

    def on_packet_received(self, packet):
        if self.packet_listener:
            self.packet_listener.on_packet_received(packet)

    async def send_packet(self, packet):
        await self.speed_tx.write_value(packet)

    async def drain(self):
        pass

    def can_receive(self):
        """Return True: packets from the peer arrive as RX notifications.

        Sender.run() calls this on every packet I/O to decide whether to wait
        for the receiver's ACK.
        """
        return True
# -----------------------------------------------------------------------------
# GattServer
# -----------------------------------------------------------------------------
class GattServer:
    """Packet I/O over the GATT speed service, server side.

    Registers the speed service on ``device``. Incoming packets arrive as
    writes to the TX characteristic; outgoing packets are sent as
    notifications on the RX characteristic. ``ready`` is set while the peer
    has RX notifications enabled, and cleared on unsubscription or
    disconnection.
    """

    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()
        # Each packet is one attribute value: no extra framing bytes.
        self.overhead_size = 0

        # Setup the GATT service
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.Properties.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        self.speed_rx = Characteristic(
            SPEED_RX_UUID, Characteristic.Properties.NOTIFY, 0
        )

        speed_service = Service(
            SPEED_SERVICE_UUID,
            [self.speed_tx, self.speed_rx],
        )
        device.add_services([speed_service])

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        if notify_enabled:
            logging.info(color('*** RX subscription', 'blue'))
            self.ready.set()
        else:
            logging.info(color('*** RX un-subscription', 'blue'))
            self.ready.clear()

    def on_tx_write(self, _, value):
        if self.packet_listener:
            self.packet_listener.on_packet_received(value)

    async def send_packet(self, packet):
        await self.device.notify_subscribers(self.speed_rx, packet)

    async def drain(self):
        pass

    def can_receive(self):
        """Return True: packets from the peer arrive as TX writes.

        Sender.run() calls this on every packet I/O to decide whether to wait
        for the receiver's ACK.
        """
        return True
# -----------------------------------------------------------------------------
# StreamedPacketIO
# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Length-prefixed packet framing over a byte stream.

    Each packet is sent as a 2-byte big-endian length followed by the packet
    bytes (hence ``overhead_size = 2``). Subclasses set ``io_sink`` (a callable
    taking the framed bytes) and feed received stream bytes to
    :meth:`on_packet`, which reassembles packets and passes each one to
    ``packet_listener.on_packet_received``.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0
        self.overhead_size = 2

    def _deliver_rx_packet(self):
        """Hand the reassembled packet to the listener and reset the parser."""
        if self.packet_listener:
            self.packet_listener.on_packet_received(self.rx_packet)
        self.rx_packet = b''
        self.rx_packet_header = b''

    def on_packet(self, packet):
        """Consume received stream bytes, delivering every completed packet."""
        while packet:
            if self.rx_packet_need:
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    # Packet completed
                    self._deliver_rx_packet()
            else:
                # Expect the next packet
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[:header_bytes_needed]
                self.rx_packet_header += header_bytes
                if len(self.rx_packet_header) != 2:
                    return
                packet = packet[len(header_bytes) :]
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                if not self.rx_packet_need:
                    # A zero-length packet is complete as soon as its header
                    # is. Without clearing the header here, the parser would
                    # re-read the same stale header forever and never consume
                    # the remaining bytes.
                    self._deliver_rx_packet()

    async def send_packet(self, packet):
        """Frame ``packet`` with its 2-byte length and write it to the sink."""
        if not self.io_sink:
            logging.info(color('!!! No sink, dropping packet', 'red'))
            return
        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)

    def can_receive(self):
        return True
# -----------------------------------------------------------------------------
# L2capClient
# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Packet I/O over an LE credit-based L2CAP channel opened by this side.

    On connection, a channel is opened on ``psm`` with the configured credits,
    MTU and MPS; ``ready`` is set once the channel is open.
    """

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm = psm
        self.max_credits = max_credits
        self.mtu = mtu
        self.mps = mps
        self.l2cap_channel = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection: Connection) -> None:
        """Open the L2CAP channel over ``connection`` and start using it."""
        connection.on('disconnection', self.on_disconnection)

        logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel = await connection.create_l2cap_channel(
                spec=l2cap.LeCreditBasedChannelSpec(
                    psm=self.psm,
                    max_credits=self.max_credits,
                    mtu=self.mtu,
                    mps=self.mps,
                )
            )
            logging.info(color(f'*** L2CAP channel: {channel}', 'cyan'))
        except Exception as error:
            logging.info(color(f'!!! Connection failed: {error}', 'red'))
            return

        self._attach_channel(channel)

    def _attach_channel(self, channel):
        """Route stream I/O through ``channel`` and mark the I/O as ready."""
        self.io_sink = channel.write
        self.l2cap_channel = channel
        channel.on('close', self.on_l2cap_close)
        channel.sink = self.on_packet
        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        logging.info(color('*** L2CAP channel closed', 'red'))

    async def drain(self):
        """Wait for the channel's pending output to drain."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
# -----------------------------------------------------------------------------
# L2capServer
# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Packet I/O over an LE credit-based L2CAP channel accepted from the peer.

    The constructor registers an L2CAP server on ``device`` for ``psm``;
    ``ready`` is set when a channel is accepted.
    """

    def __init__(
        self,
        device: Device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP connections
        channel_spec = l2cap.LeCreditBasedChannelSpec(
            psm=psm, mtu=mtu, mps=mps, max_credits=max_credits
        )
        device.create_l2cap_server(spec=channel_spec, handler=self.on_l2cap_channel)
        logging.info(
            color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        """Adopt an accepted channel as this I/O's transport."""
        logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
        channel = l2cap_channel
        self.io_sink = channel.write
        self.l2cap_channel = channel
        channel.on('close', self.on_l2cap_close)
        channel.sink = self.on_packet
        self.ready.set()

    def on_l2cap_close(self):
        """Forget the channel once it has closed."""
        logging.info(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None

    async def drain(self):
        """Wait for the channel's pending output to drain."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
# -----------------------------------------------------------------------------
# RfcommClient
# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O over an RFCOMM DLC opened by this side.

    On connection, the RFCOMM channel is used as given or, when ``channel`` is
    0, looked up from ``uuid``. An RFCOMM multiplexer is started and a DLC is
    opened with the optional frame-size and credit settings. ``ready`` is set
    once the DLC is open.
    """

    def __init__(
        self,
        device,
        channel,
        uuid,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.device = device
        self.channel = channel
        self.uuid = uuid
        self.l2cap_mtu = l2cap_mtu
        self.max_frame_size = max_frame_size
        self.initial_credits = initial_credits
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.rfcomm_session = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Resolve the channel, open the DLC and start using it."""
        connection.on('disconnection', self.on_disconnection)

        channel = self.channel
        if channel == 0:
            # No explicit channel: find it from the service UUID.
            logging.info(
                color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
            )
            channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
                connection, self.uuid
            )
            if not channel:
                logging.warning(
                    color('!!! No RFComm service with this UUID found', 'red')
                )
                await connection.disconnect()
                return
            logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))

        logging.info(color('*** Starting RFCOMM client...', 'blue'))
        client_options = {'l2cap_mtu': self.l2cap_mtu} if self.l2cap_mtu else {}
        rfcomm_client = bumble.rfcomm.Client(connection, **client_options)
        rfcomm_mux = await rfcomm_client.start()
        logging.info(color('*** Started', 'blue'))

        logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
        dlc_options = {
            option: value
            for option, value in (
                ('max_frame_size', self.max_frame_size),
                ('initial_credits', self.initial_credits),
            )
            if value is not None
        }
        try:
            session = await rfcomm_mux.open_dlc(channel, **dlc_options)
            logging.info(color(f'### Session open: {session}', 'yellow'))
            if self.max_credits is not None:
                session.rx_max_credits = self.max_credits
            if self.credits_threshold is not None:
                session.rx_credits_threshold = self.credits_threshold
        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Session open failed: {error}', 'red'))
            await rfcomm_mux.disconnect()
            return

        session.sink = self.on_packet
        self.io_sink = session.write
        self.rfcomm_session = session
        self.ready.set()

    def on_disconnection(self, _):
        pass

    async def drain(self):
        """Wait for the DLC's pending output to drain."""
        assert self.rfcomm_session
        await self.rfcomm_session.drain()
# -----------------------------------------------------------------------------
# RfcommServer
# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O over an RFCOMM DLC accepted from the peer.

    The constructor registers an RFCOMM server on ``device``, listens on
    ``channel`` with the optional frame-size / initial-credit settings, and
    publishes an SDP record for the channel number returned by ``listen()``.
    ``ready`` is set when a DLC connects.
    """

    def __init__(
        self,
        device,
        channel,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.dlc = None
        self.ready = asyncio.Event()

        # Create and register a server
        server_options = {}
        if l2cap_mtu:
            server_options['l2cap_mtu'] = l2cap_mtu
        rfcomm_server = bumble.rfcomm.Server(device, **server_options)

        # Listen for incoming DLC connections
        dlc_options = {}
        if max_frame_size is not None:
            dlc_options['max_frame_size'] = max_frame_size
        if initial_credits is not None:
            dlc_options['initial_credits'] = initial_credits
        channel_number = rfcomm_server.listen(self.on_dlc, channel, **dlc_options)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)
        logging.info(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        """Configure an accepted DLC and adopt it as this I/O's transport."""
        logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
        # Apply each optional RX credit setting exactly once, using the same
        # DLC attribute names as RfcommClient (`rx_credits_threshold`, not
        # `rx_threshold`).
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            dlc.rx_credits_threshold = self.credits_threshold
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.dlc = dlc
        self.ready.set()

    async def drain(self):
        """Wait for the DLC's pending output to drain."""
        assert self.dlc
        await self.dlc.drain()
# -----------------------------------------------------------------------------
# IsoClient
# -----------------------------------------------------------------------------
class IsoClient(StreamedPacketIO):
    """Packet I/O carried over a Connected Isochronous Stream (CIS) link.

    The CIS link is established elsewhere and handed to ``on_connection``.
    That method configures the data path in the direction implied by
    ``sender`` and relays received SDU fragments to ``on_packet``.
    """

    def __init__(
        self,
        device: Device,
    ) -> None:
        super().__init__()
        self.device = device
        self.cis_link: CisLink | None = None
        self.ready = asyncio.Event()

    async def on_connection(
        self, connection: Connection, cis_link: CisLink, sender: bool
    ) -> None:
        """Attach to ``cis_link``, set up its data path and signal readiness.

        Args:
            connection: connection associated with the CIS; only used here
                to subscribe to its disconnection event.
            cis_link: the CIS used for sending and receiving.
            sender: True selects the host-to-controller data path, False
                selects controller-to-host.
        """
        connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
        self.cis_link = cis_link
        self.io_sink = cis_link.write

        if sender:
            data_path_direction = cis_link.Direction.HOST_TO_CONTROLLER
        else:
            data_path_direction = cis_link.Direction.CONTROLLER_TO_HOST
        await cis_link.setup_data_path(data_path_direction)

        cis_link.sink = self.on_iso_packet
        self.ready.set()

    def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
        """Forward the SDU fragment carried by ``iso_packet`` to ``on_packet``."""
        self.on_packet(iso_packet.iso_sdu_fragment)

    def on_disconnection(self, _):
        """Ignore disconnection; no cleanup is performed here."""

    async def drain(self):
        """Await the CIS link's ``drain()``; a no-op before a link is attached."""
        link = self.cis_link
        if link is not None:
            await link.drain()

    def can_receive(self):
        """Always False for ISO packet I/O."""
        return False
# -----------------------------------------------------------------------------
# IsoServer
# -----------------------------------------------------------------------------
class IsoServer(StreamedPacketIO):
    """Packet I/O carried over a CIS link, for the side that awaits the CIS.

    Mirrors ``IsoClient``. The CIS link handed to ``on_connection`` is
    remembered so that ``drain()`` can actually wait on it.
    """

    def __init__(
        self,
        device: Device,
    ):
        super().__init__()
        self.device = device
        self.cis_link: CisLink | None = None
        self.ready = asyncio.Event()

        logging.info(
            color(
                '### Listening for ISO connection',
                'yellow',
            )
        )

    async def on_connection(
        self, connection: Connection, cis_link: CisLink, sender: bool
    ) -> None:
        """Attach to ``cis_link``, set up its data path and signal readiness.

        Args:
            connection: connection associated with the CIS; only used here
                to subscribe to its disconnection event.
            cis_link: the CIS used for sending and receiving.
            sender: True selects the host-to-controller data path, False
                selects controller-to-host.
        """
        connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
        # Keep a reference so drain() has a link to wait on. Without it,
        # drain() would always return immediately.
        self.cis_link = cis_link
        self.io_sink = cis_link.write
        await cis_link.setup_data_path(
            cis_link.Direction.HOST_TO_CONTROLLER
            if sender
            else cis_link.Direction.CONTROLLER_TO_HOST
        )
        cis_link.sink = self.on_iso_packet
        self.ready.set()

    def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
        """Forward the SDU fragment carried by ``iso_packet`` to ``on_packet``."""
        self.on_packet(iso_packet.iso_sdu_fragment)

    def on_disconnection(self, _):
        """Ignore disconnection; no cleanup is performed here."""

    async def drain(self):
        """Await the CIS link's ``drain()``; a no-op before a link is attached."""
        if self.cis_link is None:
            return
        await self.cis_link.drain()

    def can_receive(self):
        """Always False for ISO packet I/O."""
        return False
# -----------------------------------------------------------------------------
# Central
# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Benchmark role that initiates the connection to a peripheral.

    ``run()`` does the following, in order:

    1. Opens the HCI transport, then creates and powers on a device.
    2. Connects to ``peripheral_address`` (BR/EDR when ``classic`` is set,
       LE otherwise).
    3. Applies the optional link settings: role switch, data length,
       authentication, encryption, PHY and ISO.
    4. Hands the connection to the packet I/O mode and runs the scenario.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        scenario_factory,
        mode_factory,
        connection_interval,
        phy,
        authenticate,
        encrypt,
        iso,
        iso_sdu_interval_c_to_p,
        iso_sdu_interval_p_to_c,
        iso_max_sdu_c_to_p,
        iso_max_sdu_p_to_c,
        iso_max_transport_latency_c_to_p,
        iso_max_transport_latency_p_to_c,
        iso_rtn_c_to_p,
        iso_rtn_p_to_c,
        classic,
        extended_data_length,
        role_switch,
        le_scan,
        le_advertise,
        classic_page_scan,
        classic_inquiry_scan,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.iso = iso
        self.iso_sdu_interval_c_to_p = iso_sdu_interval_c_to_p
        self.iso_sdu_interval_p_to_c = iso_sdu_interval_p_to_c
        self.iso_max_sdu_c_to_p = iso_max_sdu_c_to_p
        self.iso_max_sdu_p_to_c = iso_max_sdu_p_to_c
        self.iso_max_transport_latency_c_to_p = iso_max_transport_latency_c_to_p
        self.iso_max_transport_latency_p_to_c = iso_max_transport_latency_p_to_c
        self.iso_rtn_c_to_p = iso_rtn_c_to_p
        self.iso_rtn_p_to_c = iso_rtn_p_to_c
        self.scenario_factory = scenario_factory
        self.mode_factory = mode_factory
        self.authenticate = authenticate
        # Authentication implies encryption.
        self.encrypt = encrypt or authenticate
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.le_scan = le_scan
        self.le_advertise = le_advertise
        self.classic_page_scan = classic_page_scan
        self.classic_inquiry_scan = classic_inquiry_scan
        self.device = None
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for this PHY.
                self.connection_parameter_preferences[self.phy] = (
                    connection_parameter_preferences
                )
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        """Connect to the peripheral, run the scenario, then disconnect.

        Connection failures are logged and make this method return early.
        """
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            scenario = self.scenario_factory(mode)
            self.device.classic_enabled = self.classic
            self.device.cis_enabled = self.iso

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await pre_power_on(self.device, self.classic)
            await self.device.power_on()
            await post_power_on(
                self.device,
                self.le_scan,
                self.le_advertise,
                self.classic_page_scan,
                self.classic_inquiry_scan,
            )

            logging.info(
                color(f'### Connecting to {self.peripheral_address}...', 'cyan')
            )
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=(
                        PhysicalTransport.BR_EDR
                        if self.classic
                        else PhysicalTransport.LE
                    ),
                )
            except CommandTimeoutError:
                logging.info(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                logging.info(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                logging.info(
                    color(f'!!! Connection failed: {error.error_name}', 'red')
                )
                return

            logging.info(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)
            if not self.classic:
                phy = await self.connection.get_phy()
                print_connection_phy(phy)

            # Switch roles if needed.
            if self.role_switch:
                await switch_roles(self.connection, self.role_switch)

            # Wait a bit after the connection, some controllers aren't very good when
            # we start sending data right away while some connection parameters are
            # updated post connection
            await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)

            # Request a new data length if requested
            if self.extended_data_length:
                logging.info(color('+++ Requesting extended data length', 'cyan'))
                await self.connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )

            # Authenticate if requested
            if self.authenticate:
                # Request authentication
                logging.info(color('*** Authenticating...', 'cyan'))
                await self.connection.authenticate()
                logging.info(color('*** Authenticated', 'cyan'))

            # Encrypt if requested
            if self.encrypt:
                # Enable encryption
                logging.info(color('*** Enabling encryption...', 'cyan'))
                await self.connection.encrypt()
                logging.info(color('*** Encryption on', 'cyan'))

            # Set the PHY if requested
            if self.phy is not None:
                try:
                    await self.connection.set_phy(
                        tx_phys=[self.phy], rx_phys=[self.phy]
                    )
                except HCI_Error as error:
                    logging.info(
                        color(
                            f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                        )
                    )

            # Setup ISO streams.
            if self.iso:
                cis_link = await self._create_cis_link(scenario.is_sender())
                await mode.on_connection(
                    self.connection, cis_link, scenario.is_sender()
                )
            else:
                await mode.on_connection(self.connection)

            await scenario.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

            # The peer may have disconnected during the run or the linger
            # period. In that case on_disconnection() has already cleared
            # self.connection and there is nothing left to disconnect.
            if self.connection is not None:
                await self.connection.disconnect()

    async def _create_cis_link(self, is_sender):
        """Set up a CIG with one CIS on the current connection; return its link.

        Options not given on the command line default to the DEFAULT_ISO_*
        constants for the direction data flows in: central -> peripheral
        when this side sends, peripheral -> central otherwise. The opposite
        direction defaults to 0.
        """
        if is_sender:
            sdu_interval_c_to_p = (
                self.iso_sdu_interval_c_to_p or DEFAULT_ISO_SDU_INTERVAL_C_TO_P
            )
            sdu_interval_p_to_c = self.iso_sdu_interval_p_to_c or 0
            max_transport_latency_c_to_p = (
                self.iso_max_transport_latency_c_to_p
                or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P
            )
            max_transport_latency_p_to_c = (
                self.iso_max_transport_latency_p_to_c or 0
            )
            max_sdu_c_to_p = self.iso_max_sdu_c_to_p or DEFAULT_ISO_MAX_SDU_C_TO_P
            max_sdu_p_to_c = self.iso_max_sdu_p_to_c or 0
            rtn_c_to_p = self.iso_rtn_c_to_p or DEFAULT_ISO_RTN_C_TO_P
            rtn_p_to_c = self.iso_rtn_p_to_c or 0
        else:
            sdu_interval_p_to_c = (
                self.iso_sdu_interval_p_to_c or DEFAULT_ISO_SDU_INTERVAL_P_TO_C
            )
            sdu_interval_c_to_p = self.iso_sdu_interval_c_to_p or 0
            max_transport_latency_p_to_c = (
                self.iso_max_transport_latency_p_to_c
                or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C
            )
            max_transport_latency_c_to_p = (
                self.iso_max_transport_latency_c_to_p or 0
            )
            max_sdu_p_to_c = self.iso_max_sdu_p_to_c or DEFAULT_ISO_MAX_SDU_P_TO_C
            max_sdu_c_to_p = self.iso_max_sdu_c_to_p or 0
            rtn_p_to_c = self.iso_rtn_p_to_c or DEFAULT_ISO_RTN_P_TO_C
            rtn_c_to_p = self.iso_rtn_c_to_p or 0

        cis_handles = await self.device.setup_cig(
            CigParameters(
                cig_id=1,
                sdu_interval_c_to_p=sdu_interval_c_to_p,
                sdu_interval_p_to_c=sdu_interval_p_to_c,
                max_transport_latency_c_to_p=max_transport_latency_c_to_p,
                max_transport_latency_p_to_c=max_transport_latency_p_to_c,
                cis_parameters=[
                    CigParameters.CisParameters(
                        cis_id=2,
                        max_sdu_c_to_p=max_sdu_c_to_p,
                        max_sdu_p_to_c=max_sdu_p_to_c,
                        rtn_c_to_p=rtn_c_to_p,
                        rtn_p_to_c=rtn_p_to_c,
                    )
                ],
            )
        )
        cis_link = (
            await self.device.create_cis([(cis_handles[0], self.connection)])
        )[0]
        print_cis_link(cis_link)
        return cis_link

    def on_disconnection(self, reason):
        """Log the disconnection and forget the connection."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        """Print the connection after its parameters change."""
        print_connection(self.connection)

    def on_connection_phy_update(self, phy):
        """Print the new PHY."""
        print_connection_phy(phy)

    def on_connection_att_mtu_update(self):
        """Print the connection after an ATT MTU change."""
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        """Print the connection after a data length change."""
        print_connection(self.connection)

    def on_role_change(self):
        """Print the connection after a role change."""
        print_connection(self.connection)
# -----------------------------------------------------------------------------
# Peripheral
# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark role that waits for a central to connect.

    ``run()`` opens the HCI transport, powers on a device and waits for a
    connection. It then hands that connection to the packet I/O mode (for
    ISO, once a CIS request arrives) and runs the scenario.
    """

    def __init__(
        self,
        transport,
        scenario_factory,
        mode_factory,
        classic,
        iso,
        extended_data_length,
        role_switch,
        le_scan,
        le_advertise,
        classic_page_scan,
        classic_inquiry_scan,
    ):
        # Initialize the listener base classes, as Central does.
        super().__init__()
        self.transport = transport
        self.classic = classic
        self.iso = iso
        self.scenario_factory = scenario_factory
        self.mode_factory = mode_factory
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.le_scan = le_scan
        self.classic_page_scan = classic_page_scan
        self.classic_inquiry_scan = classic_inquiry_scan
        self.scenario = None
        self.mode = None
        self.device = None
        self.connection = None
        # Set by on_connection(); run() waits on it.
        self.connected = asyncio.Event()

        # Without an explicit interval, LE uses the default advertising
        # interval while Classic passes 0.
        if le_advertise:
            self.le_advertise = le_advertise
        else:
            self.le_advertise = 0 if classic else DEFAULT_ADVERTISING_INTERVAL

    async def run(self):
        """Power on, wait for a connection, then run the scenario."""
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.scenario = self.scenario_factory(self.mode)
            self.device.classic_enabled = self.classic
            self.device.cis_enabled = self.iso

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await pre_power_on(self.device, self.classic)
            await self.device.power_on()
            # In Classic mode, page and inquiry scan are always requested.
            await post_power_on(
                self.device,
                self.le_scan,
                self.le_advertise,
                self.classic or self.classic_page_scan,
                self.classic or self.classic_inquiry_scan,
            )

            if self.classic:
                logging.info(
                    color(
                        '### Waiting for connection on'
                        f' {self.device.public_address}...',
                        'cyan',
                    )
                )
            else:
                logging.info(
                    color(
                        f'### Waiting for connection on {peripheral_address}...',
                        'cyan',
                    )
                )

            await self.connected.wait()
            logging.info(color('### Connected', 'cyan'))
            print_connection(self.connection)

            if self.iso:
                # For ISO, the mode is attached once the central requests a CIS.
                async def on_cis_request(cis_link: CisLink) -> None:
                    logging.info(color("@@@ Accepting CIS", "green"))
                    await self.device.accept_cis_request(cis_link)
                    print_cis_link(cis_link)

                    await self.mode.on_connection(
                        self.connection, cis_link, self.scenario.is_sender()
                    )

                self.connection.on(self.connection.EVENT_CIS_REQUEST, on_cis_request)
            else:
                await self.mode.on_connection(self.connection)

            await self.scenario.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        """Adopt an incoming connection and apply the requested link settings."""
        connection.listener = self
        self.connection = connection
        self.connected.set()

        # Stop being discoverable and connectable if possible
        if self.classic:
            if not self.classic_inquiry_scan:
                logging.info(color("*** Stopping inquiry scan", "blue"))
                AsyncRunner.spawn(self.device.set_discoverable(False))

            if not self.classic_page_scan:
                logging.info(color("*** Stopping page scan", "blue"))
                AsyncRunner.spawn(self.device.set_connectable(False))

        # Request a new data length if needed
        if not self.classic and self.extended_data_length:
            logging.info("+++ Requesting extended data length")
            AsyncRunner.spawn(
                connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )
            )

        # Switch roles if needed.
        if self.role_switch:
            AsyncRunner.spawn(switch_roles(connection, self.role_switch))

    def on_disconnection(self, reason):
        """Log, reset the scenario and, for Classic, become reachable again."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.scenario.reset()

        if self.classic:
            logging.info(color("*** Enabling inquiry scan", "blue"))
            AsyncRunner.spawn(self.device.set_discoverable(True))
            logging.info(color("*** Enabling page scan", "blue"))
            AsyncRunner.spawn(self.device.set_connectable(True))

    def on_connection_parameters_update(self):
        """Print the connection after its parameters change."""
        print_connection(self.connection)

    def on_connection_phy_update(self, phy):
        """Print the new PHY."""
        print_connection_phy(phy)

    def on_connection_att_mtu_update(self):
        """Print the connection after an ATT MTU change."""
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        """Print the connection after a data length change."""
        print_connection(self.connection)

    def on_role_change(self):
        """Print the connection after a role change."""
        print_connection(self.connection)
# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a callable that builds the packet I/O object for the chosen mode.

    The mode is ``ctx.obj['mode']`` when set, otherwise ``default_mode``.
    The returned callable takes a device. It reads the mode-specific options
    from ``ctx.obj`` at call time and raises ``ValueError('invalid mode')``
    for an unknown mode.
    """
    selected_mode = ctx.obj['mode']
    if selected_mode is None:
        selected_mode = default_mode

    def create_mode(device):
        obj = ctx.obj

        def l2cap_options():
            # Shared by the L2CAP client and server.
            return {
                'psm': obj['l2cap_psm'],
                'mtu': obj['l2cap_mtu'],
                'mps': obj['l2cap_mps'],
                'max_credits': obj['l2cap_max_credits'],
            }

        def rfcomm_options():
            # Shared by the RFComm client and server.
            return {
                'channel': obj['rfcomm_channel'],
                'l2cap_mtu': obj['rfcomm_l2cap_mtu'],
                'max_frame_size': obj['rfcomm_max_frame_size'],
                'initial_credits': obj['rfcomm_initial_credits'],
                'max_credits': obj['rfcomm_max_credits'],
                'credits_threshold': obj['rfcomm_credits_threshold'],
            }

        # Builders are lambdas so that only the selected class is looked up
        # and instantiated.
        builders = {
            'gatt-client': lambda: GattClient(device, att_mtu=obj['att_mtu']),
            'gatt-server': lambda: GattServer(device),
            'l2cap-client': lambda: L2capClient(device, **l2cap_options()),
            'l2cap-server': lambda: L2capServer(device, **l2cap_options()),
            'rfcomm-client': lambda: RfcommClient(
                device, uuid=obj['rfcomm_uuid'], **rfcomm_options()
            ),
            'rfcomm-server': lambda: RfcommServer(device, **rfcomm_options()),
            'iso-server': lambda: IsoServer(device),
            'iso-client': lambda: IsoClient(device),
        }

        builder = builders.get(selected_mode)
        if builder is None:
            raise ValueError('invalid mode')
        return builder()

    return create_mode
# -----------------------------------------------------------------------------
def create_scenario_factory(ctx, default_scenario):
    """Return a callable that builds the scenario object for a packet I/O.

    The scenario is ``ctx.obj['scenario']`` when set, otherwise
    ``default_scenario``. The returned callable raises ``ValueError``:

    - for ping/pong on an ISO packet I/O, which is not supported;
    - for an unknown scenario name.
    """
    chosen = ctx.obj['scenario']
    if chosen is None:
        chosen = default_scenario

    # Scenarios that cannot run over ISO, with their error messages.
    iso_unsupported = {
        'ping': 'ping not supported with ISO',
        'pong': 'pong not supported with ISO',
    }

    def create_scenario(packet_io):
        obj = ctx.obj

        if chosen in iso_unsupported and isinstance(
            packet_io, (IsoClient, IsoServer)
        ):
            raise ValueError(iso_unsupported[chosen])

        if chosen in ('send', 'ping'):
            # Both active scenarios take the same traffic-shaping options.
            scenario_class = Sender if chosen == 'send' else Ping
            return scenario_class(
                packet_io,
                start_delay=obj['start_delay'],
                repeat=obj['repeat'],
                repeat_delay=obj['repeat_delay'],
                pace=obj['pace'],
                packet_size=obj['packet_size'],
                packet_count=obj['packet_count'],
            )

        if chosen in ('receive', 'pong'):
            scenario_class = Receiver if chosen == 'receive' else Pong
            return scenario_class(packet_io, obj['linger'])

        raise ValueError('invalid scenario')

    return create_scenario
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--scenario', type=click.Choice(['send', 'receive', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
            'iso-client',
            'iso-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    default=517,
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--extended-data-length',
    metavar='<TX-OCTETS>/<TX-TIME>',
    help='Request a data length upon connection, specified as tx_octets/tx_time',
)
@click.option(
    '--role-switch',
    type=click.Choice(['central', 'peripheral']),
    help='Request role switch upon connection (central or peripheral)',
)
@click.option(
    '--le-scan',
    metavar='<WINDOW>/<INTERVAL>',
    help='Perform an LE scan with a given window and interval (milliseconds)',
)
@click.option(
    '--le-advertise',
    metavar='<INTERVAL>',
    help='Advertise with a given interval (milliseconds)',
)
@click.option(
    '--classic-page-scan',
    is_flag=True,
    help='Enable Classic page scanning',
)
@click.option(
    '--classic-inquiry-scan',
    is_flag=True,
    help='Enable Classic inquiry scanning',
)
@click.option(
    '--rfcomm-channel',
    type=int,
    default=DEFAULT_RFCOMM_CHANNEL,
    help='RFComm channel to use (specify 0 for channel discovery via SDP)',
)
@click.option(
    '--rfcomm-uuid',
    default=DEFAULT_RFCOMM_UUID,
    help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
)
@click.option(
    '--rfcomm-l2cap-mtu',
    type=int,
    help='RFComm L2CAP MTU',
)
@click.option(
    '--rfcomm-max-frame-size',
    type=int,
    help='RFComm maximum frame size',
)
@click.option(
    '--rfcomm-initial-credits',
    type=int,
    help='RFComm initial credits',
)
@click.option(
    '--rfcomm-max-credits',
    type=int,
    help='RFComm max credits',
)
@click.option(
    '--rfcomm-credits-threshold',
    type=int,
    help='RFComm credits threshold',
)
@click.option(
    '--l2cap-psm',
    type=int,
    default=DEFAULT_L2CAP_PSM,
    help='L2CAP PSM to use',
)
@click.option(
    '--l2cap-mtu',
    type=int,
    default=DEFAULT_L2CAP_MTU,
    help='L2CAP MTU to use',
)
@click.option(
    '--l2cap-mps',
    type=int,
    default=DEFAULT_L2CAP_MPS,
    help='L2CAP MPS to use',
)
@click.option(
    '--l2cap-max-credits',
    type=int,
    default=DEFAULT_L2CAP_MAX_CREDITS,
    help='L2CAP maximum number of credits allowed for the peer',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(10, 8192),
    default=500,
    help='Packet size (send or ping scenario)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=int,
    default=10,
    help='Packet count (send or ping scenario)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (send or ping scenario)',
)
@click.option(
    '--repeat',
    metavar='N',
    type=int,
    default=0,
    help=(
        'Repeat the run N times (send and ping scenarios) '
        '(0, which is the default, to run just once)'
    ),
)
@click.option(
    '--repeat-delay',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Delay, in seconds, between repeats',
)
@click.option(
    '--pace',
    metavar='MILLISECONDS',
    type=int,
    default=0,
    help=(
        'Wait N milliseconds between packets '
        '(0, which is the default, to send as fast as possible)'
    ),
)
@click.option(
    '--linger',
    is_flag=True,
    help="Don't exit at the end of a run (receive and pong scenarios)",
)
@click.pass_context
def bench(
    ctx,
    device_config,
    scenario,
    mode,
    att_mtu,
    extended_data_length,
    role_switch,
    le_scan,
    le_advertise,
    classic_page_scan,
    classic_inquiry_scan,
    packet_size,
    packet_count,
    start_delay,
    repeat,
    repeat_delay,
    pace,
    linger,
    rfcomm_channel,
    rfcomm_uuid,
    rfcomm_l2cap_mtu,
    rfcomm_max_frame_size,
    rfcomm_initial_credits,
    rfcomm_max_credits,
    rfcomm_credits_threshold,
    l2cap_psm,
    l2cap_mtu,
    l2cap_mps,
    l2cap_max_credits,
):
    # Shared options for the central/peripheral subcommands. They are
    # stashed in ctx.obj, where the mode/scenario factories and the
    # subcommands read them.
    ctx.ensure_object(dict)
    ctx.obj['device_config'] = device_config
    ctx.obj['scenario'] = scenario
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['rfcomm_channel'] = rfcomm_channel
    ctx.obj['rfcomm_uuid'] = rfcomm_uuid
    ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
    ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
    ctx.obj['rfcomm_initial_credits'] = rfcomm_initial_credits
    ctx.obj['rfcomm_max_credits'] = rfcomm_max_credits
    ctx.obj['rfcomm_credits_threshold'] = rfcomm_credits_threshold
    ctx.obj['l2cap_psm'] = l2cap_psm
    ctx.obj['l2cap_mtu'] = l2cap_mtu
    ctx.obj['l2cap_mps'] = l2cap_mps
    ctx.obj['l2cap_max_credits'] = l2cap_max_credits
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay
    ctx.obj['repeat'] = repeat
    ctx.obj['repeat_delay'] = repeat_delay
    ctx.obj['pace'] = pace
    ctx.obj['linger'] = linger

    # The consumers index [0] (tx_octets) and [1] (tx_time), so reject
    # anything that is not exactly two integers up front. Otherwise the run
    # fails much later with an IndexError.
    parsed_data_length = None
    if extended_data_length:
        parts = extended_data_length.split('/')
        try:
            if len(parts) != 2:
                raise ValueError('wrong number of values')
            parsed_data_length = [int(x) for x in parts]
        except ValueError as error:
            raise click.BadParameter(
                f'expected <TX-OCTETS>/<TX-TIME>, got {extended_data_length!r}',
                param_hint="'--extended-data-length'",
            ) from error
    ctx.obj['extended_data_length'] = parsed_data_length

    ctx.obj['role_switch'] = role_switch
    ctx.obj['le_scan'] = [float(x) for x in le_scan.split('/')] if le_scan else None
    ctx.obj['le_advertise'] = float(le_advertise) if le_advertise else None
    ctx.obj['classic_page_scan'] = classic_page_scan
    ctx.obj['classic_inquiry_scan'] = classic_inquiry_scan
    # The transport family follows from the selected mode.
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
    ctx.obj['iso'] = mode in ('iso-client', 'iso-server')
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)')
@click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)')
@click.option(
    '--iso-sdu-interval-c-to-p',
    type=int,
    help='ISO SDU interval central -> peripheral (microseconds)',
)
@click.option(
    '--iso-sdu-interval-p-to-c',
    type=int,
    help='ISO SDU interval peripheral -> central (microseconds)',
)
@click.option(
    '--iso-max-sdu-c-to-p',
    type=int,
    help='ISO max SDU central -> peripheral',
)
@click.option(
    '--iso-max-sdu-p-to-c',
    type=int,
    help='ISO max SDU peripheral -> central',
)
@click.option(
    '--iso-max-transport-latency-c-to-p',
    type=int,
    help='ISO max transport latency central -> peripheral (milliseconds)',
)
@click.option(
    '--iso-max-transport-latency-p-to-c',
    type=int,
    help='ISO max transport latency peripheral -> central (milliseconds)',
)
@click.option(
    '--iso-rtn-c-to-p',
    type=int,
    help='ISO RTN central -> peripheral (integer count)',
)
@click.option(
    '--iso-rtn-p-to-c',
    type=int,
    help='ISO RTN peripheral -> central (integer count)',
)
@click.pass_context
def central(
    ctx,
    transport,
    peripheral_address,
    connection_interval,
    phy,
    authenticate,
    encrypt,
    iso_sdu_interval_c_to_p,
    iso_sdu_interval_p_to_c,
    iso_max_sdu_c_to_p,
    iso_max_sdu_p_to_c,
    iso_max_transport_latency_c_to_p,
    iso_max_transport_latency_p_to_c,
    iso_rtn_c_to_p,
    iso_rtn_p_to_c,
):
    """Run as a central (initiates the connection)"""
    # Defaults: send scenario over a GATT client.
    scenario_factory = create_scenario_factory(ctx, 'send')
    mode_factory = create_mode_factory(ctx, 'gatt-client')

    async def run_central():
        # Keyword arguments keep this call correct even if Central's long
        # parameter list is ever reordered.
        await Central(
            transport=transport,
            peripheral_address=peripheral_address,
            scenario_factory=scenario_factory,
            mode_factory=mode_factory,
            connection_interval=connection_interval,
            phy=phy,
            authenticate=authenticate,
            encrypt=encrypt or authenticate,
            iso=ctx.obj['iso'],
            iso_sdu_interval_c_to_p=iso_sdu_interval_c_to_p,
            iso_sdu_interval_p_to_c=iso_sdu_interval_p_to_c,
            iso_max_sdu_c_to_p=iso_max_sdu_c_to_p,
            iso_max_sdu_p_to_c=iso_max_sdu_p_to_c,
            iso_max_transport_latency_c_to_p=iso_max_transport_latency_c_to_p,
            iso_max_transport_latency_p_to_c=iso_max_transport_latency_p_to_c,
            iso_rtn_c_to_p=iso_rtn_c_to_p,
            iso_rtn_p_to_c=iso_rtn_p_to_c,
            classic=ctx.obj['classic'],
            extended_data_length=ctx.obj['extended_data_length'],
            role_switch=ctx.obj['role_switch'],
            le_scan=ctx.obj['le_scan'],
            le_advertise=ctx.obj['le_advertise'],
            classic_page_scan=ctx.obj['classic_page_scan'],
            classic_inquiry_scan=ctx.obj['classic_inquiry_scan'],
        ).run()

    asyncio.run(run_central())
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Defaults: receive scenario over a GATT server.
    scenario_factory = create_scenario_factory(ctx, 'receive')
    mode_factory = create_mode_factory(ctx, 'gatt-server')

    async def run_peripheral():
        options = ctx.obj
        role = Peripheral(
            transport=transport,
            scenario_factory=scenario_factory,
            mode_factory=mode_factory,
            classic=options['classic'],
            iso=options['iso'],
            extended_data_length=options['extended_data_length'],
            role_switch=options['role_switch'],
            le_scan=options['le_scan'],
            le_advertise=options['le_advertise'],
            classic_page_scan=options['classic_page_scan'],
            classic_inquiry_scan=options['classic_inquiry_scan'],
        )
        await role.run()

    asyncio.run(run_peripheral())
def main():
    """Command-line entry point: enable INFO logging, then run the CLI."""
    bumble.logging.setup_basic_logging('INFO')
    # click parses sys.argv and supplies bench()'s parameters.
    bench()  # pylint: disable=no-value-for-parameter
# -----------------------------------------------------------------------------
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter