Compare commits

..

20 Commits

Author SHA1 Message Date
Lucas Abel 7f987dc3cd Merge pull request #198 from qiaoccolato/main
reformat protobuf import
2023-06-07 10:05:30 -07:00
qiaoccolato 689745040f Merge branch 'google:main' into main 2023-06-07 09:19:54 -07:00
Qiao Yang 809d4a18f5 reformat protobuf import 2023-06-07 09:14:50 -07:00
Gilles Boccon-Gibod 54be8b328a Merge pull request #197 from zxzxwu/typing
Add typing for HFP and RFCOMM
2023-06-07 07:09:26 -07:00
Gilles Boccon-Gibod 57b469198a Merge pull request #196 from google/gbg/better-address-resolving
pairing event improvement
2023-06-07 07:03:53 -07:00
Josh Wu 4d74339c04 Add typing for RFCOMM 2023-06-06 00:04:25 +08:00
Josh Wu 39db278f2e Add typing for HFP 2023-06-05 23:54:42 +08:00
Gilles Boccon-Gibod 27fbb58447 add basic keystore test 2023-06-04 13:01:07 -07:00
Gilles Boccon-Gibod a74c39dc2b Merge pull request #193 from benquike/main
Update device name in advertising data from load_from_dict
2023-05-24 15:04:03 -07:00
Hui Peng 22f7cef771 Update device name in advertising data from load_from_dict 2023-05-23 16:13:24 -07:00
Lucas Abel 5790d3aae8 Merge pull request #192 from google/uael/gatt-fixes
gatt: reset args ordering to original
2023-05-23 06:40:51 +02:00
Lucas Abel 744294f00e gatt: reset args ordering to original
This was a breaking change
2023-05-22 09:34:30 +00:00
Gilles Boccon-Gibod 3697b8dde9 Merge pull request #189 from google/dependabot/pip/docs/mkdocs/pymdown-extensions-10.0
Bump pymdown-extensions from 9.6 to 10.0 in /docs/mkdocs
2023-05-17 19:05:23 -07:00
Lucas Abel f3bfbab44d Merge pull request #190 from google/uael/pandora-server
pandora: import bumble pandora server from avatar
2023-05-17 22:31:09 +02:00
uael afcce0d6c8 pandora: import bumble pandora server from avatar 2023-05-17 18:18:43 +00:00
dependabot[bot] 69d45bed21 Bump pymdown-extensions from 9.6 to 10.0 in /docs/mkdocs
Bumps [pymdown-extensions](https://github.com/facelessuser/pymdown-extensions) from 9.6 to 10.0.
- [Release notes](https://github.com/facelessuser/pymdown-extensions/releases)
- [Commits](https://github.com/facelessuser/pymdown-extensions/compare/9.6...10.0)

---
updated-dependencies:
- dependency-name: pymdown-extensions
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-15 20:54:41 +00:00
Gilles Boccon-Gibod 6826f68478 fix linter warnings 2023-05-05 16:16:55 -07:00
Gilles Boccon-Gibod f80c83d0b3 better doc and default behavior for json keystore 2023-05-05 16:11:20 -07:00
Gilles Boccon-Gibod 3de35193bc rebase 2023-05-05 16:09:01 -07:00
Gilles Boccon-Gibod 740a2e0ca0 instantiate keystore after power_on 2023-05-05 16:07:16 -07:00
21 changed files with 2295 additions and 169 deletions
+14 -14
View File
@@ -207,7 +207,7 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
@@ -242,9 +242,9 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
def on_pairing(keys):
def on_pairing(address, keys):
print(color('***-----------------------------------', 'cyan'))
print(color('*** Paired!', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
@@ -283,17 +283,6 @@ async def pair(
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
@@ -323,6 +312,17 @@ async def pair(
# Get things going
await device.power_on()
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc, mitm, bond, Delegate(mode, connection, io, prompt)
+30
View File
@@ -0,0 +1,30 @@
import asyncio
import click
import logging
from bumble.pandora import PandoraDevice, serve
BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300
@click.command()
@click.option('--grpc-port', help='gRPC port to serve', default=BUMBLE_SERVER_GRPC_PORT)
@click.option(
'--rootcanal-port', help='Rootcanal TCP port', default=ROOTCANAL_PORT_CUTTLEFISH
)
@click.option(
'--transport',
help='HCI transport',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
)
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
if '<rootcanal-port>' in transport:
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
device = PandoraDevice({'transport': transport})
logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port))
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter
+5 -6
View File
@@ -133,15 +133,16 @@ async def scan(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
device.keystore = keystore
else:
resolver = None
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -149,8 +150,6 @@ async def scan(
else:
device.on('advertisement', printer.on_advertisement)
await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
+40 -22
View File
@@ -22,40 +22,58 @@ import click
from bumble.device import Device
from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address):
if address is None:
return await keystore.print()
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# -----------------------------------------------------------------------------
async def unbond(keystore_file, device_config, address):
# Create a device to manage the host
device = Device.from_config_file(device_config)
# Get all entries in the keystore
async def unbond(keystore_file, device_config, hci_transport, address):
# With a keystore file, we can instantiate the keystore directly
if keystore_file:
keystore = JsonKeyStore(None, keystore_file)
else:
keystore = device.keystore
return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address)
if keystore is None:
print('no keystore')
return
# Without a keystore file, we need to obtain the keystore from the device
async with await open_transport(hci_transport) as (hci_source, hci_sink):
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
if address is None:
await keystore.print()
else:
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# Power-on the device to ensure we have a key store
await device.power_on()
return await unbond_with_keystore(device.keystore, address)
# -----------------------------------------------------------------------------
@click.command()
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.option('--keystore-file', help='File in which the pairing keys are stored')
@click.option('--hci-transport', help='HCI transport for the controller')
@click.argument('device-config', required=False)
@click.argument('address', required=False)
def main(keystore_file, device_config, address):
def main(keystore_file, hci_transport, device_config, address):
"""
Remove pairing keys for a device, given its address.
If no keystore file is specified, the --hci-transport option must be used to
connect to a controller, so that the keystore for that controller can be
instantiated.
If no address is passed, the existing pairing keys for all addresses are printed.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address))
if not keystore_file and not hci_transport:
print('either --keystore-file or --hci-transport must be specified.')
return
asyncio.run(unbond(keystore_file, device_config, hci_transport, address))
# -----------------------------------------------------------------------------
+16 -1
View File
@@ -826,6 +826,12 @@ class DeviceConfiguration:
advertising_data = config.get('advertising_data')
if advertising_data:
self.advertising_data = bytes.fromhex(advertising_data)
elif config.get('name') is not None:
self.advertising_data = bytes(
AdvertisingData(
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
)
)
def load_from_file(self, filename):
with open(filename, 'r', encoding='utf-8') as file:
@@ -3088,7 +3094,16 @@ class Device(CompositeEventEmitter):
def on_pairing_start(self, connection: Connection) -> None:
connection.emit('pairing_start')
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
def on_pairing(
self,
connection: Connection,
identity_address: Optional[Address],
keys: PairingKeys,
sc: bool,
) -> None:
if identity_address is not None:
connection.peer_resolvable_address = connection.peer_address
connection.peer_address = identity_address
connection.sc = sc
connection.authenticated = True
connection.emit('pairing', keys)
+1 -1
View File
@@ -212,8 +212,8 @@ class Service(Attribute):
self,
uuid,
characteristics: List[Characteristic],
included_services: List[Service] = [],
primary=True,
included_services: List[Service] = [],
):
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
+14 -8
View File
@@ -18,10 +18,11 @@
import logging
import asyncio
import collections
from typing import Union
from . import rfcomm
from .colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -34,7 +35,12 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class HfpProtocol:
def __init__(self, dlc):
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
@@ -42,7 +48,7 @@ class HfpProtocol:
dlc.sink = self.feed
def feed(self, data):
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
@@ -57,19 +63,19 @@ class HfpProtocol:
if len(line) > 0:
self.on_line(line)
def on_line(self, line):
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def send_command_line(self, line):
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def send_response_line(self, line):
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
async def next_line(self):
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
@@ -77,7 +83,7 @@ class HfpProtocol:
logger.debug(color(f'<<< {line}', 'green'))
return line
async def initialize_service(self):
async def initialize_service(self) -> None:
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
await (self.next_line())
+72 -46
View File
@@ -190,10 +190,44 @@ class KeyStore:
# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
"""
KeyStore implementation that is backed by a JSON file.
This implementation supports storing a hierarchy of key sets in a single file.
A key set is a representation of a PairingKeys object. Each key set is stored
in a map, with the address of paired peer as the key. Maps are themselves grouped
into namespaces, grouping pairing keys by controller addresses.
The JSON object model looks like:
{
"<namespace>": {
"peer-address": {
"address_type": <n>,
"irk" : {
"authenticated": <true/false>,
"value": "hex-encoded-key"
},
... other keys ...
},
... other peers ...
}
... other namespaces ...
}
A namespace is typically the BD_ADDR of a controller, since that is a convenient
unique identifier, but it may be something else.
A special namespace, called the "default" namespace, is used when instantiating this
class without a namespace. With the default namespace, reading from a file will
load an existing namespace if there is only one, which may be convenient for reading
from a file with a single key set and for which the namespace isn't known. If the
file does not include any existing key set, or if there are more than one and none
has the default name, a new one will be created with the name "__DEFAULT__".
"""
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
@@ -208,8 +242,9 @@ class JsonKeyStore(KeyStore):
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else:
@@ -219,11 +254,13 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device) -> Optional[JsonKeyStore]:
if not device.config.keystore:
return None
params = device.config.keystore.split(':', 1)[1:]
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
params = device.config.keystore.split(':', 1)[1:]
if params:
filename = params[0]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
@@ -232,19 +269,31 @@ class JsonKeyStore(KeyStore):
namespace = str(device.random_address)
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
if params:
filename = params[0]
else:
filename = None
return JsonKeyStore(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
db = json.load(json_file)
except FileNotFoundError:
return {}
db = {}
# First, look for a namespace match
if self.namespace in db:
return (db, db[self.namespace])
# Then, if the namespace is the default namespace, and there's
# only one entry in the db, use that
if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db):
# Create the directory if it doesn't exist
@@ -260,53 +309,30 @@ class JsonKeyStore(KeyStore):
os.replace(temp_filename, self.filename)
async def delete(self, name: str) -> None:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
raise KeyError(name)
del namespace[name]
db, key_map = await self.load()
del key_map[name]
await self.save(db)
async def update(self, name, keys):
db = await self.load()
namespace = db.setdefault(self.namespace, {})
namespace.setdefault(name, {}).update(keys.to_dict())
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self):
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
return []
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
db, key_map = await self.load()
key_map.clear()
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
_, key_map = await self.load()
if name not in key_map:
return None
keys = namespace.get(name)
if keys is None:
return None
return PairingKeys.from_dict(keys)
return PairingKeys.from_dict(key_map[name])
# -----------------------------------------------------------------------------
+105
View File
@@ -0,0 +1,105 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bumble Pandora server.
This module implement the Pandora Bluetooth test APIs for the Bumble stack.
"""
__version__ = "0.0.1"
import grpc
import grpc.aio
from .config import Config
from .device import PandoraDevice
from .host import HostService
from .security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from typing import Callable, List, Optional
# public symbols
__all__ = [
'register_servicer_hook',
'serve',
'Config',
'PandoraDevice',
]
# Add servicers hooks.
_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
def register_servicer_hook(
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
) -> None:
_SERVICERS_HOOKS.append(hook)
async def serve(
bumble: PandoraDevice,
config: Config = Config(),
grpc_server: Optional[grpc.aio.Server] = None,
port: int = 0,
) -> None:
# initialize a gRPC server if not provided.
server = grpc_server if grpc_server is not None else grpc.aio.server()
port = server.add_insecure_port(f'localhost:{port}')
try:
while True:
# load server config from dict.
config.load_from_dict(bumble.config.get('server', {}))
# add Pandora services to the gRPC server.
add_HostServicer_to_server(
HostService(server, bumble.device, config), server
)
add_SecurityServicer_to_server(
SecurityService(bumble.device, config), server
)
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
# call hooks if any.
for hook in _SERVICERS_HOOKS:
hook(bumble, config, server)
# open device.
await bumble.open()
try:
# Pandora require classic devices to be discoverable & connectable.
if bumble.device.classic_enabled:
await bumble.device.set_discoverable(True)
await bumble.device.set_connectable(True)
# start & serve gRPC server.
await server.start()
await server.wait_for_termination()
finally:
# close device.
await bumble.close()
# re-initialize the gRPC server.
server = grpc.aio.server()
server.add_insecure_port(f'localhost:{port}')
finally:
# stop server.
await server.stop(None)
+48
View File
@@ -0,0 +1,48 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble.pairing import PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class Config:
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
pairing_sc_enable: bool = True
pairing_mitm_enable: bool = True
pairing_bonding_enable: bool = True
smp_local_initiator_key_distribution: PairingDelegate.KeyDistribution = (
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
)
smp_local_responder_key_distribution: PairingDelegate.KeyDistribution = (
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
)
def load_from_dict(self, config: Dict[str, Any]) -> None:
io_capability_name: str = config.get(
'io_capability', 'no_output_no_input'
).upper()
self.io_capability = getattr(PairingDelegate, io_capability_name)
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
self.smp_local_initiator_key_distribution = config.get(
'smp_local_initiator_key_distribution',
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
)
self.smp_local_responder_key_distribution = config.get(
'smp_local_responder_key_distribution',
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
)
+157
View File
@@ -0,0 +1,157 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic & dependency free Bumble (reference) device."""
from bumble import transport
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.device import Device, DeviceConfiguration
from bumble.host import Host
from bumble.sdp import (
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
)
from typing import Any, Dict, List, Optional
class PandoraDevice:
"""
Small wrapper around a Bumble device and it's HCI transport.
Notes:
- The Bumble device is idle by default.
- Repetitive calls to `open`/`close` will result on new Bumble device instances.
"""
# Bumble device instance & configuration.
device: Device
config: Dict[str, Any]
# HCI transport name & instance.
_hci_name: str
_hci: Optional[transport.Transport] # type: ignore[name-defined]
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.device = _make_device(config)
self._hci_name = config.get('transport', '')
self._hci = None
@property
def idle(self) -> bool:
return self._hci is None
async def open(self) -> None:
if self._hci is not None:
return
# open HCI transport & set device host.
self._hci = await transport.open_transport(self._hci_name)
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
# power-on.
await self.device.power_on()
async def close(self) -> None:
if self._hci is None:
return
# flush & re-initialize device.
await self.device.host.flush()
self.device.host = None # type: ignore[assignment]
self.device = _make_device(self.config)
# close HCI transport.
await self._hci.close()
self._hci = None
async def reset(self) -> None:
await self.close()
await self.open()
def info(self) -> Optional[Dict[str, str]]:
return {
'public_bd_address': str(self.device.public_address),
'random_address': str(self.device.random_address),
}
def _make_device(config: Dict[str, Any]) -> Device:
"""Initialize an idle Bumble device instance."""
# initialize bumble device.
device_config = DeviceConfiguration()
device_config.load_from_dict(config)
device = Device(config=device_config, host=None)
# Add fake a2dp service to avoid Android disconnect
device.sdp_service_records = _make_sdp_records(1)
return device
# TODO(b/267540823): remove when Pandora A2dp is supported
def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
return {
0x00010001: [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105),
]
)
]
),
),
]
}
+857
View File
@@ -0,0 +1,857 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import bumble.device
import grpc
import grpc.aio
import logging
import struct
from . import utils
from .config import Config
from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
UUID,
AdvertisingData,
ConnectionError,
)
from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingType,
Device,
)
from bumble.gatt import Service
from bumble.hci import (
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address,
)
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer
from pandora.host_pb2 import (
NOT_CONNECTABLE,
NOT_DISCOVERABLE,
PRIMARY_1M,
PRIMARY_CODED,
SECONDARY_1M,
SECONDARY_2M,
SECONDARY_CODED,
SECONDARY_NONE,
AdvertiseRequest,
AdvertiseResponse,
Connection,
ConnectLERequest,
ConnectLEResponse,
ConnectRequest,
ConnectResponse,
DataTypes,
DisconnectRequest,
InquiryResponse,
PrimaryPhy,
ReadLocalAddressResponse,
ScanningResponse,
ScanRequest,
SecondaryPhy,
SetConnectabilityModeRequest,
SetDiscoverabilityModeRequest,
WaitConnectionRequest,
WaitConnectionResponse,
WaitDisconnectionRequest,
)
from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
# Default value reported by Bumble for legacy Advertising reports.
# FIXME(uael): `None` might be a better value, but Bumble need to change accordingly.
0: PRIMARY_1M,
1: PRIMARY_1M,
3: PRIMARY_CODED,
}
SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
0: SECONDARY_NONE,
1: SECONDARY_1M,
2: SECONDARY_2M,
3: SECONDARY_CODED,
}
class HostService(HostServicer):
waited_connections: Set[int]
def __init__(
self, grpc_server: grpc.aio.Server, device: Device, config: Config
) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Host', 'device': device}
)
self.grpc_server = grpc_server
self.device = device
self.config = config
self.waited_connections = set()
@utils.rpc
async def FactoryReset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('FactoryReset')
# delete all bonds
if self.device.keystore is not None:
await self.device.keystore.delete_all()
# trigger gRCP server stop then return
asyncio.create_task(self.grpc_server.stop(None))
return empty_pb2.Empty()
@utils.rpc
async def Reset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('Reset')
# clear service.
self.waited_connections.clear()
# (re) power device on
await self.device.power_on()
return empty_pb2.Empty()
@utils.rpc
async def ReadLocalAddress(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> ReadLocalAddressResponse:
self.log.info('ReadLocalAddress')
return ReadLocalAddressResponse(
address=bytes(reversed(bytes(self.device.public_address)))
)
@utils.rpc
async def Connect(
self, request: ConnectRequest, context: grpc.ServicerContext
) -> ConnectResponse:
# Need to reverse bytes order since Bumble Address is using MSB.
address = Address(
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
)
self.log.info(f"Connect to {address}")
try:
connection = await self.device.connect(
address, transport=BT_BR_EDR_TRANSPORT
)
except ConnectionError as e:
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
self.log.warning(f"Peer not found: {e}")
return ConnectResponse(peer_not_found=empty_pb2.Empty())
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
self.log.warning(f"Connection already exists: {e}")
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"Connect to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectResponse(connection=Connection(cookie=cookie))
@utils.rpc
async def WaitConnection(
self, request: WaitConnectionRequest, context: grpc.ServicerContext
) -> WaitConnectionResponse:
if not request.address:
raise ValueError('Request address field must be set')
# Need to reverse bytes order since Bumble Address is using MSB.
address = Address(
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
)
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"WaitConnection from {address}...")
connection = self.device.find_connection_by_bd_addr(
address, transport=BT_BR_EDR_TRANSPORT
)
if connection and id(connection) in self.waited_connections:
# this connection was already returned: wait for a new one.
connection = None
if not connection:
connection = await self.device.accept(address)
# save connection has waited and respond.
self.waited_connections.add(id(connection))
self.log.info(
f"WaitConnection from {address} done (handle={connection.handle})"
)
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return WaitConnectionResponse(connection=Connection(cookie=cookie))
@utils.rpc
async def ConnectLE(
self, request: ConnectLERequest, context: grpc.ServicerContext
) -> ConnectLEResponse:
address = utils.address_from_request(request, request.WhichOneof("address"))
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"ConnectLE to {address}...")
try:
connection = await self.device.connect(
address,
transport=BT_LE_TRANSPORT,
own_address_type=request.own_address_type,
)
except ConnectionError as e:
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
self.log.warning(f"Peer not found: {e}")
return ConnectLEResponse(peer_not_found=empty_pb2.Empty())
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
self.log.warning(f"Connection already exists: {e}")
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectLEResponse(connection=Connection(cookie=cookie))
@utils.rpc
async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Disconnect: {connection_handle}")
self.log.info("Disconnecting...")
if connection := self.device.lookup_connection(connection_handle):
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
self.log.info("Disconnected")
return empty_pb2.Empty()
@utils.rpc
async def WaitDisconnection(
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[
None
] = asyncio.get_running_loop().create_future()
def on_disconnection(_: None) -> None:
disconnection_future.set_result(None)
connection.on('disconnection', on_disconnection)
try:
await disconnection_future
self.log.info("Disconnected")
finally:
connection.remove_listener('disconnection', on_disconnection) # type: ignore
return empty_pb2.Empty()
@utils.rpc
async def Advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if not request.legacy:
raise NotImplementedError(
"TODO: add support for extended advertising in Bumble"
)
if request.interval:
raise NotImplementedError("TODO: add support for `request.interval`")
if request.interval_range:
raise NotImplementedError("TODO: add support for `request.interval_range`")
if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy:
raise NotImplementedError("TODO: add support for `request.secondary_phy`")
if self.device.is_advertising:
raise NotImplementedError('TODO: add support for advertising sets')
if data := request.data:
self.device.advertising_data = bytes(self.unpack_data_types(data))
if scan_response_data := request.scan_response_data:
self.device.scan_response_data = bytes(
self.unpack_data_types(scan_response_data)
)
scannable = True
else:
scannable = False
# Retrieve services data
for service in self.device.gatt_server.attributes:
if isinstance(service, Service) and (
service_data := service.get_advertising_data()
):
service_uuid = service.uuid.to_hex_str('-')
if (
service_uuid in request.data.incomplete_service_class_uuids16
or service_uuid in request.data.complete_service_class_uuids16
or service_uuid in request.data.incomplete_service_class_uuids32
or service_uuid in request.data.complete_service_class_uuids32
or service_uuid
in request.data.incomplete_service_class_uuids128
or service_uuid in request.data.complete_service_class_uuids128
):
self.device.advertising_data += service_data
if (
service_uuid
in scan_response_data.incomplete_service_class_uuids16
or service_uuid
in scan_response_data.complete_service_class_uuids16
or service_uuid
in scan_response_data.incomplete_service_class_uuids32
or service_uuid
in scan_response_data.complete_service_class_uuids32
or service_uuid
in scan_response_data.incomplete_service_class_uuids128
or service_uuid
in scan_response_data.complete_service_class_uuids128
):
self.device.scan_response_data += service_data
target = None
if request.connectable and scannable:
advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE
elif scannable:
advertising_type = AdvertisingType.UNDIRECTED_SCANNABLE
else:
advertising_type = AdvertisingType.UNDIRECTED
else:
target = None
advertising_type = AdvertisingType.UNDIRECTED
if request.target:
# Need to reverse bytes order since Bumble Address is using MSB.
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
else:
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None:
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
pending_connection.set_result(connection)
self.device.on('connection', on_connection)
try:
while True:
if not self.device.is_advertising:
self.log.info('Advertise')
await self.device.start_advertising(
target=target,
advertising_type=advertising_type,
own_address_type=request.own_address_type,
)
if not request.connectable:
await asyncio.sleep(1)
continue
pending_connection: asyncio.Future[
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
self.log.info('Wait for LE connection...')
connection = await pending_connection
self.log.info(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
)
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie))
# wait a small delay before restarting the advertisement.
await asyncio.sleep(1)
finally:
if request.connectable:
self.device.remove_listener('connection', on_connection) # type: ignore
try:
self.log.info('Stop advertising')
await self.device.abort_on('flush', self.device.stop_advertising())
except:
pass
@utils.rpc
async def Scan(
self, request: ScanRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ScanningResponse, None]:
# TODO: modify `start_scanning` to accept floats instead of int for ms values
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
self.log.info('Scan')
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
await self.device.start_scanning(
legacy=request.legacy,
active=not request.passive,
own_address_type=request.own_address_type,
scan_interval=int(request.interval)
if request.interval
else DEVICE_DEFAULT_SCAN_INTERVAL,
scan_window=int(request.window)
if request.window
else DEVICE_DEFAULT_SCAN_WINDOW,
)
try:
# TODO: add support for `direct_address` in Bumble
# TODO: add support for `periodic_advertising_interval` in Bumble
while adv := await scan_queue.get():
sr = ScanningResponse(
legacy=adv.is_legacy,
connectable=adv.is_connectable,
scannable=adv.is_scannable,
truncated=adv.is_truncated,
sid=adv.sid,
primary_phy=PRIMARY_PHY_MAP[adv.primary_phy],
secondary_phy=SECONDARY_PHY_MAP[adv.secondary_phy],
tx_power=adv.tx_power,
rssi=adv.rssi,
data=self.pack_data_types(adv.data),
)
if adv.address.address_type == Address.PUBLIC_DEVICE_ADDRESS:
sr.public = bytes(reversed(bytes(adv.address)))
elif adv.address.address_type == Address.RANDOM_DEVICE_ADDRESS:
sr.random = bytes(reversed(bytes(adv.address)))
elif adv.address.address_type == Address.PUBLIC_IDENTITY_ADDRESS:
sr.public_identity = bytes(reversed(bytes(adv.address)))
else:
sr.random_static_identity = bytes(reversed(bytes(adv.address)))
yield sr
finally:
self.device.remove_listener('advertisement', handler) # type: ignore
try:
self.log.info('Stop scanning')
await self.device.abort_on('flush', self.device.stop_scanning())
except:
pass
@utils.rpc
async def Inquiry(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> AsyncGenerator[InquiryResponse, None]:
self.log.info('Inquiry')
inquiry_queue: asyncio.Queue[
Optional[Tuple[Address, int, AdvertisingData, int]]
] = asyncio.Queue()
complete_handler = self.device.on(
'inquiry_complete', lambda: inquiry_queue.put_nowait(None)
)
result_handler = self.device.on( # type: ignore
'inquiry_result',
lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore
(address, class_of_device, eir_data, rssi) # type: ignore
),
)
await self.device.start_discovery(auto_restart=False)
try:
while inquiry_result := await inquiry_queue.get():
(address, class_of_device, eir_data, rssi) = inquiry_result
# FIXME: if needed, add support for `page_scan_repetition_mode` and `clock_offset` in Bumble
yield InquiryResponse(
address=bytes(reversed(bytes(address))),
class_of_device=class_of_device,
rssi=rssi,
data=self.pack_data_types(eir_data),
)
finally:
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
try:
self.log.info('Stop inquiry')
await self.device.abort_on('flush', self.device.stop_discovery())
except:
pass
@utils.rpc
async def SetDiscoverabilityMode(
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetDiscoverabilityMode")
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
return empty_pb2.Empty()
@utils.rpc
async def SetConnectabilityMode(
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetConnectabilityMode")
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
return empty_pb2.Empty()
def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:
ad_structures: List[Tuple[int, bytes]] = []
uuids: List[str]
datas: Dict[str, bytes]
def uuid128_from_str(uuid: str) -> bytes:
"""Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
to byte format."""
return bytes(reversed(bytes.fromhex(uuid.replace('-', ''))))
def uuid32_from_str(uuid: str) -> bytes:
"""Decode a 32-bit uuid encoded as XXXXXXXX to byte format."""
return bytes(reversed(bytes.fromhex(uuid)))
def uuid16_from_str(uuid: str) -> bytes:
"""Decode a 16-bit uuid encoded as XXXX to byte format."""
return bytes(reversed(bytes.fromhex(uuid)))
if uuids := dt.incomplete_service_class_uuids16:
ad_structures.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.complete_service_class_uuids16:
ad_structures.append(
(
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.incomplete_service_class_uuids32:
ad_structures.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.complete_service_class_uuids32:
ad_structures.append(
(
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.incomplete_service_class_uuids128:
ad_structures.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.complete_service_class_uuids128:
ad_structures.append(
(
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
)
)
if dt.HasField('include_shortened_local_name'):
ad_structures.append(
(
AdvertisingData.SHORTENED_LOCAL_NAME,
bytes(self.device.name[:8], 'utf-8'),
)
)
elif dt.shortened_local_name:
ad_structures.append(
(
AdvertisingData.SHORTENED_LOCAL_NAME,
bytes(dt.shortened_local_name, 'utf-8'),
)
)
if dt.HasField('include_complete_local_name'):
ad_structures.append(
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.device.name, 'utf-8'))
)
elif dt.complete_local_name:
ad_structures.append(
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(dt.complete_local_name, 'utf-8'),
)
)
if dt.HasField('include_tx_power_level'):
raise ValueError('unsupported data type')
elif dt.tx_power_level:
ad_structures.append(
(
AdvertisingData.TX_POWER_LEVEL,
bytes(struct.pack('<I', dt.tx_power_level)[:1]),
)
)
if dt.HasField('include_class_of_device'):
ad_structures.append(
(
AdvertisingData.CLASS_OF_DEVICE,
bytes(struct.pack('<I', self.device.class_of_device)[:-1]),
)
)
elif dt.class_of_device:
ad_structures.append(
(
AdvertisingData.CLASS_OF_DEVICE,
bytes(struct.pack('<I', dt.class_of_device)[:-1]),
)
)
if dt.peripheral_connection_interval_min:
ad_structures.append(
(
AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE,
bytes(
[
*struct.pack('<H', dt.peripheral_connection_interval_min),
*struct.pack(
'<H',
dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max
else dt.peripheral_connection_interval_min,
),
]
),
)
)
if uuids := dt.service_solicitation_uuids16:
ad_structures.append(
(
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.service_solicitation_uuids32:
ad_structures.append(
(
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
)
)
if uuids := dt.service_solicitation_uuids128:
ad_structures.append(
(
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
)
)
if datas := dt.service_data_uuid16:
ad_structures.extend(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
uuid16_from_str(uuid) + data,
)
for uuid, data in datas.items()
]
)
if datas := dt.service_data_uuid32:
ad_structures.extend(
[
(
AdvertisingData.SERVICE_DATA_32_BIT_UUID,
uuid32_from_str(uuid) + data,
)
for uuid, data in datas.items()
]
)
if datas := dt.service_data_uuid128:
ad_structures.extend(
[
(
AdvertisingData.SERVICE_DATA_128_BIT_UUID,
uuid128_from_str(uuid) + data,
)
for uuid, data in datas.items()
]
)
if dt.appearance:
ad_structures.append(
(AdvertisingData.APPEARANCE, struct.pack('<H', dt.appearance))
)
if dt.advertising_interval:
ad_structures.append(
(
AdvertisingData.ADVERTISING_INTERVAL,
struct.pack('<H', dt.advertising_interval),
)
)
if dt.uri:
ad_structures.append((AdvertisingData.URI, bytes(dt.uri, 'utf-8')))
if dt.le_supported_features:
ad_structures.append(
(AdvertisingData.LE_SUPPORTED_FEATURES, dt.le_supported_features)
)
if dt.manufacturer_specific_data:
ad_structures.append(
(
AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
dt.manufacturer_specific_data,
)
)
return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
dt = DataTypes()
uuids: List[UUID]
s: str
i: int
ij: Tuple[int, int]
uuid_data: Tuple[UUID, bytes]
data: bytes
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
):
dt.incomplete_service_class_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
):
dt.complete_service_class_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
):
dt.incomplete_service_class_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
):
dt.complete_service_class_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
):
dt.incomplete_service_class_uuids128.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
):
dt.complete_service_class_uuids128.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if s := cast(str, ad.get(AdvertisingData.SHORTENED_LOCAL_NAME)):
dt.shortened_local_name = s
if s := cast(str, ad.get(AdvertisingData.COMPLETE_LOCAL_NAME)):
dt.complete_local_name = s
if i := cast(int, ad.get(AdvertisingData.TX_POWER_LEVEL)):
dt.tx_power_level = i
if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
dt.class_of_device = i
if ij := cast(
Tuple[int, int],
ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE),
):
dt.peripheral_connection_interval_min = ij[0]
dt.peripheral_connection_interval_max = ij[1]
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS),
):
dt.service_solicitation_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS),
):
dt.service_solicitation_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuids := cast(
List[UUID],
ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS),
):
dt.service_solicitation_uuids128.extend(
list(map(lambda x: x.to_hex_str('-'), uuids))
)
if uuid_data := cast(
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
):
dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if uuid_data := cast(
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
):
dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if uuid_data := cast(
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
):
dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
dt.public_target_addresses.extend(
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
)
if data := cast(bytes, ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True)):
dt.random_target_addresses.extend(
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
)
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = i
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
dt.advertising_interval = i
if s := cast(str, ad.get(AdvertisingData.URI)):
dt.uri = s
if data := cast(bytes, ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True)):
dt.le_supported_features = data
if data := cast(
bytes, ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True)
):
dt.manufacturer_specific_data = data
return dt
View File
+526
View File
@@ -0,0 +1,526 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import grpc
import logging
from . import utils
from .config import Config
from bumble import hci
from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
ProtocolError,
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
LE_LEVEL1,
LE_LEVEL2,
LE_LEVEL3,
LE_LEVEL4,
LEVEL0,
LEVEL1,
LEVEL2,
LEVEL3,
LEVEL4,
DeleteBondRequest,
IsBondedRequest,
LESecurityLevel,
PairingEvent,
PairingEventAnswer,
SecureRequest,
SecureResponse,
SecurityLevel,
WaitSecurityRequest,
WaitSecurityResponse,
)
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
class PairingDelegate(BasePairingDelegate):
def __init__(
self,
connection: BumbleConnection,
service: "SecurityService",
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(),
{'service_name': 'Security', 'device': connection.device},
)
self.connection = connection
self.service = service
super().__init__(
io_capability,
local_initiator_key_distribution,
local_responder_key_distribution,
)
async def accept(self) -> bool:
return True
def add_origin(self, ev: PairingEvent) -> PairingEvent:
if not self.connection.is_incomplete:
assert ev.connection
ev.connection.CopyFrom(
Connection(
cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
)
)
else:
# In BR/EDR, connection may not be complete,
# use address instead
assert self.connection.transport == BT_BR_EDR_TRANSPORT
ev.address = bytes(reversed(bytes(self.connection.peer_address)))
return ev
async def confirm(self, auto: bool = False) -> bool:
self.log.info(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
return True
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
self.log.info(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled number comparison request')
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def get_number(self) -> Optional[int]:
self.log.info(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled number request')
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
assert answer.answer_variant() == 'passkey'
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
self.log.info(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled pin_code request')
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
assert answer.answer_variant() == 'pin'
if answer.pin is None:
return None
pin = answer.pin.decode('utf-8')
if not pin or len(pin) > max_length:
raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')
return pin
async def display_number(self, number: int, digits: int = 6) -> None:
if (
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
):
return
self.log.info(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None:
raise RuntimeError('security: unhandled number display request')
event = self.add_origin(PairingEvent(passkey_entry_notification=number))
self.service.event_queue.put_nowait(event)
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
LEVEL0: lambda connection: True,
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
),
LEVEL4: lambda connection: connection.encryption
== hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and connection.link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
LE_LEVEL1: lambda connection: True,
LE_LEVEL2: lambda connection: connection.encryption != 0,
LE_LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated,
LE_LEVEL4: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.sc,
}
class SecurityService(SecurityServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Security', 'device': device}
)
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.device = device
self.config = config
def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
return PairingConfig(
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
delegate=PairingDelegate(
connection,
self,
io_capability=config.io_capability,
local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
local_responder_key_distribution=config.smp_local_responder_key_distribution,
),
)
self.device.pairing_config_factory = pairing_config_factory
@utils.rpc
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
self.log.info('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
if len(self.device.connections):
raise RuntimeError(
'the `OnPairing` method shall be initiated before establishing any connections.'
)
self.event_queue = asyncio.Queue()
self.event_answer = request
try:
while event := await self.event_queue.get():
yield event
finally:
self.event_queue = None
self.event_answer = None
@utils.rpc
async def Secure(
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
oneof = request.WhichOneof('level')
level = getattr(request, oneof)
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
connection.transport
] == oneof
# security level already reached
if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
self.log.info('Pair...')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
connection.request_pairing()
await wait_for_security
else:
await connection.pair()
self.log.info('Paired')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Pairing failure: {e}")
return SecureResponse(pairing_failure=empty_pb2.Empty())
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
self.log.info('Authenticate...')
await connection.authenticate()
self.log.info('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Authentication failure: {e}")
return SecureResponse(authentication_failure=empty_pb2.Empty())
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
self.log.info('Encrypt...')
await connection.encrypt()
self.log.info('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Encryption failure: {e}")
return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ?
if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty())
return SecureResponse(not_reached=empty_pb2.Empty())
@utils.rpc
async def WaitSecurity(
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
assert request.level
level = request.level
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
connection.transport
] == request.level_variant()
wait_for_security: asyncio.Future[
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
if (encryption := connection.encryption) != 0:
self.log.debug('Disable encryption...')
try:
await connection.encrypt(enable=False)
except:
pass
self.log.debug('Disable encryption: done')
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.debug('Authenticate: done')
if encryption != 0 and connection.encryption != encryption:
self.log.debug('Re-enable encryption...')
await connection.encrypt()
self.log.debug('Re-enable encryption: done')
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
self.log.info(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
and self.need_authentication(connection, level)
):
nonlocal authenticate_task
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'),
'connection_encryption_failure': set_failure('encryption_failure'),
'pairing': try_set_success,
'connection_authentication': try_set_success,
'connection_encryption_change': on_encryption_change,
}
# register event handlers
for event, listener in listeners.items():
connection.on(event, listener)
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.info('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
connection.remove_listener(event, listener) # type: ignore
# wait for `authenticate` to finish if any
if authenticate_task is not None:
self.log.info('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
self.log.info('Authenticated')
return WaitSecurityResponse(**kwargs)
def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool:
self.log.debug(
str(
{
'level': level,
'encryption': connection.encryption,
'authenticated': connection.authenticated,
'sc': connection.sc,
'link_key_type': connection.link_key_type,
}
)
)
if isinstance(level, LESecurityLevel):
return LE_LEVEL_REACHED[level](connection)
return BR_LEVEL_REACHED[level](connection)
def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
return level >= LE_LEVEL3 and not connection.authenticated
return False
def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
return False
if level == LEVEL2 and connection.encryption != 0:
return not connection.authenticated
return level >= LEVEL2 and not connection.authenticated
def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
# TODO(abel): need to support MITM
if connection.transport == BT_LE_TRANSPORT:
return level == LE_LEVEL2 and not connection.encryption
return level >= LEVEL2 and not connection.encryption
class SecurityStorageService(SecurityStorageServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
)
self.device = device
self.config = config
@utils.rpc
async def IsBonded(
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
else:
is_bonded = False
return wrappers_pb2.BoolValue(value=is_bonded)
@utils.rpc
async def DeleteBond(
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()
+112
View File
@@ -0,0 +1,112 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import grpc
import inspect
import logging
from bumble.device import Device
from bumble.hci import Address
from google.protobuf.message import Message # pytype: disable=pyi-error
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
ADDRESS_TYPES: Dict[str, int] = {
"public": Address.PUBLIC_DEVICE_ADDRESS,
"random": Address.RANDOM_DEVICE_ADDRESS,
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
"random_static_identity": Address.RANDOM_IDENTITY_ADDRESS,
}
def address_from_request(request: Message, field: Optional[str]) -> Address:
if field is None:
return Address.ANY
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
"""Formats logs from the PandoraClient."""
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
) -> Tuple[str, MutableMapping[str, Any]]:
assert self.extra
service_name = self.extra['service_name']
assert isinstance(service_name, str)
device = self.extra['device']
assert isinstance(device, Device)
addr_bytes = bytes(
reversed(bytes(device.public_address))
) # pytype: disable=attribute-error
addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]])
return (f'[bumble.{service_name}:{addr}] {msg}', kwargs)
@contextlib.contextmanager
def exception_to_rpc_error(
context: grpc.ServicerContext,
) -> Generator[None, None, None]:
try:
yield None
except NotImplementedError as e:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
context.set_details(str(e)) # type: ignore
except ValueError as e:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT) # type: ignore
context.set_details(str(e)) # type: ignore
except RuntimeError as e:
context.set_code(grpc.StatusCode.ABORTED) # type: ignore
context.set_details(str(e)) # type: ignore
# Decorate an RPC servicer method with a wrapper that transform exceptions to gRPC errors.
def rpc(func: Any) -> Any:
@functools.wraps(func)
async def asyncgen_wrapper(
self: Any, request: Any, context: grpc.ServicerContext
) -> Any:
with exception_to_rpc_error(context):
async for v in func(self, request, context):
yield v
@functools.wraps(func)
async def async_wrapper(
self: Any, request: Any, context: grpc.ServicerContext
) -> Any:
with exception_to_rpc_error(context):
return await func(self, request, context)
@functools.wraps(func)
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
for v in func(self, request, context):
yield v
@functools.wraps(func)
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
return func(self, request, context)
if inspect.isasyncgenfunction(func):
return asyncgen_wrapper
if inspect.iscoroutinefunction(func):
return async_wrapper
if inspect.isgenerator(func):
return gen_wrapper
return wrapper
+110 -67
View File
@@ -19,8 +19,9 @@ import logging
import asyncio
from pyee import EventEmitter
from typing import Optional, Tuple, Callable, Dict, Union
from . import core
from . import core, l2cap
from .colors import color
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
@@ -105,7 +106,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def compute_fcs(buffer):
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
for byte in buffer:
result = CRC_TABLE[result ^ byte]
@@ -114,7 +115,15 @@ def compute_fcs(buffer):
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
def __init__(
self,
frame_type: int,
c_r: int,
dlci: int,
p_f: int,
information: bytes = b'',
with_credits: bool = False,
) -> None:
self.type = frame_type
self.c_r = c_r
self.dlci = dlci
@@ -136,11 +145,11 @@ class RFCOMM_Frame:
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self):
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data):
def parse_mcc(data) -> Tuple[int, int, bytes]:
mcc_type = data[0] >> 2
c_r = (data[0] >> 1) & 1
length = data[1]
@@ -154,36 +163,36 @@ class RFCOMM_Frame:
return (mcc_type, c_r, value)
@staticmethod
def make_mcc(mcc_type, c_r, data):
def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:
return (
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod
def sabm(c_r, dlci):
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
@staticmethod
def ua(c_r, dlci):
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
@staticmethod
def dm(c_r, dlci):
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
@staticmethod
def disc(c_r, dlci):
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r, dlci, information, p_f=0):
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
@@ -227,15 +236,23 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
class RFCOMM_MCC_PN:
dlci: int
cl: int
priority: int
ack_timer: int
max_frame_size: int
max_retransmissions: int
window_size: int
def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
):
self.dlci = dlci
self.cl = cl
@@ -246,7 +263,7 @@ class RFCOMM_MCC_PN:
self.window_size = window_size
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
return RFCOMM_MCC_PN(
dlci=data[0],
cl=data[1],
@@ -285,7 +302,14 @@ class RFCOMM_MCC_PN:
# -----------------------------------------------------------------------------
class RFCOMM_MCC_MSC:
def __init__(self, dlci, fc, rtc, rtr, ic, dv):
dlci: int
fc: int
rtc: int
rtr: int
ic: int
dv: int
def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int):
self.dlci = dlci
self.fc = fc
self.rtc = rtc
@@ -294,7 +318,7 @@ class RFCOMM_MCC_MSC:
self.dv = dv
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
return RFCOMM_MCC_MSC(
dlci=data[0] >> 2,
fc=data[1] >> 1 & 1,
@@ -347,7 +371,12 @@ class DLC(EventEmitter):
RESET: 'RESET',
}
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]]
def __init__(
self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int
):
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
@@ -368,23 +397,23 @@ class DLC(EventEmitter):
)
@staticmethod
def state_name(state):
def state_name(state: int) -> str:
return DLC.STATE_NAMES[state]
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state
def send_frame(self, frame):
def send_frame(self, frame: RFCOMM_Frame) -> None:
self.multiplexer.send_frame(frame)
def on_frame(self, frame):
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, _frame) -> None:
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
@@ -404,7 +433,7 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTED)
self.emit('open')
def on_ua_frame(self, _frame):
def on_ua_frame(self, _frame) -> None:
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
@@ -422,15 +451,15 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
def on_dm_frame(self, frame):
def on_dm_frame(self, frame) -> None:
# TODO: handle all states
pass
def on_disc_frame(self, _frame):
def on_disc_frame(self, _frame) -> None:
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
def on_uih_frame(self, frame):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
data = frame.information
if frame.p_f == 1:
# With credits
@@ -460,10 +489,10 @@ class DLC(EventEmitter):
# Check if there's anything to send (including credits)
self.process_tx()
def on_ui_frame(self, frame):
def on_ui_frame(self, frame) -> None:
pass
def on_mcc_msc(self, c_r, msc):
def on_mcc_msc(self, c_r, msc) -> None:
if c_r:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
@@ -477,7 +506,7 @@ class DLC(EventEmitter):
# Response
logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self):
def connect(self) -> None:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
@@ -485,7 +514,7 @@ class DLC(EventEmitter):
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
def accept(self):
def accept(self) -> None:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
@@ -503,13 +532,13 @@ class DLC(EventEmitter):
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTING)
def rx_credits_needed(self):
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0
def process_tx(self):
def process_tx(self) -> None:
# Send anything we can (or an empty frame if we need to send rx credits)
rx_credits_needed = self.rx_credits_needed()
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
@@ -547,7 +576,7 @@ class DLC(EventEmitter):
rx_credits_needed = 0
# Stream protocol
def write(self, data):
def write(self, data: Union[bytes, str]) -> None:
# We can only send bytes
if not isinstance(data, bytes):
if isinstance(data, str):
@@ -559,7 +588,7 @@ class DLC(EventEmitter):
self.tx_buffer += data
self.process_tx()
def drain(self):
def drain(self) -> None:
# TODO
pass
@@ -592,7 +621,13 @@ class Multiplexer(EventEmitter):
RESET: 'RESET',
}
def __init__(self, l2cap_channel, role):
connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None:
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
@@ -607,20 +642,20 @@ class Multiplexer(EventEmitter):
l2cap_channel.sink = self.on_pdu
@staticmethod
def state_name(state):
def state_name(state: int):
return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state
def send_frame(self, frame):
def send_frame(self, frame: RFCOMM_Frame) -> None:
logger.debug(f'>>> Multiplexer sending {frame}')
self.l2cap_channel.send_pdu(frame)
def on_pdu(self, pdu):
def on_pdu(self, pdu: bytes) -> None:
frame = RFCOMM_Frame.from_bytes(pdu)
logger.debug(f'<<< Multiplexer received {frame}')
@@ -640,18 +675,18 @@ class Multiplexer(EventEmitter):
return
dlc.on_frame(frame)
def on_frame(self, frame):
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, _frame) -> None:
if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, _frame):
def on_ua_frame(self, _frame) -> None:
if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED)
if self.connection_result:
@@ -663,7 +698,7 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_dm_frame(self, _frame):
def on_dm_frame(self, _frame) -> None:
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
@@ -678,13 +713,13 @@ class Multiplexer(EventEmitter):
else:
logger.warning(f'unexpected state for DM: {self}')
def on_disc_frame(self, _frame):
def on_disc_frame(self, _frame) -> None:
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == RFCOMM_MCC_PN_TYPE:
@@ -694,10 +729,10 @@ class Multiplexer(EventEmitter):
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
def on_ui_frame(self, frame):
def on_ui_frame(self, frame) -> None:
pass
def on_mcc_pn(self, c_r, pn):
def on_mcc_pn(self, c_r, pn) -> None:
if c_r == 1:
# Command
logger.debug(f'<<< PN Command: {pn}')
@@ -736,14 +771,14 @@ class Multiplexer(EventEmitter):
else:
logger.warning('ignoring PN response')
def on_mcc_msc(self, c_r, msc):
def on_mcc_msc(self, c_r, msc) -> None:
dlc = self.dlcs.get(msc.dlci)
if dlc is None:
logger.warning(f'no dlc for DLCI {msc.dlci}')
return
dlc.on_mcc_msc(c_r, msc)
async def connect(self):
async def connect(self) -> None:
if self.state != Multiplexer.INIT:
raise InvalidStateError('invalid state')
@@ -752,7 +787,7 @@ class Multiplexer(EventEmitter):
self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
return await self.connection_result
async def disconnect(self):
async def disconnect(self) -> None:
if self.state != Multiplexer.CONNECTED:
return
@@ -765,7 +800,7 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(self, channel):
async def open_dlc(self, channel: int) -> DLC:
if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress')
@@ -796,7 +831,7 @@ class Multiplexer(EventEmitter):
self.open_result = None
return result
def on_dlc_open_complete(self, dlc):
def on_dlc_open_complete(self, dlc: DLC):
logger.debug(f'DLC [{dlc.dlci}] open complete')
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
@@ -808,13 +843,16 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device, connection):
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.Channel]
def __init__(self, device, connection) -> None:
self.device = device
self.connection = connection
self.l2cap_channel = None
self.multiplexer = None
async def start(self):
async def start(self) -> Multiplexer:
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
@@ -824,6 +862,7 @@ class Client:
logger.warning(f'L2CAP connection failed: {error}')
raise
assert self.l2cap_channel is not None
# Create a mutliplexer to manage DLCs with the server
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR)
@@ -832,7 +871,9 @@ class Client:
return self.multiplexer
async def shutdown(self):
async def shutdown(self) -> None:
if self.multiplexer is None:
return
# Disconnect the multiplexer
await self.multiplexer.disconnect()
self.multiplexer = None
@@ -843,7 +884,9 @@ class Client:
# -----------------------------------------------------------------------------
class Server(EventEmitter):
def __init__(self, device):
acceptors: Dict[int, Callable[[DLC], None]]
def __init__(self, device) -> None:
super().__init__()
self.device = device
self.multiplexer = None
@@ -852,7 +895,7 @@ class Server(EventEmitter):
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor, channel=0):
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
if channel:
if channel in self.acceptors:
# Busy
@@ -874,11 +917,11 @@ class Server(EventEmitter):
self.acceptors[channel] = acceptor
return channel
def on_connection(self, l2cap_channel):
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel):
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
# Create a new multiplexer for the channel
@@ -889,10 +932,10 @@ class Server(EventEmitter):
# Notify
self.emit('start', multiplexer)
def accept_dlc(self, channel_number):
def accept_dlc(self, channel_number: int) -> bool:
return channel_number in self.acceptors
def on_dlc(self, dlc):
def on_dlc(self, dlc: DLC) -> None:
logger.debug(f'@@@ new DLC connected: {dlc}')
# Let the acceptor know
+1 -1
View File
@@ -1805,7 +1805,7 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection, keys, session.sc)
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)
+1 -1
View File
@@ -2,5 +2,5 @@
mkdocs == 1.4.0
mkdocs-material == 8.5.6
mkdocs-material-extensions == 1.0.3
pymdown-extensions == 9.6
pymdown-extensions == 10.0
mkdocstrings-python == 0.7.1
+3
View File
@@ -40,6 +40,9 @@ disable = [
"too-many-statements",
]
[tool.pylint.main]
ignore="pandora" # FIXME: pylint does not support stubs yet:
[tool.pylint.typecheck]
signature-mutators="AsyncRunner.run_in_task"
+4 -2
View File
@@ -24,7 +24,7 @@ url = https://github.com/google/bumble
[options]
python_requires = >=3.8
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay, bumble.pandora
package_dir =
bumble = bumble
bumble.apps = apps
@@ -33,7 +33,7 @@ install_requires =
appdirs >= 1.4
click >= 7.1.2; platform_system!='Emscripten'
cryptography == 35; platform_system!='Emscripten'
grpcio >= 1.46; platform_system!='Emscripten'
grpcio == 1.51.1; platform_system!='Emscripten'
libusb1 >= 2.0.1; platform_system!='Emscripten'
libusb-package == 1.0.26.1; platform_system!='Emscripten'
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
@@ -45,6 +45,7 @@ install_requires =
websockets >= 8.1; platform_system!='Emscripten'
prettytable >= 3.6.0
humanize >= 4.6.0
bt-test-interfaces >= 0.0.2
[options.entry_points]
console_scripts =
@@ -60,6 +61,7 @@ console_scripts =
bumble-usb-probe = bumble.apps.usb_probe:main
bumble-link-relay = bumble.apps.link_relay.link_relay:main
bumble-bench = bumble.apps.bench:main
bumble-pandora-server = bumble.apps.pandora_server:main
[options.package_data]
* = py.typed, *.pyi
+179
View File
@@ -0,0 +1,179 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import json
import logging
import tempfile
import os
from bumble.keys import JsonKeyStore, PairingKeys
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
JSON1 = """
{
"my_namespace": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
},
"link_key": {
"authenticated": false,
"value": "0745dd9691e693d9dca740f7d8dfea75"
},
"ltk": {
"authenticated": false,
"value": "d1897ee10016eb1a08e4e037fd54c683"
}
}
}
}
"""
JSON2 = """
{
"my_namespace1": {
},
"my_namespace2": {
}
}
"""
JSON3 = """
{
"my_namespace1": {
},
"__DEFAULT__": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
}
}
}
}
"""
# -----------------------------------------------------------------------------
async def test_basic():
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write("{}")
file.flush()
keys = await keystore.get_all()
assert len(keys) == 0
keys = PairingKeys()
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is None
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert 'my_namespace' in json_data
assert 'foo' in json_data['my_namespace']
assert 'ltk' in json_data['my_namespace']['foo']
# -----------------------------------------------------------------------------
async def test_parsing():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write(JSON1)
file.flush()
foo = await keystore.get('14:7D:DA:4E:53:A8/P')
assert foo is not None
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# -----------------------------------------------------------------------------
async def test_default_namespace():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON1)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON2)
file.flush()
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert '__DEFAULT__' in json_data
assert 'foo' in json_data['__DEFAULT__']
assert 'ltk' in json_data['__DEFAULT__']['foo']
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON3)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_tests())