# Copyright 2021-2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- import asyncio import dataclasses import enum import logging import os import statistics import struct import time from typing import Optional import click from bumble import l2cap from bumble.core import ( PhysicalTransport, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID, UUID, CommandTimeoutError, ) from bumble.colors import color from bumble.core import ConnectionPHY from bumble.device import ( CigParameters, CisLink, Connection, ConnectionParametersPreferences, Device, Peer, ) from bumble.gatt import Characteristic, CharacteristicValue, Service from bumble.hci import ( HCI_LE_1M_PHY, HCI_LE_2M_PHY, HCI_LE_CODED_PHY, Role, HCI_Constant, HCI_Error, HCI_StatusError, HCI_IsoDataPacket, ) from bumble.sdp import ( SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PUBLIC_BROWSE_ROOT, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement, ServiceAttribute, ) from bumble.transport import open_transport_or_link import bumble.rfcomm import bumble.core from bumble.utils import AsyncRunner from bumble.pairing import PairingConfig # ----------------------------------------------------------------------------- # Logging # 
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'
DEFAULT_ADVERTISING_INTERVAL = 100

SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024

DEFAULT_ISO_MAX_SDU_C_TO_P = 251
DEFAULT_ISO_MAX_SDU_P_TO_C = 251
DEFAULT_ISO_SDU_INTERVAL_C_TO_P = 10000
DEFAULT_ISO_SDU_INTERVAL_P_TO_C = 10000
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P = 35
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C = 35
DEFAULT_ISO_RTN_C_TO_P = 3
DEFAULT_ISO_RTN_P_TO_C = 3

DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0

DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def le_phy_name(phy_id):
    """Return a short display name ('1M', '2M', 'CODED') for an LE PHY id.

    Ids outside the three known PHYs fall back to HCI_Constant.le_phy_name().
    """
    return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
        phy_id, HCI_Constant.le_phy_name(phy_id)
    )


def print_connection_phy(phy: ConnectionPHY) -> None:
    """Log the TX/RX PHYs of a connection."""
    logging.info(
        color('@@@ PHY: ', 'yellow') + f'TX:{le_phy_name(phy.tx_phy)}/'
        f'RX:{le_phy_name(phy.rx_phy)}'
    )


def print_connection(connection: Connection) -> None:
    """Log a one-line summary of a connection.

    For LE: data length (TX/RX), connection parameters
    (interval/latency/supervision timeout) and the ATT MTU.
    For other transports: the local role.
    """
    params = []
    if connection.transport == PhysicalTransport.LE:
        params.append(
            'DL=('
            f'TX:{connection.data_length[0]}/{connection.data_length[1]},'
            f'RX:{connection.data_length[2]}/{connection.data_length[3]}'
            ')'
        )

        params.append(
            'Parameters='
            f'{connection.parameters.connection_interval:.2f}/'
            f'{connection.parameters.peripheral_latency}/'
            f'{connection.parameters.supervision_timeout:.2f} '
        )

        params.append(f'MTU={connection.att_mtu}')
    else:
        params.append(f'Role={HCI_Constant.role_name(connection.role)}')

    logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(params))


def print_cis_link(cis_link: CisLink) -> None:
    """Log the negotiated parameters of an established CIS, per direction."""
    logging.info(color("@@@ CIS established", "green"))
    logging.info(color('@@@ ISO interval: ', 'green') + f"{cis_link.iso_interval}ms")
    logging.info(color('@@@ NSE: ', 'green') + f"{cis_link.nse}")

    logging.info(color('@@@ Central->Peripheral:', 'green'))
    # The PHY may be unknown (None); only log it when present.
    if cis_link.phy_c_to_p is not None:
        logging.info(
            color('@@@ PHY: ', 'green') + f"{cis_link.phy_c_to_p.name}"
        )
    logging.info(
        color('@@@ Latency: ', 'green') + f"{cis_link.transport_latency_c_to_p}µs"
    )
    logging.info(color('@@@ BN: ', 'green') + f"{cis_link.bn_c_to_p}")
    logging.info(color('@@@ FT: ', 'green') + f"{cis_link.ft_c_to_p}")
    logging.info(color('@@@ Max PDU: ', 'green') + f"{cis_link.max_pdu_c_to_p}")

    logging.info(color('@@@ Peripheral->Central:', 'green'))
    if cis_link.phy_p_to_c is not None:
        logging.info(
            color('@@@ PHY: ', 'green') + f"{cis_link.phy_p_to_c.name}"
        )
    logging.info(
        color('@@@ Latency: ', 'green') + f"{cis_link.transport_latency_p_to_c}µs"
    )
    logging.info(color('@@@ BN: ', 'green') + f"{cis_link.bn_p_to_c}")
    logging.info(color('@@@ FT: ', 'green') + f"{cis_link.ft_p_to_c}")
    logging.info(color('@@@ Max PDU: ', 'green') + f"{cis_link.max_pdu_p_to_c}")


def make_sdp_records(channel):
    """Build the SDP records advertising the benchmark RFCOMM service.

    Args:
        channel: RFCOMM server channel number to publish in the protocol
            descriptor list.

    Returns:
        A dict mapping the service record handle (0x00010001) to its list of
        ServiceAttribute objects.
    """
    return {
        0x00010001: [
            ServiceAttribute(
                SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
                DataElement.unsigned_integer_32(0x00010001),
            ),
            ServiceAttribute(
                SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
            ),
            ServiceAttribute(
                SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(UUID(DEFAULT_RFCOMM_UUID))]),
            ),
            ServiceAttribute(
                SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
                DataElement.sequence(
                    [
                        DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
                        DataElement.sequence(
                            [
                                DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
                                DataElement.unsigned_integer_8(channel),
                            ]
                        ),
                    ]
                ),
            ),
        ]
    }


def log_stats(title, stats, precision=2):
    """Log min/max/average/stdev of a sequence of numbers.

    Args:
        title: Label used in the log line.
        stats: Numeric samples. An empty sequence logs nothing.
        precision: Number of decimal places in the output.
    """
    if not stats:
        # min()/max() raise ValueError on an empty sequence. This happens,
        # e.g., when a sender never receives an ACK (its I/O cannot receive).
        return

    stats_min = min(stats)
    stats_max = max(stats)
    stats_avg = statistics.mean(stats)
    # statistics.stdev() needs at least two samples.
    stats_stdev = statistics.stdev(stats) if len(stats) >= 2 else 0
    logging.info(
        color(
            (
                f'### {title} stats: '
                f'min={stats_min:.{precision}f}, '
                f'max={stats_max:.{precision}f}, '
                f'average={stats_avg:.{precision}f}, '
                f'stdev={stats_stdev:.{precision}f}'
            ),
            'cyan',
        )
    )


async def switch_roles(connection, role):
    """Request a role switch on `connection` if it is not already in `role`.

    Args:
        connection: The connection whose role should change.
        role: "central" for Role.CENTRAL; any other value means Role.PERIPHERAL.

    An HCI_Error from the switch is logged, not raised.
    """
    target_role = Role.CENTRAL if role == "central" else Role.PERIPHERAL
    if connection.role != target_role:
        logging.info(f'{color("### Switching roles to:", "cyan")} {role}')
        try:
            await connection.switch_role(target_role)
            logging.info(color('### Role switch complete', 'cyan'))
        except HCI_Error as error:
            logging.info(f'{color("### Role switch failed:", "red")} {error}')


async def pre_power_on(device: Device, classic: bool) -> None:
    """Configure `device` before it is powered on.

    Enables/disables Classic, selects the JSON keystore, and installs a
    pairing config without Secure Connections, MITM protection or bonding.
    """
    device.classic_enabled = classic

    # Set up a pairing config factory with minimal requirements.
    device.config.keystore = "JsonKeyStore"
    device.pairing_config_factory = lambda _: PairingConfig(
        sc=False, mitm=False, bonding=False
    )


async def post_power_on(
    device: Device,
    le_scan: Optional[tuple[int, int]],
    le_advertise: Optional[int],
    classic_page_scan: bool,
    classic_inquiry_scan: bool,
) -> None:
    """Start the optional background activities after power-on.

    Args:
        device: The powered-on device.
        le_scan: (scan_window, scan_interval) in ms, or None to skip scanning.
        le_advertise: Advertising interval in ms, or None/0 to skip advertising.
        classic_page_scan: Make the device connectable (page scan).
        classic_inquiry_scan: Make the device discoverable (inquiry scan).
    """
    if classic_page_scan:
        logging.info(color("*** Enabling page scan", "blue"))
        await device.set_connectable(True)
    if classic_inquiry_scan:
        logging.info(color("*** Enabling inquiry scan", "blue"))
        await device.set_discoverable(True)

    if le_scan:
        scan_window, scan_interval = le_scan
        logging.info(
            color(
                f"*** Starting LE scanning [{scan_window}ms/{scan_interval}ms]",
                "blue",
            )
        )
        await device.start_scanning(
            scan_interval=scan_interval, scan_window=scan_window
        )

    if le_advertise:
        logging.info(color(f"*** Starting LE advertising [{le_advertise}ms]", "blue"))
        await device.start_advertising(
            advertising_interval_min=le_advertise,
            advertising_interval_max=le_advertise,
            auto_restart=True,
        )


# ----------------------------------------------------------------------------- # Packet # ----------------------------------------------------------------------------- @dataclasses.dataclass class Packet: class PacketType(enum.IntEnum): RESET = 0 SEQUENCE = 1 ACK = 2 class PacketFlags(enum.IntFlag): LAST = 1 packet_type: PacketType flags: PacketFlags = PacketFlags(0) sequence: int = 0 timestamp: int = 0 payload: bytes = b"" @classmethod def from_bytes(cls, data: bytes): if len(data) < 1: logging.warning( color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red') ) raise ValueError('packet too short') try: packet_type = cls.PacketType(data[0]) except ValueError: logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red')) raise if packet_type == cls.PacketType.RESET: return cls(packet_type) flags = cls.PacketFlags(data[1]) (sequence,) = struct.unpack_from("= 6)', 'red', ) ) return cls(packet_type, flags, sequence) if len(data) < 10: logging.warning( color( f'!!!
Packet too short (got {len(data)} bytes, need >= 10)', 'red' ) ) raise ValueError('packet too short') (timestamp,) = struct.unpack_from(" 1: expected_time = ( self.receive_times[0] + (packet.timestamp - self.packets[0].timestamp) / 1000000 ) jitter = now - expected_time else: jitter = 0.0 self.jitter.append(jitter) return jitter def show_stats(self): if len(self.jitter) < 3: return average = sum(self.jitter) / len(self.jitter) adjusted = [jitter - average for jitter in self.jitter] log_stats('Jitter (signed)', adjusted, 3) log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3) # Show a histogram bin_count = 20 bins = [0] * bin_count interval_min = min(adjusted) interval_max = max(adjusted) interval_range = interval_max - interval_min bin_thresholds = [ interval_min + i * (interval_range / bin_count) for i in range(bin_count) ] for jitter in adjusted: for i in reversed(range(bin_count)): if jitter >= bin_thresholds[i]: bins[i] += 1 break for i in range(bin_count): logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}') # ----------------------------------------------------------------------------- # Sender # ----------------------------------------------------------------------------- class Sender: def __init__( self, packet_io, start_delay, repeat, repeat_delay, pace, packet_size, packet_count, ): self.tx_start_delay = start_delay self.tx_packet_size = packet_size self.tx_packet_count = packet_count self.packet_io = packet_io self.packet_io.packet_listener = self self.repeat = repeat self.repeat_delay = repeat_delay self.pace = pace self.start_time = 0 self.bytes_sent = 0 self.stats = [] self.done = asyncio.Event() def reset(self): pass async def run(self): logging.info(color('--- Waiting for I/O to be ready...', 'blue')) await self.packet_io.ready.wait() logging.info(color('--- Go!', 'blue')) for run in range(self.repeat + 1): self.done.clear() if run > 0 and self.repeat and self.repeat_delay: logging.info(color(f'*** Repeat delay: 
{self.repeat_delay}', 'green')) await asyncio.sleep(self.repeat_delay) if self.tx_start_delay: logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue')) await asyncio.sleep(self.tx_start_delay) logging.info(color('=== Sending RESET', 'magenta')) await self.packet_io.send_packet( bytes(Packet(packet_type=Packet.PacketType.RESET)) ) self.start_time = time.time() self.bytes_sent = 0 for tx_i in range(self.tx_packet_count): if self.pace > 0: # Wait until it is time to send the next packet target_time = self.start_time + (tx_i * self.pace / 1000) now = time.time() if now < target_time: await asyncio.sleep(target_time - now) else: await self.packet_io.drain() packet = bytes( Packet( packet_type=Packet.PacketType.SEQUENCE, flags=( Packet.PacketFlags.LAST if tx_i == self.tx_packet_count - 1 else 0 ), sequence=tx_i, timestamp=int((time.time() - self.start_time) * 1000000), payload=bytes( self.tx_packet_size - 10 - self.packet_io.overhead_size ), ) ) logging.info( color( f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow' ) ) self.bytes_sent += len(packet) await self.packet_io.send_packet(packet) if self.packet_io.can_receive(): await self.done.wait() run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' logging.info(color(f'=== {run_counter} Done!', 'magenta')) if self.repeat: log_stats('Run', self.stats) if self.repeat: logging.info(color('--- End of runs', 'blue')) def on_packet_received(self, data): try: packet = Packet.from_bytes(data) except ValueError: return if packet.packet_type == Packet.PacketType.ACK: elapsed = time.time() - self.start_time average_tx_speed = self.bytes_sent / elapsed self.stats.append(average_tx_speed) logging.info( color( f'@@@ Received ACK. 
Speed: average={average_tx_speed:.4f}' f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)', 'green', ) ) self.done.set() def is_sender(self): return True # ----------------------------------------------------------------------------- # Receiver # ----------------------------------------------------------------------------- class Receiver: expected_packet_index: int start_timestamp: float last_timestamp: float def __init__(self, packet_io, linger): self.jitter_stats = JitterStats() self.packet_io = packet_io self.packet_io.packet_listener = self self.linger = linger self.done = asyncio.Event() self.reset() def reset(self): self.expected_packet_index = 0 self.measurements = [(time.time(), 0)] self.total_bytes_received = 0 self.jitter_stats.reset() def on_packet_received(self, data): try: packet = Packet.from_bytes(data) except ValueError: logging.exception("invalid packet") return if packet.packet_type == Packet.PacketType.RESET: logging.info(color('=== Received RESET', 'magenta')) self.reset() return jitter = self.jitter_stats.on_packet_received(packet) logging.info( f'<<< Received packet {packet.sequence}: ' f'flags={packet.flags}, ' f'jitter={jitter:.4f}, ' f'{len(data) + self.packet_io.overhead_size} bytes', ) if packet.sequence != self.expected_packet_index: logging.info( color( f'!!! 
Unexpected packet, expected {self.expected_packet_index} ' f'but received {packet.sequence}', 'red', ) ) now = time.time() elapsed_since_start = now - self.measurements[0][0] elapsed_since_last = now - self.measurements[-1][0] self.measurements.append((now, len(data))) self.total_bytes_received += len(data) instant_rx_speed = len(data) / elapsed_since_last average_rx_speed = self.total_bytes_received / elapsed_since_start window = self.measurements[-64:] windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / ( window[-1][0] - window[0][0] ) logging.info( color( 'Speed: ' f'instant={instant_rx_speed:.4f}, ' f'windowed={windowed_rx_speed:.4f}, ' f'average={average_rx_speed:.4f}', 'yellow', ) ) self.expected_packet_index = packet.sequence + 1 if packet.flags & Packet.PacketFlags.LAST: AsyncRunner.spawn( self.packet_io.send_packet( bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence)) ) ) logging.info(color('@@@ Received last packet', 'green')) self.jitter_stats.show_stats() if not self.linger: self.done.set() async def run(self): await self.done.wait() logging.info(color('=== Done!', 'magenta')) def is_sender(self): return False # ----------------------------------------------------------------------------- # Ping # ----------------------------------------------------------------------------- class Ping: def __init__( self, packet_io, start_delay, repeat, repeat_delay, pace, packet_size, packet_count, ): self.tx_start_delay = start_delay self.tx_packet_size = packet_size self.tx_packet_count = packet_count self.packet_io = packet_io self.packet_io.packet_listener = self self.repeat = repeat self.repeat_delay = repeat_delay self.pace = pace self.done = asyncio.Event() self.ping_times = [] self.rtts = [] self.next_expected_packet_index = 0 self.min_stats = [] self.max_stats = [] self.avg_stats = [] def reset(self): pass async def run(self): logging.info(color('--- Waiting for I/O to be ready...', 'blue')) await self.packet_io.ready.wait() 
logging.info(color('--- Go!', 'blue')) for run in range(self.repeat + 1): self.done.clear() self.ping_times = [] if run > 0 and self.repeat and self.repeat_delay: logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green')) await asyncio.sleep(self.repeat_delay) if self.tx_start_delay: logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue')) await asyncio.sleep(self.tx_start_delay) logging.info(color('=== Sending RESET', 'magenta')) await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET))) start_time = time.time() self.next_expected_packet_index = 0 for i in range(self.tx_packet_count): target_time = start_time + (i * self.pace / 1000) now = time.time() if now < target_time: await asyncio.sleep(target_time - now) now = time.time() packet = bytes( Packet( packet_type=Packet.PacketType.SEQUENCE, flags=( Packet.PacketFlags.LAST if i == self.tx_packet_count - 1 else 0 ), sequence=i, timestamp=int((now - start_time) * 1000000), payload=bytes(self.tx_packet_size - 10), ) ) logging.info(color(f'Sending packet {i}', 'yellow')) self.ping_times.append(now) await self.packet_io.send_packet(packet) await self.done.wait() min_rtt = min(self.rtts) max_rtt = max(self.rtts) avg_rtt = statistics.mean(self.rtts) stdev_rtt = statistics.stdev(self.rtts) logging.info( color( '@@@ RTTs: ' f'min={min_rtt:.2f}, ' f'max={max_rtt:.2f}, ' f'average={avg_rtt:.2f}, ' f'stdev={stdev_rtt:.2f}' ) ) self.min_stats.append(min_rtt) self.max_stats.append(max_rtt) self.avg_stats.append(avg_rtt) run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' logging.info(color(f'=== {run_counter} Done!', 'magenta')) if self.repeat: log_stats('Min RTT', self.min_stats) log_stats('Max RTT', self.max_stats) log_stats('Average RTT', self.avg_stats) if self.repeat: logging.info(color('--- End of runs', 'blue')) def on_packet_received(self, data): try: packet = Packet.from_bytes(data) except ValueError: return if packet.packet_type == Packet.PacketType.ACK: 
elapsed = time.time() - self.ping_times[packet.sequence] rtt = elapsed * 1000 self.rtts.append(rtt) logging.info( color( f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms', 'green', ) ) if packet.sequence == self.next_expected_packet_index: self.next_expected_packet_index += 1 else: logging.info( color( f'!!! Unexpected packet, ' f'expected {self.next_expected_packet_index} ' f'but received {packet.sequence}', 'red', ) ) if packet.flags & Packet.PacketFlags.LAST: self.done.set() return def is_sender(self): return True # ----------------------------------------------------------------------------- # Pong # ----------------------------------------------------------------------------- class Pong: expected_packet_index: int def __init__(self, packet_io, linger): self.jitter_stats = JitterStats() self.packet_io = packet_io self.packet_io.packet_listener = self self.linger = linger self.done = asyncio.Event() self.reset() def reset(self): self.expected_packet_index = 0 self.jitter_stats.reset() def on_packet_received(self, data): try: packet = Packet.from_bytes(data) except ValueError: return if packet.packet_type == Packet.PacketType.RESET: logging.info(color('=== Received RESET', 'magenta')) self.reset() return jitter = self.jitter_stats.on_packet_received(packet) logging.info( color( f'<<< Received packet {packet.sequence}: ' f'flags={packet.flags}, {len(data)} bytes, ' f'jitter={jitter:.4f}', 'green', ) ) if packet.sequence != self.expected_packet_index: logging.info( color( f'!!! 
Unexpected packet, expected {self.expected_packet_index} ' f'but received {packet.sequence}', 'red', ) ) self.expected_packet_index = packet.sequence + 1 AsyncRunner.spawn( self.packet_io.send_packet( bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence)) ) ) if packet.flags & Packet.PacketFlags.LAST: self.jitter_stats.show_stats() if not self.linger: self.done.set() async def run(self): await self.done.wait() logging.info(color('=== Done!', 'magenta')) def is_sender(self): return False # ----------------------------------------------------------------------------- # GattClient # ----------------------------------------------------------------------------- class GattClient: def __init__(self, _device, att_mtu=None): self.att_mtu = att_mtu self.speed_rx = None self.speed_tx = None self.packet_listener = None self.ready = asyncio.Event() self.overhead_size = 0 async def on_connection(self, connection): peer = Peer(connection) if self.att_mtu: logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue')) await peer.request_mtu(self.att_mtu) logging.info(color('*** Discovering services...', 'blue')) await peer.discover_services() speed_services = peer.get_services_by_uuid(SPEED_SERVICE_UUID) if not speed_services: logging.info(color('!!! Speed Service not found', 'red')) return speed_service = speed_services[0] logging.info(color('*** Discovering characteristics...', 'blue')) await speed_service.discover_characteristics() speed_txs = speed_service.get_characteristics_by_uuid(SPEED_TX_UUID) if not speed_txs: logging.info(color('!!! Speed TX not found', 'red')) return self.speed_tx = speed_txs[0] speed_rxs = speed_service.get_characteristics_by_uuid(SPEED_RX_UUID) if not speed_rxs: logging.info(color('!!! 
Speed RX not found', 'red')) return self.speed_rx = speed_rxs[0] logging.info(color('*** Subscribing to RX', 'blue')) await self.speed_rx.subscribe(self.on_packet_received) logging.info(color('*** Discovery complete', 'blue')) connection.on('disconnection', self.on_disconnection) self.ready.set() def on_disconnection(self, _): self.ready.clear() def on_packet_received(self, packet): if self.packet_listener: self.packet_listener.on_packet_received(packet) async def send_packet(self, packet): await self.speed_tx.write_value(packet) async def drain(self): pass # ----------------------------------------------------------------------------- # GattServer # ----------------------------------------------------------------------------- class GattServer: def __init__(self, device): self.device = device self.packet_listener = None self.ready = asyncio.Event() self.overhead_size = 0 # Setup the GATT service self.speed_tx = Characteristic( SPEED_TX_UUID, Characteristic.Properties.WRITE, Characteristic.WRITEABLE, CharacteristicValue(write=self.on_tx_write), ) self.speed_rx = Characteristic( SPEED_RX_UUID, Characteristic.Properties.NOTIFY, 0 ) speed_service = Service( SPEED_SERVICE_UUID, [self.speed_tx, self.speed_rx], ) device.add_services([speed_service]) self.speed_rx.on('subscription', self.on_rx_subscription) async def on_connection(self, connection): connection.on('disconnection', self.on_disconnection) def on_disconnection(self, _): self.ready.clear() def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled): if notify_enabled: logging.info(color('*** RX subscription', 'blue')) self.ready.set() else: logging.info(color('*** RX un-subscription', 'blue')) self.ready.clear() def on_tx_write(self, _, value): if self.packet_listener: self.packet_listener.on_packet_received(value) async def send_packet(self, packet): await self.device.notify_subscribers(self.speed_rx, packet) async def drain(self): pass # 
# -----------------------------------------------------------------------------
# StreamedPacketIO
# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Frame packets over a byte stream with a 2-byte big-endian length prefix.

    Subclasses set `io_sink` (a callable taking bytes) for output and feed
    incoming stream data to `on_packet()`.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''  # payload bytes of the frame being reassembled
        self.rx_packet_header = b''  # length-prefix bytes received so far
        self.rx_packet_need = 0  # payload bytes still missing for this frame
        self.overhead_size = 2  # size of the length prefix

    def _deliver_rx_packet(self):
        """Hand the reassembled frame to the listener and reset framing state."""
        rx_packet = self.rx_packet
        # Reset before calling out so a raising listener cannot desync framing.
        self.rx_packet = b''
        self.rx_packet_header = b''
        if self.packet_listener:
            self.packet_listener.on_packet_received(rx_packet)

    def on_packet(self, packet):
        """Consume a chunk of stream data, delivering each completed frame."""
        while packet:
            if self.rx_packet_need:
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    # Packet completed
                    self._deliver_rx_packet()
            else:
                # Expect the next packet
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[:header_bytes_needed]
                self.rx_packet_header += header_bytes
                if len(self.rx_packet_header) != 2:
                    return
                packet = packet[len(header_bytes) :]
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                if not self.rx_packet_need:
                    # Zero-length frame: complete immediately. Otherwise the
                    # full header is never cleared and this loop spins forever.
                    self._deliver_rx_packet()

    async def send_packet(self, packet):
        """Send `packet` with its length prefix, or drop it if there is no sink."""
        if not self.io_sink:
            logging.info(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)

    def can_receive(self):
        return True


# -----------------------------------------------------------------------------
# L2capClient
# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Streamed packet I/O over an outgoing LE credit-based L2CAP channel."""

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm = psm
        self.max_credits = max_credits
        self.mtu = mtu
        self.mps = mps
        self.l2cap_channel = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection: Connection) -> None:
        """Open the L2CAP channel on `connection` and mark ready on success."""
        connection.on('disconnection', self.on_disconnection)

        # Connect a new L2CAP channel
        logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            l2cap_channel = await connection.create_l2cap_channel(
                spec=l2cap.LeCreditBasedChannelSpec(
                    psm=self.psm,
                    max_credits=self.max_credits,
                    mtu=self.mtu,
                    mps=self.mps,
                )
            )
            logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
        except Exception as error:
            logging.info(color(f'!!! Connection failed: {error}', 'red'))
            return

        self.io_sink = l2cap_channel.write
        self.l2cap_channel = l2cap_channel
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        logging.info(color('*** L2CAP channel closed', 'red'))
        # Drop the closed channel, as L2capServer does.
        self.l2cap_channel = None

    async def drain(self):
        assert self.l2cap_channel
        await self.l2cap_channel.drain()


# -----------------------------------------------------------------------------
# L2capServer
# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Streamed packet I/O over an incoming LE credit-based L2CAP channel."""

    def __init__(
        self,
        device: Device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP connections
        device.create_l2cap_server(
            spec=l2cap.LeCreditBasedChannelSpec(
                psm=psm, mtu=mtu, mps=mps, max_credits=max_credits
            ),
            handler=self.on_l2cap_channel,
        )
        logging.info(
            color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        """Adopt an incoming channel as the packet I/O and mark ready."""
        logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))

        self.io_sink = l2cap_channel.write
        self.l2cap_channel = l2cap_channel
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        logging.info(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None

    async def drain(self):
        assert self.l2cap_channel
        await self.l2cap_channel.drain()


# -----------------------------------------------------------------------------
# RfcommClient
# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Streamed packet I/O over an outgoing RFCOMM DLC.

    A `channel` of 0 means "discover the channel number via SDP from `uuid`".
    Optional tuning values that are None keep the library defaults.
    """

    def __init__(
        self,
        device,
        channel,
        uuid,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.device = device
        self.channel = channel
        self.uuid = uuid
        self.l2cap_mtu = l2cap_mtu
        self.max_frame_size = max_frame_size
        self.initial_credits = initial_credits
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.rfcomm_session = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Open an RFCOMM session on `connection` and mark ready on success."""
        connection.on('disconnection', self.on_disconnection)

        # Find the channel number if not specified
        channel = self.channel
        if channel == 0:
            logging.info(
                color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
            )
            channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
                connection, self.uuid
            )
            if channel:
                logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
            else:
                logging.warning(
                    color('!!! No RFComm service with this UUID found', 'red')
                )
                await connection.disconnect()
                return

        # Create a client and start it
        logging.info(color('*** Starting RFCOMM client...', 'blue'))
        rfcomm_options = {}
        if self.l2cap_mtu:
            rfcomm_options['l2cap_mtu'] = self.l2cap_mtu
        rfcomm_client = bumble.rfcomm.Client(connection, **rfcomm_options)
        rfcomm_mux = await rfcomm_client.start()
        logging.info(color('*** Started', 'blue'))

        logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
        try:
            dlc_options = {}
            if self.max_frame_size is not None:
                dlc_options['max_frame_size'] = self.max_frame_size
            if self.initial_credits is not None:
                dlc_options['initial_credits'] = self.initial_credits
            rfcomm_session = await rfcomm_mux.open_dlc(channel, **dlc_options)
            logging.info(color(f'### Session open: {rfcomm_session}', 'yellow'))
            if self.max_credits is not None:
                rfcomm_session.rx_max_credits = self.max_credits
            if self.credits_threshold is not None:
                rfcomm_session.rx_credits_threshold = self.credits_threshold
        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Session open failed: {error}', 'red'))
            await rfcomm_mux.disconnect()
            return

        rfcomm_session.sink = self.on_packet
        self.io_sink = rfcomm_session.write
        self.rfcomm_session = rfcomm_session
        self.ready.set()

    def on_disconnection(self, _):
        pass

    async def drain(self):
        assert self.rfcomm_session
        await self.rfcomm_session.drain()


# -----------------------------------------------------------------------------
# RfcommServer
# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Streamed packet I/O over an incoming RFCOMM DLC.

    Registers an RFCOMM server, listens on `channel`, and publishes the
    resulting channel number through SDP. Optional tuning values that are
    None keep the library defaults.
    """

    def __init__(
        self,
        device,
        channel,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        super().__init__()
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.dlc = None
        self.ready = asyncio.Event()

        # Create and register a server
        server_options = {}
        if l2cap_mtu:
            server_options['l2cap_mtu'] = l2cap_mtu
        rfcomm_server = bumble.rfcomm.Server(device, **server_options)

        # Listen for incoming DLC connections
        dlc_options = {}
        if max_frame_size is not None:
            dlc_options['max_frame_size'] = max_frame_size
        if initial_credits is not None:
            dlc_options['initial_credits'] = initial_credits
        channel_number = rfcomm_server.listen(self.on_dlc, channel, **dlc_options)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        logging.info(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        """Adopt an incoming DLC as the packet I/O and mark ready."""
        logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
        # Apply credit tuning once, using the same attribute names as
        # RfcommClient (the previous code also set a misspelled
        # `rx_threshold` attribute).
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            dlc.rx_credits_threshold = self.credits_threshold
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.dlc = dlc
        self.ready.set()

    async def drain(self):
        assert self.dlc
        await self.dlc.drain()


# -----------------------------------------------------------------------------
# IsoClient
# -----------------------------------------------------------------------------
class IsoClient(StreamedPacketIO):
    """Streamed packet I/O over a CIS link, central side (send-only framing)."""

    def __init__(
        self,
        device: Device,
    ) -> None:
        super().__init__()
        self.device = device
        self.ready = asyncio.Event()
        self.cis_link: Optional[CisLink] = None

    async def on_connection(
        self, connection: Connection, cis_link: CisLink, sender: bool
    ) -> None:
        """Set up the ISO data path in the direction implied by `sender`."""
        connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
        self.cis_link = cis_link
        self.io_sink = cis_link.write
        await cis_link.setup_data_path(
            cis_link.Direction.HOST_TO_CONTROLLER
            if sender
            else cis_link.Direction.CONTROLLER_TO_HOST
        )
        cis_link.sink = self.on_iso_packet
        self.ready.set()

    def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
        self.on_packet(iso_packet.iso_sdu_fragment)

    def on_disconnection(self, _):
        pass

    async def drain(self):
        if self.cis_link is None:
            return
        await self.cis_link.drain()

    def can_receive(self):
        return False


# -----------------------------------------------------------------------------
# IsoServer
# -----------------------------------------------------------------------------
class IsoServer(StreamedPacketIO):
    """Streamed packet I/O over a CIS link, peripheral side."""

    def __init__(
        self,
        device: Device,
    ):
        super().__init__()
        self.device = device
        self.cis_link: Optional[CisLink] = None
        self.ready = asyncio.Event()

        logging.info(
            color(
                '### Listening for ISO connection',
                'yellow',
            )
        )

    async def on_connection(
        self, connection: Connection, cis_link: CisLink, sender: bool
    ) -> None:
        """Set up the ISO data path in the direction implied by `sender`."""
        connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
        # Keep the link so drain() can use it (previously never stored, which
        # made drain() a no-op).
        self.cis_link = cis_link
        self.io_sink = cis_link.write
        await cis_link.setup_data_path(
            cis_link.Direction.HOST_TO_CONTROLLER
            if sender
            else cis_link.Direction.CONTROLLER_TO_HOST
        )
        cis_link.sink = self.on_iso_packet
        self.ready.set()

    def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
        self.on_packet(iso_packet.iso_sdu_fragment)

    def on_disconnection(self, _):
        pass

    async def drain(self):
        if self.cis_link is None:
            return
        await self.cis_link.drain()

    def can_receive(self):
        return False


# ----------------------------------------------------------------------------- # Central # ----------------------------------------------------------------------------- class Central(Connection.Listener): def __init__( self, transport, peripheral_address, scenario_factory, mode_factory, connection_interval, phy, authenticate, encrypt, iso, iso_sdu_interval_c_to_p, iso_sdu_interval_p_to_c, iso_max_sdu_c_to_p, iso_max_sdu_p_to_c, iso_max_transport_latency_c_to_p, iso_max_transport_latency_p_to_c, iso_rtn_c_to_p, iso_rtn_p_to_c, classic, extended_data_length, role_switch, le_scan, le_advertise, classic_page_scan, classic_inquiry_scan, ): super().__init__() self.transport = transport self.peripheral_address = peripheral_address self.classic = classic self.iso = iso self.iso_sdu_interval_c_to_p = iso_sdu_interval_c_to_p self.iso_sdu_interval_p_to_c = iso_sdu_interval_p_to_c self.iso_max_sdu_c_to_p = iso_max_sdu_c_to_p self.iso_max_sdu_p_to_c = iso_max_sdu_p_to_c self.iso_max_transport_latency_c_to_p = iso_max_transport_latency_c_to_p self.iso_max_transport_latency_p_to_c = iso_max_transport_latency_p_to_c self.iso_rtn_c_to_p = iso_rtn_c_to_p self.iso_rtn_p_to_c = iso_rtn_p_to_c self.scenario_factory = scenario_factory self.mode_factory = mode_factory self.authenticate = authenticate self.encrypt = encrypt or authenticate self.extended_data_length = extended_data_length self.role_switch = role_switch self.le_scan = le_scan self.le_advertise = le_advertise self.classic_page_scan = classic_page_scan self.classic_inquiry_scan = classic_inquiry_scan self.device = None self.connection = None if phy: self.phy = { '1m': HCI_LE_1M_PHY, '2m': HCI_LE_2M_PHY, 'coded': HCI_LE_CODED_PHY, }[phy] else: self.phy = None if connection_interval:
connection_parameter_preferences = ConnectionParametersPreferences() connection_parameter_preferences.connection_interval_min = ( connection_interval ) connection_parameter_preferences.connection_interval_max = ( connection_interval ) # Preferences for the 1M PHY are always set. self.connection_parameter_preferences = { HCI_LE_1M_PHY: connection_parameter_preferences, } if self.phy not in (None, HCI_LE_1M_PHY): # Add an connections parameters entry for this PHY. self.connection_parameter_preferences[self.phy] = ( connection_parameter_preferences ) else: self.connection_parameter_preferences = None async def run(self): logging.info(color('>>> Connecting to HCI...', 'green')) async with await open_transport_or_link(self.transport) as ( hci_source, hci_sink, ): logging.info(color('>>> Connected', 'green')) central_address = DEFAULT_CENTRAL_ADDRESS self.device = Device.with_hci( DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink ) mode = self.mode_factory(self.device) scenario = self.scenario_factory(mode) self.device.classic_enabled = self.classic self.device.cis_enabled = self.iso # Set up a pairing config factory with minimal requirements. self.device.pairing_config_factory = lambda _: PairingConfig( sc=False, mitm=False, bonding=False ) await pre_power_on(self.device, self.classic) await self.device.power_on() await post_power_on( self.device, self.le_scan, self.le_advertise, self.classic_page_scan, self.classic_inquiry_scan, ) logging.info( color(f'### Connecting to {self.peripheral_address}...', 'cyan') ) try: self.connection = await self.device.connect( self.peripheral_address, connection_parameters_preferences=self.connection_parameter_preferences, transport=( PhysicalTransport.BR_EDR if self.classic else PhysicalTransport.LE ), ) except CommandTimeoutError: logging.info(color('!!! Connection timed out', 'red')) return except bumble.core.ConnectionError as error: logging.info(color(f'!!! 
Connection error: {error}', 'red')) return except HCI_StatusError as error: logging.info(color(f'!!! Connection failed: {error.error_name}')) return logging.info(color('### Connected', 'cyan')) self.connection.listener = self print_connection(self.connection) if not self.classic: phy = await self.connection.get_phy() print_connection_phy(phy) # Switch roles if needed. if self.role_switch: await switch_roles(self.connection, self.role_switch) # Wait a bit after the connection, some controllers aren't very good when # we start sending data right away while some connection parameters are # updated post connection await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME) # Request a new data length if requested if self.extended_data_length: logging.info(color('+++ Requesting extended data length', 'cyan')) await self.connection.set_data_length( self.extended_data_length[0], self.extended_data_length[1] ) # Authenticate if requested if self.authenticate: # Request authentication logging.info(color('*** Authenticating...', 'cyan')) await self.connection.authenticate() logging.info(color('*** Authenticated', 'cyan')) # Encrypt if requested if self.encrypt: # Enable encryption logging.info(color('*** Enabling encryption...', 'cyan')) await self.connection.encrypt() logging.info(color('*** Encryption on', 'cyan')) # Set the PHY if requested if self.phy is not None: try: await self.connection.set_phy( tx_phys=[self.phy], rx_phys=[self.phy] ) except HCI_Error as error: logging.info( color( f'!!! Unable to set the PHY: {error.error_name}', 'yellow' ) ) # Setup ISO streams. 
if self.iso: if scenario.is_sender(): sdu_interval_c_to_p = ( self.iso_sdu_interval_c_to_p or DEFAULT_ISO_SDU_INTERVAL_C_TO_P ) sdu_interval_p_to_c = self.iso_sdu_interval_p_to_c or 0 max_transport_latency_c_to_p = ( self.iso_max_transport_latency_c_to_p or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P ) max_transport_latency_p_to_c = ( self.iso_max_transport_latency_p_to_c or 0 ) max_sdu_c_to_p = ( self.iso_max_sdu_c_to_p or DEFAULT_ISO_MAX_SDU_C_TO_P ) max_sdu_p_to_c = self.iso_max_sdu_p_to_c or 0 rtn_c_to_p = self.iso_rtn_c_to_p or DEFAULT_ISO_RTN_C_TO_P rtn_p_to_c = self.iso_rtn_p_to_c or 0 else: sdu_interval_p_to_c = ( self.iso_sdu_interval_p_to_c or DEFAULT_ISO_SDU_INTERVAL_P_TO_C ) sdu_interval_c_to_p = self.iso_sdu_interval_c_to_p or 0 max_transport_latency_p_to_c = ( self.iso_max_transport_latency_p_to_c or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C ) max_transport_latency_c_to_p = ( self.iso_max_transport_latency_c_to_p or 0 ) max_sdu_p_to_c = ( self.iso_max_sdu_p_to_c or DEFAULT_ISO_MAX_SDU_P_TO_C ) max_sdu_c_to_p = self.iso_max_sdu_c_to_p or 0 rtn_p_to_c = self.iso_rtn_p_to_c or DEFAULT_ISO_RTN_P_TO_C rtn_c_to_p = self.iso_rtn_c_to_p or 0 cis_handles = await self.device.setup_cig( CigParameters( cig_id=1, sdu_interval_c_to_p=sdu_interval_c_to_p, sdu_interval_p_to_c=sdu_interval_p_to_c, max_transport_latency_c_to_p=max_transport_latency_c_to_p, max_transport_latency_p_to_c=max_transport_latency_p_to_c, cis_parameters=[ CigParameters.CisParameters( cis_id=2, max_sdu_c_to_p=max_sdu_c_to_p, max_sdu_p_to_c=max_sdu_p_to_c, rtn_c_to_p=rtn_c_to_p, rtn_p_to_c=rtn_p_to_c, ) ], ) ) cis_link = ( await self.device.create_cis([(cis_handles[0], self.connection)]) )[0] print_cis_link(cis_link) await mode.on_connection( self.connection, cis_link, scenario.is_sender() ) else: await mode.on_connection(self.connection) await scenario.run() await asyncio.sleep(DEFAULT_LINGER_TIME) await self.connection.disconnect() def on_disconnection(self, reason): logging.info(color(f'!!! 
Disconnection: reason={reason}', 'red')) self.connection = None def on_connection_parameters_update(self): print_connection(self.connection) def on_connection_phy_update(self, phy): print_connection_phy(phy) def on_connection_att_mtu_update(self): print_connection(self.connection) def on_connection_data_length_change(self): print_connection(self.connection) def on_role_change(self): print_connection(self.connection) # ----------------------------------------------------------------------------- # Peripheral # ----------------------------------------------------------------------------- class Peripheral(Device.Listener, Connection.Listener): def __init__( self, transport, scenario_factory, mode_factory, classic, iso, extended_data_length, role_switch, le_scan, le_advertise, classic_page_scan, classic_inquiry_scan, ): self.transport = transport self.classic = classic self.iso = iso self.scenario_factory = scenario_factory self.mode_factory = mode_factory self.extended_data_length = extended_data_length self.role_switch = role_switch self.le_scan = le_scan self.classic_page_scan = classic_page_scan self.classic_inquiry_scan = classic_inquiry_scan self.scenario = None self.mode = None self.device = None self.connection = None self.connected = asyncio.Event() if le_advertise: self.le_advertise = le_advertise else: self.le_advertise = 0 if classic else DEFAULT_ADVERTISING_INTERVAL async def run(self): logging.info(color('>>> Connecting to HCI...', 'green')) async with await open_transport_or_link(self.transport) as ( hci_source, hci_sink, ): logging.info(color('>>> Connected', 'green')) peripheral_address = DEFAULT_PERIPHERAL_ADDRESS self.device = Device.with_hci( DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink ) self.device.listener = self self.mode = self.mode_factory(self.device) self.scenario = self.scenario_factory(self.mode) self.device.classic_enabled = self.classic self.device.cis_enabled = self.iso # Set up a pairing config factory with minimal 
requirements. self.device.pairing_config_factory = lambda _: PairingConfig( sc=False, mitm=False, bonding=False ) await pre_power_on(self.device, self.classic) await self.device.power_on() await post_power_on( self.device, self.le_scan, self.le_advertise, self.classic or self.classic_page_scan, self.classic or self.classic_inquiry_scan, ) if self.classic: logging.info( color( '### Waiting for connection on' f' {self.device.public_address}...', 'cyan', ) ) else: logging.info( color( f'### Waiting for connection on {peripheral_address}...', 'cyan', ) ) await self.connected.wait() logging.info(color('### Connected', 'cyan')) print_connection(self.connection) if self.iso: async def on_cis_request(cis_link: CisLink) -> None: logging.info(color("@@@ Accepting CIS", "green")) await self.device.accept_cis_request(cis_link) print_cis_link(cis_link) await self.mode.on_connection( self.connection, cis_link, self.scenario.is_sender() ) self.connection.on(self.connection.EVENT_CIS_REQUEST, on_cis_request) else: await self.mode.on_connection(self.connection) await self.scenario.run() await asyncio.sleep(DEFAULT_LINGER_TIME) def on_connection(self, connection): connection.listener = self self.connection = connection self.connected.set() # Stop being discoverable and connectable if possible if self.classic: if not self.classic_inquiry_scan: logging.info(color("*** Stopping inquiry scan", "blue")) AsyncRunner.spawn(self.device.set_discoverable(False)) if not self.classic_page_scan: logging.info(color("*** Stopping page scan", "blue")) AsyncRunner.spawn(self.device.set_connectable(False)) # Request a new data length if needed if not self.classic and self.extended_data_length: logging.info("+++ Requesting extended data length") AsyncRunner.spawn( connection.set_data_length( self.extended_data_length[0], self.extended_data_length[1] ) ) # Switch roles if needed. 
if self.role_switch: AsyncRunner.spawn(switch_roles(connection, self.role_switch)) def on_disconnection(self, reason): logging.info(color(f'!!! Disconnection: reason={reason}', 'red')) self.connection = None self.scenario.reset() if self.classic: logging.info(color("*** Enabling inquiry scan", "blue")) AsyncRunner.spawn(self.device.set_discoverable(True)) logging.info(color("*** Enabling page scan", "blue")) AsyncRunner.spawn(self.device.set_connectable(True)) def on_connection_parameters_update(self): print_connection(self.connection) def on_connection_phy_update(self, phy): print_connection_phy(phy) def on_connection_att_mtu_update(self): print_connection(self.connection) def on_connection_data_length_change(self): print_connection(self.connection) def on_role_change(self): print_connection(self.connection) # ----------------------------------------------------------------------------- def create_mode_factory(ctx, default_mode): mode = ctx.obj['mode'] if mode is None: mode = default_mode def create_mode(device): if mode == 'gatt-client': return GattClient(device, att_mtu=ctx.obj['att_mtu']) if mode == 'gatt-server': return GattServer(device) if mode == 'l2cap-client': return L2capClient( device, psm=ctx.obj['l2cap_psm'], mtu=ctx.obj['l2cap_mtu'], mps=ctx.obj['l2cap_mps'], max_credits=ctx.obj['l2cap_max_credits'], ) if mode == 'l2cap-server': return L2capServer( device, psm=ctx.obj['l2cap_psm'], mtu=ctx.obj['l2cap_mtu'], mps=ctx.obj['l2cap_mps'], max_credits=ctx.obj['l2cap_max_credits'], ) if mode == 'rfcomm-client': return RfcommClient( device, channel=ctx.obj['rfcomm_channel'], uuid=ctx.obj['rfcomm_uuid'], l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], max_frame_size=ctx.obj['rfcomm_max_frame_size'], initial_credits=ctx.obj['rfcomm_initial_credits'], max_credits=ctx.obj['rfcomm_max_credits'], credits_threshold=ctx.obj['rfcomm_credits_threshold'], ) if mode == 'rfcomm-server': return RfcommServer( device, channel=ctx.obj['rfcomm_channel'], 
l2cap_mtu=ctx.obj['rfcomm_l2cap_mtu'], max_frame_size=ctx.obj['rfcomm_max_frame_size'], initial_credits=ctx.obj['rfcomm_initial_credits'], max_credits=ctx.obj['rfcomm_max_credits'], credits_threshold=ctx.obj['rfcomm_credits_threshold'], ) if mode == 'iso-server': return IsoServer(device) if mode == 'iso-client': return IsoClient(device) raise ValueError('invalid mode') return create_mode # ----------------------------------------------------------------------------- def create_scenario_factory(ctx, default_scenario): scenario = ctx.obj['scenario'] if scenario is None: scenario = default_scenario def create_scenario(packet_io): if scenario == 'send': return Sender( packet_io, start_delay=ctx.obj['start_delay'], repeat=ctx.obj['repeat'], repeat_delay=ctx.obj['repeat_delay'], pace=ctx.obj['pace'], packet_size=ctx.obj['packet_size'], packet_count=ctx.obj['packet_count'], ) if scenario == 'receive': return Receiver(packet_io, ctx.obj['linger']) if scenario == 'ping': if isinstance(packet_io, (IsoClient, IsoServer)): raise ValueError('ping not supported with ISO') return Ping( packet_io, start_delay=ctx.obj['start_delay'], repeat=ctx.obj['repeat'], repeat_delay=ctx.obj['repeat_delay'], pace=ctx.obj['pace'], packet_size=ctx.obj['packet_size'], packet_count=ctx.obj['packet_count'], ) if scenario == 'pong': if isinstance(packet_io, (IsoClient, IsoServer)): raise ValueError('pong not supported with ISO') return Pong(packet_io, ctx.obj['linger']) raise ValueError('invalid scenario') return create_scenario # ----------------------------------------------------------------------------- # Main # ----------------------------------------------------------------------------- @click.group() @click.option('--device-config', metavar='FILENAME', help='Device configuration file') @click.option('--scenario', type=click.Choice(['send', 'receive', 'ping', 'pong'])) @click.option( '--mode', type=click.Choice( [ 'gatt-client', 'gatt-server', 'l2cap-client', 'l2cap-server', 'rfcomm-client', 
'rfcomm-server', 'iso-client', 'iso-server', ] ), ) @click.option( '--att-mtu', metavar='MTU', type=click.IntRange(23, 517), default=517, help='GATT MTU (gatt-client mode)', ) @click.option( '--extended-data-length', metavar='/', help='Request a data length upon connection, specified as tx_octets/tx_time', ) @click.option( '--role-switch', type=click.Choice(['central', 'peripheral']), help='Request role switch upon connection (central or peripheral)', ) @click.option( '--le-scan', metavar='/', help='Perform an LE scan with a given window and interval (milliseconds)', ) @click.option( '--le-advertise', metavar='', help='Advertise with a given interval (milliseconds)', ) @click.option( '--classic-page-scan', is_flag=True, help='Enable Classic page scanning', ) @click.option( '--classic-inquiry-scan', is_flag=True, help='Enable Classic enquiry scanning', ) @click.option( '--rfcomm-channel', type=int, default=DEFAULT_RFCOMM_CHANNEL, help='RFComm channel to use (specify 0 for channel discovery via SDP)', ) @click.option( '--rfcomm-uuid', default=DEFAULT_RFCOMM_UUID, help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)', ) @click.option( '--rfcomm-l2cap-mtu', type=int, help='RFComm L2CAP MTU', ) @click.option( '--rfcomm-max-frame-size', type=int, help='RFComm maximum frame size', ) @click.option( '--rfcomm-initial-credits', type=int, help='RFComm initial credits', ) @click.option( '--rfcomm-max-credits', type=int, help='RFComm max credits', ) @click.option( '--rfcomm-credits-threshold', type=int, help='RFComm credits threshold', ) @click.option( '--l2cap-psm', type=int, default=DEFAULT_L2CAP_PSM, help='L2CAP PSM to use', ) @click.option( '--l2cap-mtu', type=int, default=DEFAULT_L2CAP_MTU, help='L2CAP MTU to use', ) @click.option( '--l2cap-mps', type=int, default=DEFAULT_L2CAP_MPS, help='L2CAP MPS to use', ) @click.option( '--l2cap-max-credits', type=int, default=DEFAULT_L2CAP_MAX_CREDITS, help='L2CAP maximum number of credits allowed for the peer', ) 
@click.option( '--packet-size', '-s', metavar='SIZE', type=click.IntRange(10, 8192), default=500, help='Packet size (send or ping scenario)', ) @click.option( '--packet-count', '-c', metavar='COUNT', type=int, default=10, help='Packet count (send or ping scenario)', ) @click.option( '--start-delay', '-sd', metavar='SECONDS', type=int, default=1, help='Start delay (send or ping scenario)', ) @click.option( '--repeat', metavar='N', type=int, default=0, help=( 'Repeat the run N times (send and ping scenario)' '(0, which is the fault, to run just once) ' ), ) @click.option( '--repeat-delay', metavar='SECONDS', type=int, default=1, help=('Delay, in seconds, between repeats'), ) @click.option( '--pace', metavar='MILLISECONDS', type=int, default=0, help=( 'Wait N milliseconds between packets ' '(0, which is the fault, to send as fast as possible) ' ), ) @click.option( '--linger', is_flag=True, help="Don't exit at the end of a run (receive and pong scenarios)", ) @click.pass_context def bench( ctx, device_config, scenario, mode, att_mtu, extended_data_length, role_switch, le_scan, le_advertise, classic_page_scan, classic_inquiry_scan, packet_size, packet_count, start_delay, repeat, repeat_delay, pace, linger, rfcomm_channel, rfcomm_uuid, rfcomm_l2cap_mtu, rfcomm_max_frame_size, rfcomm_initial_credits, rfcomm_max_credits, rfcomm_credits_threshold, l2cap_psm, l2cap_mtu, l2cap_mps, l2cap_max_credits, ): ctx.ensure_object(dict) ctx.obj['device_config'] = device_config ctx.obj['scenario'] = scenario ctx.obj['mode'] = mode ctx.obj['att_mtu'] = att_mtu ctx.obj['rfcomm_channel'] = rfcomm_channel ctx.obj['rfcomm_uuid'] = rfcomm_uuid ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size ctx.obj['rfcomm_initial_credits'] = rfcomm_initial_credits ctx.obj['rfcomm_max_credits'] = rfcomm_max_credits ctx.obj['rfcomm_credits_threshold'] = rfcomm_credits_threshold ctx.obj['l2cap_psm'] = l2cap_psm ctx.obj['l2cap_mtu'] = l2cap_mtu 
ctx.obj['l2cap_mps'] = l2cap_mps ctx.obj['l2cap_max_credits'] = l2cap_max_credits ctx.obj['packet_size'] = packet_size ctx.obj['packet_count'] = packet_count ctx.obj['start_delay'] = start_delay ctx.obj['repeat'] = repeat ctx.obj['repeat_delay'] = repeat_delay ctx.obj['pace'] = pace ctx.obj['linger'] = linger ctx.obj['extended_data_length'] = ( [int(x) for x in extended_data_length.split('/')] if extended_data_length else None ) ctx.obj['role_switch'] = role_switch ctx.obj['le_scan'] = [float(x) for x in le_scan.split('/')] if le_scan else None ctx.obj['le_advertise'] = float(le_advertise) if le_advertise else None ctx.obj['classic_page_scan'] = classic_page_scan ctx.obj['classic_inquiry_scan'] = classic_inquiry_scan ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server') ctx.obj['iso'] = mode in ('iso-client', 'iso-server') @bench.command() @click.argument('transport') @click.option( '--peripheral', 'peripheral_address', metavar='ADDRESS_OR_NAME', default=DEFAULT_PERIPHERAL_ADDRESS, help='Address or name to connect to', ) @click.option( '--connection-interval', '--ci', metavar='CONNECTION_INTERVAL', type=int, help='Connection interval (in ms)', ) @click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use') @click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)') @click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)') @click.option( '--iso-sdu-interval-c-to-p', type=int, help='ISO SDU central -> peripheral (microseconds)', ) @click.option( '--iso-sdu-interval-p-to-c', type=int, help='ISO SDU interval peripheral -> central (microseconds)', ) @click.option( '--iso-max-sdu-c-to-p', type=int, help='ISO max SDU central -> peripheral', ) @click.option( '--iso-max-sdu-p-to-c', type=int, help='ISO max SDU peripheral -> central', ) @click.option( '--iso-max-transport-latency-c-to-p', type=int, help='ISO max transport latency central -> peripheral (milliseconds)', ) @click.option( 
'--iso-max-transport-latency-p-to-c', type=int, help='ISO max transport latency peripheral -> central (milliseconds)', ) @click.option( '--iso-rtn-c-to-p', type=int, help='ISO RTN central -> peripheral (integer count)', ) @click.option( '--iso-rtn-p-to-c', type=int, help='ISO RTN peripheral -> central (integer count)', ) @click.pass_context def central( ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt, iso_sdu_interval_c_to_p, iso_sdu_interval_p_to_c, iso_max_sdu_c_to_p, iso_max_sdu_p_to_c, iso_max_transport_latency_c_to_p, iso_max_transport_latency_p_to_c, iso_rtn_c_to_p, iso_rtn_p_to_c, ): """Run as a central (initiates the connection)""" scenario_factory = create_scenario_factory(ctx, 'send') mode_factory = create_mode_factory(ctx, 'gatt-client') async def run_central(): await Central( transport, peripheral_address, scenario_factory, mode_factory, connection_interval, phy, authenticate, encrypt or authenticate, ctx.obj['iso'], iso_sdu_interval_c_to_p, iso_sdu_interval_p_to_c, iso_max_sdu_c_to_p, iso_max_sdu_p_to_c, iso_max_transport_latency_c_to_p, iso_max_transport_latency_p_to_c, iso_rtn_c_to_p, iso_rtn_p_to_c, ctx.obj['classic'], ctx.obj['extended_data_length'], ctx.obj['role_switch'], ctx.obj['le_scan'], ctx.obj['le_advertise'], ctx.obj['classic_page_scan'], ctx.obj['classic_inquiry_scan'], ).run() asyncio.run(run_central()) @bench.command() @click.argument('transport') @click.pass_context def peripheral(ctx, transport): """Run as a peripheral (waits for a connection)""" scenario_factory = create_scenario_factory(ctx, 'receive') mode_factory = create_mode_factory(ctx, 'gatt-server') async def run_peripheral(): await Peripheral( transport, scenario_factory, mode_factory, ctx.obj['classic'], ctx.obj['iso'], ctx.obj['extended_data_length'], ctx.obj['role_switch'], ctx.obj['le_scan'], ctx.obj['le_advertise'], ctx.obj['classic_page_scan'], ctx.obj['classic_inquiry_scan'], ).run() asyncio.run(run_peripheral()) def main(): 
logging.basicConfig( level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper(), format="[%(asctime)s.%(msecs)03d] %(levelname)s:%(name)s:%(message)s", datefmt="%H:%M:%S", ) bench() # ----------------------------------------------------------------------------- if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter