add auto-snooping for transports

This commit is contained in:
Gilles Boccon-Gibod
2023-03-19 21:18:44 -07:00
parent e77723a5f9
commit dc3ac3060e
6 changed files with 227 additions and 21 deletions

View File

@@ -276,7 +276,7 @@ class Host(AbortableEventEmitter):
def send_hci_packet(self, packet): def send_hci_packet(self, packet):
if self.snooper: if self.snooper:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(packet.to_bytes()) self.hci_sink.on_packet(packet.to_bytes())
@@ -425,7 +425,7 @@ class Host(AbortableEventEmitter):
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper: if self.snooper:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet # If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET: if packet.hci_packet_type == HCI_COMMAND_PACKET:

View File

@@ -15,12 +15,21 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum from enum import IntEnum
import logging
import struct import struct
import datetime import datetime
from typing import BinaryIO from typing import BinaryIO, Generator
import os
from bumble.hci import HCI_Packet, HCI_COMMAND_PACKET, HCI_EVENT_PACKET from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -44,7 +53,7 @@ class Snooper:
HCI_BSCP = 1003 HCI_BSCP = 1003
H5 = 1004 H5 = 1004
def snoop(self, hci_packet: HCI_Packet, direction: Direction) -> None: def snoop(self, hci_packet: bytes, direction: Direction) -> None:
"""Snoop on an HCI packet.""" """Snoop on an HCI packet."""
@@ -67,9 +76,10 @@ class BtSnooper(Snooper):
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4) self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
) )
def snoop(self, hci_packet: HCI_Packet, direction: Snooper.Direction) -> None: def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
flags = int(direction) flags = int(direction)
if hci_packet.hci_packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET): packet_type = hci_packet[0]
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
flags |= 0x10 flags |= 0x10
# Compute the current timestamp # Compute the current timestamp
@@ -79,15 +89,82 @@ class BtSnooper(Snooper):
) )
# Emit the record # Emit the record
packet_data = bytes(hci_packet)
self.output.write( self.output.write(
struct.pack( struct.pack(
'>IIIIQ', '>IIIIQ',
len(packet_data), # Original Length len(hci_packet), # Original Length
len(packet_data), # Included Length len(hci_packet), # Included Length
flags, # Packet Flags flags, # Packet Flags
0, # Cumulative Drops 0, # Cumulative Drops
timestamp, # Timestamp timestamp, # Timestamp
) )
+ packet_data + hci_packet
) )
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@contextmanager
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
Create a snooper given a specification string.
The general syntax for the specification string is:
<snooper-type>:<type-specific-arguments>
Supported snooper types are:
btsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open the file
logger.debug(f'Snoop file: {file_path}')
with open(file_path, 'wb') as snoop_file:
_SNOOPER_INSTANCE_COUNT += 1
yield BtSnooper(snoop_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError('I/O type not supported')
raise ValueError(f'snooper type {snooper_type} not found')

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2021-2023 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,10 +15,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from contextlib import asynccontextmanager
import logging import logging
import os
from .common import Transport, AsyncPipeSink from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller from ..controller import Controller
from ..snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -26,14 +29,53 @@ from ..controller import Controller
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def _wrap_transport(transport: Transport) -> Transport:
"""
Automatically wrap a Transport instance when a wrapping class can be inferred
from the environment.
If no wrapping class is applicable, the transport argument is returned as-is.
"""
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
try:
return SnoopingTransport.create_with(
transport, create_snooper(snooper_spec)
)
except Exception as exc:
logger.warning(f'Exception while creating snooper: {exc}')
return transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_transport(name: str) -> Transport: async def open_transport(name: str) -> Transport:
''' """
Open a transport by name. Open a transport by name.
The name must be <type>:<parameters> The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types). Where <parameters> depend on the type (and may be empty for some types).
The supported types are: serial,udp,tcp,pty,usb The supported types are:
''' * serial
* udp
* tcp-client
* tcp-server
* ws-client
* ws-server
* pty
* file
* vhci
* hci-socket
* usb
* pyusb
* android-emulator
"""
return _wrap_transport(await _open_transport(name))
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
@@ -107,7 +149,18 @@ async def open_transport(name: str) -> Transport:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_transport_or_link(name): async def open_transport_or_link(name: str) -> Transport:
"""
Open a transport or a link relay.
Args:
name:
Name of the transport or link relay to open.
When the name starts with "link-relay:", open a link relay (see RemoteLink
for details on what the arguments are).
For other namespaces, see `open_transport`.
"""
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
from ..link import RemoteLink # lazy import from ..link import RemoteLink # lazy import
@@ -119,6 +172,6 @@ async def open_transport_or_link(name):
async def close(self): async def close(self):
link.close() link.close()
return LinkTransport(controller, AsyncPipeSink(controller)) return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller)))
return await open_transport(name) return await open_transport(name)

View File

@@ -15,12 +15,16 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct import struct
import asyncio import asyncio
import logging import logging
from typing import ContextManager
from .. import hci from .. import hci
from ..colors import color from ..colors import color
from ..snoop import Snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -246,6 +250,20 @@ class StreamPacketSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Transport: class Transport:
"""
Base class for all transports.
A Transport represents a source and a sink together.
An instance must be closed by calling close() when no longer used. Instances
implement the ContextManager protocol so that they may be used in a `async with`
statement.
An instance is iterable. The iterator yields, in order, its source and sink, so
that it may be used with a convenient call syntax like:
async with create_transport() as (source, sink):
...
"""
def __init__(self, source, sink): def __init__(self, source, sink):
self.source = source self.source = source
self.sink = sink self.sink = sink
@@ -335,3 +353,60 @@ class PumpedTransport(Transport):
async def close(self): async def close(self):
await super().close() await super().close()
await self.close_function() await self.close_function()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.
The returned instance will exit the snooper context when it is closed.
"""
with contextlib.ExitStack() as exit_stack:
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source:
def __init__(self, source, snooper):
self.source = source
self.snooper = snooper
self.sink = None
def set_packet_sink(self, sink):
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink, snooper):
self.sink = sink
self.snooper = snooper
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None):
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)
self.transport = transport
self.close_snooper = close_snooper
async def close(self):
await self.transport.close()
if self.close_snooper:
self.close_snooper()

View File

@@ -72,7 +72,7 @@ test =
development = development =
black == 22.10 black == 22.10
invoke >= 1.7.3 invoke >= 1.7.3
mypy == 0.991 mypy == 1.1.1
nox >= 2022 nox >= 2022
pylint == 2.15.8 pylint == 2.15.8
types-appdirs >= 1.4.3 types-appdirs >= 1.4.3

View File

@@ -72,5 +72,6 @@ def test_parser_extensions():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
test_parser() if __name__ == '__main__':
test_parser_extensions() test_parser()
test_parser_extensions()