diff --git a/bumble/device.py b/bumble/device.py index 10ce28a..6bc945a 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations from enum import IntEnum +import copy import functools import json import asyncio @@ -40,6 +41,7 @@ from typing import ( overload, TYPE_CHECKING, ) +from typing_extensions import Self from pyee import EventEmitter @@ -1252,75 +1254,47 @@ class Connection(CompositeEventEmitter): # ----------------------------------------------------------------------------- +@dataclass class DeviceConfiguration: - def __init__(self) -> None: - # Setup defaults - self.name = DEVICE_DEFAULT_NAME - self.address = Address(DEVICE_DEFAULT_ADDRESS) - self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE - self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA - self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL - self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL - self.le_enabled = True - # LE host enable 2nd parameter - self.le_simultaneous_enabled = False - self.classic_enabled = False - self.classic_sc_enabled = True - self.classic_ssp_enabled = True - self.classic_smp_enabled = True - self.classic_accept_any = True - self.connectable = True - self.discoverable = True - self.advertising_data = bytes( - AdvertisingData( - [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))] - ) + # Setup defaults + name: str = DEVICE_DEFAULT_NAME + address: Address = Address(DEVICE_DEFAULT_ADDRESS) + class_of_device: int = DEVICE_DEFAULT_CLASS_OF_DEVICE + scan_response_data: bytes = DEVICE_DEFAULT_SCAN_RESPONSE_DATA + advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL + advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL + le_enabled: bool = True + # LE host enable 2nd parameter + le_simultaneous_enabled: bool = False + classic_enabled: bool = False + classic_sc_enabled: bool = True + classic_ssp_enabled: bool = True + classic_smp_enabled: bool = True + classic_accept_any: bool = True + connectable: bool = True + discoverable: bool = True + advertising_data: bytes = bytes( + AdvertisingData( + [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))] ) - self.irk = bytes(16) # This really must be changed for any level of security - self.keystore = None + ) + irk: bytes = bytes(16) # This really must be changed for any level of security + keystore: Optional[str] = None + address_resolution_offload: bool = False + cis_enabled: bool = False + + def __post_init__(self) -> None: self.gatt_services: List[Dict[str, Any]] = [] - self.address_resolution_offload = False - self.cis_enabled = False def load_from_dict(self, config: Dict[str, Any]) -> None: + config = copy.deepcopy(config) + # Load simple properties - self.name = config.get('name', self.name) - if address := config.get('address', None): + if address := config.pop('address', None): self.address = Address(address) - self.class_of_device = config.get('class_of_device', self.class_of_device) - self.advertising_interval_min = config.get( - 'advertising_interval', self.advertising_interval_min - ) - self.advertising_interval_max = self.advertising_interval_min - self.keystore = config.get('keystore') - self.le_enabled = config.get('le_enabled', self.le_enabled) - self.le_simultaneous_enabled = config.get( - 'le_simultaneous_enabled', self.le_simultaneous_enabled - ) - self.classic_enabled = config.get('classic_enabled', self.classic_enabled) - self.classic_sc_enabled = config.get( - 'classic_sc_enabled', self.classic_sc_enabled - ) - self.classic_ssp_enabled = config.get( - 'classic_ssp_enabled', self.classic_ssp_enabled - ) - self.classic_smp_enabled = config.get( - 'classic_smp_enabled', self.classic_smp_enabled - ) - self.classic_accept_any = config.get( - 'classic_accept_any', self.classic_accept_any - ) - self.connectable = config.get('connectable', self.connectable) - self.discoverable = config.get('discoverable', self.discoverable) - self.gatt_services = config.get('gatt_services', self.gatt_services) - self.address_resolution_offload = config.get( - 'address_resolution_offload', self.address_resolution_offload - ) - self.cis_enabled = config.get('cis_enabled', self.cis_enabled) # Load or synthesize an IRK - irk = config.get('irk') - if irk: + if irk := config.pop('irk', None): self.irk = bytes.fromhex(irk) elif self.address != Address(DEVICE_DEFAULT_ADDRESS): # Construct an IRK from the address bytes @@ -1332,21 +1306,53 @@ class DeviceConfiguration: # Fallback - when both IRK and address are not set, randomly generate an IRK. self.irk = secrets.token_bytes(16) + if (name := config.pop('name', None)) is not None: + self.name = name + # Load advertising data - advertising_data = config.get('advertising_data') - if advertising_data: + if advertising_data := config.pop('advertising_data', None): self.advertising_data = bytes.fromhex(advertising_data) - elif config.get('name') is not None: + elif name is not None: self.advertising_data = bytes( AdvertisingData( [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))] ) ) - def load_from_file(self, filename): + # Load advertising interval (for backward compatibility) + if advertising_interval := config.pop('advertising_interval', None): + self.advertising_interval_min = advertising_interval + self.advertising_interval_max = advertising_interval + if ( + 'advertising_interval_max' in config + or 'advertising_interval_min' in config + ): + logger.warning( + 'Trying to set both advertising_interval and ' + 'advertising_interval_min/max, advertising_interval will be' + 'ignored.' + ) + + # Load data in primitive types. + for key, value in config.items(): + setattr(self, key, value) + + def load_from_file(self, filename: str) -> None: with open(filename, 'r', encoding='utf-8') as file: self.load_from_dict(json.load(file)) + @classmethod + def from_file(cls: Type[Self], filename: str) -> Self: + config = cls() + config.load_from_file(filename) + return config + + @classmethod + def from_dict(cls: Type[Self], config: Dict[str, Any]) -> Self: + device_config = cls() + device_config.load_from_dict(config) + return device_config + # ----------------------------------------------------------------------------- # Decorators used with the following Device class @@ -1470,8 +1476,7 @@ class Device(CompositeEventEmitter): @classmethod def from_config_file(cls, filename: str) -> Device: - config = DeviceConfiguration() - config.load_from_file(filename) + config = DeviceConfiguration.from_file(filename) return cls(config=config) @classmethod @@ -1488,8 +1493,7 @@ class Device(CompositeEventEmitter): def from_config_file_with_hci( cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink ) -> Device: - config = DeviceConfiguration() - config.load_from_file(filename) + config = DeviceConfiguration.from_file(filename) return cls.from_config_with_hci(config, hci_source, hci_sink) def __init__( diff --git a/bumble/keys.py b/bumble/keys.py index 5be5e09..facaa37 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -25,7 +25,8 @@ import asyncio import logging import os import json -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type +from typing_extensions import Self from .colors import color from .hci import Address @@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore): logger.debug(f'JSON keystore: {self.filename}') - @staticmethod - def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]: + @classmethod + def from_device( + cls: Type[Self], device: Device, filename: Optional[str] = None + ) -> Self: if not filename: # Extract the filename from the config if there is one if device.config.keystore is not None: @@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore): else: namespace = JsonKeyStore.DEFAULT_NAMESPACE - return JsonKeyStore(namespace, filename) + return cls(namespace, filename) async def load(self): # Try to open the file, without failing. If the file does not exist, it