diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index d8ea06e..0a38f08 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -60,12 +60,23 @@ class Driver(abc.ABC): # Functions # ----------------------------------------------------------------------------- async def get_driver_for_host(host): - """Probe all known diver classes until one returns a valid instance for a host, - or none is found. + """Probe diver classes until one returns a valid instance for a host, or none is + found. + If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - if driver := await rtk.Driver.for_host(host): - logger.debug("Instantiated RTK driver") - return driver + driver_classes = {"rtk": rtk.Driver} + if driver_name := host.hci_metadata.get("driver"): + # Only probe a single driver + probe_list = [driver_name] + else: + # Probe all drivers + probe_list = driver_classes.keys() + + for driver_name in probe_list: + logger.debug(f"Probing {driver_name} driver class") + if driver := await rtk.Driver.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver return None diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index f78a14d..0b64e0c 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -470,8 +470,12 @@ class Driver: logger.debug("USB metadata not found") return False - vendor_id = host.hci_metadata.get("vendor_id", None) - product_id = host.hci_metadata.get("product_id", None) + if host.hci_metadata.get('driver') == 'rtk': + # Forced driver + return True + + vendor_id = host.hci_metadata.get("vendor_id") + product_id = host.hci_metadata.get("product_id") if vendor_id is None or product_id is None: logger.debug("USB metadata not sufficient") return False @@ -486,6 +490,9 @@ class Driver: @classmethod async def driver_info_for_host(cls, host): + await host.send_command(HCI_Reset_Command(), check_result=True) + host.ready = True # Needed to let the host know the controller is ready. + response = await host.send_command( HCI_Read_Local_Version_Information_Command(), check_result=True ) diff --git a/bumble/host.py b/bumble/host.py index 3ae2280..190ab89 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,7 @@ import collections import logging import struct -from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING from bumble.colors import color from bumble.l2cap import L2CAP_PDU @@ -124,7 +124,8 @@ class Connection: class Host(AbortableEventEmitter): connections: Dict[int, Connection] acl_packet_queue: collections.deque[HCI_AclDataPacket] - hci_sink: TransportSink + hci_sink: Optional[TransportSink] = None + hci_metadata: Dict[str, Any] long_term_key_provider: Optional[ Callable[[int, bytes, int], Awaitable[Optional[bytes]]] ] @@ -137,9 +138,8 @@ class Host(AbortableEventEmitter): ) -> None: super().__init__() - self.hci_metadata = None + self.hci_metadata = {} self.ready = False # True when we can accept incoming packets - self.reset_done = False self.connections = {} # Connections, by connection handle self.pending_command = None self.pending_response = None @@ -162,10 +162,7 @@ class Host(AbortableEventEmitter): # Connect to the source and sink if specified if controller_source: - controller_source.set_packet_sink(self) - self.hci_metadata = getattr( - controller_source, 'metadata', self.hci_metadata - ) + self.set_packet_source(controller_source) if controller_sink: self.set_packet_sink(controller_sink) @@ -200,17 +197,21 @@ class Host(AbortableEventEmitter): self.ready = False await self.flush() - await self.send_command(HCI_Reset_Command(), check_result=True) - self.ready = True - # Instantiate and init a driver for the host if needed. # NOTE: we don't keep a reference to the driver here, because we don't # currently have a need for the driver later on. But if the driver interface # evolves, it may be required, then, to store a reference to the driver in # an object property. + reset_needed = True if driver_factory is not None: if driver := await driver_factory(self): await driver.init_controller() + reset_needed = False + + # Send a reset command unless a driver has already done so. + if reset_needed: + await self.send_command(HCI_Reset_Command(), check_result=True) + self.ready = True response = await self.send_command( HCI_Read_Local_Supported_Commands_Command(), check_result=True @@ -313,25 +314,28 @@ class Host(AbortableEventEmitter): ) ) - self.reset_done = True - @property - def controller(self) -> TransportSink: + def controller(self) -> Optional[TransportSink]: return self.hci_sink @controller.setter - def controller(self, controller): + def controller(self, controller) -> None: self.set_packet_sink(controller) if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink: TransportSink) -> None: + def set_packet_sink(self, sink: Optional[TransportSink]) -> None: self.hci_sink = sink + def set_packet_source(self, source: TransportSource) -> None: + source.set_packet_sink(self) + self.hci_metadata = getattr(source, 'metadata', self.hci_metadata) + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) + if self.hci_sink: + self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index bc0766b..4822dfe 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from typing import Optional from .common import Transport, AsyncPipeSink, SnoopingTransport from ..snoop import create_snooper @@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport: async def open_transport(name: str) -> Transport: """ Open a transport by name. - The name must be : - Where depend on the type (and may be empty for some types). + The name must be : + Where depend on the type (and may be empty for some types), and + is either omitted, or a ,-separated list of = pairs, + enclosed in []. + If there are not metadata or parameter, the : after the may be omitted. + Examples: + * usb:0 + * usb:[driver=rtk]0 + * android-netsim + The supported types are: * serial * udp @@ -71,15 +80,34 @@ async def open_transport(name: str) -> Transport: * android-netsim """ - return _wrap_transport(await _open_transport(name)) + scheme, *tail = name.split(':', 1) + spec = tail[0] if tail else None + if spec: + # Metadata may precede the spec + if spec.startswith('['): + metadata_str, *tail = spec[1:].split(']') + spec = tail[0] if tail else None + metadata = dict([entry.split('=') for entry in metadata_str.split(',')]) + else: + metadata = None + + transport = await _open_transport(scheme, spec) + if metadata: + transport.source.metadata = { # type: ignore[attr-defined] + **metadata, + **getattr(transport.source, 'metadata', {}), + } + # pylint: disable=line-too-long + logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined] + + return _wrap_transport(transport) # ----------------------------------------------------------------------------- -async def _open_transport(name: str) -> Transport: +async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: # pylint: disable=import-outside-toplevel # pylint: disable=too-many-return-statements - scheme, *spec = name.split(':', 1) if scheme == 'serial' and spec: from .serial import open_serial_transport diff --git a/bumble/transport/common.py b/bumble/transport/common.py index ace04da..f767f54 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -21,7 +21,7 @@ import struct import asyncio import logging import io -from typing import ContextManager, Tuple, Optional, Protocol, Dict +from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from bumble import hci from bumble.colors import color diff --git a/docs/mkdocs/src/drivers/index.md b/docs/mkdocs/src/drivers/index.md index a904e00..cb0a981 100644 --- a/docs/mkdocs/src/drivers/index.md +++ b/docs/mkdocs/src/drivers/index.md @@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly. This may include, for instance, loading a Firmware image or patch, loading a configuration. +By default, drivers will be automatically probed to determine if they should be +used with particular HCI controller. +When the transport for an HCI controller is instantiated from a transport name, +a driver may also be forced by specifying ``driver=`` in the optional +metadata portion of the transport name. For example, +``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the +first USB device, even if a normal probe would not have selected it based on the +USB vendor ID and product ID. + Drivers included in the module are: * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. \ No newline at end of file diff --git a/docs/mkdocs/src/drivers/realtek.md b/docs/mkdocs/src/drivers/realtek.md index acbce49..599ce04 100644 --- a/docs/mkdocs/src/drivers/realtek.md +++ b/docs/mkdocs/src/drivers/realtek.md @@ -1,13 +1,16 @@ REALTEK DRIVER ============== -This driver supports loading firmware images and optional config data to +This driver supports loading firmware images and optional config data to USB dongles with a Realtek chipset. A number of USB dongles are supported, but likely not all. -When using a USB dongle, the USB product ID and manufacturer ID are used +When using a USB dongle, the USB product ID and vendor ID are used to find whether a matching set of firmware image and config data is needed for that specific model. If a match exists, the driver will try load the firmware image and, if needed, config data. +Alternatively, the metadata property ``driver=rtk`` may be specified in a transport +name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just +``usb:0`` for the first USB device). The driver will look for those files by name, in order, in: * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`