Merge pull request #271 from zxzxwu/device_typing

Typing transport and relateds
This commit is contained in:
zxzxwu
2023-09-09 00:55:59 +08:00
committed by GitHub
19 changed files with 188 additions and 99 deletions

View File

@@ -15,6 +15,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import asyncio import asyncio
import itertools import itertools
@@ -58,8 +60,10 @@ from bumble.hci import (
HCI_Packet, HCI_Packet,
HCI_Role_Change_Event, HCI_Role_Change_Event,
) )
from typing import Optional, Union, Dict from typing import Optional, Union, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -104,7 +108,7 @@ class Controller:
self, self,
name, name,
host_source=None, host_source=None,
host_sink=None, host_sink: Optional[TransportSink] = None,
link=None, link=None,
public_address: Optional[Union[bytes, str, Address]] = None, public_address: Optional[Union[bytes, str, Address]] = None,
): ):

View File

@@ -23,7 +23,18 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
TYPE_CHECKING,
)
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -152,6 +163,9 @@ from . import sdp
from . import l2cap from . import l2cap
from . import core from . import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -942,7 +956,13 @@ class Device(CompositeEventEmitter):
pass pass
@classmethod @classmethod
def with_hci(cls, name, address, hci_source, hci_sink): def with_hci(
cls,
name: str,
address: Address,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
''' '''
Create a Device instance with a Host configured to communicate with a controller Create a Device instance with a Host configured to communicate with a controller
through an HCI source/sink through an HCI source/sink
@@ -951,18 +971,25 @@ class Device(CompositeEventEmitter):
return cls(name=name, address=address, host=host) return cls(name=name, address=address, host=host)
@classmethod @classmethod
def from_config_file(cls, filename): def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls(config=config) return cls(config=config)
@classmethod @classmethod
def from_config_with_hci(cls, config, hci_source, hci_sink): def from_config_with_hci(
cls,
config: DeviceConfiguration,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
host = Host(controller_source=hci_source, controller_sink=hci_sink) host = Host(controller_source=hci_source, controller_sink=hci_sink)
return cls(config=config, host=host) return cls(config=config, host=host)
@classmethod @classmethod
def from_config_file_with_hci(cls, filename, hci_source, hci_sink): def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink) return cls.from_config_with_hci(config, hci_source, hci_sink)
@@ -2238,9 +2265,11 @@ class Device(CompositeEventEmitter):
def request_pairing(self, connection): def request_pairing(self, connection):
return self.smp_manager.request_pairing(connection) return self.smp_manager.request_pairing(connection)
async def get_long_term_key(self, connection_handle, rand, ediv): async def get_long_term_key(
self, connection_handle: int, rand: bytes, ediv: int
) -> Optional[bytes]:
if (connection := self.lookup_connection(connection_handle)) is None: if (connection := self.lookup_connection(connection_handle)) is None:
return return None
# Start by looking for the key in an SMP session # Start by looking for the key in an SMP session
ltk = self.smp_manager.get_long_term_key(connection, rand, ediv) ltk = self.smp_manager.get_long_term_key(connection, rand, ediv)
@@ -2260,6 +2289,7 @@ class Device(CompositeEventEmitter):
if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
return keys.ltk_peripheral.value return keys.ltk_peripheral.value
return None
async def get_link_key(self, address: Address) -> Optional[bytes]: async def get_link_key(self, address: Address) -> Optional[bytes]:
if self.keystore is None: if self.keystore is None:

View File

@@ -21,7 +21,7 @@ import collections
import logging import logging
import struct import struct
from typing import Optional from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
@@ -73,10 +73,14 @@ from .core import (
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
ConnectionPHY, ConnectionPHY,
ConnectionParameters, ConnectionParameters,
InvalidStateError,
) )
from .utils import AbortableEventEmitter from .utils import AbortableEventEmitter
from .transport.common import TransportLostError from .transport.common import TransportLostError
if TYPE_CHECKING:
from .transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -116,10 +120,21 @@ class Connection:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None): connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: TransportSink
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]
def __init__(
self,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
) -> None:
super().__init__() super().__init__()
self.hci_sink = None
self.hci_metadata = None self.hci_metadata = None
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.reset_done = False self.reset_done = False
@@ -299,7 +314,7 @@ class Host(AbortableEventEmitter):
self.reset_done = True self.reset_done = True
@property @property
def controller(self): def controller(self) -> TransportSink:
return self.hci_sink return self.hci_sink
@controller.setter @controller.setter
@@ -308,13 +323,12 @@ class Host(AbortableEventEmitter):
if controller: if controller:
controller.set_packet_sink(self) controller.set_packet_sink(self)
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.hci_sink = sink self.hci_sink = sink
def send_hci_packet(self, packet: HCI_Packet) -> None: def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper: if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(bytes(packet)) self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False): async def send_command(self, command, check_result=False):

View File

@@ -20,7 +20,6 @@ import logging
import os import os
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper from ..snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -119,7 +118,8 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'file': if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None) assert spec is not None
return await open_file_transport(spec[0])
if scheme == 'vhci': if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
@@ -134,12 +134,14 @@ async def _open_transport(name: str) -> Transport:
if scheme == 'usb': if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None) assert spec is not None
return await open_usb_transport(spec[0])
if scheme == 'pyusb': if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None) assert spec is not None
return await open_pyusb_transport(spec[0])
if scheme == 'android-emulator': if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
@@ -168,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport:
""" """
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
from ..controller import Controller
from ..link import RemoteLink # lazy import from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])

View File

@@ -18,7 +18,7 @@
import logging import logging
import grpc.aio import grpc.aio
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec): async def open_android_emulator_transport(spec: str | None) -> Transport:
''' '''
Open a transport connection to an Android emulator via its gRPC interface. Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax: The parameter string has this syntax:
@@ -66,7 +66,7 @@ async def open_android_emulator_transport(spec):
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = 8554 server_port = '8554'
if spec is not None: if spec is not None:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
@@ -82,6 +82,7 @@ async def open_android_emulator_transport(spec):
logger.debug(f'connecting to gRPC server at {server_address}') logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
if mode == 'host': if mode == 'host':
# Connect as a host # Connect as a host
service = EmulatedBluetoothServiceStub(channel) service = EmulatedBluetoothServiceStub(channel)

View File

@@ -121,7 +121,9 @@ def publish_grpc_port(grpc_port) -> bool:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(server_host, server_port): async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int
) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise ValueError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:

View File

@@ -20,11 +20,12 @@ import contextlib
import struct import struct
import asyncio import asyncio
import logging import logging
from typing import ContextManager import io
from typing import ContextManager, Tuple, Optional, Protocol, Dict
from .. import hci from bumble import hci
from ..colors import color from bumble.colors import color
from ..snoop import Snooper from bumble.snoop import Snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -36,7 +37,7 @@ logger = logging.getLogger(__name__)
# Information needed to parse HCI packets with a generic parser: # Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents: # For each packet type, the info represents:
# (length-size, length-offset, unpack-type) # (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = { HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
@@ -44,6 +45,8 @@ HCI_PACKET_INFO = {
} }
# -----------------------------------------------------------------------------
# Errors
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TransportLostError(Exception): class TransportLostError(Exception):
""" """
@@ -51,24 +54,34 @@ class TransportLostError(Exception):
""" """
# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None:
...
class TransportSource(Protocol):
def set_packet_sink(self, sink: TransportSink) -> None:
...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketPump: class PacketPump:
""" """
Pump HCI packets from a reader to a sink. Pump HCI packets from a reader to a sink.
""" """
def __init__(self, reader, sink): def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
self.reader = reader self.reader = reader
self.sink = sink self.sink = sink
async def run(self): async def run(self) -> None:
while True: while True:
try: try:
# Get a packet from the source
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
# Deliver the packet to the sink # Deliver the packet to the sink
self.sink.on_packet(packet) self.sink.on_packet(await self.reader.next_packet())
except Exception as error: except Exception as error:
logger.warning(f'!!! {error}') logger.warning(f'!!! {error}')
@@ -86,18 +99,22 @@ class PacketParser:
NEED_LENGTH = 1 NEED_LENGTH = 1
NEED_BODY = 2 NEED_BODY = 2
def __init__(self, sink=None): sink: Optional[TransportSink]
extended_packet_info: Dict[int, Tuple[int, int, str]]
packet_info: Optional[Tuple[int, int, str]] = None
def __init__(self, sink: Optional[TransportSink] = None) -> None:
self.sink = sink self.sink = sink
self.extended_packet_info = {} self.extended_packet_info = {}
self.reset() self.reset()
def reset(self): def reset(self) -> None:
self.state = PacketParser.NEED_TYPE self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1 self.bytes_needed = 1
self.packet = bytearray() self.packet = bytearray()
self.packet_info = None self.packet_info = None
def feed_data(self, data): def feed_data(self, data: bytes) -> None:
data_offset = 0 data_offset = 0
data_left = len(data) data_left = len(data)
while data_left and self.bytes_needed: while data_left and self.bytes_needed:
@@ -118,6 +135,7 @@ class PacketParser:
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
assert self.packet_info is not None
body_length = struct.unpack_from( body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1] self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0] )[0]
@@ -135,7 +153,7 @@ class PacketParser:
) )
self.reset() self.reset()
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
@@ -145,10 +163,10 @@ class PacketReader:
Reader that reads HCI packets from a sync source. Reader that reads HCI packets from a sync source.
""" """
def __init__(self, source): def __init__(self, source: io.BufferedReader) -> None:
self.source = source self.source = source
def next_packet(self): def next_packet(self) -> Optional[bytes]:
# Get the packet type # Get the packet type
packet_type = self.source.read(1) packet_type = self.source.read(1)
if len(packet_type) != 1: if len(packet_type) != 1:
@@ -157,7 +175,7 @@ class PacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found') raise ValueError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -180,17 +198,17 @@ class AsyncPacketReader:
Reader that reads HCI packets from an async source. Reader that reads HCI packets from an async source.
""" """
def __init__(self, source): def __init__(self, source: asyncio.StreamReader) -> None:
self.source = source self.source = source
async def next_packet(self): async def next_packet(self) -> bytes:
# Get the packet type # Get the packet type
packet_type = await self.source.readexactly(1) packet_type = await self.source.readexactly(1)
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found') raise ValueError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -209,11 +227,11 @@ class AsyncPipeSink:
Sink that forwards packets asynchronously to another sink. Sink that forwards packets asynchronously to another sink.
""" """
def __init__(self, sink): def __init__(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.loop.call_soon(self.sink.on_packet, packet) self.loop.call_soon(self.sink.on_packet, packet)
@@ -223,50 +241,48 @@ class ParserSource:
Base class designed to be subclassed by transport-specific source classes Base class designed to be subclassed by transport-specific source classes
""" """
terminated: asyncio.Future terminated: asyncio.Future[None]
parser: PacketParser parser: PacketParser
def __init__(self): def __init__(self) -> None:
self.parser = PacketParser() self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink) self.parser.set_packet_sink(sink)
def on_transport_lost(self): def on_transport_lost(self) -> None:
self.terminated.set_result(None) self.terminated.set_result(None)
if self.parser.sink: if self.parser.sink:
try: if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost() self.parser.sink.on_transport_lost()
except AttributeError:
pass
async def wait_for_termination(self): async def wait_for_termination(self) -> None:
""" """
Convenience method for backward compatibility. Prefer using the `terminated` Convenience method for backward compatibility. Prefer using the `terminated`
attribute instead. attribute instead.
""" """
return await self.terminated return await self.terminated
def close(self): def close(self) -> None:
pass pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data): def data_received(self, data: bytes) -> None:
self.parser.feed_data(data) self.parser.feed_data(data)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSink: class StreamPacketSink:
def __init__(self, transport): def __init__(self, transport: asyncio.WriteTransport) -> None:
self.transport = transport self.transport = transport
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.transport.write(packet) self.transport.write(packet)
def close(self): def close(self) -> None:
self.transport.close() self.transport.close()
@@ -286,7 +302,7 @@ class Transport:
... ...
""" """
def __init__(self, source, sink): def __init__(self, source: TransportSource, sink: TransportSink) -> None:
self.source = source self.source = source
self.sink = sink self.sink = sink
@@ -300,19 +316,23 @@ class Transport:
return iter((self.source, self.sink)) return iter((self.source, self.sink))
async def close(self) -> None: async def close(self) -> None:
self.source.close() if hasattr(self.source, 'close'):
self.sink.close() self.source.close()
if hasattr(self.sink, 'close'):
self.sink.close()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource): class PumpedPacketSource(ParserSource):
def __init__(self, receive): pump_task: Optional[asyncio.Task[None]]
def __init__(self, receive) -> None:
super().__init__() super().__init__()
self.receive_function = receive self.receive_function = receive
self.pump_task = None self.pump_task = None
def start(self): def start(self) -> None:
async def pump_packets(): async def pump_packets() -> None:
while True: while True:
try: try:
packet = await self.receive_function() packet = await self.receive_function()
@@ -322,12 +342,12 @@ class PumpedPacketSource(ParserSource):
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error) self.terminated.set_exception(error)
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self) -> None:
if self.pump_task: if self.pump_task:
self.pump_task.cancel() self.pump_task.cancel()
@@ -339,7 +359,7 @@ class PumpedPacketSink:
self.packet_queue = asyncio.Queue() self.packet_queue = asyncio.Queue()
self.pump_task = None self.pump_task = None
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.packet_queue.put_nowait(packet) self.packet_queue.put_nowait(packet)
def start(self): def start(self):
@@ -364,15 +384,23 @@ class PumpedPacketSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedTransport(Transport): class PumpedTransport(Transport):
def __init__(self, source, sink, close_function): source: PumpedPacketSource
sink: PumpedPacketSink
def __init__(
self,
source: PumpedPacketSource,
sink: PumpedPacketSink,
close_function,
) -> None:
super().__init__(source, sink) super().__init__(source, sink)
self.close_function = close_function self.close_function = close_function
def start(self): def start(self) -> None:
self.source.start() self.source.start()
self.sink.start() self.sink.start()
async def close(self): async def close(self) -> None:
await super().close() await super().close()
await self.close_function() await self.close_function()
@@ -397,31 +425,37 @@ class SnoopingTransport(Transport):
raise RuntimeError('unexpected code path') # Satisfy the type checker raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source: class Source:
def __init__(self, source, snooper): sink: TransportSink
def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source self.source = source
self.snooper = snooper self.snooper = snooper
self.sink = None
def set_packet_sink(self, sink): def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
self.source.set_packet_sink(self) self.source.set_packet_sink(self)
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink: if self.sink:
self.sink.on_packet(packet) self.sink.on_packet(packet)
class Sink: class Sink:
def __init__(self, sink, snooper): def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
self.sink = sink self.sink = sink
self.snooper = snooper self.snooper = snooper
def on_packet(self, packet): def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink: if self.sink:
self.sink.on_packet(packet) self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None): def __init__(
self,
transport: Transport,
snooper: Snooper,
close_snooper=None,
) -> None:
super().__init__( super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper) self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
) )

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_file_transport(spec): async def open_file_transport(spec: str) -> Transport:
''' '''
Open a File transport (typically not for a real file, but for a PTY or other unix Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files). virtual files).

View File

@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec): async def open_hci_socket_transport(spec: str | None) -> Transport:
''' '''
Open an HCI Socket (only available on some platforms). Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter) The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -47,7 +47,7 @@ async def open_hci_socket_transport(spec):
hci_socket = socket.socket( hci_socket = socket.socket(
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI, socket.BTPROTO_HCI, # type: ignore
) )
except AttributeError as error: except AttributeError as error:
# Not supported on this platform # Not supported on this platform

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pty_transport(spec): async def open_pty_transport(spec: str | None) -> Transport:
''' '''
Open a PTY transport. Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link The parameter string may be empty, or a path name where a symbolic link

View File

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pyusb_transport(spec): async def open_pyusb_transport(spec: str) -> Transport:
''' '''
Open a USB transport. [Implementation based on PyUSB] Open a USB transport. [Implementation based on PyUSB]
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_serial_transport(spec): async def open_serial_transport(spec: str) -> Transport:
''' '''
Open a serial port transport. Open a serial port transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_tcp_client_transport(spec): async def open_tcp_client_transport(spec: str) -> Transport:
''' '''
Open a TCP client transport. Open a TCP client transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_tcp_server_transport(spec): async def open_tcp_server_transport(spec: str) -> Transport:
''' '''
Open a TCP server transport. Open a TCP server transport.
The parameter string has this syntax: The parameter string has this syntax:
@@ -42,7 +43,7 @@ async def open_tcp_server_transport(spec):
async def close(self): async def close(self):
await super().close() await super().close()
class TcpServerProtocol: class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink): def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source self.packet_source = packet_source
self.packet_sink = packet_sink self.packet_sink = packet_sink

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_udp_transport(spec): async def open_udp_transport(spec: str) -> Transport:
''' '''
Open a UDP transport. Open a UDP transport.
The parameter string has this syntax: The parameter string has this syntax:

View File

@@ -60,7 +60,7 @@ def load_libusb():
usb1.loadLibrary(libusb_dll) usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec): async def open_usb_transport(spec: str) -> Transport:
''' '''
Open a USB transport. Open a USB transport.
The moniker string has this syntax: The moniker string has this syntax:

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from .common import Transport
from .file import open_file_transport from .file import open_file_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_vhci_transport(spec): async def open_vhci_transport(spec: str | None) -> Transport:
''' '''
Open a VHCI transport (only available on some platforms). Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device The parameter string is either empty (to use the default VHCI device
@@ -42,15 +43,15 @@ async def open_vhci_transport(spec):
# Override the source's `data_received` method so that we can # Override the source's `data_received` method so that we can
# filter out the vendor packet that is received just after the # filter out the vendor packet that is received just after the
# initial open # initial open
def vhci_data_received(data): def vhci_data_received(data: bytes) -> None:
if len(data) > 0 and data[0] == HCI_VENDOR_PKT: if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
if len(data) == 4: if len(data) == 4:
hci_index = data[2] << 8 | data[3] hci_index = data[2] << 8 | data[3]
logger.info(f'HCI index {hci_index}') logger.info(f'HCI index {hci_index}')
else: else:
transport.source.parser.feed_data(data) transport.source.parser.feed_data(data) # type: ignore
transport.source.data_received = vhci_data_received transport.source.data_received = vhci_data_received # type: ignore
# Write the initial config # Write the initial config
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))

View File

@@ -16,9 +16,9 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import websockets import websockets.client
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_ws_client_transport(spec): async def open_ws_client_transport(spec: str) -> Transport:
''' '''
Open a WebSocket client transport. Open a WebSocket client transport.
The parameter string has this syntax: The parameter string has this syntax:
@@ -38,7 +38,7 @@ async def open_ws_client_transport(spec):
remote_host, remote_port = spec.split(':') remote_host, remote_port = spec.split(':')
uri = f'ws://{remote_host}:{remote_port}' uri = f'ws://{remote_host}:{remote_port}'
websocket = await websockets.connect(uri) websocket = await websockets.client.connect(uri)
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(websocket.recv), PumpedPacketSource(websocket.recv),

View File

@@ -15,7 +15,6 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import websockets import websockets
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_ws_server_transport(spec): async def open_ws_server_transport(spec: str) -> Transport:
''' '''
Open a WebSocket server transport. Open a WebSocket server transport.
The parameter string has this syntax: The parameter string has this syntax: