diff --git a/bumble/controller.py b/bumble/controller.py index 382be13..9b2960a 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -15,6 +15,8 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations + import logging import asyncio import itertools @@ -58,8 +60,10 @@ from bumble.hci import ( HCI_Packet, HCI_Role_Change_Event, ) -from typing import Optional, Union, Dict +from typing import Optional, Union, Dict, TYPE_CHECKING +if TYPE_CHECKING: + from bumble.transport.common import TransportSink, TransportSource # ----------------------------------------------------------------------------- # Logging @@ -104,7 +108,7 @@ class Controller: self, name, host_source=None, - host_sink=None, + host_sink: Optional[TransportSink] = None, link=None, public_address: Optional[Union[bytes, str, Address]] = None, ): diff --git a/bumble/device.py b/bumble/device.py index 46ce012..fca2e7a 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,7 +23,18 @@ import asyncio import logging from contextlib import asynccontextmanager, AsyncExitStack from dataclasses import dataclass -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + Union, + TYPE_CHECKING, +) from .colors import color from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU @@ -152,6 +163,9 @@ from . import sdp from . import l2cap from . import core +if TYPE_CHECKING: + from .transport.common import TransportSource, TransportSink + # ----------------------------------------------------------------------------- # Logging @@ -942,7 +956,13 @@ class Device(CompositeEventEmitter): pass @classmethod - def with_hci(cls, name, address, hci_source, hci_sink): + def with_hci( + cls, + name: str, + address: Address, + hci_source: TransportSource, + hci_sink: TransportSink, + ) -> Device: ''' Create a Device instance with a Host configured to communicate with a controller through an HCI source/sink @@ -951,18 +971,25 @@ class Device(CompositeEventEmitter): return cls(name=name, address=address, host=host) @classmethod - def from_config_file(cls, filename): + def from_config_file(cls, filename: str) -> Device: config = DeviceConfiguration() config.load_from_file(filename) return cls(config=config) @classmethod - def from_config_with_hci(cls, config, hci_source, hci_sink): + def from_config_with_hci( + cls, + config: DeviceConfiguration, + hci_source: TransportSource, + hci_sink: TransportSink, + ) -> Device: host = Host(controller_source=hci_source, controller_sink=hci_sink) return cls(config=config, host=host) @classmethod - def from_config_file_with_hci(cls, filename, hci_source, hci_sink): + def from_config_file_with_hci( + cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink + ) -> Device: config = DeviceConfiguration() config.load_from_file(filename) return cls.from_config_with_hci(config, hci_source, hci_sink) @@ -2238,9 +2265,11 @@ class Device(CompositeEventEmitter): def request_pairing(self, connection): return self.smp_manager.request_pairing(connection) - async def get_long_term_key(self, connection_handle, rand, ediv): + async def get_long_term_key( + self, connection_handle: int, rand: bytes, ediv: int + ) -> Optional[bytes]: if (connection := self.lookup_connection(connection_handle)) is None: - return + return None # Start by looking for the key in an SMP session ltk = self.smp_manager.get_long_term_key(connection, rand, ediv) @@ -2260,6 +2289,7 @@ class Device(CompositeEventEmitter): if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral: return keys.ltk_peripheral.value + return None async def get_link_key(self, address: Address) -> Optional[bytes]: if self.keystore is None: diff --git a/bumble/host.py b/bumble/host.py index 288b1b6..02caa46 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,7 @@ import collections import logging import struct -from typing import Optional +from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable from bumble.colors import color from bumble.l2cap import L2CAP_PDU @@ -73,10 +73,14 @@ from .core import ( BT_LE_TRANSPORT, ConnectionPHY, ConnectionParameters, + InvalidStateError, ) from .utils import AbortableEventEmitter from .transport.common import TransportLostError +if TYPE_CHECKING: + from .transport.common import TransportSink, TransportSource + # ----------------------------------------------------------------------------- # Logging @@ -116,10 +120,21 @@ class Connection: # ----------------------------------------------------------------------------- class Host(AbortableEventEmitter): - def __init__(self, controller_source=None, controller_sink=None): + connections: Dict[int, Connection] + acl_packet_queue: collections.deque[HCI_AclDataPacket] + hci_sink: TransportSink + long_term_key_provider: Optional[ + Callable[[int, bytes, int], Awaitable[Optional[bytes]]] + ] + link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]] + + def __init__( + self, + controller_source: Optional[TransportSource] = None, + controller_sink: Optional[TransportSink] = None, + ) -> None: super().__init__() - self.hci_sink = None self.hci_metadata = None self.ready = False # True when we can accept incoming packets self.reset_done = False @@ -299,7 +314,7 @@ class Host(AbortableEventEmitter): self.reset_done = True @property - def controller(self): + def controller(self) -> TransportSink: return self.hci_sink @controller.setter @@ -308,13 +323,12 @@ class Host(AbortableEventEmitter): if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink): + def set_packet_sink(self, sink: TransportSink) -> None: self.hci_sink = sink def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index c722283..bc0766b 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -20,7 +20,6 @@ import logging import os from .common import Transport, AsyncPipeSink, SnoopingTransport -from ..controller import Controller from ..snoop import create_snooper # ----------------------------------------------------------------------------- @@ -119,7 +118,8 @@ async def _open_transport(name: str) -> Transport: if scheme == 'file': from .file import open_file_transport - return await open_file_transport(spec[0] if spec else None) + assert spec is not None + return await open_file_transport(spec[0]) if scheme == 'vhci': from .vhci import open_vhci_transport @@ -134,12 +134,14 @@ async def _open_transport(name: str) -> Transport: if scheme == 'usb': from .usb import open_usb_transport - return await open_usb_transport(spec[0] if spec else None) + assert spec is not None + return await open_usb_transport(spec[0]) if scheme == 'pyusb': from .pyusb import open_pyusb_transport - return await open_pyusb_transport(spec[0] if spec else None) + assert spec is not None + return await open_pyusb_transport(spec[0]) if scheme == 'android-emulator': from .android_emulator import open_android_emulator_transport @@ -168,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport: """ if name.startswith('link-relay:'): + from ..controller import Controller from ..link import RemoteLink # lazy import link = RemoteLink(name[11:]) diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index b78e263..5ef0047 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -18,7 +18,7 @@ import logging import grpc.aio -from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink +from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport # pylint: disable=no-name-in-module from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_android_emulator_transport(spec): +async def open_android_emulator_transport(spec: str | None) -> Transport: ''' Open a transport connection to an Android emulator via its gRPC interface. The parameter string has this syntax: @@ -66,7 +66,7 @@ async def open_android_emulator_transport(spec): # Parse the parameters mode = 'host' server_host = 'localhost' - server_port = 8554 + server_port = '8554' if spec is not None: params = spec.split(',') for param in params: @@ -82,6 +82,7 @@ async def open_android_emulator_transport(spec): logger.debug(f'connecting to gRPC server at {server_address}') channel = grpc.aio.insecure_channel(server_address) + service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub if mode == 'host': # Connect as a host service = EmulatedBluetoothServiceStub(channel) diff --git a/bumble/transport/android_netsim.py b/bumble/transport/android_netsim.py index 99ebf87..76a7385 100644 --- a/bumble/transport/android_netsim.py +++ b/bumble/transport/android_netsim.py @@ -121,7 +121,9 @@ def publish_grpc_port(grpc_port) -> bool: # ----------------------------------------------------------------------------- -async def open_android_netsim_controller_transport(server_host, server_port): +async def open_android_netsim_controller_transport( + server_host: str | None, server_port: int +) -> Transport: if not server_port: raise ValueError('invalid port') if server_host == '_' or not server_host: diff --git a/bumble/transport/common.py b/bumble/transport/common.py index c7be3ad..5d5bdf1 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -20,11 +20,12 @@ import contextlib import struct import asyncio import logging -from typing import ContextManager +import io +from typing import ContextManager, Tuple, Optional, Protocol, Dict -from .. import hci -from ..colors import color -from ..snoop import Snooper +from bumble import hci +from bumble.colors import color +from bumble.snoop import Snooper # ----------------------------------------------------------------------------- @@ -36,7 +37,7 @@ logger = logging.getLogger(__name__) # Information needed to parse HCI packets with a generic parser: # For each packet type, the info represents: # (length-size, length-offset, unpack-type) -HCI_PACKET_INFO = { +HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = { hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), @@ -44,6 +45,8 @@ HCI_PACKET_INFO = { } +# ----------------------------------------------------------------------------- +# Errors # ----------------------------------------------------------------------------- class TransportLostError(Exception): """ @@ -51,24 +54,34 @@ class TransportLostError(Exception): """ +# ----------------------------------------------------------------------------- +# Typing Protocols +# ----------------------------------------------------------------------------- +class TransportSink(Protocol): + def on_packet(self, packet: bytes) -> None: + ... + + +class TransportSource(Protocol): + def set_packet_sink(self, sink: TransportSink) -> None: + ... + + # ----------------------------------------------------------------------------- class PacketPump: """ Pump HCI packets from a reader to a sink. """ - def __init__(self, reader, sink): + def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None: self.reader = reader self.sink = sink - async def run(self): + async def run(self) -> None: while True: try: - # Get a packet from the source - packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet()) - # Deliver the packet to the sink - self.sink.on_packet(packet) + self.sink.on_packet(await self.reader.next_packet()) except Exception as error: logger.warning(f'!!! {error}') @@ -86,18 +99,22 @@ class PacketParser: NEED_LENGTH = 1 NEED_BODY = 2 - def __init__(self, sink=None): + sink: Optional[TransportSink] + extended_packet_info: Dict[int, Tuple[int, int, str]] + packet_info: Optional[Tuple[int, int, str]] = None + + def __init__(self, sink: Optional[TransportSink] = None) -> None: self.sink = sink self.extended_packet_info = {} self.reset() - def reset(self): + def reset(self) -> None: self.state = PacketParser.NEED_TYPE self.bytes_needed = 1 self.packet = bytearray() self.packet_info = None - def feed_data(self, data): + def feed_data(self, data: bytes) -> None: data_offset = 0 data_left = len(data) while data_left and self.bytes_needed: @@ -118,6 +135,7 @@ class PacketParser: self.state = PacketParser.NEED_LENGTH self.bytes_needed = self.packet_info[0] + self.packet_info[1] elif self.state == PacketParser.NEED_LENGTH: + assert self.packet_info is not None body_length = struct.unpack_from( self.packet_info[2], self.packet, 1 + self.packet_info[1] )[0] @@ -135,7 +153,7 @@ class PacketParser: ) self.reset() - def set_packet_sink(self, sink): + def set_packet_sink(self, sink: TransportSink) -> None: self.sink = sink @@ -145,10 +163,10 @@ class PacketReader: Reader that reads HCI packets from a sync source. """ - def __init__(self, source): + def __init__(self, source: io.BufferedReader) -> None: self.source = source - def next_packet(self): + def next_packet(self) -> Optional[bytes]: # Get the packet type packet_type = self.source.read(1) if len(packet_type) != 1: @@ -157,7 +175,7 @@ class PacketReader: # Get the packet info based on its type packet_info = HCI_PACKET_INFO.get(packet_type[0]) if packet_info is None: - raise ValueError(f'invalid packet type {packet_type} found') + raise ValueError(f'invalid packet type {packet_type[0]} found') # Read the header (that includes the length) header_size = packet_info[0] + packet_info[1] @@ -180,17 +198,17 @@ class AsyncPacketReader: Reader that reads HCI packets from an async source. """ - def __init__(self, source): + def __init__(self, source: asyncio.StreamReader) -> None: self.source = source - async def next_packet(self): + async def next_packet(self) -> bytes: # Get the packet type packet_type = await self.source.readexactly(1) # Get the packet info based on its type packet_info = HCI_PACKET_INFO.get(packet_type[0]) if packet_info is None: - raise ValueError(f'invalid packet type {packet_type} found') + raise ValueError(f'invalid packet type {packet_type[0]} found') # Read the header (that includes the length) header_size = packet_info[0] + packet_info[1] @@ -209,11 +227,11 @@ class AsyncPipeSink: Sink that forwards packets asynchronously to another sink. """ - def __init__(self, sink): + def __init__(self, sink: TransportSink) -> None: self.sink = sink self.loop = asyncio.get_running_loop() - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: self.loop.call_soon(self.sink.on_packet, packet) @@ -223,50 +241,48 @@ class ParserSource: Base class designed to be subclassed by transport-specific source classes """ - terminated: asyncio.Future + terminated: asyncio.Future[None] parser: PacketParser - def __init__(self): + def __init__(self) -> None: self.parser = PacketParser() self.terminated = asyncio.get_running_loop().create_future() - def set_packet_sink(self, sink): + def set_packet_sink(self, sink: TransportSink) -> None: self.parser.set_packet_sink(sink) - def on_transport_lost(self): + def on_transport_lost(self) -> None: self.terminated.set_result(None) if self.parser.sink: - try: + if hasattr(self.parser.sink, 'on_transport_lost'): self.parser.sink.on_transport_lost() - except AttributeError: - pass - async def wait_for_termination(self): + async def wait_for_termination(self) -> None: """ Convenience method for backward compatibility. Prefer using the `terminated` attribute instead. """ return await self.terminated - def close(self): + def close(self) -> None: pass # ----------------------------------------------------------------------------- class StreamPacketSource(asyncio.Protocol, ParserSource): - def data_received(self, data): + def data_received(self, data: bytes) -> None: self.parser.feed_data(data) # ----------------------------------------------------------------------------- class StreamPacketSink: - def __init__(self, transport): + def __init__(self, transport: asyncio.WriteTransport) -> None: self.transport = transport - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: self.transport.write(packet) - def close(self): + def close(self) -> None: self.transport.close() @@ -286,7 +302,7 @@ class Transport: ... """ - def __init__(self, source, sink): + def __init__(self, source: TransportSource, sink: TransportSink) -> None: self.source = source self.sink = sink @@ -300,19 +316,23 @@ class Transport: return iter((self.source, self.sink)) async def close(self) -> None: - self.source.close() - self.sink.close() + if hasattr(self.source, 'close'): + self.source.close() + if hasattr(self.sink, 'close'): + self.sink.close() # ----------------------------------------------------------------------------- class PumpedPacketSource(ParserSource): - def __init__(self, receive): + pump_task: Optional[asyncio.Task[None]] + + def __init__(self, receive) -> None: super().__init__() self.receive_function = receive self.pump_task = None - def start(self): - async def pump_packets(): + def start(self) -> None: + async def pump_packets() -> None: while True: try: packet = await self.receive_function() @@ -322,12 +342,12 @@ class PumpedPacketSource(ParserSource): break except Exception as error: logger.warning(f'exception while waiting for packet: {error}') - self.terminated.set_result(error) + self.terminated.set_exception(error) break self.pump_task = asyncio.create_task(pump_packets()) - def close(self): + def close(self) -> None: if self.pump_task: self.pump_task.cancel() @@ -339,7 +359,7 @@ class PumpedPacketSink: self.packet_queue = asyncio.Queue() self.pump_task = None - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: self.packet_queue.put_nowait(packet) def start(self): @@ -364,15 +384,23 @@ class PumpedPacketSink: # ----------------------------------------------------------------------------- class PumpedTransport(Transport): - def __init__(self, source, sink, close_function): + source: PumpedPacketSource + sink: PumpedPacketSink + + def __init__( + self, + source: PumpedPacketSource, + sink: PumpedPacketSink, + close_function, + ) -> None: super().__init__(source, sink) self.close_function = close_function - def start(self): + def start(self) -> None: self.source.start() self.sink.start() - async def close(self): + async def close(self) -> None: await super().close() await self.close_function() @@ -397,31 +425,37 @@ class SnoopingTransport(Transport): raise RuntimeError('unexpected code path') # Satisfy the type checker class Source: - def __init__(self, source, snooper): + sink: TransportSink + + def __init__(self, source: TransportSource, snooper: Snooper): self.source = source self.snooper = snooper - self.sink = None - def set_packet_sink(self, sink): + def set_packet_sink(self, sink: TransportSink) -> None: self.sink = sink self.source.set_packet_sink(self) - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST) if self.sink: self.sink.on_packet(packet) class Sink: - def __init__(self, sink, snooper): + def __init__(self, sink: TransportSink, snooper: Snooper) -> None: self.sink = sink self.snooper = snooper - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER) if self.sink: self.sink.on_packet(packet) - def __init__(self, transport, snooper, close_snooper=None): + def __init__( + self, + transport: Transport, + snooper: Snooper, + close_snooper=None, + ) -> None: super().__init__( self.Source(transport.source, snooper), self.Sink(transport.sink, snooper) ) diff --git a/bumble/transport/file.py b/bumble/transport/file.py index 9c073d2..dee1c23 100644 --- a/bumble/transport/file.py +++ b/bumble/transport/file.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_file_transport(spec): +async def open_file_transport(spec: str) -> Transport: ''' Open a File transport (typically not for a real file, but for a PTY or other unix virtual files). diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index 4e1ad99..9891c5b 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_hci_socket_transport(spec): +async def open_hci_socket_transport(spec: str | None) -> Transport: ''' Open an HCI Socket (only available on some platforms). The parameter string is either empty (to use the first/default Bluetooth adapter) @@ -47,7 +47,7 @@ async def open_hci_socket_transport(spec): hci_socket = socket.socket( socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, - socket.BTPROTO_HCI, + socket.BTPROTO_HCI, # type: ignore ) except AttributeError as error: # Not supported on this platform diff --git a/bumble/transport/pty.py b/bumble/transport/pty.py index e6e2ab5..7765b09 100644 --- a/bumble/transport/pty.py +++ b/bumble/transport/pty.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_pty_transport(spec): +async def open_pty_transport(spec: str | None) -> Transport: ''' Open a PTY transport. The parameter string may be empty, or a path name where a symbolic link diff --git a/bumble/transport/pyusb.py b/bumble/transport/pyusb.py index 8ad8598..5e686d1 100644 --- a/bumble/transport/pyusb.py +++ b/bumble/transport/pyusb.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_pyusb_transport(spec): +async def open_pyusb_transport(spec: str) -> Transport: ''' Open a USB transport. [Implementation based on PyUSB] The parameter string has this syntax: diff --git a/bumble/transport/serial.py b/bumble/transport/serial.py index c83b605..c48cdc6 100644 --- a/bumble/transport/serial.py +++ b/bumble/transport/serial.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_serial_transport(spec): +async def open_serial_transport(spec: str) -> Transport: ''' Open a serial port transport. The parameter string has this syntax: diff --git a/bumble/transport/tcp_client.py b/bumble/transport/tcp_client.py index 456a19a..4fb268a 100644 --- a/bumble/transport/tcp_client.py +++ b/bumble/transport/tcp_client.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_tcp_client_transport(spec): +async def open_tcp_client_transport(spec: str) -> Transport: ''' Open a TCP client transport. The parameter string has this syntax: diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py index 11b0453..77d0304 100644 --- a/bumble/transport/tcp_server.py +++ b/bumble/transport/tcp_server.py @@ -15,6 +15,7 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_tcp_server_transport(spec): +async def open_tcp_server_transport(spec: str) -> Transport: ''' Open a TCP server transport. The parameter string has this syntax: @@ -42,7 +43,7 @@ async def open_tcp_server_transport(spec): async def close(self): await super().close() - class TcpServerProtocol: + class TcpServerProtocol(asyncio.BaseProtocol): def __init__(self, packet_source, packet_sink): self.packet_source = packet_source self.packet_sink = packet_sink diff --git a/bumble/transport/udp.py b/bumble/transport/udp.py index e5e26fa..faa9bf0 100644 --- a/bumble/transport/udp.py +++ b/bumble/transport/udp.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_udp_transport(spec): +async def open_udp_transport(spec: str) -> Transport: ''' Open a UDP transport. The parameter string has this syntax: diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index 13cad60..ccc82c1 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -60,7 +60,7 @@ def load_libusb(): usb1.loadLibrary(libusb_dll) -async def open_usb_transport(spec): +async def open_usb_transport(spec: str) -> Transport: ''' Open a USB transport. The moniker string has this syntax: diff --git a/bumble/transport/vhci.py b/bumble/transport/vhci.py index ec61ab4..5795840 100644 --- a/bumble/transport/vhci.py +++ b/bumble/transport/vhci.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- import logging +from .common import Transport from .file import open_file_transport # ----------------------------------------------------------------------------- @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_vhci_transport(spec): +async def open_vhci_transport(spec: str | None) -> Transport: ''' Open a VHCI transport (only available on some platforms). The parameter string is either empty (to use the default VHCI device @@ -42,15 +43,15 @@ async def open_vhci_transport(spec): # Override the source's `data_received` method so that we can # filter out the vendor packet that is received just after the # initial open - def vhci_data_received(data): + def vhci_data_received(data: bytes) -> None: if len(data) > 0 and data[0] == HCI_VENDOR_PKT: if len(data) == 4: hci_index = data[2] << 8 | data[3] logger.info(f'HCI index {hci_index}') else: - transport.source.parser.feed_data(data) + transport.source.parser.feed_data(data) # type: ignore - transport.source.data_received = vhci_data_received + transport.source.data_received = vhci_data_received # type: ignore # Write the initial config transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) diff --git a/bumble/transport/ws_client.py b/bumble/transport/ws_client.py index 85f6e88..facd1c9 100644 --- a/bumble/transport/ws_client.py +++ b/bumble/transport/ws_client.py @@ -16,9 +16,9 @@ # Imports # ----------------------------------------------------------------------------- import logging -import websockets +import websockets.client -from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport +from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport # ----------------------------------------------------------------------------- # Logging @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_ws_client_transport(spec): +async def open_ws_client_transport(spec: str) -> Transport: ''' Open a WebSocket client transport. The parameter string has this syntax: @@ -38,7 +38,7 @@ async def open_ws_client_transport(spec): remote_host, remote_port = spec.split(':') uri = f'ws://{remote_host}:{remote_port}' - websocket = await websockets.connect(uri) + websocket = await websockets.client.connect(uri) transport = PumpedTransport( PumpedPacketSource(websocket.recv), diff --git a/bumble/transport/ws_server.py b/bumble/transport/ws_server.py index ddebef2..3c72c36 100644 --- a/bumble/transport/ws_server.py +++ b/bumble/transport/ws_server.py @@ -15,7 +15,6 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import asyncio import logging import websockets @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -async def open_ws_server_transport(spec): +async def open_ws_server_transport(spec: str) -> Transport: ''' Open a WebSocket server transport. The parameter string has this syntax: