forked from auracaster/bumble_mirror
Compare commits
9 Commits
gbg/snoop
...
gbg/snoop-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa9af61cbe | ||
|
|
dc3ac3060e | ||
|
|
e77723a5f9 | ||
|
|
fe8cf51432 | ||
|
|
97a0e115ae | ||
|
|
46e7aac77c | ||
|
|
08a6f4fa49 | ||
|
|
ca063eda0b | ||
|
|
c97ba4319f |
@@ -25,12 +25,14 @@
|
||||
from __future__ import annotations
|
||||
import struct
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Type, TYPE_CHECKING
|
||||
|
||||
from bumble.core import UUID, name_or_number
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -749,7 +751,25 @@ class Attribute(EventEmitter):
|
||||
def decode_value(self, value_bytes):
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection):
|
||||
def read_value(self, connection: Connection):
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
||||
) and not connection.authenticated:
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
|
||||
# TODO: handle authorization better
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
|
||||
if read := getattr(self.value, 'read', None):
|
||||
try:
|
||||
value = read(connection) # pylint: disable=not-callable
|
||||
@@ -762,7 +782,25 @@ class Attribute(EventEmitter):
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection, value_bytes):
|
||||
def write_value(self, connection: Connection, value_bytes):
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
|
||||
) and not connection.authenticated:
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
|
||||
# TODO: handle authorization better
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
|
||||
value = self.decode_value(value_bytes)
|
||||
|
||||
if write := getattr(self.value, 'write', None):
|
||||
|
||||
@@ -144,9 +144,12 @@ class ConnectionError(BaseError): # pylint: disable=redefined-builtin
|
||||
class UUID:
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
|
||||
|
||||
Note that this class expects and works in little-endian byte-order throughout.
|
||||
The exception is when interacting with strings, which are in big-endian byte-order.
|
||||
'''
|
||||
|
||||
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
|
||||
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
|
||||
UUIDS: List[UUID] = [] # Registry of all instances created
|
||||
|
||||
def __init__(self, uuid_str_or_int, name=None):
|
||||
@@ -209,13 +212,20 @@ class UUID:
|
||||
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
|
||||
|
||||
def to_bytes(self, force_128=False):
|
||||
if len(self.uuid_bytes) == 16 or not force_128:
|
||||
'''
|
||||
Serialize UUID in little-endian byte-order
|
||||
'''
|
||||
if not force_128:
|
||||
return self.uuid_bytes
|
||||
|
||||
if len(self.uuid_bytes) == 4:
|
||||
return self.uuid_bytes + UUID.BASE_UUID
|
||||
|
||||
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
|
||||
if len(self.uuid_bytes) == 2:
|
||||
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
|
||||
elif len(self.uuid_bytes) == 4:
|
||||
return self.BASE_UUID + self.uuid_bytes
|
||||
elif len(self.uuid_bytes) == 16:
|
||||
return self.uuid_bytes
|
||||
else:
|
||||
assert False, "unreachable"
|
||||
|
||||
def to_pdu_bytes(self):
|
||||
'''
|
||||
|
||||
@@ -534,6 +534,9 @@ class Connection(CompositeEventEmitter):
|
||||
def on_connection_parameters_update_failure(self, error):
|
||||
pass
|
||||
|
||||
def on_connection_data_length_change(self):
|
||||
pass
|
||||
|
||||
def on_connection_phy_update(self):
|
||||
pass
|
||||
|
||||
@@ -2008,7 +2011,7 @@ class Device(CompositeEventEmitter):
|
||||
NOTE: the name of the parameters may look odd, but it just follows the names
|
||||
used in the Bluetooth spec.
|
||||
'''
|
||||
await self.send_command(
|
||||
result = await self.send_command(
|
||||
HCI_LE_Connection_Update_Command(
|
||||
connection_handle=connection.handle,
|
||||
connection_interval_min=connection_interval_min,
|
||||
@@ -2017,9 +2020,10 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeout=supervision_timeout,
|
||||
min_ce_length=min_ce_length,
|
||||
max_ce_length=max_ce_length,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
)
|
||||
if result.status != HCI_Command_Status_Event.PENDING:
|
||||
raise HCI_StatusError(result)
|
||||
|
||||
async def get_connection_rssi(self, connection):
|
||||
result = await self.send_command(
|
||||
|
||||
@@ -61,7 +61,6 @@ from .att import (
|
||||
from .gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
@@ -543,8 +542,6 @@ class Server(EventEmitter):
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
):
|
||||
# TODO: check permissions
|
||||
|
||||
this_uuid_size = len(attribute.type.to_pdu_bytes())
|
||||
|
||||
if attributes:
|
||||
@@ -638,6 +635,13 @@ class Server(EventEmitter):
|
||||
'''
|
||||
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
attributes = []
|
||||
for attribute in (
|
||||
attribute
|
||||
@@ -647,10 +651,21 @@ class Server(EventEmitter):
|
||||
and attribute.handle <= request.ending_handle
|
||||
and pdu_space_available
|
||||
):
|
||||
# TODO: check permissions
|
||||
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
if not attributes:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=attribute.handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
break
|
||||
|
||||
# Check the attribute value size
|
||||
attribute_value = attribute.read_value(connection)
|
||||
max_attribute_size = min(connection.att_mtu - 4, 253)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
# We need to truncate
|
||||
@@ -676,11 +691,7 @@ class Server(EventEmitter):
|
||||
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
logging.warning(f"not found {request}")
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
@@ -690,10 +701,17 @@ class Server(EventEmitter):
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
# TODO: check permissions
|
||||
value = attribute.read_value(connection)
|
||||
value_size = min(connection.att_mtu - 1, len(value))
|
||||
response = ATT_Read_Response(attribute_value=value[:value_size])
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
value_size = min(connection.att_mtu - 1, len(value))
|
||||
response = ATT_Read_Response(attribute_value=value[:value_size])
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -708,29 +726,36 @@ class Server(EventEmitter):
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
# TODO: check permissions
|
||||
value = attribute.read_value(connection)
|
||||
if request.value_offset > len(value):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_OFFSET_ERROR,
|
||||
)
|
||||
elif len(value) <= connection.att_mtu - 1:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
part_size = min(
|
||||
connection.att_mtu - 1, len(value) - request.value_offset
|
||||
)
|
||||
response = ATT_Read_Blob_Response(
|
||||
part_attribute_value=value[
|
||||
request.value_offset : request.value_offset + part_size
|
||||
]
|
||||
)
|
||||
if request.value_offset > len(value):
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_OFFSET_ERROR,
|
||||
)
|
||||
elif len(value) <= connection.att_mtu - 1:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
)
|
||||
else:
|
||||
part_size = min(
|
||||
connection.att_mtu - 1, len(value) - request.value_offset
|
||||
)
|
||||
response = ATT_Read_Blob_Response(
|
||||
part_attribute_value=value[
|
||||
request.value_offset : request.value_offset + part_size
|
||||
]
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -746,7 +771,6 @@ class Server(EventEmitter):
|
||||
if request.attribute_group_type not in (
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
):
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -766,8 +790,10 @@ class Server(EventEmitter):
|
||||
and attribute.handle <= request.ending_handle
|
||||
and pdu_space_available
|
||||
):
|
||||
# Check the attribute value size
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = attribute.read_value(connection)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
# We need to truncate
|
||||
|
||||
@@ -1421,7 +1421,11 @@ class HCI_Constant:
|
||||
# -----------------------------------------------------------------------------
|
||||
class HCI_Error(ProtocolError):
|
||||
def __init__(self, error_code):
|
||||
super().__init__(error_code, 'hci', HCI_Constant.error_name(error_code))
|
||||
super().__init__(
|
||||
error_code,
|
||||
error_namespace='hci',
|
||||
error_name=HCI_Constant.error_name(error_code),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -276,7 +276,7 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
def send_hci_packet(self, packet):
|
||||
if self.snooper:
|
||||
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
|
||||
self.hci_sink.on_packet(packet.to_bytes())
|
||||
|
||||
@@ -425,7 +425,7 @@ class Host(AbortableEventEmitter):
|
||||
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
|
||||
|
||||
if self.snooper:
|
||||
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
|
||||
# If the packet is a command, invoke the handler for this packet
|
||||
if packet.hci_packet_type == HCI_COMMAND_PACKET:
|
||||
|
||||
@@ -15,12 +15,21 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import struct
|
||||
import datetime
|
||||
from typing import BinaryIO
|
||||
from typing import BinaryIO, Generator
|
||||
import os
|
||||
|
||||
from bumble.hci import HCI_Packet, HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -44,7 +53,7 @@ class Snooper:
|
||||
HCI_BSCP = 1003
|
||||
H5 = 1004
|
||||
|
||||
def snoop(self, hci_packet: HCI_Packet, direction: Direction) -> None:
|
||||
def snoop(self, hci_packet: bytes, direction: Direction) -> None:
|
||||
"""Snoop on an HCI packet."""
|
||||
|
||||
|
||||
@@ -67,9 +76,10 @@ class BtSnooper(Snooper):
|
||||
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
|
||||
)
|
||||
|
||||
def snoop(self, hci_packet: HCI_Packet, direction: Snooper.Direction) -> None:
|
||||
def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
|
||||
flags = int(direction)
|
||||
if hci_packet.hci_packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
|
||||
packet_type = hci_packet[0]
|
||||
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
|
||||
flags |= 0x10
|
||||
|
||||
# Compute the current timestamp
|
||||
@@ -79,15 +89,82 @@ class BtSnooper(Snooper):
|
||||
)
|
||||
|
||||
# Emit the record
|
||||
packet_data = bytes(hci_packet)
|
||||
self.output.write(
|
||||
struct.pack(
|
||||
'>IIIIQ',
|
||||
len(packet_data), # Original Length
|
||||
len(packet_data), # Included Length
|
||||
len(hci_packet), # Original Length
|
||||
len(hci_packet), # Included Length
|
||||
flags, # Packet Flags
|
||||
0, # Cumulative Drops
|
||||
timestamp, # Timestamp
|
||||
)
|
||||
+ packet_data
|
||||
+ hci_packet
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_SNOOPER_INSTANCE_COUNT = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
"""
|
||||
Create a snooper given a specification string.
|
||||
|
||||
The general syntax for the specification string is:
|
||||
<snooper-type>:<type-specific-arguments>
|
||||
|
||||
Supported snooper types are:
|
||||
|
||||
btsnoop
|
||||
The syntax for the type-specific arguments for this type is:
|
||||
<io-type>:<io-type-specific-arguments>
|
||||
|
||||
Supported I/O types are:
|
||||
|
||||
file
|
||||
The type-specific arguments for this I/O type is a string that is converted
|
||||
to a file path using the python `str.format()` string formatting. The log
|
||||
records will be written to that file if it can be opened/created.
|
||||
The keyword args that may be referenced by the string pattern are:
|
||||
now: the value of `datetime.now()`
|
||||
utcnow: the value of `datetime.utcnow()`
|
||||
pid: the current process ID.
|
||||
instance: the instance ID in the current process.
|
||||
|
||||
Examples:
|
||||
btsnoop:file:my_btsnoop.log
|
||||
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
|
||||
|
||||
"""
|
||||
if ':' not in spec:
|
||||
raise ValueError('snooper type prefix missing')
|
||||
|
||||
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
||||
|
||||
if snooper_type == 'btsnoop':
|
||||
if ':' not in snooper_args:
|
||||
raise ValueError('I/O type for btsnoop snooper type missing')
|
||||
|
||||
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
||||
if io_type == 'file':
|
||||
# Process the file name string pattern.
|
||||
global _SNOOPER_INSTANCE_COUNT
|
||||
file_path = io_name.format(
|
||||
now=datetime.datetime.now(),
|
||||
utcnow=datetime.datetime.utcnow(),
|
||||
pid=os.getpid(),
|
||||
instance=_SNOOPER_INSTANCE_COUNT,
|
||||
)
|
||||
|
||||
# Open the file
|
||||
logger.debug(f'Snoop file: {file_path}')
|
||||
with open(file_path, 'wb') as snoop_file:
|
||||
_SNOOPER_INSTANCE_COUNT += 1
|
||||
yield BtSnooper(snoop_file)
|
||||
_SNOOPER_INSTANCE_COUNT -= 1
|
||||
return
|
||||
|
||||
raise ValueError(f'I/O type {io_type} not supported')
|
||||
|
||||
raise ValueError(f'snooper type {snooper_type} not found')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,10 +15,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .common import Transport, AsyncPipeSink
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from ..controller import Controller
|
||||
from ..snoop import create_snooper
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -26,14 +29,53 @@ from ..controller import Controller
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def _wrap_transport(transport: Transport) -> Transport:
|
||||
"""
|
||||
Automatically wrap a Transport instance when a wrapping class can be inferred
|
||||
from the environment.
|
||||
If no wrapping class is applicable, the transport argument is returned as-is.
|
||||
"""
|
||||
|
||||
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
|
||||
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
|
||||
try:
|
||||
return SnoopingTransport.create_with(
|
||||
transport, create_snooper(snooper_spec)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f'Exception while creating snooper: {exc}')
|
||||
|
||||
return transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_transport(name: str) -> Transport:
|
||||
'''
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The supported types are: serial,udp,tcp,pty,usb
|
||||
'''
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
* tcp-client
|
||||
* tcp-server
|
||||
* ws-client
|
||||
* ws-server
|
||||
* pty
|
||||
* file
|
||||
* vhci
|
||||
* hci-socket
|
||||
* usb
|
||||
* pyusb
|
||||
* android-emulator
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
@@ -107,7 +149,18 @@ async def open_transport(name: str) -> Transport:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_transport_or_link(name):
|
||||
async def open_transport_or_link(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport or a link relay.
|
||||
|
||||
Args:
|
||||
name:
|
||||
Name of the transport or link relay to open.
|
||||
When the name starts with "link-relay:", open a link relay (see RemoteLink
|
||||
for details on what the arguments are).
|
||||
For other namespaces, see `open_transport`.
|
||||
|
||||
"""
|
||||
if name.startswith('link-relay:'):
|
||||
from ..link import RemoteLink # lazy import
|
||||
|
||||
@@ -119,6 +172,6 @@ async def open_transport_or_link(name):
|
||||
async def close(self):
|
||||
link.close()
|
||||
|
||||
return LinkTransport(controller, AsyncPipeSink(controller))
|
||||
return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller)))
|
||||
|
||||
return await open_transport(name)
|
||||
|
||||
@@ -15,12 +15,16 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import struct
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import ContextManager
|
||||
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
from ..snoop import Snooper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -246,6 +250,20 @@ class StreamPacketSink:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Transport:
|
||||
"""
|
||||
Base class for all transports.
|
||||
|
||||
A Transport represents a source and a sink together.
|
||||
An instance must be closed by calling close() when no longer used. Instances
|
||||
implement the ContextManager protocol so that they may be used in a `async with`
|
||||
statement.
|
||||
An instance is iterable. The iterator yields, in order, its source and sink, so
|
||||
that it may be used with a convenient call syntax like:
|
||||
|
||||
async with create_transport() as (source, sink):
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, source, sink):
|
||||
self.source = source
|
||||
self.sink = sink
|
||||
@@ -335,3 +353,60 @@ class PumpedTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await self.close_function()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopingTransport(Transport):
|
||||
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
|
||||
|
||||
@staticmethod
|
||||
def create_with(
|
||||
transport: Transport, snooper: ContextManager[Snooper]
|
||||
) -> SnoopingTransport:
|
||||
"""
|
||||
Create an instance given a snooper that works as as context manager.
|
||||
|
||||
The returned instance will exit the snooper context when it is closed.
|
||||
"""
|
||||
with contextlib.ExitStack() as exit_stack:
|
||||
return SnoopingTransport(
|
||||
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
||||
)
|
||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
||||
|
||||
class Source:
|
||||
def __init__(self, source, snooper):
|
||||
self.source = source
|
||||
self.snooper = snooper
|
||||
self.sink = None
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
self.sink = sink
|
||||
self.source.set_packet_sink(self)
|
||||
|
||||
def on_packet(self, packet):
|
||||
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
class Sink:
|
||||
def __init__(self, sink, snooper):
|
||||
self.sink = sink
|
||||
self.snooper = snooper
|
||||
|
||||
def on_packet(self, packet):
|
||||
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
def __init__(self, transport, snooper, close_snooper=None):
|
||||
super().__init__(
|
||||
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
|
||||
)
|
||||
self.transport = transport
|
||||
self.close_snooper = close_snooper
|
||||
|
||||
async def close(self):
|
||||
await self.transport.close()
|
||||
if self.close_snooper:
|
||||
self.close_snooper()
|
||||
|
||||
@@ -72,7 +72,7 @@ test =
|
||||
development =
|
||||
black == 22.10
|
||||
invoke >= 1.7.3
|
||||
mypy == 0.991
|
||||
mypy == 1.1.1
|
||||
nox >= 2022
|
||||
pylint == 2.15.8
|
||||
types-appdirs >= 1.4.3
|
||||
|
||||
@@ -72,5 +72,6 @@ def test_parser_extensions():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
test_parser()
|
||||
test_parser_extensions()
|
||||
if __name__ == '__main__':
|
||||
test_parser()
|
||||
test_parser_extensions()
|
||||
|
||||
Reference in New Issue
Block a user