From 4a333b6c0f1118c22b7110295c959bb1ecd5e0ed Mon Sep 17 00:00:00 2001 From: uael Date: Wed, 19 Apr 2023 22:20:55 +0000 Subject: [PATCH] keys: add an in-memory key-store fallback Instead of defaulting the key-store to `None`, use an in-memory one. This way a keystore is always available. A future improvement could be to rework the device keystore initialization to remove checks like `if self.keystore:` along the codebase. --- bumble/device.py | 1 + bumble/keys.py | 41 +++++++++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index 61594352..de3912b6 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -2243,6 +2243,7 @@ class Device(CompositeEventEmitter): return None return keys.link_key.value + return None # [Classic only] async def authenticate(self, connection): diff --git a/bumble/keys.py b/bumble/keys.py index bbd46a51..a30e7530 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -25,7 +25,7 @@ import asyncio import logging import os import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from .colors import color from .hci import Address @@ -139,19 +139,19 @@ class PairingKeys: # ----------------------------------------------------------------------------- class KeyStore: - async def delete(self, name): + async def delete(self, name: str): pass - async def update(self, name, keys): + async def update(self, name: str, keys: PairingKeys) -> None: pass - async def get(self, _name): - return PairingKeys() + async def get(self, _name: str) -> Optional[PairingKeys]: + return None - async def get_all(self): + async def get_all(self) -> List[Tuple[str, PairingKeys]]: return [] - async def delete_all(self): + async def delete_all(self) -> None: all_keys = await self.get_all() await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) @@ -177,15 +177,15 @@ class KeyStore: separator = '\n' @staticmethod - def create_for_device(device: Device) -> Optional[KeyStore]: + def create_for_device(device: Device) -> KeyStore: if device.config.keystore is None: - return None + return MemoryKeyStore() keystore_type = device.config.keystore.split(':', 1)[0] if keystore_type == 'JsonKeyStore': return JsonKeyStore.from_device(device) - return None + return MemoryKeyStore() # ----------------------------------------------------------------------------- @@ -307,3 +307,24 @@ class JsonKeyStore(KeyStore): return None return PairingKeys.from_dict(keys) + + +# ----------------------------------------------------------------------------- +class MemoryKeyStore(KeyStore): + all_keys: Dict[str, PairingKeys] + + def __init__(self) -> None: + self.all_keys = {} + + async def delete(self, name: str) -> None: + if name in self.all_keys: + del self.all_keys[name] + + async def update(self, name: str, keys: PairingKeys) -> None: + self.all_keys[name] = keys + + async def get(self, name: str) -> Optional[PairingKeys]: + return self.all_keys.get(name) + + async def get_all(self) -> List[Tuple[str, PairingKeys]]: + return list(self.all_keys.items())