Make DeviceConfiguration dataclass

This commit is contained in:
Josh Wu
2024-04-29 20:04:59 +08:00
parent 1b33c9eb74
commit a5ac5f26e2
2 changed files with 79 additions and 72 deletions

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from enum import IntEnum from enum import IntEnum
import copy
import functools import functools
import json import json
import asyncio import asyncio
@@ -40,6 +41,7 @@ from typing import (
overload, overload,
TYPE_CHECKING, TYPE_CHECKING,
) )
from typing_extensions import Self
from pyee import EventEmitter from pyee import EventEmitter
@@ -1252,75 +1254,47 @@ class Connection(CompositeEventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclass
class DeviceConfiguration: class DeviceConfiguration:
def __init__(self) -> None: # Setup defaults
# Setup defaults name: str = DEVICE_DEFAULT_NAME
self.name = DEVICE_DEFAULT_NAME address: Address = Address(DEVICE_DEFAULT_ADDRESS)
self.address = Address(DEVICE_DEFAULT_ADDRESS) class_of_device: int = DEVICE_DEFAULT_CLASS_OF_DEVICE
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE scan_response_data: bytes = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL le_enabled: bool = True
self.le_enabled = True # LE host enable 2nd parameter
# LE host enable 2nd parameter le_simultaneous_enabled: bool = False
self.le_simultaneous_enabled = False classic_enabled: bool = False
self.classic_enabled = False classic_sc_enabled: bool = True
self.classic_sc_enabled = True classic_ssp_enabled: bool = True
self.classic_ssp_enabled = True classic_smp_enabled: bool = True
self.classic_smp_enabled = True classic_accept_any: bool = True
self.classic_accept_any = True connectable: bool = True
self.connectable = True discoverable: bool = True
self.discoverable = True advertising_data: bytes = bytes(
self.advertising_data = bytes( AdvertisingData(
AdvertisingData( [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))]
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
) )
self.irk = bytes(16) # This really must be changed for any level of security )
self.keystore = None irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None
address_resolution_offload: bool = False
cis_enabled: bool = False
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = [] self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
self.cis_enabled = False
def load_from_dict(self, config: Dict[str, Any]) -> None: def load_from_dict(self, config: Dict[str, Any]) -> None:
config = copy.deepcopy(config)
# Load simple properties # Load simple properties
self.name = config.get('name', self.name) if address := config.pop('address', None):
if address := config.get('address', None):
self.address = Address(address) self.address = Address(address)
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get(
'advertising_interval', self.advertising_interval_min
)
self.advertising_interval_max = self.advertising_interval_min
self.keystore = config.get('keystore')
self.le_enabled = config.get('le_enabled', self.le_enabled)
self.le_simultaneous_enabled = config.get(
'le_simultaneous_enabled', self.le_simultaneous_enabled
)
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get(
'classic_sc_enabled', self.classic_sc_enabled
)
self.classic_ssp_enabled = config.get(
'classic_ssp_enabled', self.classic_ssp_enabled
)
self.classic_smp_enabled = config.get(
'classic_smp_enabled', self.classic_smp_enabled
)
self.classic_accept_any = config.get(
'classic_accept_any', self.classic_accept_any
)
self.connectable = config.get('connectable', self.connectable)
self.discoverable = config.get('discoverable', self.discoverable)
self.gatt_services = config.get('gatt_services', self.gatt_services)
self.address_resolution_offload = config.get(
'address_resolution_offload', self.address_resolution_offload
)
self.cis_enabled = config.get('cis_enabled', self.cis_enabled)
# Load or synthesize an IRK # Load or synthesize an IRK
irk = config.get('irk') if irk := config.pop('irk', None):
if irk:
self.irk = bytes.fromhex(irk) self.irk = bytes.fromhex(irk)
elif self.address != Address(DEVICE_DEFAULT_ADDRESS): elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
# Construct an IRK from the address bytes # Construct an IRK from the address bytes
@@ -1332,21 +1306,53 @@ class DeviceConfiguration:
# Fallback - when both IRK and address are not set, randomly generate an IRK. # Fallback - when both IRK and address are not set, randomly generate an IRK.
self.irk = secrets.token_bytes(16) self.irk = secrets.token_bytes(16)
if (name := config.pop('name', None)) is not None:
self.name = name
# Load advertising data # Load advertising data
advertising_data = config.get('advertising_data') if advertising_data := config.pop('advertising_data', None):
if advertising_data:
self.advertising_data = bytes.fromhex(advertising_data) self.advertising_data = bytes.fromhex(advertising_data)
elif config.get('name') is not None: elif name is not None:
self.advertising_data = bytes( self.advertising_data = bytes(
AdvertisingData( AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))] [(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
) )
) )
def load_from_file(self, filename): # Load advertising interval (for backward compatibility)
if advertising_interval := config.pop('advertising_interval', None):
self.advertising_interval_min = advertising_interval
self.advertising_interval_max = advertising_interval
if (
'advertising_interval_max' in config
or 'advertising_interval_min' in config
):
logger.warning(
'Trying to set both advertising_interval and '
'advertising_interval_min/max, advertising_interval will be'
'ignored.'
)
# Load data in primitive types.
for key, value in config.items():
setattr(self, key, value)
def load_from_file(self, filename: str) -> None:
with open(filename, 'r', encoding='utf-8') as file: with open(filename, 'r', encoding='utf-8') as file:
self.load_from_dict(json.load(file)) self.load_from_dict(json.load(file))
@classmethod
def from_file(cls: Type[Self], filename: str) -> Self:
config = cls()
config.load_from_file(filename)
return config
@classmethod
def from_dict(cls: Type[Self], config: Dict[str, Any]) -> Self:
device_config = cls()
device_config.load_from_dict(config)
return device_config
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Decorators used with the following Device class # Decorators used with the following Device class
@@ -1470,8 +1476,7 @@ class Device(CompositeEventEmitter):
@classmethod @classmethod
def from_config_file(cls, filename: str) -> Device: def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration.from_file(filename)
config.load_from_file(filename)
return cls(config=config) return cls(config=config)
@classmethod @classmethod
@@ -1488,8 +1493,7 @@ class Device(CompositeEventEmitter):
def from_config_file_with_hci( def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device: ) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration.from_file(filename)
config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink) return cls.from_config_with_hci(config, hci_source, hci_sink)
def __init__( def __init__(

View File

@@ -25,7 +25,8 @@ import asyncio
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self
from .colors import color from .colors import color
from .hci import Address from .hci import Address
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}') logger.debug(f'JSON keystore: {self.filename}')
@staticmethod @classmethod
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]: def from_device(
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename: if not filename:
# Extract the filename from the config if there is one # Extract the filename from the config if there is one
if device.config.keystore is not None: if device.config.keystore is not None:
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE namespace = JsonKeyStore.DEFAULT_NAMESPACE
return JsonKeyStore(namespace, filename) return cls(namespace, filename)
async def load(self): async def load(self):
# Try to open the file, without failing. If the file does not exist, it # Try to open the file, without failing. If the file does not exist, it