forked from auracaster/bumble_mirror
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f987dc3cd | |||
| 689745040f | |||
| 809d4a18f5 | |||
| 54be8b328a | |||
| 57b469198a | |||
| 4d74339c04 | |||
| 39db278f2e | |||
| 27fbb58447 | |||
| a74c39dc2b | |||
| 22f7cef771 | |||
| 5790d3aae8 | |||
| 744294f00e | |||
| 3697b8dde9 | |||
| f3bfbab44d | |||
| afcce0d6c8 | |||
| 69d45bed21 | |||
| 6826f68478 | |||
| f80c83d0b3 | |||
| 3de35193bc | |||
| 740a2e0ca0 |
+14
-14
@@ -207,7 +207,7 @@ def on_connection(connection, request):
|
||||
|
||||
# Listen for pairing events
|
||||
connection.on('pairing_start', on_pairing_start)
|
||||
connection.on('pairing', on_pairing)
|
||||
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
|
||||
connection.on('pairing_failure', on_pairing_failure)
|
||||
|
||||
# Listen for encryption changes
|
||||
@@ -242,9 +242,9 @@ def on_pairing_start():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_pairing(keys):
|
||||
def on_pairing(address, keys):
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
print(color('*** Paired!', 'cyan'))
|
||||
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
|
||||
keys.print(prefix=color('*** ', 'cyan'))
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
Waiter.instance.terminate()
|
||||
@@ -283,17 +283,6 @@ async def pair(
|
||||
# Create a device to manage the host
|
||||
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
|
||||
|
||||
# Set a custom keystore if specified on the command line
|
||||
if keystore_file:
|
||||
device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)
|
||||
|
||||
# Print the existing keys before pairing
|
||||
if print_keys and device.keystore:
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
print(color('@@@ Pairing Keys:', 'blue'))
|
||||
await device.keystore.print(prefix=color('@@@ ', 'blue'))
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
@@ -323,6 +312,17 @@ async def pair(
|
||||
# Get things going
|
||||
await device.power_on()
|
||||
|
||||
# Set a custom keystore if specified on the command line
|
||||
if keystore_file:
|
||||
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
|
||||
|
||||
# Print the existing keys before pairing
|
||||
if print_keys and device.keystore:
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
print(color('@@@ Pairing Keys:', 'blue'))
|
||||
await device.keystore.print(prefix=color('@@@ ', 'blue'))
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
|
||||
# Set up a pairing config factory
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(
|
||||
sc, mitm, bond, Delegate(mode, connection, io, prompt)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import asyncio
|
||||
import click
|
||||
import logging
|
||||
|
||||
from bumble.pandora import PandoraDevice, serve
|
||||
|
||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||
ROOTCANAL_PORT_CUTTLEFISH = 7300
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--grpc-port', help='gRPC port to serve', default=BUMBLE_SERVER_GRPC_PORT)
|
||||
@click.option(
|
||||
'--rootcanal-port', help='Rootcanal TCP port', default=ROOTCANAL_PORT_CUTTLEFISH
|
||||
)
|
||||
@click.option(
|
||||
'--transport',
|
||||
help='HCI transport',
|
||||
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
|
||||
)
|
||||
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
|
||||
if '<rootcanal-port>' in transport:
|
||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||
device = PandoraDevice({'transport': transport})
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(serve(device, port=grpc_port))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
+5
-6
@@ -133,15 +133,16 @@ async def scan(
|
||||
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||
)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
if keystore_file:
|
||||
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
|
||||
device.keystore = keystore
|
||||
else:
|
||||
resolver = None
|
||||
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
|
||||
|
||||
if device.keystore:
|
||||
resolving_keys = await device.keystore.get_resolving_keys()
|
||||
resolver = AddressResolver(resolving_keys)
|
||||
else:
|
||||
resolver = None
|
||||
|
||||
printer = AdvertisementPrinter(min_rssi, resolver)
|
||||
if raw:
|
||||
@@ -149,8 +150,6 @@ async def scan(
|
||||
else:
|
||||
device.on('advertisement', printer.on_advertisement)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
if phy is None:
|
||||
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
|
||||
else:
|
||||
|
||||
+40
-22
@@ -22,40 +22,58 @@ import click
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.transport import open_transport
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def unbond_with_keystore(keystore, address):
|
||||
if address is None:
|
||||
return await keystore.print()
|
||||
|
||||
try:
|
||||
await keystore.delete(address)
|
||||
except KeyError:
|
||||
print('!!! pairing not found')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def unbond(keystore_file, device_config, address):
|
||||
# Create a device to manage the host
|
||||
device = Device.from_config_file(device_config)
|
||||
|
||||
# Get all entries in the keystore
|
||||
async def unbond(keystore_file, device_config, hci_transport, address):
|
||||
# With a keystore file, we can instantiate the keystore directly
|
||||
if keystore_file:
|
||||
keystore = JsonKeyStore(None, keystore_file)
|
||||
else:
|
||||
keystore = device.keystore
|
||||
return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address)
|
||||
|
||||
if keystore is None:
|
||||
print('no keystore')
|
||||
return
|
||||
# Without a keystore file, we need to obtain the keystore from the device
|
||||
async with await open_transport(hci_transport) as (hci_source, hci_sink):
|
||||
# Create a device to manage the host
|
||||
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
|
||||
|
||||
if address is None:
|
||||
await keystore.print()
|
||||
else:
|
||||
try:
|
||||
await keystore.delete(address)
|
||||
except KeyError:
|
||||
print('!!! pairing not found')
|
||||
# Power-on the device to ensure we have a key store
|
||||
await device.power_on()
|
||||
|
||||
return await unbond_with_keystore(device.keystore, address)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option('--keystore-file', help='File in which to store the pairing keys')
|
||||
@click.argument('device-config')
|
||||
@click.option('--keystore-file', help='File in which the pairing keys are stored')
|
||||
@click.option('--hci-transport', help='HCI transport for the controller')
|
||||
@click.argument('device-config', required=False)
|
||||
@click.argument('address', required=False)
|
||||
def main(keystore_file, device_config, address):
|
||||
def main(keystore_file, hci_transport, device_config, address):
|
||||
"""
|
||||
Remove pairing keys for a device, given its address.
|
||||
|
||||
If no keystore file is specified, the --hci-transport option must be used to
|
||||
connect to a controller, so that the keystore for that controller can be
|
||||
instantiated.
|
||||
If no address is passed, the existing pairing keys for all addresses are printed.
|
||||
"""
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(unbond(keystore_file, device_config, address))
|
||||
|
||||
if not keystore_file and not hci_transport:
|
||||
print('either --keystore-file or --hci-transport must be specified.')
|
||||
return
|
||||
|
||||
asyncio.run(unbond(keystore_file, device_config, hci_transport, address))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
+16
-1
@@ -826,6 +826,12 @@ class DeviceConfiguration:
|
||||
advertising_data = config.get('advertising_data')
|
||||
if advertising_data:
|
||||
self.advertising_data = bytes.fromhex(advertising_data)
|
||||
elif config.get('name') is not None:
|
||||
self.advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.name, 'utf-8'))]
|
||||
)
|
||||
)
|
||||
|
||||
def load_from_file(self, filename):
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
@@ -3088,7 +3094,16 @@ class Device(CompositeEventEmitter):
|
||||
def on_pairing_start(self, connection: Connection) -> None:
|
||||
connection.emit('pairing_start')
|
||||
|
||||
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
|
||||
def on_pairing(
|
||||
self,
|
||||
connection: Connection,
|
||||
identity_address: Optional[Address],
|
||||
keys: PairingKeys,
|
||||
sc: bool,
|
||||
) -> None:
|
||||
if identity_address is not None:
|
||||
connection.peer_resolvable_address = connection.peer_address
|
||||
connection.peer_address = identity_address
|
||||
connection.sc = sc
|
||||
connection.authenticated = True
|
||||
connection.emit('pairing', keys)
|
||||
|
||||
+1
-1
@@ -212,8 +212,8 @@ class Service(Attribute):
|
||||
self,
|
||||
uuid,
|
||||
characteristics: List[Characteristic],
|
||||
included_services: List[Service] = [],
|
||||
primary=True,
|
||||
included_services: List[Service] = [],
|
||||
):
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
|
||||
+14
-8
@@ -18,10 +18,11 @@
|
||||
import logging
|
||||
import asyncio
|
||||
import collections
|
||||
from typing import Union
|
||||
|
||||
from . import rfcomm
|
||||
from .colors import color
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -34,7 +35,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HfpProtocol:
|
||||
def __init__(self, dlc):
|
||||
dlc: rfcomm.DLC
|
||||
buffer: str
|
||||
lines: collections.deque
|
||||
lines_available: asyncio.Event
|
||||
|
||||
def __init__(self, dlc: rfcomm.DLC) -> None:
|
||||
self.dlc = dlc
|
||||
self.buffer = ''
|
||||
self.lines = collections.deque()
|
||||
@@ -42,7 +48,7 @@ class HfpProtocol:
|
||||
|
||||
dlc.sink = self.feed
|
||||
|
||||
def feed(self, data):
|
||||
def feed(self, data: Union[bytes, str]) -> None:
|
||||
# Convert the data to a string if needed
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode('utf-8')
|
||||
@@ -57,19 +63,19 @@ class HfpProtocol:
|
||||
if len(line) > 0:
|
||||
self.on_line(line)
|
||||
|
||||
def on_line(self, line):
|
||||
def on_line(self, line: str) -> None:
|
||||
self.lines.append(line)
|
||||
self.lines_available.set()
|
||||
|
||||
def send_command_line(self, line):
|
||||
def send_command_line(self, line: str) -> None:
|
||||
logger.debug(color(f'>>> {line}', 'yellow'))
|
||||
self.dlc.write(line + '\r')
|
||||
|
||||
def send_response_line(self, line):
|
||||
def send_response_line(self, line: str) -> None:
|
||||
logger.debug(color(f'>>> {line}', 'yellow'))
|
||||
self.dlc.write('\r\n' + line + '\r\n')
|
||||
|
||||
async def next_line(self):
|
||||
async def next_line(self) -> str:
|
||||
await self.lines_available.wait()
|
||||
line = self.lines.popleft()
|
||||
if not self.lines:
|
||||
@@ -77,7 +83,7 @@ class HfpProtocol:
|
||||
logger.debug(color(f'<<< {line}', 'green'))
|
||||
return line
|
||||
|
||||
async def initialize_service(self):
|
||||
async def initialize_service(self) -> None:
|
||||
# Perform Service Level Connection Initialization
|
||||
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
|
||||
await (self.next_line())
|
||||
|
||||
+72
-46
@@ -190,10 +190,44 @@ class KeyStore:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class JsonKeyStore(KeyStore):
|
||||
"""
|
||||
KeyStore implementation that is backed by a JSON file.
|
||||
|
||||
This implementation supports storing a hierarchy of key sets in a single file.
|
||||
A key set is a representation of a PairingKeys object. Each key set is stored
|
||||
in a map, with the address of paired peer as the key. Maps are themselves grouped
|
||||
into namespaces, grouping pairing keys by controller addresses.
|
||||
The JSON object model looks like:
|
||||
{
|
||||
"<namespace>": {
|
||||
"peer-address": {
|
||||
"address_type": <n>,
|
||||
"irk" : {
|
||||
"authenticated": <true/false>,
|
||||
"value": "hex-encoded-key"
|
||||
},
|
||||
... other keys ...
|
||||
},
|
||||
... other peers ...
|
||||
}
|
||||
... other namespaces ...
|
||||
}
|
||||
|
||||
A namespace is typically the BD_ADDR of a controller, since that is a convenient
|
||||
unique identifier, but it may be something else.
|
||||
A special namespace, called the "default" namespace, is used when instantiating this
|
||||
class without a namespace. With the default namespace, reading from a file will
|
||||
load an existing namespace if there is only one, which may be convenient for reading
|
||||
from a file with a single key set and for which the namespace isn't known. If the
|
||||
file does not include any existing key set, or if there are more than one and none
|
||||
has the default name, a new one will be created with the name "__DEFAULT__".
|
||||
"""
|
||||
|
||||
APP_NAME = 'Bumble'
|
||||
APP_AUTHOR = 'Google'
|
||||
KEYS_DIR = 'Pairing'
|
||||
DEFAULT_NAMESPACE = '__DEFAULT__'
|
||||
DEFAULT_BASE_NAME = "keys"
|
||||
|
||||
def __init__(self, namespace, filename=None):
|
||||
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
|
||||
@@ -208,8 +242,9 @@ class JsonKeyStore(KeyStore):
|
||||
self.directory_name = os.path.join(
|
||||
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
|
||||
)
|
||||
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
|
||||
json_filename = (
|
||||
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
|
||||
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
|
||||
)
|
||||
self.filename = os.path.join(self.directory_name, json_filename)
|
||||
else:
|
||||
@@ -219,11 +254,13 @@ class JsonKeyStore(KeyStore):
|
||||
logger.debug(f'JSON keystore: {self.filename}')
|
||||
|
||||
@staticmethod
|
||||
def from_device(device: Device) -> Optional[JsonKeyStore]:
|
||||
if not device.config.keystore:
|
||||
return None
|
||||
|
||||
params = device.config.keystore.split(':', 1)[1:]
|
||||
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
|
||||
if not filename:
|
||||
# Extract the filename from the config if there is one
|
||||
if device.config.keystore is not None:
|
||||
params = device.config.keystore.split(':', 1)[1:]
|
||||
if params:
|
||||
filename = params[0]
|
||||
|
||||
# Use a namespace based on the device address
|
||||
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
|
||||
@@ -232,19 +269,31 @@ class JsonKeyStore(KeyStore):
|
||||
namespace = str(device.random_address)
|
||||
else:
|
||||
namespace = JsonKeyStore.DEFAULT_NAMESPACE
|
||||
if params:
|
||||
filename = params[0]
|
||||
else:
|
||||
filename = None
|
||||
|
||||
return JsonKeyStore(namespace, filename)
|
||||
|
||||
async def load(self):
|
||||
# Try to open the file, without failing. If the file does not exist, it
|
||||
# will be created upon saving.
|
||||
try:
|
||||
with open(self.filename, 'r', encoding='utf-8') as json_file:
|
||||
return json.load(json_file)
|
||||
db = json.load(json_file)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
db = {}
|
||||
|
||||
# First, look for a namespace match
|
||||
if self.namespace in db:
|
||||
return (db, db[self.namespace])
|
||||
|
||||
# Then, if the namespace is the default namespace, and there's
|
||||
# only one entry in the db, use that
|
||||
if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
|
||||
return next(iter(db.items()))
|
||||
|
||||
# Finally, just create an empty key map for the namespace
|
||||
key_map = {}
|
||||
db[self.namespace] = key_map
|
||||
return (db, key_map)
|
||||
|
||||
async def save(self, db):
|
||||
# Create the directory if it doesn't exist
|
||||
@@ -260,53 +309,30 @@ class JsonKeyStore(KeyStore):
|
||||
os.replace(temp_filename, self.filename)
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.get(self.namespace)
|
||||
if namespace is None:
|
||||
raise KeyError(name)
|
||||
|
||||
del namespace[name]
|
||||
db, key_map = await self.load()
|
||||
del key_map[name]
|
||||
await self.save(db)
|
||||
|
||||
async def update(self, name, keys):
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.setdefault(self.namespace, {})
|
||||
namespace.setdefault(name, {}).update(keys.to_dict())
|
||||
|
||||
db, key_map = await self.load()
|
||||
key_map.setdefault(name, {}).update(keys.to_dict())
|
||||
await self.save(db)
|
||||
|
||||
async def get_all(self):
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.get(self.namespace)
|
||||
if namespace is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
|
||||
]
|
||||
_, key_map = await self.load()
|
||||
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
|
||||
|
||||
async def delete_all(self):
|
||||
db = await self.load()
|
||||
|
||||
db.pop(self.namespace, None)
|
||||
|
||||
db, key_map = await self.load()
|
||||
key_map.clear()
|
||||
await self.save(db)
|
||||
|
||||
async def get(self, name: str) -> Optional[PairingKeys]:
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.get(self.namespace)
|
||||
if namespace is None:
|
||||
_, key_map = await self.load()
|
||||
if name not in key_map:
|
||||
return None
|
||||
|
||||
keys = namespace.get(name)
|
||||
if keys is None:
|
||||
return None
|
||||
|
||||
return PairingKeys.from_dict(keys)
|
||||
return PairingKeys.from_dict(key_map[name])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Bumble Pandora server.
|
||||
This module implement the Pandora Bluetooth test APIs for the Bumble stack.
|
||||
"""
|
||||
|
||||
__version__ = "0.0.1"
|
||||
|
||||
import grpc
|
||||
import grpc.aio
|
||||
|
||||
from .config import Config
|
||||
from .device import PandoraDevice
|
||||
from .host import HostService
|
||||
from .security import SecurityService, SecurityStorageService
|
||||
from pandora.host_grpc_aio import add_HostServicer_to_server
|
||||
from pandora.security_grpc_aio import (
|
||||
add_SecurityServicer_to_server,
|
||||
add_SecurityStorageServicer_to_server,
|
||||
)
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
# public symbols
|
||||
__all__ = [
|
||||
'register_servicer_hook',
|
||||
'serve',
|
||||
'Config',
|
||||
'PandoraDevice',
|
||||
]
|
||||
|
||||
|
||||
# Add servicers hooks.
|
||||
_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
|
||||
|
||||
|
||||
def register_servicer_hook(
|
||||
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
|
||||
) -> None:
|
||||
_SERVICERS_HOOKS.append(hook)
|
||||
|
||||
|
||||
async def serve(
|
||||
bumble: PandoraDevice,
|
||||
config: Config = Config(),
|
||||
grpc_server: Optional[grpc.aio.Server] = None,
|
||||
port: int = 0,
|
||||
) -> None:
|
||||
# initialize a gRPC server if not provided.
|
||||
server = grpc_server if grpc_server is not None else grpc.aio.server()
|
||||
port = server.add_insecure_port(f'localhost:{port}')
|
||||
|
||||
try:
|
||||
while True:
|
||||
# load server config from dict.
|
||||
config.load_from_dict(bumble.config.get('server', {}))
|
||||
|
||||
# add Pandora services to the gRPC server.
|
||||
add_HostServicer_to_server(
|
||||
HostService(server, bumble.device, config), server
|
||||
)
|
||||
add_SecurityServicer_to_server(
|
||||
SecurityService(bumble.device, config), server
|
||||
)
|
||||
add_SecurityStorageServicer_to_server(
|
||||
SecurityStorageService(bumble.device, config), server
|
||||
)
|
||||
|
||||
# call hooks if any.
|
||||
for hook in _SERVICERS_HOOKS:
|
||||
hook(bumble, config, server)
|
||||
|
||||
# open device.
|
||||
await bumble.open()
|
||||
try:
|
||||
# Pandora require classic devices to be discoverable & connectable.
|
||||
if bumble.device.classic_enabled:
|
||||
await bumble.device.set_discoverable(True)
|
||||
await bumble.device.set_connectable(True)
|
||||
|
||||
# start & serve gRPC server.
|
||||
await server.start()
|
||||
await server.wait_for_termination()
|
||||
finally:
|
||||
# close device.
|
||||
await bumble.close()
|
||||
|
||||
# re-initialize the gRPC server.
|
||||
server = grpc.aio.server()
|
||||
server.add_insecure_port(f'localhost:{port}')
|
||||
finally:
|
||||
# stop server.
|
||||
await server.stop(None)
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from bumble.pairing import PairingDelegate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
|
||||
pairing_sc_enable: bool = True
|
||||
pairing_mitm_enable: bool = True
|
||||
pairing_bonding_enable: bool = True
|
||||
smp_local_initiator_key_distribution: PairingDelegate.KeyDistribution = (
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
|
||||
)
|
||||
smp_local_responder_key_distribution: PairingDelegate.KeyDistribution = (
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
|
||||
)
|
||||
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
io_capability_name: str = config.get(
|
||||
'io_capability', 'no_output_no_input'
|
||||
).upper()
|
||||
self.io_capability = getattr(PairingDelegate, io_capability_name)
|
||||
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
|
||||
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
|
||||
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
|
||||
self.smp_local_initiator_key_distribution = config.get(
|
||||
'smp_local_initiator_key_distribution',
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
)
|
||||
self.smp_local_responder_key_distribution = config.get(
|
||||
'smp_local_responder_key_distribution',
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
BT_HANDSFREE_SERVICE,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_RFCOMM_PROTOCOL_ID,
|
||||
)
|
||||
from bumble.device import Device, DeviceConfiguration
|
||||
from bumble.host import Host
|
||||
from bumble.sdp import (
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class PandoraDevice:
|
||||
"""
|
||||
Small wrapper around a Bumble device and it's HCI transport.
|
||||
Notes:
|
||||
- The Bumble device is idle by default.
|
||||
- Repetitive calls to `open`/`close` will result on new Bumble device instances.
|
||||
"""
|
||||
|
||||
# Bumble device instance & configuration.
|
||||
device: Device
|
||||
config: Dict[str, Any]
|
||||
|
||||
# HCI transport name & instance.
|
||||
_hci_name: str
|
||||
_hci: Optional[transport.Transport] # type: ignore[name-defined]
|
||||
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.device = _make_device(config)
|
||||
self._hci_name = config.get('transport', '')
|
||||
self._hci = None
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return self._hci is None
|
||||
|
||||
async def open(self) -> None:
|
||||
if self._hci is not None:
|
||||
return
|
||||
|
||||
# open HCI transport & set device host.
|
||||
self._hci = await transport.open_transport(self._hci_name)
|
||||
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
|
||||
|
||||
# power-on.
|
||||
await self.device.power_on()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._hci is None:
|
||||
return
|
||||
|
||||
# flush & re-initialize device.
|
||||
await self.device.host.flush()
|
||||
self.device.host = None # type: ignore[assignment]
|
||||
self.device = _make_device(self.config)
|
||||
|
||||
# close HCI transport.
|
||||
await self._hci.close()
|
||||
self._hci = None
|
||||
|
||||
async def reset(self) -> None:
|
||||
await self.close()
|
||||
await self.open()
|
||||
|
||||
def info(self) -> Optional[Dict[str, str]]:
|
||||
return {
|
||||
'public_bd_address': str(self.device.public_address),
|
||||
'random_address': str(self.device.random_address),
|
||||
}
|
||||
|
||||
|
||||
def _make_device(config: Dict[str, Any]) -> Device:
|
||||
"""Initialize an idle Bumble device instance."""
|
||||
|
||||
# initialize bumble device.
|
||||
device_config = DeviceConfiguration()
|
||||
device_config.load_from_dict(config)
|
||||
device = Device(config=device_config, host=None)
|
||||
|
||||
# Add fake a2dp service to avoid Android disconnect
|
||||
device.sdp_service_records = _make_sdp_records(1)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
# TODO(b/267540823): remove when Pandora A2dp is supported
|
||||
def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
|
||||
return {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(0x00010001),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HANDSFREE_SERVICE),
|
||||
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_8(rfcomm_channel),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HANDSFREE_SERVICE),
|
||||
DataElement.unsigned_integer_16(0x0105),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,857 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
ConnectionError,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
)
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora.host_pb2 import (
|
||||
NOT_CONNECTABLE,
|
||||
NOT_DISCOVERABLE,
|
||||
PRIMARY_1M,
|
||||
PRIMARY_CODED,
|
||||
SECONDARY_1M,
|
||||
SECONDARY_2M,
|
||||
SECONDARY_CODED,
|
||||
SECONDARY_NONE,
|
||||
AdvertiseRequest,
|
||||
AdvertiseResponse,
|
||||
Connection,
|
||||
ConnectLERequest,
|
||||
ConnectLEResponse,
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
DataTypes,
|
||||
DisconnectRequest,
|
||||
InquiryResponse,
|
||||
PrimaryPhy,
|
||||
ReadLocalAddressResponse,
|
||||
ScanningResponse,
|
||||
ScanRequest,
|
||||
SecondaryPhy,
|
||||
SetConnectabilityModeRequest,
|
||||
SetDiscoverabilityModeRequest,
|
||||
WaitConnectionRequest,
|
||||
WaitConnectionResponse,
|
||||
WaitDisconnectionRequest,
|
||||
)
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
|
||||
# Default value reported by Bumble for legacy Advertising reports.
|
||||
# FIXME(uael): `None` might be a better value, but Bumble need to change accordingly.
|
||||
0: PRIMARY_1M,
|
||||
1: PRIMARY_1M,
|
||||
3: PRIMARY_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
|
||||
0: SECONDARY_NONE,
|
||||
1: SECONDARY_1M,
|
||||
2: SECONDARY_2M,
|
||||
3: SECONDARY_CODED,
|
||||
}
|
||||
|
||||
|
||||
class HostService(HostServicer):
|
||||
waited_connections: Set[int]
|
||||
|
||||
def __init__(
|
||||
self, grpc_server: grpc.aio.Server, device: Device, config: Config
|
||||
) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'Host', 'device': device}
|
||||
)
|
||||
self.grpc_server = grpc_server
|
||||
self.device = device
|
||||
self.config = config
|
||||
self.waited_connections = set()
|
||||
|
||||
@utils.rpc
|
||||
async def FactoryReset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('FactoryReset')
|
||||
|
||||
# delete all bonds
|
||||
if self.device.keystore is not None:
|
||||
await self.device.keystore.delete_all()
|
||||
|
||||
# trigger gRCP server stop then return
|
||||
asyncio.create_task(self.grpc_server.stop(None))
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def Reset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('Reset')
|
||||
|
||||
# clear service.
|
||||
self.waited_connections.clear()
|
||||
|
||||
# (re) power device on
|
||||
await self.device.power_on()
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def ReadLocalAddress(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> ReadLocalAddressResponse:
|
||||
self.log.info('ReadLocalAddress')
|
||||
return ReadLocalAddressResponse(
|
||||
address=bytes(reversed(bytes(self.device.public_address)))
|
||||
)
|
||||
|
||||
@utils.rpc
|
||||
async def Connect(
|
||||
self, request: ConnectRequest, context: grpc.ServicerContext
|
||||
) -> ConnectResponse:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
address = Address(
|
||||
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
self.log.info(f"Connect to {address}")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
self.log.warning(f"Peer not found: {e}")
|
||||
return ConnectResponse(peer_not_found=empty_pb2.Empty())
|
||||
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
self.log.warning(f"Connection already exists: {e}")
|
||||
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"Connect to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def WaitConnection(
|
||||
self, request: WaitConnectionRequest, context: grpc.ServicerContext
|
||||
) -> WaitConnectionResponse:
|
||||
if not request.address:
|
||||
raise ValueError('Request address field must be set')
|
||||
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
address = Address(
|
||||
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"WaitConnection from {address}...")
|
||||
|
||||
connection = self.device.find_connection_by_bd_addr(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
if connection and id(connection) in self.waited_connections:
|
||||
# this connection was already returned: wait for a new one.
|
||||
connection = None
|
||||
|
||||
if not connection:
|
||||
connection = await self.device.accept(address)
|
||||
|
||||
# save connection has waited and respond.
|
||||
self.waited_connections.add(id(connection))
|
||||
|
||||
self.log.info(
|
||||
f"WaitConnection from {address} done (handle={connection.handle})"
|
||||
)
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return WaitConnectionResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def ConnectLE(
|
||||
self, request: ConnectLERequest, context: grpc.ServicerContext
|
||||
) -> ConnectLEResponse:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"ConnectLE to {address}...")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
own_address_type=request.own_address_type,
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
self.log.warning(f"Peer not found: {e}")
|
||||
return ConnectLEResponse(peer_not_found=empty_pb2.Empty())
|
||||
if e.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
self.log.warning(f"Connection already exists: {e}")
|
||||
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectLEResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
@utils.rpc
|
||||
async def Disconnect(
|
||||
self, request: DisconnectRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Disconnect: {connection_handle}")
|
||||
|
||||
self.log.info("Disconnecting...")
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
|
||||
self.log.info("Disconnected")
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def WaitDisconnection(
|
||||
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitDisconnection: {connection_handle}")
|
||||
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
disconnection_future: asyncio.Future[
|
||||
None
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_disconnection(_: None) -> None:
|
||||
disconnection_future.set_result(None)
|
||||
|
||||
connection.on('disconnection', on_disconnection)
|
||||
try:
|
||||
await disconnection_future
|
||||
self.log.info("Disconnected")
|
||||
finally:
|
||||
connection.remove_listener('disconnection', on_disconnection) # type: ignore
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def Advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if not request.legacy:
|
||||
raise NotImplementedError(
|
||||
"TODO: add support for extended advertising in Bumble"
|
||||
)
|
||||
if request.interval:
|
||||
raise NotImplementedError("TODO: add support for `request.interval`")
|
||||
if request.interval_range:
|
||||
raise NotImplementedError("TODO: add support for `request.interval_range`")
|
||||
if request.primary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.primary_phy`")
|
||||
if request.secondary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.secondary_phy`")
|
||||
|
||||
if self.device.is_advertising:
|
||||
raise NotImplementedError('TODO: add support for advertising sets')
|
||||
|
||||
if data := request.data:
|
||||
self.device.advertising_data = bytes(self.unpack_data_types(data))
|
||||
|
||||
if scan_response_data := request.scan_response_data:
|
||||
self.device.scan_response_data = bytes(
|
||||
self.unpack_data_types(scan_response_data)
|
||||
)
|
||||
scannable = True
|
||||
else:
|
||||
scannable = False
|
||||
|
||||
# Retrieve services data
|
||||
for service in self.device.gatt_server.attributes:
|
||||
if isinstance(service, Service) and (
|
||||
service_data := service.get_advertising_data()
|
||||
):
|
||||
service_uuid = service.uuid.to_hex_str('-')
|
||||
if (
|
||||
service_uuid in request.data.incomplete_service_class_uuids16
|
||||
or service_uuid in request.data.complete_service_class_uuids16
|
||||
or service_uuid in request.data.incomplete_service_class_uuids32
|
||||
or service_uuid in request.data.complete_service_class_uuids32
|
||||
or service_uuid
|
||||
in request.data.incomplete_service_class_uuids128
|
||||
or service_uuid in request.data.complete_service_class_uuids128
|
||||
):
|
||||
self.device.advertising_data += service_data
|
||||
if (
|
||||
service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids16
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids16
|
||||
or service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids32
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids32
|
||||
or service_uuid
|
||||
in scan_response_data.incomplete_service_class_uuids128
|
||||
or service_uuid
|
||||
in scan_response_data.complete_service_class_uuids128
|
||||
):
|
||||
self.device.scan_response_data += service_data
|
||||
|
||||
target = None
|
||||
if request.connectable and scannable:
|
||||
advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE
|
||||
elif scannable:
|
||||
advertising_type = AdvertisingType.UNDIRECTED_SCANNABLE
|
||||
else:
|
||||
advertising_type = AdvertisingType.UNDIRECTED
|
||||
else:
|
||||
target = None
|
||||
advertising_type = AdvertisingType.UNDIRECTED
|
||||
|
||||
if request.target:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
else:
|
||||
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
|
||||
try:
|
||||
while True:
|
||||
if not self.device.is_advertising:
|
||||
self.log.info('Advertise')
|
||||
await self.device.start_advertising(
|
||||
target=target,
|
||||
advertising_type=advertising_type,
|
||||
own_address_type=request.own_address_type,
|
||||
)
|
||||
|
||||
if not request.connectable:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
pending_connection: asyncio.Future[
|
||||
bumble.device.Connection
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
self.log.info('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
|
||||
self.log.info(
|
||||
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
|
||||
)
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
yield AdvertiseResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
# wait a small delay before restarting the advertisement.
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
if request.connectable:
|
||||
self.device.remove_listener('connection', on_connection) # type: ignore
|
||||
|
||||
try:
|
||||
self.log.info('Stop advertising')
|
||||
await self.device.abort_on('flush', self.device.stop_advertising())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def Scan(
|
||||
self, request: ScanRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[ScanningResponse, None]:
|
||||
# TODO: modify `start_scanning` to accept floats instead of int for ms values
|
||||
if request.phys:
|
||||
raise NotImplementedError("TODO: add support for `request.phys`")
|
||||
|
||||
self.log.info('Scan')
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
await self.device.start_scanning(
|
||||
legacy=request.legacy,
|
||||
active=not request.passive,
|
||||
own_address_type=request.own_address_type,
|
||||
scan_interval=int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
scan_window=int(request.window)
|
||||
if request.window
|
||||
else DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
)
|
||||
|
||||
try:
|
||||
# TODO: add support for `direct_address` in Bumble
|
||||
# TODO: add support for `periodic_advertising_interval` in Bumble
|
||||
while adv := await scan_queue.get():
|
||||
sr = ScanningResponse(
|
||||
legacy=adv.is_legacy,
|
||||
connectable=adv.is_connectable,
|
||||
scannable=adv.is_scannable,
|
||||
truncated=adv.is_truncated,
|
||||
sid=adv.sid,
|
||||
primary_phy=PRIMARY_PHY_MAP[adv.primary_phy],
|
||||
secondary_phy=SECONDARY_PHY_MAP[adv.secondary_phy],
|
||||
tx_power=adv.tx_power,
|
||||
rssi=adv.rssi,
|
||||
data=self.pack_data_types(adv.data),
|
||||
)
|
||||
|
||||
if adv.address.address_type == Address.PUBLIC_DEVICE_ADDRESS:
|
||||
sr.public = bytes(reversed(bytes(adv.address)))
|
||||
elif adv.address.address_type == Address.RANDOM_DEVICE_ADDRESS:
|
||||
sr.random = bytes(reversed(bytes(adv.address)))
|
||||
elif adv.address.address_type == Address.PUBLIC_IDENTITY_ADDRESS:
|
||||
sr.public_identity = bytes(reversed(bytes(adv.address)))
|
||||
else:
|
||||
sr.random_static_identity = bytes(reversed(bytes(adv.address)))
|
||||
|
||||
yield sr
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('advertisement', handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop scanning')
|
||||
await self.device.abort_on('flush', self.device.stop_scanning())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def Inquiry(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[InquiryResponse, None]:
|
||||
self.log.info('Inquiry')
|
||||
|
||||
inquiry_queue: asyncio.Queue[
|
||||
Optional[Tuple[Address, int, AdvertisingData, int]]
|
||||
] = asyncio.Queue()
|
||||
complete_handler = self.device.on(
|
||||
'inquiry_complete', lambda: inquiry_queue.put_nowait(None)
|
||||
)
|
||||
result_handler = self.device.on( # type: ignore
|
||||
'inquiry_result',
|
||||
lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore
|
||||
(address, class_of_device, eir_data, rssi) # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
await self.device.start_discovery(auto_restart=False)
|
||||
try:
|
||||
while inquiry_result := await inquiry_queue.get():
|
||||
(address, class_of_device, eir_data, rssi) = inquiry_result
|
||||
# FIXME: if needed, add support for `page_scan_repetition_mode` and `clock_offset` in Bumble
|
||||
yield InquiryResponse(
|
||||
address=bytes(reversed(bytes(address))),
|
||||
class_of_device=class_of_device,
|
||||
rssi=rssi,
|
||||
data=self.pack_data_types(eir_data),
|
||||
)
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
|
||||
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop inquiry')
|
||||
await self.device.abort_on('flush', self.device.stop_discovery())
|
||||
except:
|
||||
pass
|
||||
|
||||
@utils.rpc
|
||||
async def SetDiscoverabilityMode(
|
||||
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetDiscoverabilityMode")
|
||||
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@utils.rpc
|
||||
async def SetConnectabilityMode(
|
||||
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetConnectabilityMode")
|
||||
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:
|
||||
ad_structures: List[Tuple[int, bytes]] = []
|
||||
|
||||
uuids: List[str]
|
||||
datas: Dict[str, bytes]
|
||||
|
||||
def uuid128_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
|
||||
to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid.replace('-', ''))))
|
||||
|
||||
def uuid32_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 32-bit uuid encoded as XXXXXXXX to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid)))
|
||||
|
||||
def uuid16_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 16-bit uuid encoded as XXXX to byte format."""
|
||||
return bytes(reversed(bytes.fromhex(uuid)))
|
||||
|
||||
if uuids := dt.incomplete_service_class_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.incomplete_service_class_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.incomplete_service_class_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.complete_service_class_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_shortened_local_name'):
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.SHORTENED_LOCAL_NAME,
|
||||
bytes(self.device.name[:8], 'utf-8'),
|
||||
)
|
||||
)
|
||||
elif dt.shortened_local_name:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.SHORTENED_LOCAL_NAME,
|
||||
bytes(dt.shortened_local_name, 'utf-8'),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_complete_local_name'):
|
||||
ad_structures.append(
|
||||
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(self.device.name, 'utf-8'))
|
||||
)
|
||||
elif dt.complete_local_name:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(dt.complete_local_name, 'utf-8'),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_tx_power_level'):
|
||||
raise ValueError('unsupported data type')
|
||||
elif dt.tx_power_level:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.TX_POWER_LEVEL,
|
||||
bytes(struct.pack('<I', dt.tx_power_level)[:1]),
|
||||
)
|
||||
)
|
||||
if dt.HasField('include_class_of_device'):
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.CLASS_OF_DEVICE,
|
||||
bytes(struct.pack('<I', self.device.class_of_device)[:-1]),
|
||||
)
|
||||
)
|
||||
elif dt.class_of_device:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.CLASS_OF_DEVICE,
|
||||
bytes(struct.pack('<I', dt.class_of_device)[:-1]),
|
||||
)
|
||||
)
|
||||
if dt.peripheral_connection_interval_min:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE,
|
||||
bytes(
|
||||
[
|
||||
*struct.pack('<H', dt.peripheral_connection_interval_min),
|
||||
*struct.pack(
|
||||
'<H',
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids16:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid16_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids32:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid32_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if uuids := dt.service_solicitation_uuids128:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
|
||||
b''.join([uuid128_from_str(uuid) for uuid in uuids]),
|
||||
)
|
||||
)
|
||||
if datas := dt.service_data_uuid16:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
uuid16_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if datas := dt.service_data_uuid32:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_32_BIT_UUID,
|
||||
uuid32_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if datas := dt.service_data_uuid128:
|
||||
ad_structures.extend(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_128_BIT_UUID,
|
||||
uuid128_from_str(uuid) + data,
|
||||
)
|
||||
for uuid, data in datas.items()
|
||||
]
|
||||
)
|
||||
if dt.appearance:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.APPEARANCE, struct.pack('<H', dt.appearance))
|
||||
)
|
||||
if dt.advertising_interval:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.ADVERTISING_INTERVAL,
|
||||
struct.pack('<H', dt.advertising_interval),
|
||||
)
|
||||
)
|
||||
if dt.uri:
|
||||
ad_structures.append((AdvertisingData.URI, bytes(dt.uri, 'utf-8')))
|
||||
if dt.le_supported_features:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_SUPPORTED_FEATURES, dt.le_supported_features)
|
||||
)
|
||||
if dt.manufacturer_specific_data:
|
||||
ad_structures.append(
|
||||
(
|
||||
AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
|
||||
dt.manufacturer_specific_data,
|
||||
)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
|
||||
dt = DataTypes()
|
||||
uuids: List[UUID]
|
||||
s: str
|
||||
i: int
|
||||
ij: Tuple[int, int]
|
||||
uuid_data: Tuple[UUID, bytes]
|
||||
data: bytes
|
||||
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if s := cast(str, ad.get(AdvertisingData.SHORTENED_LOCAL_NAME)):
|
||||
dt.shortened_local_name = s
|
||||
if s := cast(str, ad.get(AdvertisingData.COMPLETE_LOCAL_NAME)):
|
||||
dt.complete_local_name = s
|
||||
if i := cast(int, ad.get(AdvertisingData.TX_POWER_LEVEL)):
|
||||
dt.tx_power_level = i
|
||||
if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
|
||||
dt.class_of_device = i
|
||||
if ij := cast(
|
||||
Tuple[int, int],
|
||||
ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE),
|
||||
):
|
||||
dt.peripheral_connection_interval_min = ij[0]
|
||||
dt.peripheral_connection_interval_max = ij[1]
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
|
||||
dt.public_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if data := cast(bytes, ad.get(AdvertisingData.RANDOM_TARGET_ADDRESS, raw=True)):
|
||||
dt.random_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = i
|
||||
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
|
||||
dt.advertising_interval = i
|
||||
if s := cast(str, ad.get(AdvertisingData.URI)):
|
||||
dt.uri = s
|
||||
if data := cast(bytes, ad.get(AdvertisingData.LE_SUPPORTED_FEATURES, raw=True)):
|
||||
dt.le_supported_features = data
|
||||
if data := cast(
|
||||
bytes, ad.get(AdvertisingData.MANUFACTURER_SPECIFIC_DATA, raw=True)
|
||||
):
|
||||
dt.manufacturer_specific_data = data
|
||||
|
||||
return dt
|
||||
@@ -0,0 +1,526 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import grpc
|
||||
import logging
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble import hci
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
from contextlib import suppress
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_pb2 import Connection
|
||||
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
|
||||
from pandora.security_pb2 import (
|
||||
LE_LEVEL1,
|
||||
LE_LEVEL2,
|
||||
LE_LEVEL3,
|
||||
LE_LEVEL4,
|
||||
LEVEL0,
|
||||
LEVEL1,
|
||||
LEVEL2,
|
||||
LEVEL3,
|
||||
LEVEL4,
|
||||
DeleteBondRequest,
|
||||
IsBondedRequest,
|
||||
LESecurityLevel,
|
||||
PairingEvent,
|
||||
PairingEventAnswer,
|
||||
SecureRequest,
|
||||
SecureResponse,
|
||||
SecurityLevel,
|
||||
WaitSecurityRequest,
|
||||
WaitSecurityResponse,
|
||||
)
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
|
||||
|
||||
|
||||
class PairingDelegate(BasePairingDelegate):
|
||||
def __init__(
|
||||
self,
|
||||
connection: BumbleConnection,
|
||||
service: "SecurityService",
|
||||
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
|
||||
) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(),
|
||||
{'service_name': 'Security', 'device': connection.device},
|
||||
)
|
||||
self.connection = connection
|
||||
self.service = service
|
||||
super().__init__(
|
||||
io_capability,
|
||||
local_initiator_key_distribution,
|
||||
local_responder_key_distribution,
|
||||
)
|
||||
|
||||
async def accept(self) -> bool:
|
||||
return True
|
||||
|
||||
def add_origin(self, ev: PairingEvent) -> PairingEvent:
|
||||
if not self.connection.is_incomplete:
|
||||
assert ev.connection
|
||||
ev.connection.CopyFrom(
|
||||
Connection(
|
||||
cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
|
||||
)
|
||||
)
|
||||
else:
|
||||
# In BR/EDR, connection may not be complete,
|
||||
# use address instead
|
||||
assert self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
ev.address = bytes(reversed(bytes(self.connection.peer_address)))
|
||||
|
||||
return ev
|
||||
|
||||
async def confirm(self, auto: bool = False) -> bool:
|
||||
self.log.info(
|
||||
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
return True
|
||||
|
||||
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
|
||||
self.log.info(
|
||||
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled number comparison request')
|
||||
|
||||
event = self.add_origin(PairingEvent(numeric_comparison=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def get_number(self) -> Optional[int]:
|
||||
self.log.info(
|
||||
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled number request')
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
assert answer.answer_variant() == 'passkey'
|
||||
return answer.passkey
|
||||
|
||||
async def get_string(self, max_length: int) -> Optional[str]:
|
||||
self.log.info(
|
||||
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None or self.service.event_answer is None:
|
||||
raise RuntimeError('security: unhandled pin_code request')
|
||||
|
||||
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
assert answer.answer_variant() == 'pin'
|
||||
|
||||
if answer.pin is None:
|
||||
return None
|
||||
|
||||
pin = answer.pin.decode('utf-8')
|
||||
if not pin or len(pin) > max_length:
|
||||
raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')
|
||||
|
||||
return pin
|
||||
|
||||
async def display_number(self, number: int, digits: int = 6) -> None:
|
||||
if (
|
||||
self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
|
||||
):
|
||||
return
|
||||
|
||||
self.log.info(
|
||||
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
if self.service.event_queue is None:
|
||||
raise RuntimeError('security: unhandled number display request')
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_notification=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
|
||||
|
||||
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LEVEL0: lambda connection: True,
|
||||
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
|
||||
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
|
||||
LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
in (
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
),
|
||||
LEVEL4: lambda connection: connection.encryption
|
||||
== hci.HCI_Encryption_Change_Event.AES_CCM
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
}
|
||||
|
||||
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LE_LEVEL1: lambda connection: True,
|
||||
LE_LEVEL2: lambda connection: connection.encryption != 0,
|
||||
LE_LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated,
|
||||
LE_LEVEL4: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.sc,
|
||||
}
|
||||
|
||||
|
||||
class SecurityService(SecurityServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'Security', 'device': device}
|
||||
)
|
||||
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
|
||||
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
|
||||
return PairingConfig(
|
||||
sc=config.pairing_sc_enable,
|
||||
mitm=config.pairing_mitm_enable,
|
||||
bonding=config.pairing_bonding_enable,
|
||||
delegate=PairingDelegate(
|
||||
connection,
|
||||
self,
|
||||
io_capability=config.io_capability,
|
||||
local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
|
||||
local_responder_key_distribution=config.smp_local_responder_key_distribution,
|
||||
),
|
||||
)
|
||||
|
||||
self.device.pairing_config_factory = pairing_config_factory
|
||||
|
||||
@utils.rpc
|
||||
async def OnPairing(
|
||||
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[PairingEvent, None]:
|
||||
self.log.info('OnPairing')
|
||||
|
||||
if self.event_queue is not None:
|
||||
raise RuntimeError('already streaming pairing events')
|
||||
|
||||
if len(self.device.connections):
|
||||
raise RuntimeError(
|
||||
'the `OnPairing` method shall be initiated before establishing any connections.'
|
||||
)
|
||||
|
||||
self.event_queue = asyncio.Queue()
|
||||
self.event_answer = request
|
||||
|
||||
try:
|
||||
while event := await self.event_queue.get():
|
||||
yield event
|
||||
|
||||
finally:
|
||||
self.event_queue = None
|
||||
self.event_answer = None
|
||||
|
||||
@utils.rpc
|
||||
async def Secure(
|
||||
self, request: SecureRequest, context: grpc.ServicerContext
|
||||
) -> SecureResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Secure: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
|
||||
oneof = request.WhichOneof('level')
|
||||
level = getattr(request, oneof)
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
connection.transport
|
||||
] == oneof
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
|
||||
# trigger pairing if needed
|
||||
if self.need_pairing(connection, level):
|
||||
try:
|
||||
self.log.info('Pair...')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
wait_for_security: asyncio.Future[
|
||||
bool
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
||||
|
||||
connection.request_pairing()
|
||||
|
||||
await wait_for_security
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
self.log.info('Paired')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Pairing failure: {e}")
|
||||
return SecureResponse(pairing_failure=empty_pb2.Empty())
|
||||
|
||||
# trigger authentication if needed
|
||||
if self.need_authentication(connection, level):
|
||||
try:
|
||||
self.log.info('Authenticate...')
|
||||
await connection.authenticate()
|
||||
self.log.info('Authenticated')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during authentication")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Authentication failure: {e}")
|
||||
return SecureResponse(authentication_failure=empty_pb2.Empty())
|
||||
|
||||
# trigger encryption if needed
|
||||
if self.need_encryption(connection, level):
|
||||
try:
|
||||
self.log.info('Encrypt...')
|
||||
await connection.encrypt()
|
||||
self.log.info('Encrypted')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
except (HCI_Error, ProtocolError) as e:
|
||||
self.log.warning(f"Encryption failure: {e}")
|
||||
return SecureResponse(encryption_failure=empty_pb2.Empty())
|
||||
|
||||
# security level has been reached ?
|
||||
if self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
return SecureResponse(not_reached=empty_pb2.Empty())
|
||||
|
||||
@utils.rpc
|
||||
async def WaitSecurity(
|
||||
self, request: WaitSecurityRequest, context: grpc.ServicerContext
|
||||
) -> WaitSecurityResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitSecurity: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
|
||||
assert request.level
|
||||
level = request.level
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
connection.transport
|
||||
] == request.level_variant()
|
||||
|
||||
wait_for_security: asyncio.Future[
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
if (encryption := connection.encryption) != 0:
|
||||
self.log.debug('Disable encryption...')
|
||||
try:
|
||||
await connection.encrypt(enable=False)
|
||||
except:
|
||||
pass
|
||||
self.log.debug('Disable encryption: done')
|
||||
|
||||
self.log.debug('Authenticate...')
|
||||
await connection.authenticate()
|
||||
self.log.debug('Authenticate: done')
|
||||
|
||||
if encryption != 0 and connection.encryption != encryption:
|
||||
self.log.debug('Re-enable encryption...')
|
||||
await connection.encrypt()
|
||||
self.log.debug('Re-enable encryption: done')
|
||||
|
||||
def set_failure(name: str) -> Callable[..., None]:
|
||||
def wrapper(*args: Any) -> None:
|
||||
self.log.info(f'Wait for security: error `{name}`: {args}')
|
||||
wait_for_security.set_result(name)
|
||||
|
||||
return wrapper
|
||||
|
||||
def try_set_success(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
|
||||
def on_encryption_change(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
elif (
|
||||
connection.transport == BT_BR_EDR_TRANSPORT
|
||||
and self.need_authentication(connection, level)
|
||||
):
|
||||
nonlocal authenticate_task
|
||||
if authenticate_task is None:
|
||||
authenticate_task = asyncio.create_task(authenticate())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
'connection_authentication_failure': set_failure('authentication_failure'),
|
||||
'connection_encryption_failure': set_failure('encryption_failure'),
|
||||
'pairing': try_set_success,
|
||||
'connection_authentication': try_set_success,
|
||||
'connection_encryption_change': on_encryption_change,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.info('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
self.log.info('Wait for authentication...')
|
||||
try:
|
||||
await authenticate_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.info('Authenticated')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
def reached_security_level(
|
||||
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
|
||||
) -> bool:
|
||||
self.log.debug(
|
||||
str(
|
||||
{
|
||||
'level': level,
|
||||
'encryption': connection.encryption,
|
||||
'authenticated': connection.authenticated,
|
||||
'sc': connection.sc,
|
||||
'link_key_type': connection.link_key_type,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(level, LESecurityLevel):
|
||||
return LE_LEVEL_REACHED[level](connection)
|
||||
|
||||
return BR_LEVEL_REACHED[level](connection)
|
||||
|
||||
def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return level >= LE_LEVEL3 and not connection.authenticated
|
||||
return False
|
||||
|
||||
def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return False
|
||||
if level == LEVEL2 and connection.encryption != 0:
|
||||
return not connection.authenticated
|
||||
return level >= LEVEL2 and not connection.authenticated
|
||||
|
||||
def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
|
||||
# TODO(abel): need to support MITM
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
return level == LE_LEVEL2 and not connection.encryption
|
||||
return level >= LEVEL2 and not connection.encryption
|
||||
|
||||
|
||||
class SecurityStorageService(SecurityStorageServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
|
||||
)
|
||||
self.device = device
|
||||
self.config = config
|
||||
|
||||
@utils.rpc
|
||||
async def IsBonded(
|
||||
self, request: IsBondedRequest, context: grpc.ServicerContext
|
||||
) -> wrappers_pb2.BoolValue:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"IsBonded: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
is_bonded = await self.device.keystore.get(str(address)) is not None
|
||||
else:
|
||||
is_bonded = False
|
||||
|
||||
return wrappers_pb2.BoolValue(value=is_bonded)
|
||||
|
||||
@utils.rpc
|
||||
async def DeleteBond(
|
||||
self, request: DeleteBondRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"DeleteBond: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
with suppress(KeyError):
|
||||
await self.device.keystore.delete(str(address))
|
||||
|
||||
return empty_pb2.Empty()
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.hci import Address
|
||||
from google.protobuf.message import Message # pytype: disable=pyi-error
|
||||
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
|
||||
|
||||
ADDRESS_TYPES: Dict[str, int] = {
|
||||
"public": Address.PUBLIC_DEVICE_ADDRESS,
|
||||
"random": Address.RANDOM_DEVICE_ADDRESS,
|
||||
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
"random_static_identity": Address.RANDOM_IDENTITY_ADDRESS,
|
||||
}
|
||||
|
||||
|
||||
def address_from_request(request: Message, field: Optional[str]) -> Address:
|
||||
if field is None:
|
||||
return Address.ANY
|
||||
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
|
||||
|
||||
|
||||
class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
"""Formats logs from the PandoraClient."""
|
||||
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> Tuple[str, MutableMapping[str, Any]]:
|
||||
assert self.extra
|
||||
service_name = self.extra['service_name']
|
||||
assert isinstance(service_name, str)
|
||||
device = self.extra['device']
|
||||
assert isinstance(device, Device)
|
||||
addr_bytes = bytes(
|
||||
reversed(bytes(device.public_address))
|
||||
) # pytype: disable=attribute-error
|
||||
addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]])
|
||||
return (f'[bumble.{service_name}:{addr}] {msg}', kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def exception_to_rpc_error(
|
||||
context: grpc.ServicerContext,
|
||||
) -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
except NotImplementedError as e:
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
except ValueError as e:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
except RuntimeError as e:
|
||||
context.set_code(grpc.StatusCode.ABORTED) # type: ignore
|
||||
context.set_details(str(e)) # type: ignore
|
||||
|
||||
|
||||
# Decorate an RPC servicer method with a wrapper that transform exceptions to gRPC errors.
|
||||
def rpc(func: Any) -> Any:
|
||||
@functools.wraps(func)
|
||||
async def asyncgen_wrapper(
|
||||
self: Any, request: Any, context: grpc.ServicerContext
|
||||
) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
async for v in func(self, request, context):
|
||||
yield v
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
self: Any, request: Any, context: grpc.ServicerContext
|
||||
) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
return await func(self, request, context)
|
||||
|
||||
@functools.wraps(func)
|
||||
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
for v in func(self, request, context):
|
||||
yield v
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
|
||||
with exception_to_rpc_error(context):
|
||||
return func(self, request, context)
|
||||
|
||||
if inspect.isasyncgenfunction(func):
|
||||
return asyncgen_wrapper
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
|
||||
if inspect.isgenerator(func):
|
||||
return gen_wrapper
|
||||
|
||||
return wrapper
|
||||
+110
-67
@@ -19,8 +19,9 @@ import logging
|
||||
import asyncio
|
||||
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Tuple, Callable, Dict, Union
|
||||
|
||||
from . import core
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
|
||||
|
||||
@@ -105,7 +106,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def compute_fcs(buffer):
|
||||
def compute_fcs(buffer: bytes) -> int:
|
||||
result = 0xFF
|
||||
for byte in buffer:
|
||||
result = CRC_TABLE[result ^ byte]
|
||||
@@ -114,7 +115,15 @@ def compute_fcs(buffer):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RFCOMM_Frame:
|
||||
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
|
||||
def __init__(
|
||||
self,
|
||||
frame_type: int,
|
||||
c_r: int,
|
||||
dlci: int,
|
||||
p_f: int,
|
||||
information: bytes = b'',
|
||||
with_credits: bool = False,
|
||||
) -> None:
|
||||
self.type = frame_type
|
||||
self.c_r = c_r
|
||||
self.dlci = dlci
|
||||
@@ -136,11 +145,11 @@ class RFCOMM_Frame:
|
||||
else:
|
||||
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
|
||||
|
||||
def type_name(self):
|
||||
def type_name(self) -> str:
|
||||
return RFCOMM_FRAME_TYPE_NAMES[self.type]
|
||||
|
||||
@staticmethod
|
||||
def parse_mcc(data):
|
||||
def parse_mcc(data) -> Tuple[int, int, bytes]:
|
||||
mcc_type = data[0] >> 2
|
||||
c_r = (data[0] >> 1) & 1
|
||||
length = data[1]
|
||||
@@ -154,36 +163,36 @@ class RFCOMM_Frame:
|
||||
return (mcc_type, c_r, value)
|
||||
|
||||
@staticmethod
|
||||
def make_mcc(mcc_type, c_r, data):
|
||||
def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:
|
||||
return (
|
||||
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
|
||||
+ data
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sabm(c_r, dlci):
|
||||
def sabm(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def ua(c_r, dlci):
|
||||
def ua(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def dm(c_r, dlci):
|
||||
def dm(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def disc(c_r, dlci):
|
||||
def disc(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def uih(c_r, dlci, information, p_f=0):
|
||||
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
|
||||
return RFCOMM_Frame(
|
||||
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes):
|
||||
# Extract fields
|
||||
dlci = (data[0] >> 2) & 0x3F
|
||||
c_r = (data[0] >> 1) & 0x01
|
||||
@@ -227,15 +236,23 @@ class RFCOMM_Frame:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RFCOMM_MCC_PN:
|
||||
dlci: int
|
||||
cl: int
|
||||
priority: int
|
||||
ack_timer: int
|
||||
max_frame_size: int
|
||||
max_retransmissions: int
|
||||
window_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dlci,
|
||||
cl,
|
||||
priority,
|
||||
ack_timer,
|
||||
max_frame_size,
|
||||
max_retransmissions,
|
||||
window_size,
|
||||
dlci: int,
|
||||
cl: int,
|
||||
priority: int,
|
||||
ack_timer: int,
|
||||
max_frame_size: int,
|
||||
max_retransmissions: int,
|
||||
window_size: int,
|
||||
):
|
||||
self.dlci = dlci
|
||||
self.cl = cl
|
||||
@@ -246,7 +263,7 @@ class RFCOMM_MCC_PN:
|
||||
self.window_size = window_size
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes):
|
||||
return RFCOMM_MCC_PN(
|
||||
dlci=data[0],
|
||||
cl=data[1],
|
||||
@@ -285,7 +302,14 @@ class RFCOMM_MCC_PN:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RFCOMM_MCC_MSC:
|
||||
def __init__(self, dlci, fc, rtc, rtr, ic, dv):
|
||||
dlci: int
|
||||
fc: int
|
||||
rtc: int
|
||||
rtr: int
|
||||
ic: int
|
||||
dv: int
|
||||
|
||||
def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int):
|
||||
self.dlci = dlci
|
||||
self.fc = fc
|
||||
self.rtc = rtc
|
||||
@@ -294,7 +318,7 @@ class RFCOMM_MCC_MSC:
|
||||
self.dv = dv
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes):
|
||||
return RFCOMM_MCC_MSC(
|
||||
dlci=data[0] >> 2,
|
||||
fc=data[1] >> 1 & 1,
|
||||
@@ -347,7 +371,12 @@ class DLC(EventEmitter):
|
||||
RESET: 'RESET',
|
||||
}
|
||||
|
||||
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
|
||||
connection_result: Optional[asyncio.Future]
|
||||
sink: Optional[Callable[[bytes], None]]
|
||||
|
||||
def __init__(
|
||||
self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int
|
||||
):
|
||||
super().__init__()
|
||||
self.multiplexer = multiplexer
|
||||
self.dlci = dlci
|
||||
@@ -368,23 +397,23 @@ class DLC(EventEmitter):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def state_name(state):
|
||||
def state_name(state: int) -> str:
|
||||
return DLC.STATE_NAMES[state]
|
||||
|
||||
def change_state(self, new_state):
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
|
||||
)
|
||||
self.state = new_state
|
||||
|
||||
def send_frame(self, frame):
|
||||
def send_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
self.multiplexer.send_frame(frame)
|
||||
|
||||
def on_frame(self, frame):
|
||||
def on_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
|
||||
handler(frame)
|
||||
|
||||
def on_sabm_frame(self, _frame):
|
||||
def on_sabm_frame(self, _frame) -> None:
|
||||
if self.state != DLC.CONNECTING:
|
||||
logger.warning(
|
||||
color('!!! received SABM when not in CONNECTING state', 'red')
|
||||
@@ -404,7 +433,7 @@ class DLC(EventEmitter):
|
||||
self.change_state(DLC.CONNECTED)
|
||||
self.emit('open')
|
||||
|
||||
def on_ua_frame(self, _frame):
|
||||
def on_ua_frame(self, _frame) -> None:
|
||||
if self.state != DLC.CONNECTING:
|
||||
logger.warning(
|
||||
color('!!! received SABM when not in CONNECTING state', 'red')
|
||||
@@ -422,15 +451,15 @@ class DLC(EventEmitter):
|
||||
self.change_state(DLC.CONNECTED)
|
||||
self.multiplexer.on_dlc_open_complete(self)
|
||||
|
||||
def on_dm_frame(self, frame):
|
||||
def on_dm_frame(self, frame) -> None:
|
||||
# TODO: handle all states
|
||||
pass
|
||||
|
||||
def on_disc_frame(self, _frame):
|
||||
def on_disc_frame(self, _frame) -> None:
|
||||
# TODO: handle all states
|
||||
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
|
||||
|
||||
def on_uih_frame(self, frame):
|
||||
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
data = frame.information
|
||||
if frame.p_f == 1:
|
||||
# With credits
|
||||
@@ -460,10 +489,10 @@ class DLC(EventEmitter):
|
||||
# Check if there's anything to send (including credits)
|
||||
self.process_tx()
|
||||
|
||||
def on_ui_frame(self, frame):
|
||||
def on_ui_frame(self, frame) -> None:
|
||||
pass
|
||||
|
||||
def on_mcc_msc(self, c_r, msc):
|
||||
def on_mcc_msc(self, c_r, msc) -> None:
|
||||
if c_r:
|
||||
# Command
|
||||
logger.debug(f'<<< MCC MSC Command: {msc}')
|
||||
@@ -477,7 +506,7 @@ class DLC(EventEmitter):
|
||||
# Response
|
||||
logger.debug(f'<<< MCC MSC Response: {msc}')
|
||||
|
||||
def connect(self):
|
||||
def connect(self) -> None:
|
||||
if self.state != DLC.INIT:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
@@ -485,7 +514,7 @@ class DLC(EventEmitter):
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
|
||||
|
||||
def accept(self):
|
||||
def accept(self) -> None:
|
||||
if self.state != DLC.INIT:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
@@ -503,13 +532,13 @@ class DLC(EventEmitter):
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
self.change_state(DLC.CONNECTING)
|
||||
|
||||
def rx_credits_needed(self):
|
||||
def rx_credits_needed(self) -> int:
|
||||
if self.rx_credits <= self.rx_threshold:
|
||||
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
|
||||
|
||||
return 0
|
||||
|
||||
def process_tx(self):
|
||||
def process_tx(self) -> None:
|
||||
# Send anything we can (or an empty frame if we need to send rx credits)
|
||||
rx_credits_needed = self.rx_credits_needed()
|
||||
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
|
||||
@@ -547,7 +576,7 @@ class DLC(EventEmitter):
|
||||
rx_credits_needed = 0
|
||||
|
||||
# Stream protocol
|
||||
def write(self, data):
|
||||
def write(self, data: Union[bytes, str]) -> None:
|
||||
# We can only send bytes
|
||||
if not isinstance(data, bytes):
|
||||
if isinstance(data, str):
|
||||
@@ -559,7 +588,7 @@ class DLC(EventEmitter):
|
||||
self.tx_buffer += data
|
||||
self.process_tx()
|
||||
|
||||
def drain(self):
|
||||
def drain(self) -> None:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@@ -592,7 +621,13 @@ class Multiplexer(EventEmitter):
|
||||
RESET: 'RESET',
|
||||
}
|
||||
|
||||
def __init__(self, l2cap_channel, role):
|
||||
connection_result: Optional[asyncio.Future]
|
||||
disconnection_result: Optional[asyncio.Future]
|
||||
open_result: Optional[asyncio.Future]
|
||||
acceptor: Optional[Callable[[int], bool]]
|
||||
dlcs: Dict[int, DLC]
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None:
|
||||
super().__init__()
|
||||
self.role = role
|
||||
self.l2cap_channel = l2cap_channel
|
||||
@@ -607,20 +642,20 @@ class Multiplexer(EventEmitter):
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
|
||||
@staticmethod
|
||||
def state_name(state):
|
||||
def state_name(state: int):
|
||||
return Multiplexer.STATE_NAMES[state]
|
||||
|
||||
def change_state(self, new_state):
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
|
||||
)
|
||||
self.state = new_state
|
||||
|
||||
def send_frame(self, frame):
|
||||
def send_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
logger.debug(f'>>> Multiplexer sending {frame}')
|
||||
self.l2cap_channel.send_pdu(frame)
|
||||
|
||||
def on_pdu(self, pdu):
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
frame = RFCOMM_Frame.from_bytes(pdu)
|
||||
logger.debug(f'<<< Multiplexer received {frame}')
|
||||
|
||||
@@ -640,18 +675,18 @@ class Multiplexer(EventEmitter):
|
||||
return
|
||||
dlc.on_frame(frame)
|
||||
|
||||
def on_frame(self, frame):
|
||||
def on_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
|
||||
handler(frame)
|
||||
|
||||
def on_sabm_frame(self, _frame):
|
||||
def on_sabm_frame(self, _frame) -> None:
|
||||
if self.state != Multiplexer.INIT:
|
||||
logger.debug('not in INIT state, ignoring SABM')
|
||||
return
|
||||
self.change_state(Multiplexer.CONNECTED)
|
||||
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
|
||||
|
||||
def on_ua_frame(self, _frame):
|
||||
def on_ua_frame(self, _frame) -> None:
|
||||
if self.state == Multiplexer.CONNECTING:
|
||||
self.change_state(Multiplexer.CONNECTED)
|
||||
if self.connection_result:
|
||||
@@ -663,7 +698,7 @@ class Multiplexer(EventEmitter):
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
|
||||
def on_dm_frame(self, _frame):
|
||||
def on_dm_frame(self, _frame) -> None:
|
||||
if self.state == Multiplexer.OPENING:
|
||||
self.change_state(Multiplexer.CONNECTED)
|
||||
if self.open_result:
|
||||
@@ -678,13 +713,13 @@ class Multiplexer(EventEmitter):
|
||||
else:
|
||||
logger.warning(f'unexpected state for DM: {self}')
|
||||
|
||||
def on_disc_frame(self, _frame):
|
||||
def on_disc_frame(self, _frame) -> None:
|
||||
self.change_state(Multiplexer.DISCONNECTED)
|
||||
self.send_frame(
|
||||
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
|
||||
)
|
||||
|
||||
def on_uih_frame(self, frame):
|
||||
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
|
||||
|
||||
if mcc_type == RFCOMM_MCC_PN_TYPE:
|
||||
@@ -694,10 +729,10 @@ class Multiplexer(EventEmitter):
|
||||
mcs = RFCOMM_MCC_MSC.from_bytes(value)
|
||||
self.on_mcc_msc(c_r, mcs)
|
||||
|
||||
def on_ui_frame(self, frame):
|
||||
def on_ui_frame(self, frame) -> None:
|
||||
pass
|
||||
|
||||
def on_mcc_pn(self, c_r, pn):
|
||||
def on_mcc_pn(self, c_r, pn) -> None:
|
||||
if c_r == 1:
|
||||
# Command
|
||||
logger.debug(f'<<< PN Command: {pn}')
|
||||
@@ -736,14 +771,14 @@ class Multiplexer(EventEmitter):
|
||||
else:
|
||||
logger.warning('ignoring PN response')
|
||||
|
||||
def on_mcc_msc(self, c_r, msc):
|
||||
def on_mcc_msc(self, c_r, msc) -> None:
|
||||
dlc = self.dlcs.get(msc.dlci)
|
||||
if dlc is None:
|
||||
logger.warning(f'no dlc for DLCI {msc.dlci}')
|
||||
return
|
||||
dlc.on_mcc_msc(c_r, msc)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
if self.state != Multiplexer.INIT:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
@@ -752,7 +787,7 @@ class Multiplexer(EventEmitter):
|
||||
self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
|
||||
return await self.connection_result
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.state != Multiplexer.CONNECTED:
|
||||
return
|
||||
|
||||
@@ -765,7 +800,7 @@ class Multiplexer(EventEmitter):
|
||||
)
|
||||
await self.disconnection_result
|
||||
|
||||
async def open_dlc(self, channel):
|
||||
async def open_dlc(self, channel: int) -> DLC:
|
||||
if self.state != Multiplexer.CONNECTED:
|
||||
if self.state == Multiplexer.OPENING:
|
||||
raise InvalidStateError('open already in progress')
|
||||
@@ -796,7 +831,7 @@ class Multiplexer(EventEmitter):
|
||||
self.open_result = None
|
||||
return result
|
||||
|
||||
def on_dlc_open_complete(self, dlc):
|
||||
def on_dlc_open_complete(self, dlc: DLC):
|
||||
logger.debug(f'DLC [{dlc.dlci}] open complete')
|
||||
self.change_state(Multiplexer.CONNECTED)
|
||||
if self.open_result:
|
||||
@@ -808,13 +843,16 @@ class Multiplexer(EventEmitter):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
def __init__(self, device, connection):
|
||||
multiplexer: Optional[Multiplexer]
|
||||
l2cap_channel: Optional[l2cap.Channel]
|
||||
|
||||
def __init__(self, device, connection) -> None:
|
||||
self.device = device
|
||||
self.connection = connection
|
||||
self.l2cap_channel = None
|
||||
self.multiplexer = None
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> Multiplexer:
|
||||
# Create a new L2CAP connection
|
||||
try:
|
||||
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
|
||||
@@ -824,6 +862,7 @@ class Client:
|
||||
logger.warning(f'L2CAP connection failed: {error}')
|
||||
raise
|
||||
|
||||
assert self.l2cap_channel is not None
|
||||
# Create a mutliplexer to manage DLCs with the server
|
||||
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR)
|
||||
|
||||
@@ -832,7 +871,9 @@ class Client:
|
||||
|
||||
return self.multiplexer
|
||||
|
||||
async def shutdown(self):
|
||||
async def shutdown(self) -> None:
|
||||
if self.multiplexer is None:
|
||||
return
|
||||
# Disconnect the multiplexer
|
||||
await self.multiplexer.disconnect()
|
||||
self.multiplexer = None
|
||||
@@ -843,7 +884,9 @@ class Client:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
def __init__(self, device):
|
||||
acceptors: Dict[int, Callable[[DLC], None]]
|
||||
|
||||
def __init__(self, device) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.multiplexer = None
|
||||
@@ -852,7 +895,7 @@ class Server(EventEmitter):
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
|
||||
|
||||
def listen(self, acceptor, channel=0):
|
||||
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
|
||||
if channel:
|
||||
if channel in self.acceptors:
|
||||
# Busy
|
||||
@@ -874,11 +917,11 @@ class Server(EventEmitter):
|
||||
self.acceptors[channel] = acceptor
|
||||
return channel
|
||||
|
||||
def on_connection(self, l2cap_channel):
|
||||
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel):
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
# Create a new multiplexer for the channel
|
||||
@@ -889,10 +932,10 @@ class Server(EventEmitter):
|
||||
# Notify
|
||||
self.emit('start', multiplexer)
|
||||
|
||||
def accept_dlc(self, channel_number):
|
||||
def accept_dlc(self, channel_number: int) -> bool:
|
||||
return channel_number in self.acceptors
|
||||
|
||||
def on_dlc(self, dlc):
|
||||
def on_dlc(self, dlc: DLC) -> None:
|
||||
logger.debug(f'@@@ new DLC connected: {dlc}')
|
||||
|
||||
# Let the acceptor know
|
||||
|
||||
+1
-1
@@ -1805,7 +1805,7 @@ class Manager(EventEmitter):
|
||||
self.device.abort_on('flush', store_keys())
|
||||
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection, keys, session.sc)
|
||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||
|
||||
def on_pairing_failure(self, session: Session, reason: int) -> None:
|
||||
self.device.on_pairing_failure(session.connection, reason)
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
mkdocs == 1.4.0
|
||||
mkdocs-material == 8.5.6
|
||||
mkdocs-material-extensions == 1.0.3
|
||||
pymdown-extensions == 9.6
|
||||
pymdown-extensions == 10.0
|
||||
mkdocstrings-python == 0.7.1
|
||||
|
||||
@@ -40,6 +40,9 @@ disable = [
|
||||
"too-many-statements",
|
||||
]
|
||||
|
||||
[tool.pylint.main]
|
||||
ignore="pandora" # FIXME: pylint does not support stubs yet:
|
||||
|
||||
[tool.pylint.typecheck]
|
||||
signature-mutators="AsyncRunner.run_in_task"
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ url = https://github.com/google/bumble
|
||||
|
||||
[options]
|
||||
python_requires = >=3.8
|
||||
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay
|
||||
packages = bumble, bumble.transport, bumble.profiles, bumble.apps, bumble.apps.link_relay, bumble.pandora
|
||||
package_dir =
|
||||
bumble = bumble
|
||||
bumble.apps = apps
|
||||
@@ -33,7 +33,7 @@ install_requires =
|
||||
appdirs >= 1.4
|
||||
click >= 7.1.2; platform_system!='Emscripten'
|
||||
cryptography == 35; platform_system!='Emscripten'
|
||||
grpcio >= 1.46; platform_system!='Emscripten'
|
||||
grpcio == 1.51.1; platform_system!='Emscripten'
|
||||
libusb1 >= 2.0.1; platform_system!='Emscripten'
|
||||
libusb-package == 1.0.26.1; platform_system!='Emscripten'
|
||||
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
|
||||
@@ -45,6 +45,7 @@ install_requires =
|
||||
websockets >= 8.1; platform_system!='Emscripten'
|
||||
prettytable >= 3.6.0
|
||||
humanize >= 4.6.0
|
||||
bt-test-interfaces >= 0.0.2
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
@@ -60,6 +61,7 @@ console_scripts =
|
||||
bumble-usb-probe = bumble.apps.usb_probe:main
|
||||
bumble-link-relay = bumble.apps.link_relay.link_relay:main
|
||||
bumble-bench = bumble.apps.bench:main
|
||||
bumble-pandora-server = bumble.apps.pandora_server:main
|
||||
|
||||
[options.package_data]
|
||||
* = py.typed, *.pyi
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from bumble.keys import JsonKeyStore, PairingKeys
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
JSON1 = """
|
||||
{
|
||||
"my_namespace": {
|
||||
"14:7D:DA:4E:53:A8/P": {
|
||||
"address_type": 0,
|
||||
"irk": {
|
||||
"authenticated": false,
|
||||
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
|
||||
},
|
||||
"link_key": {
|
||||
"authenticated": false,
|
||||
"value": "0745dd9691e693d9dca740f7d8dfea75"
|
||||
},
|
||||
"ltk": {
|
||||
"authenticated": false,
|
||||
"value": "d1897ee10016eb1a08e4e037fd54c683"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
JSON2 = """
|
||||
{
|
||||
"my_namespace1": {
|
||||
},
|
||||
"my_namespace2": {
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
JSON3 = """
|
||||
{
|
||||
"my_namespace1": {
|
||||
},
|
||||
"__DEFAULT__": {
|
||||
"14:7D:DA:4E:53:A8/P": {
|
||||
"address_type": 0,
|
||||
"irk": {
|
||||
"authenticated": false,
|
||||
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def test_basic():
|
||||
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file:
|
||||
keystore = JsonKeyStore('my_namespace', file.name)
|
||||
file.write("{}")
|
||||
file.flush()
|
||||
|
||||
keys = await keystore.get_all()
|
||||
assert len(keys) == 0
|
||||
|
||||
keys = PairingKeys()
|
||||
await keystore.update('foo', keys)
|
||||
foo = await keystore.get('foo')
|
||||
assert foo is not None
|
||||
assert foo.ltk is None
|
||||
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||
keys.ltk = PairingKeys.Key(ltk)
|
||||
await keystore.update('foo', keys)
|
||||
foo = await keystore.get('foo')
|
||||
assert foo is not None
|
||||
assert foo.ltk is not None
|
||||
assert foo.ltk.value == ltk
|
||||
|
||||
file.flush()
|
||||
with open(file.name, "r", encoding="utf-8") as json_file:
|
||||
json_data = json.load(json_file)
|
||||
assert 'my_namespace' in json_data
|
||||
assert 'foo' in json_data['my_namespace']
|
||||
assert 'ltk' in json_data['my_namespace']['foo']
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def test_parsing():
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
|
||||
keystore = JsonKeyStore('my_namespace', file.name)
|
||||
file.write(JSON1)
|
||||
file.flush()
|
||||
|
||||
foo = await keystore.get('14:7D:DA:4E:53:A8/P')
|
||||
assert foo is not None
|
||||
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def test_default_namespace():
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
|
||||
keystore = JsonKeyStore(None, file.name)
|
||||
file.write(JSON1)
|
||||
file.flush()
|
||||
|
||||
all_keys = await keystore.get_all()
|
||||
assert len(all_keys) == 1
|
||||
name, keys = all_keys[0]
|
||||
assert name == '14:7D:DA:4E:53:A8/P'
|
||||
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
|
||||
keystore = JsonKeyStore(None, file.name)
|
||||
file.write(JSON2)
|
||||
file.flush()
|
||||
|
||||
keys = PairingKeys()
|
||||
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||
keys.ltk = PairingKeys.Key(ltk)
|
||||
await keystore.update('foo', keys)
|
||||
file.flush()
|
||||
with open(file.name, "r", encoding="utf-8") as json_file:
|
||||
json_data = json.load(json_file)
|
||||
assert '__DEFAULT__' in json_data
|
||||
assert 'foo' in json_data['__DEFAULT__']
|
||||
assert 'ltk' in json_data['__DEFAULT__']['foo']
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
|
||||
keystore = JsonKeyStore(None, file.name)
|
||||
file.write(JSON3)
|
||||
file.flush()
|
||||
|
||||
all_keys = await keystore.get_all()
|
||||
assert len(all_keys) == 1
|
||||
name, keys = all_keys[0]
|
||||
assert name == '14:7D:DA:4E:53:A8/P'
|
||||
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run_tests():
|
||||
await test_basic()
|
||||
await test_parsing()
|
||||
await test_default_namespace()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run_tests())
|
||||
Reference in New Issue
Block a user