mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Merge pull request #478 from zxzxwu/config
Make DeviceConfiguration dataclass
This commit is contained in:
140
bumble/device.py
140
bumble/device.py
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
import copy
|
||||
import functools
|
||||
import json
|
||||
import asyncio
|
||||
@@ -40,6 +41,7 @@ from typing import (
|
||||
overload,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -1252,75 +1254,47 @@ class Connection(CompositeEventEmitter):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class DeviceConfiguration:
|
||||
def __init__(self) -> None:
|
||||
# Setup defaults
|
||||
self.name = DEVICE_DEFAULT_NAME
|
||||
self.address = Address(DEVICE_DEFAULT_ADDRESS)
|
||||
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
|
||||
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
|
||||
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
self.advertising_interval_max = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
self.le_enabled = True
|
||||
# LE host enable 2nd parameter
|
||||
self.le_simultaneous_enabled = False
|
||||
self.classic_enabled = False
|
||||
self.classic_sc_enabled = True
|
||||
self.classic_ssp_enabled = True
|
||||
self.classic_smp_enabled = True
|
||||
self.classic_accept_any = True
|
||||
self.connectable = True
|
||||
self.discoverable = True
|
||||
self.advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
|
||||
)
|
||||
# Setup defaults
|
||||
name: str = DEVICE_DEFAULT_NAME
|
||||
address: Address = Address(DEVICE_DEFAULT_ADDRESS)
|
||||
class_of_device: int = DEVICE_DEFAULT_CLASS_OF_DEVICE
|
||||
scan_response_data: bytes = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
|
||||
advertising_interval_min: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
advertising_interval_max: int = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
le_enabled: bool = True
|
||||
# LE host enable 2nd parameter
|
||||
le_simultaneous_enabled: bool = False
|
||||
classic_enabled: bool = False
|
||||
classic_sc_enabled: bool = True
|
||||
classic_ssp_enabled: bool = True
|
||||
classic_smp_enabled: bool = True
|
||||
classic_accept_any: bool = True
|
||||
connectable: bool = True
|
||||
discoverable: bool = True
|
||||
advertising_data: bytes = bytes(
|
||||
AdvertisingData(
|
||||
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(DEVICE_DEFAULT_NAME, 'utf-8'))]
|
||||
)
|
||||
self.irk = bytes(16) # This really must be changed for any level of security
|
||||
self.keystore = None
|
||||
)
|
||||
irk: bytes = bytes(16) # This really must be changed for any level of security
|
||||
keystore: Optional[str] = None
|
||||
address_resolution_offload: bool = False
|
||||
cis_enabled: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.gatt_services: List[Dict[str, Any]] = []
|
||||
self.address_resolution_offload = False
|
||||
self.cis_enabled = False
|
||||
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
# Load simple properties
|
||||
self.name = config.get('name', self.name)
|
||||
if address := config.get('address', None):
|
||||
if address := config.pop('address', None):
|
||||
self.address = Address(address)
|
||||
self.class_of_device = config.get('class_of_device', self.class_of_device)
|
||||
self.advertising_interval_min = config.get(
|
||||
'advertising_interval', self.advertising_interval_min
|
||||
)
|
||||
self.advertising_interval_max = self.advertising_interval_min
|
||||
self.keystore = config.get('keystore')
|
||||
self.le_enabled = config.get('le_enabled', self.le_enabled)
|
||||
self.le_simultaneous_enabled = config.get(
|
||||
'le_simultaneous_enabled', self.le_simultaneous_enabled
|
||||
)
|
||||
self.classic_enabled = config.get('classic_enabled', self.classic_enabled)
|
||||
self.classic_sc_enabled = config.get(
|
||||
'classic_sc_enabled', self.classic_sc_enabled
|
||||
)
|
||||
self.classic_ssp_enabled = config.get(
|
||||
'classic_ssp_enabled', self.classic_ssp_enabled
|
||||
)
|
||||
self.classic_smp_enabled = config.get(
|
||||
'classic_smp_enabled', self.classic_smp_enabled
|
||||
)
|
||||
self.classic_accept_any = config.get(
|
||||
'classic_accept_any', self.classic_accept_any
|
||||
)
|
||||
self.connectable = config.get('connectable', self.connectable)
|
||||
self.discoverable = config.get('discoverable', self.discoverable)
|
||||
self.gatt_services = config.get('gatt_services', self.gatt_services)
|
||||
self.address_resolution_offload = config.get(
|
||||
'address_resolution_offload', self.address_resolution_offload
|
||||
)
|
||||
self.cis_enabled = config.get('cis_enabled', self.cis_enabled)
|
||||
|
||||
# Load or synthesize an IRK
|
||||
irk = config.get('irk')
|
||||
if irk:
|
||||
if irk := config.pop('irk', None):
|
||||
self.irk = bytes.fromhex(irk)
|
||||
elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
|
||||
# Construct an IRK from the address bytes
|
||||
@@ -1332,21 +1306,53 @@ class DeviceConfiguration:
|
||||
# Fallback - when both IRK and address are not set, randomly generate an IRK.
|
||||
self.irk = secrets.token_bytes(16)
|
||||
|
||||
if (name := config.pop('name', None)) is not None:
|
||||
self.name = name
|
||||
|
||||
# Load advertising data
|
||||
advertising_data = config.get('advertising_data')
|
||||
if advertising_data:
|
||||
if advertising_data := config.pop('advertising_data', None):
|
||||
self.advertising_data = bytes.fromhex(advertising_data)
|
||||
elif config.get('name') is not None:
|
||||
elif name is not None:
|
||||
self.advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
|
||||
)
|
||||
)
|
||||
|
||||
def load_from_file(self, filename):
|
||||
# Load advertising interval (for backward compatibility)
|
||||
if advertising_interval := config.pop('advertising_interval', None):
|
||||
self.advertising_interval_min = advertising_interval
|
||||
self.advertising_interval_max = advertising_interval
|
||||
if (
|
||||
'advertising_interval_max' in config
|
||||
or 'advertising_interval_min' in config
|
||||
):
|
||||
logger.warning(
|
||||
'Trying to set both advertising_interval and '
|
||||
'advertising_interval_min/max, advertising_interval will be'
|
||||
'ignored.'
|
||||
)
|
||||
|
||||
# Load data in primitive types.
|
||||
for key, value in config.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def load_from_file(self, filename: str) -> None:
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
self.load_from_dict(json.load(file))
|
||||
|
||||
@classmethod
|
||||
def from_file(cls: Type[Self], filename: str) -> Self:
|
||||
config = cls()
|
||||
config.load_from_file(filename)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[Self], config: Dict[str, Any]) -> Self:
|
||||
device_config = cls()
|
||||
device_config.load_from_dict(config)
|
||||
return device_config
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Decorators used with the following Device class
|
||||
@@ -1470,8 +1476,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
@classmethod
|
||||
def from_config_file(cls, filename: str) -> Device:
|
||||
config = DeviceConfiguration()
|
||||
config.load_from_file(filename)
|
||||
config = DeviceConfiguration.from_file(filename)
|
||||
return cls(config=config)
|
||||
|
||||
@classmethod
|
||||
@@ -1488,8 +1493,7 @@ class Device(CompositeEventEmitter):
|
||||
def from_config_file_with_hci(
|
||||
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
|
||||
) -> Device:
|
||||
config = DeviceConfiguration()
|
||||
config.load_from_file(filename)
|
||||
config = DeviceConfiguration.from_file(filename)
|
||||
return cls.from_config_with_hci(config, hci_source, hci_sink)
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -25,7 +25,8 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
logger.debug(f'JSON keystore: {self.filename}')
|
||||
|
||||
@staticmethod
|
||||
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
|
||||
@classmethod
|
||||
def from_device(
|
||||
cls: Type[Self], device: Device, filename: Optional[str] = None
|
||||
) -> Self:
|
||||
if not filename:
|
||||
# Extract the filename from the config if there is one
|
||||
if device.config.keystore is not None:
|
||||
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
|
||||
else:
|
||||
namespace = JsonKeyStore.DEFAULT_NAMESPACE
|
||||
|
||||
return JsonKeyStore(namespace, filename)
|
||||
return cls(namespace, filename)
|
||||
|
||||
async def load(self):
|
||||
# Try to open the file, without failing. If the file does not exist, it
|
||||
|
||||
Reference in New Issue
Block a user