From dc3ac3060ebb4700381f121d4ed749593a25fdf5 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Sun, 19 Mar 2023 21:18:44 -0700 Subject: [PATCH] add auto-snooping for transports --- bumble/host.py | 4 +- bumble/snoop.py | 95 ++++++++++++++++++++++++++++++++---- bumble/transport/__init__.py | 67 ++++++++++++++++++++++--- bumble/transport/common.py | 75 ++++++++++++++++++++++++++++ setup.cfg | 2 +- tests/transport_test.py | 5 +- 6 files changed, 227 insertions(+), 21 deletions(-) diff --git a/bumble/host.py b/bumble/host.py index 87ec610..9f667a1 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -276,7 +276,7 @@ class Host(AbortableEventEmitter): def send_hci_packet(self, packet): if self.snooper: - self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) + self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) self.hci_sink.on_packet(packet.to_bytes()) @@ -425,7 +425,7 @@ class Host(AbortableEventEmitter): logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') if self.snooper: - self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) + self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) # If the packet is a command, invoke the handler for this packet if packet.hci_packet_type == HCI_COMMAND_PACKET: diff --git a/bumble/snoop.py b/bumble/snoop.py index 359fa38..462e923 100644 --- a/bumble/snoop.py +++ b/bumble/snoop.py @@ -15,12 +15,21 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from contextlib import contextmanager from enum import IntEnum +import logging import struct import datetime -from typing import BinaryIO +from typing import BinaryIO, Generator +import os -from bumble.hci import HCI_Packet, HCI_COMMAND_PACKET, HCI_EVENT_PACKET +from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- @@ -44,7 +53,7 @@ class Snooper: HCI_BSCP = 1003 H5 = 1004 - def snoop(self, hci_packet: HCI_Packet, direction: Direction) -> None: + def snoop(self, hci_packet: bytes, direction: Direction) -> None: """Snoop on an HCI packet.""" @@ -67,9 +76,10 @@ class BtSnooper(Snooper): self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4) ) - def snoop(self, hci_packet: HCI_Packet, direction: Snooper.Direction) -> None: + def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None: flags = int(direction) - if hci_packet.hci_packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET): + packet_type = hci_packet[0] + if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET): flags |= 0x10 # Compute the current timestamp @@ -79,15 +89,82 @@ class BtSnooper(Snooper): ) # Emit the record - packet_data = bytes(hci_packet) self.output.write( struct.pack( '>IIIIQ', - len(packet_data), # Original Length - len(packet_data), # Included Length + len(hci_packet), # Original Length + len(hci_packet), # Included Length flags, # Packet Flags 0, # Cumulative Drops timestamp, # Timestamp ) - + packet_data + + hci_packet ) + + +# ----------------------------------------------------------------------------- +_SNOOPER_INSTANCE_COUNT = 0 + + +@contextmanager +def create_snooper(spec: str) -> Generator[Snooper, None, None]: + """ + Create a snooper given a specification string. + + The general syntax for the specification string is: + : + + Supported snooper types are: + + btsnoop + The syntax for the type-specific arguments for this type is: + : + + Supported I/O types are: + + file + The type-specific arguments for this I/O type is a string that is converted + to a file path using the python `str.format()` string formatting. The log + records will be written to that file if it can be opened/created. + The keyword args that may be referenced by the string pattern are: + now: the value of `datetime.now()` + utcnow: the value of `datetime.utcnow()` + pid: the current process ID. + instance: the instance ID in the current process. + + Examples: + btsnoop:file:my_btsnoop.log + btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log + + """ + if ':' not in spec: + raise ValueError('snooper type prefix missing') + + snooper_type, snooper_args = spec.split(':', maxsplit=1) + + if snooper_type == 'btsnoop': + if ':' not in snooper_args: + raise ValueError('I/O type for btsnoop type missing') + + io_type, io_name = snooper_args.split(':', maxsplit=1) + if io_type == 'file': + # Process the file name string pattern. + global _SNOOPER_INSTANCE_COUNT + file_path = io_name.format( + now=datetime.datetime.now(), + utcnow=datetime.datetime.utcnow(), + pid=os.getpid(), + instance=_SNOOPER_INSTANCE_COUNT, + ) + + # Open the file + logger.debug(f'Snoop file: {file_path}') + with open(file_path, 'wb') as snoop_file: + _SNOOPER_INSTANCE_COUNT += 1 + yield BtSnooper(snoop_file) + _SNOOPER_INSTANCE_COUNT -= 1 + return + + raise ValueError('I/O type not supported') + + raise ValueError(f'snooper type {snooper_type} not found') diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index 8a93ed7..2d4600f 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 Google LLC +# Copyright 2021-2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,13 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from contextlib import asynccontextmanager import logging +import os -from .common import Transport, AsyncPipeSink +from .common import Transport, AsyncPipeSink, SnoopingTransport from ..controller import Controller +from ..snoop import create_snooper # ----------------------------------------------------------------------------- # Logging @@ -26,14 +29,53 @@ from ..controller import Controller logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +def _wrap_transport(transport: Transport) -> Transport: + """ + Automatically wrap a Transport instance when a wrapping class can be inferred + from the environment. + If no wrapping class is applicable, the transport argument is returned as-is. + """ + + # If BUMBLE_SNOOPER is set, try to automatically create a snooper. + if snooper_spec := os.getenv('BUMBLE_SNOOPER'): + try: + return SnoopingTransport.create_with( + transport, create_snooper(snooper_spec) + ) + except Exception as exc: + logger.warning(f'Exception while creating snooper: {exc}') + + return transport + + # ----------------------------------------------------------------------------- async def open_transport(name: str) -> Transport: - ''' + """ Open a transport by name. The name must be : Where depend on the type (and may be empty for some types). - The supported types are: serial,udp,tcp,pty,usb - ''' + The supported types are: + * serial + * udp + * tcp-client + * tcp-server + * ws-client + * ws-server + * pty + * file + * vhci + * hci-socket + * usb + * pyusb + * android-emulator + """ + + return _wrap_transport(await _open_transport(name)) + + +# ----------------------------------------------------------------------------- +async def _open_transport(name: str) -> Transport: # pylint: disable=import-outside-toplevel # pylint: disable=too-many-return-statements @@ -107,7 +149,18 @@ async def open_transport(name: str) -> Transport: # ----------------------------------------------------------------------------- -async def open_transport_or_link(name): +async def open_transport_or_link(name: str) -> Transport: + """ + Open a transport or a link relay. + + Args: + name: + Name of the transport or link relay to open. + When the name starts with "link-relay:", open a link relay (see RemoteLink + for details on what the arguments are). + For other namespaces, see `open_transport`. + + """ if name.startswith('link-relay:'): from ..link import RemoteLink # lazy import @@ -119,6 +172,6 @@ async def open_transport_or_link(name): async def close(self): link.close() - return LinkTransport(controller, AsyncPipeSink(controller)) + return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller))) return await open_transport(name) diff --git a/bumble/transport/common.py b/bumble/transport/common.py index 945ba4b..05a1fb5 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -15,12 +15,16 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations +import contextlib import struct import asyncio import logging +from typing import ContextManager from .. import hci from ..colors import color +from ..snoop import Snooper # ----------------------------------------------------------------------------- @@ -246,6 +250,20 @@ class StreamPacketSink: # ----------------------------------------------------------------------------- class Transport: + """ + Base class for all transports. + + A Transport represents a source and a sink together. + An instance must be closed by calling close() when no longer used. Instances + implement the ContextManager protocol so that they may be used in a `async with` + statement. + An instance is iterable. The iterator yields, in order, its source and sink, so + that it may be used with a convenient call syntax like: + + async with create_transport() as (source, sink): + ... + """ + def __init__(self, source, sink): self.source = source self.sink = sink @@ -335,3 +353,60 @@ class PumpedTransport(Transport): async def close(self): await super().close() await self.close_function() + + +# ----------------------------------------------------------------------------- +class SnoopingTransport(Transport): + """Transport wrapper that snoops on packets to/from a wrapped transport.""" + + @staticmethod + def create_with( + transport: Transport, snooper: ContextManager[Snooper] + ) -> SnoopingTransport: + """ + Create an instance given a snooper that works as as context manager. + + The returned instance will exit the snooper context when it is closed. + """ + with contextlib.ExitStack() as exit_stack: + return SnoopingTransport( + transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close + ) + raise RuntimeError('unexpected code path') # Satisfy the type checker + + class Source: + def __init__(self, source, snooper): + self.source = source + self.snooper = snooper + self.sink = None + + def set_packet_sink(self, sink): + self.sink = sink + self.source.set_packet_sink(self) + + def on_packet(self, packet): + self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) + if self.sink: + self.sink.on_packet(packet) + + class Sink: + def __init__(self, sink, snooper): + self.sink = sink + self.snooper = snooper + + def on_packet(self, packet): + self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) + if self.sink: + self.sink.on_packet(packet) + + def __init__(self, transport, snooper, close_snooper=None): + super().__init__( + self.Source(transport.source, snooper), self.Sink(transport.sink, snooper) + ) + self.transport = transport + self.close_snooper = close_snooper + + async def close(self): + await self.transport.close() + if self.close_snooper: + self.close_snooper() diff --git a/setup.cfg b/setup.cfg index ef3dbdf..662dd5c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,7 +72,7 @@ test = development = black == 22.10 invoke >= 1.7.3 - mypy == 0.991 + mypy == 1.1.1 nox >= 2022 pylint == 2.15.8 types-appdirs >= 1.4.3 diff --git a/tests/transport_test.py b/tests/transport_test.py index c0069a0..cd3c5f2 100644 --- a/tests/transport_test.py +++ b/tests/transport_test.py @@ -72,5 +72,6 @@ def test_parser_extensions(): # ----------------------------------------------------------------------------- -test_parser() -test_parser_extensions() +if __name__ == '__main__': + test_parser() + test_parser_extensions()