Compare commits

..

5 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 27fbb58447 add basic keystore test 2023-06-04 13:01:07 -07:00
Gilles Boccon-Gibod 6826f68478 fix linter warnings 2023-05-05 16:16:55 -07:00
Gilles Boccon-Gibod f80c83d0b3 better doc and default behavior for json keystore 2023-05-05 16:11:20 -07:00
Gilles Boccon-Gibod 3de35193bc rebase 2023-05-05 16:09:01 -07:00
Gilles Boccon-Gibod 740a2e0ca0 instantiate keystore after power_on 2023-05-05 16:07:16 -07:00
11 changed files with 329 additions and 211 deletions
+14 -14
View File
@@ -207,7 +207,7 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
@@ -242,9 +242,9 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
def on_pairing(keys):
def on_pairing(address, keys):
print(color('***-----------------------------------', 'cyan'))
print(color('*** Paired!', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
@@ -283,17 +283,6 @@ async def pair(
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
@@ -323,6 +312,17 @@ async def pair(
# Get things going
await device.power_on()
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc, mitm, bond, Delegate(mode, connection, io, prompt)
+5 -6
View File
@@ -133,15 +133,16 @@ async def scan(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
device.keystore = keystore
else:
resolver = None
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -149,8 +150,6 @@ async def scan(
else:
device.on('advertisement', printer.on_advertisement)
await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
+40 -22
View File
@@ -22,40 +22,58 @@ import click
from bumble.device import Device
from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address):
if address is None:
return await keystore.print()
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# -----------------------------------------------------------------------------
async def unbond(keystore_file, device_config, address):
# Create a device to manage the host
device = Device.from_config_file(device_config)
# Get all entries in the keystore
async def unbond(keystore_file, device_config, hci_transport, address):
# With a keystore file, we can instantiate the keystore directly
if keystore_file:
keystore = JsonKeyStore(None, keystore_file)
else:
keystore = device.keystore
return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address)
if keystore is None:
print('no keystore')
return
# Without a keystore file, we need to obtain the keystore from the device
async with await open_transport(hci_transport) as (hci_source, hci_sink):
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
if address is None:
await keystore.print()
else:
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# Power-on the device to ensure we have a key store
await device.power_on()
return await unbond_with_keystore(device.keystore, address)
# -----------------------------------------------------------------------------
@click.command()
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.option('--keystore-file', help='File in which the pairing keys are stored')
@click.option('--hci-transport', help='HCI transport for the controller')
@click.argument('device-config', required=False)
@click.argument('address', required=False)
def main(keystore_file, device_config, address):
def main(keystore_file, hci_transport, device_config, address):
"""
Remove pairing keys for a device, given its address.
If no keystore file is specified, the --hci-transport option must be used to
connect to a controller, so that the keystore for that controller can be
instantiated.
If no address is passed, the existing pairing keys for all addresses are printed.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address))
if not keystore_file and not hci_transport:
print('either --keystore-file or --hci-transport must be specified.')
return
asyncio.run(unbond(keystore_file, device_config, hci_transport, address))
# -----------------------------------------------------------------------------
+11 -2
View File
@@ -889,7 +889,7 @@ def host_event_handler(function):
# List of host event handlers for the Device class.
# (we define this list outside the class, because referencing a class in method
# decorators is not straightforward)
device_host_event_handlers: List[str] = []
device_host_event_handlers: list[str] = []
# -----------------------------------------------------------------------------
@@ -3088,7 +3088,16 @@ class Device(CompositeEventEmitter):
def on_pairing_start(self, connection: Connection) -> None:
connection.emit('pairing_start')
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
def on_pairing(
self,
connection: Connection,
identity_address: Optional[Address],
keys: PairingKeys,
sc: bool,
) -> None:
if identity_address is not None:
connection.peer_resolvable_address = connection.peer_address
connection.peer_address = identity_address
connection.sc = sc
connection.authenticated = True
connection.emit('pairing', keys)
+2 -37
View File
@@ -205,16 +205,8 @@ class Service(Attribute):
'''
uuid: UUID
characteristics: List[Characteristic]
included_services: List[Service]
def __init__(
self,
uuid,
characteristics: List[Characteristic],
included_services: List[Service] = [],
primary=True,
):
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
@@ -227,7 +219,7 @@ class Service(Attribute):
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = included_services[:]
# self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
@@ -261,33 +253,6 @@ class TemplateService(Service):
super().__init__(self.UUID, characteristics, primary)
# -----------------------------------------------------------------------------
class IncludedServiceDeclaration(Attribute):
'''
See Vol 3, Part G - 3.2 INCLUDE DEFINITION
'''
service: Service
def __init__(self, service):
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
super().__init__(
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.service = service
def __str__(self):
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
f'uuid={self.service.uuid}, '
f'{self.service.properties!s})'
)
# -----------------------------------------------------------------------------
class Characteristic(Attribute):
'''
+3 -62
View File
@@ -63,7 +63,6 @@ from .gatt import (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
)
@@ -110,7 +109,6 @@ class AttributeProxy(EventEmitter):
class ServiceProxy(AttributeProxy):
uuid: UUID
characteristics: List[CharacteristicProxy]
included_services: List[ServiceProxy]
@staticmethod
def from_client(service_class, client, service_uuid):
@@ -504,69 +502,12 @@ class Client:
return services
async def discover_included_services(
self, service: ServiceProxy
) -> List[ServiceProxy]:
async def discover_included_services(self, _service):
'''
See Vol 3, Part G - 4.5.1 Find Included Services
'''
starting_handle = service.handle
ending_handle = service.end_group_handle
included_services: List[ServiceProxy] = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE,
)
)
if response is None:
# TODO raise appropriate exception
return []
# Check if we reached the end of the iteration
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering included services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
raise ATT_Error(
error_code=response.error_code,
message='Unexpected error while discovering included services',
)
break
# Stop if for some reason the list was empty
if not response.attributes:
break
# Process all included services returned in this iteration
for attribute_handle, attribute_value in response.attributes:
if attribute_handle < starting_handle:
# Something's not right
logger.warning(f'bogus handle value: {attribute_handle}')
return []
group_starting_handle, group_ending_handle = struct.unpack_from(
'<HH', attribute_value
)
service_uuid = UUID.from_bytes(attribute_value[4:])
included_service = ServiceProxy(
self, group_starting_handle, group_ending_handle, service_uuid, True
)
included_services.append(included_service)
# Move on to the next included services
starting_handle = response.attributes[-1][0] + 1
service.included_services = included_services
return included_services
# TODO
return []
async def discover_characteristics(
self, uuids, service: Optional[ServiceProxy]
+1 -11
View File
@@ -68,7 +68,6 @@ from .gatt import (
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
IncludedServiceDeclaration,
Descriptor,
Service,
)
@@ -95,7 +94,6 @@ class Server(EventEmitter):
def __init__(self, device):
super().__init__()
self.device = device
self.services = []
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = (
@@ -224,14 +222,7 @@ class Server(EventEmitter):
# Add the service attribute to the DB
self.add_attribute(service)
# Add all included service
for included_service in service.included_services:
# Not registered yet, register the included service first.
if included_service not in self.services:
self.add_service(included_service)
# TODO: Handle circular service reference
include_declaration = IncludedServiceDeclaration(included_service)
self.add_attribute(include_declaration)
# TODO: add included services
# Add all characteristics
for characteristic in service.characteristics:
@@ -283,7 +274,6 @@ class Server(EventEmitter):
# Update the service group end
service.end_group_handle = self.attributes[-1].handle
self.services.append(service)
def add_services(self, services):
for service in services:
+72 -46
View File
@@ -190,10 +190,44 @@ class KeyStore:
# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
"""
KeyStore implementation that is backed by a JSON file.
This implementation supports storing a hierarchy of key sets in a single file.
A key set is a representation of a PairingKeys object. Each key set is stored
in a map, with the address of paired peer as the key. Maps are themselves grouped
into namespaces, grouping pairing keys by controller addresses.
The JSON object model looks like:
{
"<namespace>": {
"peer-address": {
"address_type": <n>,
"irk" : {
"authenticated": <true/false>,
"value": "hex-encoded-key"
},
... other keys ...
},
... other peers ...
}
... other namespaces ...
}
A namespace is typically the BD_ADDR of a controller, since that is a convenient
unique identifier, but it may be something else.
A special namespace, called the "default" namespace, is used when instantiating this
class without a namespace. With the default namespace, reading from a file will
load an existing namespace if there is only one, which may be convenient for reading
from a file with a single key set and for which the namespace isn't known. If the
file does not include any existing key set, or if there are more than one and none
has the default name, a new one will be created with the name "__DEFAULT__".
"""
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
@@ -208,8 +242,9 @@ class JsonKeyStore(KeyStore):
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else:
@@ -219,11 +254,13 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device) -> Optional[JsonKeyStore]:
if not device.config.keystore:
return None
params = device.config.keystore.split(':', 1)[1:]
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
params = device.config.keystore.split(':', 1)[1:]
if params:
filename = params[0]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
@@ -232,19 +269,31 @@ class JsonKeyStore(KeyStore):
namespace = str(device.random_address)
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
if params:
filename = params[0]
else:
filename = None
return JsonKeyStore(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
db = json.load(json_file)
except FileNotFoundError:
return {}
db = {}
# First, look for a namespace match
if self.namespace in db:
return (db, db[self.namespace])
# Then, if the namespace is the default namespace, and there's
# only one entry in the db, use that
if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db):
# Create the directory if it doesn't exist
@@ -260,53 +309,30 @@ class JsonKeyStore(KeyStore):
os.replace(temp_filename, self.filename)
async def delete(self, name: str) -> None:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
raise KeyError(name)
del namespace[name]
db, key_map = await self.load()
del key_map[name]
await self.save(db)
async def update(self, name, keys):
db = await self.load()
namespace = db.setdefault(self.namespace, {})
namespace.setdefault(name, {}).update(keys.to_dict())
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self):
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
return []
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
db, key_map = await self.load()
key_map.clear()
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
_, key_map = await self.load()
if name not in key_map:
return None
keys = namespace.get(name)
if keys is None:
return None
return PairingKeys.from_dict(keys)
return PairingKeys.from_dict(key_map[name])
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -1805,7 +1805,7 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection, keys, session.sc)
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)
+179
View File
@@ -0,0 +1,179 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import json
import logging
import tempfile
import os
from bumble.keys import JsonKeyStore, PairingKeys
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
JSON1 = """
{
"my_namespace": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
},
"link_key": {
"authenticated": false,
"value": "0745dd9691e693d9dca740f7d8dfea75"
},
"ltk": {
"authenticated": false,
"value": "d1897ee10016eb1a08e4e037fd54c683"
}
}
}
}
"""
JSON2 = """
{
"my_namespace1": {
},
"my_namespace2": {
}
}
"""
JSON3 = """
{
"my_namespace1": {
},
"__DEFAULT__": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
}
}
}
}
"""
# -----------------------------------------------------------------------------
async def test_basic():
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write("{}")
file.flush()
keys = await keystore.get_all()
assert len(keys) == 0
keys = PairingKeys()
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is None
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert 'my_namespace' in json_data
assert 'foo' in json_data['my_namespace']
assert 'ltk' in json_data['my_namespace']['foo']
# -----------------------------------------------------------------------------
async def test_parsing():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write(JSON1)
file.flush()
foo = await keystore.get('14:7D:DA:4E:53:A8/P')
assert foo is not None
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# -----------------------------------------------------------------------------
async def test_default_namespace():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON1)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON2)
file.flush()
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert '__DEFAULT__' in json_data
assert 'foo' in json_data['__DEFAULT__']
assert 'ltk' in json_data['__DEFAULT__']['foo']
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON3)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_tests())
+1 -10
View File
@@ -190,9 +190,7 @@ async def test_self_gatt():
s1 = Service('8140E247-04F0-42C1-BC34-534C344DAFCA', [c1, c2, c3])
s2 = Service('97210A0F-1875-4D05-9E5D-326EB171257A', [c4])
s3 = Service('1853', [])
s4 = Service('3A12C182-14E2-4FE0-8C5B-65D7C569F9DB', [], included_services=[s2, s3])
two_devices.devices[1].add_services([s1, s2, s4])
two_devices.devices[1].add_services([s1, s2])
# Start
await two_devices.devices[0].power_on()
@@ -227,13 +225,6 @@ async def test_self_gatt():
assert result is not None
assert result == c1.value
result = await peer.discover_service(s4.uuid)
assert len(result) == 1
result = await peer.discover_included_services(result[0])
assert len(result) == 2
# Service UUID is only present when the UUID is 16-bit Bluetooth UUID
assert result[1].uuid.to_bytes() == s3.uuid.to_bytes()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio