diff --git a/apps/controller_loopback.py b/apps/controller_loopback.py new file mode 100644 index 00000000..a1cc3a52 --- /dev/null +++ b/apps/controller_loopback.py @@ -0,0 +1,198 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import os +import time +from typing import Optional +from bumble.colors import color +from bumble.hci import ( + HCI_READ_LOOPBACK_MODE_COMMAND, + HCI_Read_Loopback_Mode_Command, + HCI_WRITE_LOOPBACK_MODE_COMMAND, + HCI_Write_Loopback_Mode_Command, + LoopbackMode, +) +from bumble.host import Host +from bumble.transport import open_transport_or_link +import click + + +class Loopback: + """Send and receive ACL data packets in local loopback mode""" + + def __init__(self, packet_size: int, packet_count: int, transport: str): + self.transport = transport + self.packet_size = packet_size + self.packet_count = packet_count + self.connection_handle: Optional[int] = None + self.connection_event = asyncio.Event() + self.done = asyncio.Event() + self.expected_cid = 0 + self.bytes_received = 0 + self.start_timestamp = 0.0 + self.last_timestamp = 0.0 + + def on_connection(self, connection_handle: int, *args): + """Retrieve connection handle from new connection event""" + if not self.connection_event.is_set(): + # save first connection handle for ACL + # subsequent connections are SCO + self.connection_handle = connection_handle + self.connection_event.set() + + def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes): + """Calculate packet receive speed""" + now = time.time() + print(f'<<< Received packet {cid}: {len(pdu)} bytes') + assert connection_handle == self.connection_handle + assert cid == self.expected_cid + self.expected_cid += 1 + if cid == 0: + self.start_timestamp = now + self.last_timestamp = now + else: + elapsed_since_start = now - self.start_timestamp + elapsed_since_last = now - self.last_timestamp + self.bytes_received += len(pdu) + instant_rx_speed = len(pdu) / elapsed_since_last + average_rx_speed = self.bytes_received / elapsed_since_start + print( + color( + f'@@@ RX speed: instant={instant_rx_speed:.4f},' + f' average={average_rx_speed:.4f}', + 'cyan', + ) + ) + if self.expected_cid == self.packet_count: + print(color('@@@ Received last packet', 'green')) + self.done.set() + + async def run(self): + """Run a loopback throughput test""" + print(color('>>> Connecting to HCI...', 'green')) + async with await open_transport_or_link(self.transport) as ( + hci_source, + hci_sink, + ): + print(color('>>> Connected', 'green')) + + host = Host(hci_source, hci_sink) + await host.reset() + + # make sure data can fit in one l2cap pdu + l2cap_header_size = 4 + max_packet_size = host.hc_acl_data_packet_length - l2cap_header_size + if self.packet_size > max_packet_size: + print( + color( + f'!!! Packet size ({self.packet_size}) larger than max supported' + f' size ({max_packet_size})', + 'red', + ) + ) + return + + if not host.supports_command( + HCI_WRITE_LOOPBACK_MODE_COMMAND + ) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND): + print(color('!!! Loopback mode not supported', 'red')) + return + + # set event callbacks + host.on('connection', self.on_connection) + host.on('l2cap_pdu', self.on_l2cap_pdu) + + loopback_mode = LoopbackMode.LOCAL + + print(color('### Setting loopback mode', 'blue')) + await host.send_command( + HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL), + check_result=True, + ) + + print(color('### Checking loopback mode', 'blue')) + response = await host.send_command( + HCI_Read_Loopback_Mode_Command(), check_result=True + ) + if response.return_parameters.loopback_mode != loopback_mode: + print(color('!!! Loopback mode mismatch', 'red')) + return + + await self.connection_event.wait() + print(color('### Connected', 'cyan')) + + print(color('=== Start sending', 'magenta')) + start_time = time.time() + bytes_sent = 0 + for cid in range(0, self.packet_count): + # using the cid as an incremental index + host.send_l2cap_pdu( + self.connection_handle, cid, bytes(self.packet_size) + ) + print( + color( + f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow' + ) + ) + bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes + await asyncio.sleep(0) # yield to allow packet receive + + await self.done.wait() + print(color('=== Done!', 'magenta')) + + elapsed = time.time() - start_time + average_tx_speed = bytes_sent / elapsed + print( + color( + f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes' + f' in {elapsed:.2f} seconds)', + 'green', + ) + ) + + +# ----------------------------------------------------------------------------- +@click.command() +@click.option( + '--packet-size', + '-s', + metavar='SIZE', + type=click.IntRange(8, 4096), + default=500, + help='Packet size', +) +@click.option( + '--packet-count', + '-c', + metavar='COUNT', + type=int, + default=10, + help='Packet count', +) +@click.argument('transport') +def main(packet_size, packet_count, transport): + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + + loopback = Loopback(packet_size, packet_count, transport) + asyncio.run(loopback.run()) + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + main() diff --git a/bumble/hci.py b/bumble/hci.py index 36c049c9..5a488380 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1987,6 +1987,17 @@ class OwnAddressType(enum.IntEnum): return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name} +# ----------------------------------------------------------------------------- +class LoopbackMode(enum.IntEnum): + DISABLED = 0 + LOCAL = 1 + REMOTE = 2 + + @classmethod + def type_spec(cls): + return {'size': 1, 'mapper': lambda x: LoopbackMode(x).name} + + # ----------------------------------------------------------------------------- class HCI_Packet: ''' @@ -3313,6 +3324,27 @@ class HCI_Read_Encryption_Key_Size_Command(HCI_Command): ''' +# ----------------------------------------------------------------------------- +@HCI_Command.command( + return_parameters_fields=[ + ('status', STATUS_SPEC), + ('loopback_mode', LoopbackMode.type_spec()), + ], +) +class HCI_Read_Loopback_Mode_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.6.1 Read Loopback Mode Command + ''' + + +# ----------------------------------------------------------------------------- +@HCI_Command.command([('loopback_mode', 1)]) +class HCI_Write_Loopback_Mode_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.6.2 Write Loopback Mode Command + ''' + + # ----------------------------------------------------------------------------- @HCI_Command.command([('le_event_mask', 8)]) class HCI_LE_Set_Event_Mask_Command(HCI_Command):