diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index d8ea06e6..b5712e66 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -19,12 +19,17 @@ like loading firmware after a cold start. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import abc +from __future__ import annotations import logging import pathlib import platform -from . import rtk +from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING +from . import rtk +from .common import Driver + +if TYPE_CHECKING: + from bumble.host import Host # ----------------------------------------------------------------------------- # Logging @@ -32,40 +37,31 @@ from . import rtk logger = logging.getLogger(__name__) -# ----------------------------------------------------------------------------- -# Classes -# ----------------------------------------------------------------------------- -class Driver(abc.ABC): - """Base class for drivers.""" - - @staticmethod - async def for_host(_host): - """Return a driver instance for a host. - - Args: - host: Host object for which a driver should be created. - - Returns: - A Driver instance if a driver should be instantiated for this host, or - None if no driver instance of this class is needed. - """ - return None - - @abc.abstractmethod - async def init_controller(self): - """Initialize the controller.""" - - # ----------------------------------------------------------------------------- # Functions # ----------------------------------------------------------------------------- -async def get_driver_for_host(host): - """Probe all known diver classes until one returns a valid instance for a host, - or none is found. +async def get_driver_for_host(host: Host) -> Optional[Driver]: + """Probe diver classes until one returns a valid instance for a host, or none is + found. + If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - if driver := await rtk.Driver.for_host(host): - logger.debug("Instantiated RTK driver") - return driver + driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver} + probe_list: Iterable[str] + if driver_name := host.hci_metadata.get("driver"): + # Only probe a single driver + probe_list = [driver_name] + else: + # Probe all drivers + probe_list = driver_classes.keys() + + for driver_name in probe_list: + if driver_class := driver_classes.get(driver_name): + logger.debug(f"Probing driver class: {driver_name}") + if driver := await driver_class.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver + else: + logger.debug(f"Skipping unknown driver class: {driver_name}") return None diff --git a/bumble/drivers/common.py b/bumble/drivers/common.py new file mode 100644 index 00000000..a4c0427c --- /dev/null +++ b/bumble/drivers/common.py @@ -0,0 +1,45 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Common types for drivers. +""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import abc + + +# ----------------------------------------------------------------------------- +# Classes +# ----------------------------------------------------------------------------- +class Driver(abc.ABC): + """Base class for drivers.""" + + @staticmethod + async def for_host(_host): + """Return a driver instance for a host. + + Args: + host: Host object for which a driver should be created. + + Returns: + A Driver instance if a driver should be instantiated for this host, or + None if no driver instance of this class is needed. + """ + return None + + @abc.abstractmethod + async def init_controller(self): + """Initialize the controller.""" diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index f78a14d3..4a9034db 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -41,7 +41,7 @@ from bumble.hci import ( HCI_Reset_Command, HCI_Read_Local_Version_Information_Command, ) - +from bumble.drivers import common # ----------------------------------------------------------------------------- # Logging @@ -285,7 +285,7 @@ class Firmware: ) -class Driver: +class Driver(common.Driver): @dataclass class DriverInfo: rom: int @@ -470,8 +470,12 @@ class Driver: logger.debug("USB metadata not found") return False - vendor_id = host.hci_metadata.get("vendor_id", None) - product_id = host.hci_metadata.get("product_id", None) + if host.hci_metadata.get('driver') == 'rtk': + # Forced driver + return True + + vendor_id = host.hci_metadata.get("vendor_id") + product_id = host.hci_metadata.get("product_id") if vendor_id is None or product_id is None: logger.debug("USB metadata not sufficient") return False @@ -486,6 +490,9 @@ class Driver: @classmethod async def driver_info_for_host(cls, host): + await host.send_command(HCI_Reset_Command(), check_result=True) + host.ready = True # Needed to let the host know the controller is ready. + response = await host.send_command( HCI_Read_Local_Version_Information_Command(), check_result=True ) diff --git a/bumble/host.py b/bumble/host.py index 3ae2280b..190ab89e 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,7 @@ import collections import logging import struct -from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING from bumble.colors import color from bumble.l2cap import L2CAP_PDU @@ -124,7 +124,8 @@ class Connection: class Host(AbortableEventEmitter): connections: Dict[int, Connection] acl_packet_queue: collections.deque[HCI_AclDataPacket] - hci_sink: TransportSink + hci_sink: Optional[TransportSink] = None + hci_metadata: Dict[str, Any] long_term_key_provider: Optional[ Callable[[int, bytes, int], Awaitable[Optional[bytes]]] ] @@ -137,9 +138,8 @@ class Host(AbortableEventEmitter): ) -> None: super().__init__() - self.hci_metadata = None + self.hci_metadata = {} self.ready = False # True when we can accept incoming packets - self.reset_done = False self.connections = {} # Connections, by connection handle self.pending_command = None self.pending_response = None @@ -162,10 +162,7 @@ class Host(AbortableEventEmitter): # Connect to the source and sink if specified if controller_source: - controller_source.set_packet_sink(self) - self.hci_metadata = getattr( - controller_source, 'metadata', self.hci_metadata - ) + self.set_packet_source(controller_source) if controller_sink: self.set_packet_sink(controller_sink) @@ -200,17 +197,21 @@ class Host(AbortableEventEmitter): self.ready = False await self.flush() - await self.send_command(HCI_Reset_Command(), check_result=True) - self.ready = True - # Instantiate and init a driver for the host if needed. # NOTE: we don't keep a reference to the driver here, because we don't # currently have a need for the driver later on. But if the driver interface # evolves, it may be required, then, to store a reference to the driver in # an object property. + reset_needed = True if driver_factory is not None: if driver := await driver_factory(self): await driver.init_controller() + reset_needed = False + + # Send a reset command unless a driver has already done so. + if reset_needed: + await self.send_command(HCI_Reset_Command(), check_result=True) + self.ready = True response = await self.send_command( HCI_Read_Local_Supported_Commands_Command(), check_result=True @@ -313,25 +314,28 @@ class Host(AbortableEventEmitter): ) ) - self.reset_done = True - @property - def controller(self) -> TransportSink: + def controller(self) -> Optional[TransportSink]: return self.hci_sink @controller.setter - def controller(self, controller): + def controller(self, controller) -> None: self.set_packet_sink(controller) if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink: TransportSink) -> None: + def set_packet_sink(self, sink: Optional[TransportSink]) -> None: self.hci_sink = sink + def set_packet_source(self, source: TransportSource) -> None: + source.set_packet_sink(self) + self.hci_metadata = getattr(source, 'metadata', self.hci_metadata) + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) + if self.hci_sink: + self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index bc0766b2..065e6964 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from typing import Optional from .common import Transport, AsyncPipeSink, SnoopingTransport from ..snoop import create_snooper @@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport: async def open_transport(name: str) -> Transport: """ Open a transport by name. - The name must be : - Where depend on the type (and may be empty for some types). + The name must be : + Where depend on the type (and may be empty for some types), and + is either omitted, or a ,-separated list of = pairs, + enclosed in []. + If there are not metadata or parameter, the : after the may be omitted. + Examples: + * usb:0 + * usb:[driver=rtk]0 + * android-netsim + The supported types are: * serial * udp @@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport: * android-netsim """ - return _wrap_transport(await _open_transport(name)) + scheme, *tail = name.split(':', 1) + spec = tail[0] if tail else None + if spec: + # Metadata may precede the spec + if spec.startswith('['): + metadata_str, *tail = spec[1:].split(']') + spec = tail[0] if tail else None + metadata = dict([entry.split('=') for entry in metadata_str.split(',')]) + else: + metadata = None + + transport = await _open_transport(scheme, spec) + if metadata: + transport.source.metadata = { # type: ignore[attr-defined] + **metadata, + **getattr(transport.source, 'metadata', {}), + } + # pylint: disable=line-too-long + logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined] + + return _wrap_transport(transport) # ----------------------------------------------------------------------------- -async def _open_transport(name: str) -> Transport: +async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: # pylint: disable=import-outside-toplevel # pylint: disable=too-many-return-statements - scheme, *spec = name.split(':', 1) if scheme == 'serial' and spec: from .serial import open_serial_transport - return await open_serial_transport(spec[0]) + return await open_serial_transport(spec) if scheme == 'udp' and spec: from .udp import open_udp_transport - return await open_udp_transport(spec[0]) + return await open_udp_transport(spec) if scheme == 'tcp-client' and spec: from .tcp_client import open_tcp_client_transport - return await open_tcp_client_transport(spec[0]) + return await open_tcp_client_transport(spec) if scheme == 'tcp-server' and spec: from .tcp_server import open_tcp_server_transport - return await open_tcp_server_transport(spec[0]) + return await open_tcp_server_transport(spec) if scheme == 'ws-client' and spec: from .ws_client import open_ws_client_transport - return await open_ws_client_transport(spec[0]) + return await open_ws_client_transport(spec) if scheme == 'ws-server' and spec: from .ws_server import open_ws_server_transport - return await open_ws_server_transport(spec[0]) + return await open_ws_server_transport(spec) if scheme == 'pty': from .pty import open_pty_transport - return await open_pty_transport(spec[0] if spec else None) + return await open_pty_transport(spec) if scheme == 'file': from .file import open_file_transport assert spec is not None - return await open_file_transport(spec[0]) + return await open_file_transport(spec) if scheme == 'vhci': from .vhci import open_vhci_transport - return await open_vhci_transport(spec[0] if spec else None) + return await open_vhci_transport(spec) if scheme == 'hci-socket': from .hci_socket import open_hci_socket_transport - return await open_hci_socket_transport(spec[0] if spec else None) + return await open_hci_socket_transport(spec) if scheme == 'usb': from .usb import open_usb_transport - assert spec is not None - return await open_usb_transport(spec[0]) + assert spec + return await open_usb_transport(spec) if scheme == 'pyusb': from .pyusb import open_pyusb_transport - assert spec is not None - return await open_pyusb_transport(spec[0]) + assert spec + return await open_pyusb_transport(spec) if scheme == 'android-emulator': from .android_emulator import open_android_emulator_transport - return await open_android_emulator_transport(spec[0] if spec else None) + return await open_android_emulator_transport(spec) if scheme == 'android-netsim': from .android_netsim import open_android_netsim_transport - return await open_android_netsim_transport(spec[0] if spec else None) + return await open_android_netsim_transport(spec) raise ValueError('unknown transport scheme') diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 8d19a9e2..9cd7ec21 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: mode = 'host' server_host = 'localhost' server_port = '8554' - if spec is not None: + if spec: params = spec.split(',') for param in params: if param.startswith('mode='): diff --git a/bumble/transport/common.py b/bumble/transport/common.py index ace04da5..f767f54f 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -21,7 +21,7 @@ import struct import asyncio import logging import io -from typing import ContextManager, Tuple, Optional, Protocol, Dict +from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from bumble import hci from bumble.colors import color diff --git a/bumble/transport/hci_socket.py b/bumble/transport/hci_socket.py index df9e885a..41250433 100644 --- a/bumble/transport/hci_socket.py +++ b/bumble/transport/hci_socket.py @@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport: ) from error # Compute the adapter index - if spec is None: - adapter_index = 0 - else: - adapter_index = int(spec) + adapter_index = int(spec) if spec else 0 # Bind the socket # NOTE: since Python doesn't support binding with the required address format (yet), diff --git a/docs/mkdocs/src/drivers/index.md b/docs/mkdocs/src/drivers/index.md index a904e006..cb0a981e 100644 --- a/docs/mkdocs/src/drivers/index.md +++ b/docs/mkdocs/src/drivers/index.md @@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly. This may include, for instance, loading a Firmware image or patch, loading a configuration. +By default, drivers will be automatically probed to determine if they should be +used with particular HCI controller. +When the transport for an HCI controller is instantiated from a transport name, +a driver may also be forced by specifying ``driver=`` in the optional +metadata portion of the transport name. For example, +``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the +first USB device, even if a normal probe would not have selected it based on the +USB vendor ID and product ID. + Drivers included in the module are: * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. \ No newline at end of file diff --git a/docs/mkdocs/src/drivers/realtek.md b/docs/mkdocs/src/drivers/realtek.md index acbce490..599ce048 100644 --- a/docs/mkdocs/src/drivers/realtek.md +++ b/docs/mkdocs/src/drivers/realtek.md @@ -1,13 +1,16 @@ REALTEK DRIVER ============== -This driver supports loading firmware images and optional config data to +This driver supports loading firmware images and optional config data to USB dongles with a Realtek chipset. A number of USB dongles are supported, but likely not all. -When using a USB dongle, the USB product ID and manufacturer ID are used +When using a USB dongle, the USB product ID and vendor ID are used to find whether a matching set of firmware image and config data is needed for that specific model. If a match exists, the driver will try load the firmware image and, if needed, config data. +Alternatively, the metadata property ``driver=rtk`` may be specified in a transport +name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just +``usb:0`` for the first USB device). The driver will look for those files by name, in order, in: * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`