diff --git a/bumble/drivers/__init__.py b/bumble/drivers/__init__.py index 0a38f08..b5712e6 100644 --- a/bumble/drivers/__init__.py +++ b/bumble/drivers/__init__.py @@ -19,12 +19,17 @@ like loading firmware after a cold start. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import abc +from __future__ import annotations import logging import pathlib import platform -from . import rtk +from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING +from . import rtk +from .common import Driver + +if TYPE_CHECKING: + from bumble.host import Host # ----------------------------------------------------------------------------- # Logging @@ -32,39 +37,16 @@ from . import rtk logger = logging.getLogger(__name__) -# ----------------------------------------------------------------------------- -# Classes -# ----------------------------------------------------------------------------- -class Driver(abc.ABC): - """Base class for drivers.""" - - @staticmethod - async def for_host(_host): - """Return a driver instance for a host. - - Args: - host: Host object for which a driver should be created. - - Returns: - A Driver instance if a driver should be instantiated for this host, or - None if no driver instance of this class is needed. - """ - return None - - @abc.abstractmethod - async def init_controller(self): - """Initialize the controller.""" - - # ----------------------------------------------------------------------------- # Functions # ----------------------------------------------------------------------------- -async def get_driver_for_host(host): +async def get_driver_for_host(host: Host) -> Optional[Driver]: """Probe diver classes until one returns a valid instance for a host, or none is found. If a "driver" HCI metadata entry is present, only that driver class will be probed. """ - driver_classes = {"rtk": rtk.Driver} + driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver} + probe_list: Iterable[str] if driver_name := host.hci_metadata.get("driver"): # Only probe a single driver probe_list = [driver_name] @@ -73,10 +55,13 @@ async def get_driver_for_host(host): probe_list = driver_classes.keys() for driver_name in probe_list: - logger.debug(f"Probing {driver_name} driver class") - if driver := await rtk.Driver.for_host(host): - logger.debug(f"Instantiated {driver_name} driver") - return driver + if driver_class := driver_classes.get(driver_name): + logger.debug(f"Probing driver class: {driver_name}") + if driver := await driver_class.for_host(host): + logger.debug(f"Instantiated {driver_name} driver") + return driver + else: + logger.debug(f"Skipping unknown driver class: {driver_name}") return None diff --git a/bumble/drivers/common.py b/bumble/drivers/common.py new file mode 100644 index 0000000..a4c0427 --- /dev/null +++ b/bumble/drivers/common.py @@ -0,0 +1,45 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Common types for drivers. +""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import abc + + +# ----------------------------------------------------------------------------- +# Classes +# ----------------------------------------------------------------------------- +class Driver(abc.ABC): + """Base class for drivers.""" + + @staticmethod + async def for_host(_host): + """Return a driver instance for a host. + + Args: + host: Host object for which a driver should be created. + + Returns: + A Driver instance if a driver should be instantiated for this host, or + None if no driver instance of this class is needed. + """ + return None + + @abc.abstractmethod + async def init_controller(self): + """Initialize the controller.""" diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index 0b64e0c..4a9034d 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -41,7 +41,7 @@ from bumble.hci import ( HCI_Reset_Command, HCI_Read_Local_Version_Information_Command, ) - +from bumble.drivers import common # ----------------------------------------------------------------------------- # Logging @@ -285,7 +285,7 @@ class Firmware: ) -class Driver: +class Driver(common.Driver): @dataclass class DriverInfo: rom: int