Merge pull request #478 from zxzxwu/config

Make DeviceConfiguration dataclass
This commit is contained in:
zxzxwu
2024-05-13 16:57:15 +08:00
committed by GitHub
2 changed files with 79 additions and 72 deletions

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import copy
import functools
import json
import asyncio
@@ -40,6 +41,7 @@ from typing import (
overload,
TYPE_CHECKING,
)
from typing_extensions import Self
from pyee import EventEmitter
@@ -1252,75 +1254,47 @@ class Connection(CompositeEventEmitter):
# -----------------------------------------------------------------------------
@dataclass
class DeviceConfiguration:
def __init__(self) -> None:
# Setup defaults
self.name = DEVICE_DEFAULT_NAME
self.address = Address(DEVICE_DEFAULT_ADDRESS)
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
self.le_enabled = True
# LE host enable 2nd parameter
self.le_simultaneous_enabled = False
self.classic_enabled = False
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
self.classic_smp_enabled = True
self.classic_accept_any = True
self.connectable = True
self.discoverable = True
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
# Setup defaults
name: str = DEVICE_DEFAULT_NAME
address: Address = Address(DEVICE_DEFAULT_ADDRESS)
class_of_device: int = DEVICE_DEFAULT_CLASS_OF_DEVICE
scan_response_data: bytes = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
le_enabled: bool = True
# LE host enable 2nd parameter
le_simultaneous_enabled: bool = False
classic_enabled: bool = False
classic_sc_enabled: bool = True
classic_ssp_enabled: bool = True
classic_smp_enabled: bool = True
classic_accept_any: bool = True
connectable: bool = True
discoverable: bool = True
advertising_data: bytes = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))]
)
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
)
irk: bytes = bytes(16) # This really must be changed for any level of security
keystore: Optional[str] = None
address_resolution_offload: bool = False
cis_enabled: bool = False
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
self.cis_enabled = False
def load_from_dict(self, config: Dict[str, Any]) -> None:
config = copy.deepcopy(config)
# Load simple properties
self.name = config.get('name', self.name)
if address := config.get('address', None):
if address := config.pop('address', None):
self.address = Address(address)
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get(
'advertising_interval', self.advertising_interval_min
)
self.advertising_interval_max = self.advertising_interval_min
self.keystore = config.get('keystore')
self.le_enabled = config.get('le_enabled', self.le_enabled)
self.le_simultaneous_enabled = config.get(
'le_simultaneous_enabled', self.le_simultaneous_enabled
)
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
self.classic_sc_enabled = config.get(
'classic_sc_enabled', self.classic_sc_enabled
)
self.classic_ssp_enabled = config.get(
'classic_ssp_enabled', self.classic_ssp_enabled
)
self.classic_smp_enabled = config.get(
'classic_smp_enabled', self.classic_smp_enabled
)
self.classic_accept_any = config.get(
'classic_accept_any', self.classic_accept_any
)
self.connectable = config.get('connectable', self.connectable)
self.discoverable = config.get('discoverable', self.discoverable)
self.gatt_services = config.get('gatt_services', self.gatt_services)
self.address_resolution_offload = config.get(
'address_resolution_offload', self.address_resolution_offload
)
self.cis_enabled = config.get('cis_enabled', self.cis_enabled)
# Load or synthesize an IRK
irk = config.get('irk')
if irk:
if irk := config.pop('irk', None):
self.irk = bytes.fromhex(irk)
elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
# Construct an IRK from the address bytes
@@ -1332,21 +1306,53 @@ class DeviceConfiguration:
# Fallback - when both IRK and address are not set, randomly generate an IRK.
self.irk = secrets.token_bytes(16)
if (name := config.pop('name', None)) is not None:
self.name = name
# Load advertising data
advertising_data = config.get('advertising_data')
if advertising_data:
if advertising_data := config.pop('advertising_data', None):
self.advertising_data = bytes.fromhex(advertising_data)
elif config.get('name') is not None:
elif name is not None:
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
)
def load_from_file(self, filename):
# Load advertising interval (for backward compatibility)
if advertising_interval := config.pop('advertising_interval', None):
self.advertising_interval_min = advertising_interval
self.advertising_interval_max = advertising_interval
if (
'advertising_interval_max' in config
or 'advertising_interval_min' in config
):
logger.warning(
'Trying to set both advertising_interval and '
'advertising_interval_min/max, advertising_interval will be'
'ignored.'
)
# Load data in primitive types.
for key, value in config.items():
setattr(self, key, value)
def load_from_file(self, filename: str) -> None:
with open(filename, 'r', encoding='utf-8') as file:
self.load_from_dict(json.load(file))
@classmethod
def from_file(cls: Type[Self], filename: str) -> Self:
config = cls()
config.load_from_file(filename)
return config
@classmethod
def from_dict(cls: Type[Self], config: Dict[str, Any]) -> Self:
device_config = cls()
device_config.load_from_dict(config)
return device_config
# -----------------------------------------------------------------------------
# Decorators used with the following Device class
@@ -1470,8 +1476,7 @@ class Device(CompositeEventEmitter):
@classmethod
def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
config = DeviceConfiguration.from_file(filename)
return cls(config=config)
@classmethod
@@ -1488,8 +1493,7 @@ class Device(CompositeEventEmitter):
def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration()
config.load_from_file(filename)
config = DeviceConfiguration.from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink)
def __init__(

View File

@@ -25,7 +25,8 @@ import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self
from .colors import color
from .hci import Address
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
@classmethod
def from_device(
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
return JsonKeyStore(namespace, filename)
return cls(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it