Merge pull request #171 from google/gbg/keystore-name

use device public or static address for keystore namespace
This commit is contained in:
Gilles Boccon-Gibod
2023-04-03 17:58:18 -07:00
committed by GitHub
3 changed files with 38 additions and 15 deletions

View File

@@ -339,8 +339,7 @@ async def run(
# Create a UDP to TX bridge (receive from TX, send to UDP) # Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint( bridge.tx_socket, _ = await loop.create_datagram_endpoint(
# pylint: disable-next=unnecessary-lambda asyncio.DatagramProtocol,
lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port), remote_addr=(send_host, send_port),
) )

View File

@@ -878,7 +878,7 @@ device_host_event_handlers: list[str] = []
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Device(CompositeEventEmitter): class Device(CompositeEventEmitter):
# incomplete list of fields. # Incomplete list of fields.
random_address: Address random_address: Address
public_address: Address public_address: Address
classic_enabled: bool classic_enabled: bool
@@ -893,6 +893,7 @@ class Device(CompositeEventEmitter):
Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]] Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]]
] ]
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
config: DeviceConfiguration
@composite_listener @composite_listener
class Listener: class Listener:
@@ -980,9 +981,10 @@ class Device(CompositeEventEmitter):
self.connect_own_address_type = None self.connect_own_address_type = None
# Use the initial config or a default # Use the initial config or a default
config = config or DeviceConfiguration()
self.config = config
self.public_address = Address('00:00:00:00:00:00') self.public_address = Address('00:00:00:00:00:00')
if config is None:
config = DeviceConfiguration()
self.name = config.name self.name = config.name
self.random_address = config.address self.random_address = config.address
self.class_of_device = config.class_of_device self.class_of_device = config.class_of_device
@@ -990,7 +992,7 @@ class Device(CompositeEventEmitter):
self.advertising_data = config.advertising_data self.advertising_data = config.advertising_data
self.advertising_interval_min = config.advertising_interval_min self.advertising_interval_min = config.advertising_interval_min
self.advertising_interval_max = config.advertising_interval_max self.advertising_interval_max = config.advertising_interval_max
self.keystore = KeyStore.create_for_device(config) self.keystore = None
self.irk = config.irk self.irk = config.irk
self.le_enabled = config.le_enabled self.le_enabled = config.le_enabled
self.classic_enabled = config.classic_enabled self.classic_enabled = config.classic_enabled
@@ -1167,6 +1169,7 @@ class Device(CompositeEventEmitter):
# Reset the controller # Reset the controller
await self.host.reset() await self.host.reset()
# Try to get the public address from the controller
response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg] response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg]
if response.return_parameters.status == HCI_SUCCESS: if response.return_parameters.status == HCI_SUCCESS:
logger.debug( logger.debug(
@@ -1174,6 +1177,11 @@ class Device(CompositeEventEmitter):
) )
self.public_address = response.return_parameters.bd_addr self.public_address = response.return_parameters.bd_addr
# Instantiate the Key Store (we do this here rather than at __init__ time
# because some Key Store implementations use the public address as a namespace)
if self.keystore is None:
self.keystore = KeyStore.create_for_device(self)
if self.host.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND): if self.host.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND):
await self.send_command( await self.send_command(
HCI_Write_LE_Host_Support_Command( HCI_Write_LE_Host_Support_Command(

View File

@@ -20,15 +20,19 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import json import json
from typing import Optional from typing import TYPE_CHECKING, Optional
from .colors import color from .colors import color
from .hci import Address from .hci import Address
if TYPE_CHECKING:
from .device import Device
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -173,13 +177,13 @@ class KeyStore:
separator = '\n' separator = '\n'
@staticmethod @staticmethod
def create_for_device(device_config): def create_for_device(device: Device) -> Optional[KeyStore]:
if device_config.keystore is None: if device.config.keystore is None:
return None return None
keystore_type = device_config.keystore.split(':', 1)[0] keystore_type = device.config.keystore.split(':', 1)[0]
if keystore_type == 'JsonKeyStore': if keystore_type == 'JsonKeyStore':
return JsonKeyStore.from_device_config(device_config) return JsonKeyStore.from_device(device)
return None return None
@@ -204,7 +208,9 @@ class JsonKeyStore(KeyStore):
self.directory_name = os.path.join( self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
) )
json_filename = f'{self.namespace}.json'.lower().replace(':', '-') json_filename = (
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename) self.filename = os.path.join(self.directory_name, json_filename)
else: else:
self.filename = filename self.filename = filename
@@ -213,9 +219,19 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}') logger.debug(f'JSON keystore: {self.filename}')
@staticmethod @staticmethod
def from_device_config(device_config): def from_device(device: Device) -> Optional[JsonKeyStore]:
params = device_config.keystore.split(':', 1)[1:] if not device.config.keystore:
namespace = str(device_config.address) return None
params = device.config.keystore.split(':', 1)[1:]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
namespace = str(device.public_address)
elif device.random_address != Address.ANY_RANDOM:
namespace = str(device.random_address)
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
if params: if params:
filename = params[0] filename = params[0]
else: else: