From 740a2e0ca0000b5ad85e55102c77162949de95ef Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 5 May 2023 15:57:13 -0700 Subject: [PATCH 1/7] instantiate keystore after power_on --- apps/pair.py | 28 +++++++++++------------ apps/scan.py | 11 ++++----- apps/unbond.py | 62 ++++++++++++++++++++++++++++++++------------------ 3 files changed, 59 insertions(+), 42 deletions(-) diff --git a/apps/pair.py b/apps/pair.py index a7844fe..162442a 100644 --- a/apps/pair.py +++ b/apps/pair.py @@ -207,7 +207,7 @@ def on_connection(connection, request): # Listen for pairing events connection.on('pairing_start', on_pairing_start) - connection.on('pairing', on_pairing) + connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys)) connection.on('pairing_failure', on_pairing_failure) # Listen for encryption changes @@ -242,9 +242,9 @@ def on_pairing_start(): # ----------------------------------------------------------------------------- -def on_pairing(keys): +def on_pairing(address, keys): print(color('***-----------------------------------', 'cyan')) - print(color('*** Paired!', 'cyan')) + print(color(f'*** Paired! (peer identity={address})', 'cyan')) keys.print(prefix=color('*** ', 'cyan')) print(color('***-----------------------------------', 'cyan')) Waiter.instance.terminate() @@ -283,17 +283,6 @@ async def pair( # Create a device to manage the host device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) - # Set a custom keystore if specified on the command line - if keystore_file: - device.keystore = JsonKeyStore(namespace=None, filename=keystore_file) - - # Print the existing keys before pairing - if print_keys and device.keystore: - print(color('@@@-----------------------------------', 'blue')) - print(color('@@@ Pairing Keys:', 'blue')) - await device.keystore.print(prefix=color('@@@ ', 'blue')) - print(color('@@@-----------------------------------', 'blue')) - # Expose a GATT characteristic that can be used to trigger pairing by # responding with an authentication error when read if mode == 'le': @@ -323,6 +312,17 @@ async def pair( # Get things going await device.power_on() + # Set a custom keystore if specified on the command line + if keystore_file: + device.keystore = JsonKeyStore.from_device(device, filename=keystore_file) + + # Print the existing keys before pairing + if print_keys and device.keystore: + print(color('@@@-----------------------------------', 'blue')) + print(color('@@@ Pairing Keys:', 'blue')) + await device.keystore.print(prefix=color('@@@ ', 'blue')) + print(color('@@@-----------------------------------', 'blue')) + # Set up a pairing config factory device.pairing_config_factory = lambda connection: PairingConfig( sc, mitm, bond, Delegate(mode, connection, io, prompt) diff --git a/apps/scan.py b/apps/scan.py index dac7a2c..268912f 100644 --- a/apps/scan.py +++ b/apps/scan.py @@ -133,15 +133,16 @@ async def scan( 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink ) + await device.power_on() + if keystore_file: - keystore = JsonKeyStore(namespace=None, filename=keystore_file) - device.keystore = keystore - else: - resolver = None + device.keystore = JsonKeyStore.from_device(device, filename=keystore_file) if device.keystore: resolving_keys = await device.keystore.get_resolving_keys() resolver = AddressResolver(resolving_keys) + else: + resolver = None printer = AdvertisementPrinter(min_rssi, resolver) if raw: @@ -149,8 +150,6 @@ async def scan( else: device.on('advertisement', printer.on_advertisement) - await device.power_on() - if phy is None: scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY] else: diff --git a/apps/unbond.py b/apps/unbond.py index 105d9a4..5ffd746 100644 --- a/apps/unbond.py +++ b/apps/unbond.py @@ -22,40 +22,58 @@ import click from bumble.device import Device from bumble.keys import JsonKeyStore +from bumble.transport import open_transport + +# ----------------------------------------------------------------------------- +async def unbond_with_keystore(keystore, address): + if address is None: + return await keystore.print() + + try: + await keystore.delete(address) + except KeyError: + print('!!! pairing not found') # ----------------------------------------------------------------------------- -async def unbond(keystore_file, device_config, address): - # Create a device to manage the host - device = Device.from_config_file(device_config) - - # Get all entries in the keystore +async def unbond(keystore_file, device_config, hci_transport, address): + # With a keystore file, we can instantiate the keystore directly if keystore_file: - keystore = JsonKeyStore(None, keystore_file) - else: - keystore = device.keystore + return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address) - if keystore is None: - print('no keystore') - return + # Without a keystore file, we need to obtain the keystore from the device + async with await open_transport(hci_transport) as (hci_source, hci_sink): + # Create a device to manage the host + device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) - if address is None: - await keystore.print() - else: - try: - await keystore.delete(address) - except KeyError: - print('!!! pairing not found') + # Power-on the device to ensure we have a key store + await device.power_on() + + return await unbond_with_keystore(device.keystore, address) # ----------------------------------------------------------------------------- @click.command() -@click.option('--keystore-file', help='File in which to store the pairing keys') -@click.argument('device-config') +@click.option('--keystore-file', help='File in which the pairing keys are stored') +@click.option('--hci-transport', help='HCI transport for the controller') +@click.argument('device-config', required=False) @click.argument('address', required=False) -def main(keystore_file, device_config, address): +def main(keystore_file, hci_transport, device_config, address): + """ + Remove pairing keys for a device, given its address. + + If no keystore file is specified, the --hci-transport option must be used to + connect to a controller, so that the keystore for that controller can be + instantiated. + If no address is passed, the existing pairing keys for all addresses are printed. + """ logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) - asyncio.run(unbond(keystore_file, device_config, address)) + + if not keystore_file and not hci_transport: + print('either --keystore-file or --hci-transport must be specified.') + return + + asyncio.run(unbond(keystore_file, device_config, hci_transport, address)) # ----------------------------------------------------------------------------- From 3de35193bc3b8247372491e55f605517632cb3dd Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 5 May 2023 15:57:50 -0700 Subject: [PATCH 2/7] rebase --- bumble/device.py | 11 ++++++++++- bumble/smp.py | 4 +++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index 72fd755..def02f5 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -3088,7 +3088,16 @@ class Device(CompositeEventEmitter): def on_pairing_start(self, connection: Connection) -> None: connection.emit('pairing_start') - def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None: + def on_pairing( + self, + connection: Connection, + identity_address: Address, + keys: PairingKeys, + sc: bool, + ) -> None: + if identity_address is not None: + connection.peer_resolvable_address = connection.peer_address + connection.peer_address = identity_address connection.sc = sc connection.authenticated = True connection.emit('pairing', keys) diff --git a/bumble/smp.py b/bumble/smp.py index f3fbf27..79ae578 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1805,7 +1805,9 @@ class Manager(EventEmitter): self.device.abort_on('flush', store_keys()) # Notify the device - self.device.on_pairing(session.connection, keys, session.sc) + self.device.on_pairing( + session.connection, identity_address, keys, session.sc + ) def on_pairing_failure(self, session: Session, reason: int) -> None: self.device.on_pairing_failure(session.connection, reason) From f80c83d0b3b9134aaf7661b7cb3927de81b12cc9 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 5 May 2023 15:58:21 -0700 Subject: [PATCH 3/7] better doc and default behavior for json keystore --- bumble/keys.py | 117 ++++++++++++++++++++++++++++++------------------- 1 file changed, 71 insertions(+), 46 deletions(-) diff --git a/bumble/keys.py b/bumble/keys.py index a30e753..b09f1c1 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -190,10 +190,44 @@ class KeyStore: # ----------------------------------------------------------------------------- class JsonKeyStore(KeyStore): + """ + KeyStore implementation that is backed by a JSON file. + + This implementation supports storing a hierarchy of key sets in a single file. + A key set is a representation of a PairingKeys object. Each key set is stored + in a map, with the address of paired peer as the key. Maps are themselves grouped + into namespaces, grouping pairing keys by controller addresses. + The JSON object model looks like: + { + "": { + "peer-address": { + "address_type": , + "irk" : { + "authenticated": , + "value": "hex-encoded-key" + }, + ... other keys ... + }, + ... other peers ... + } + ... other namespaces ... + } + + A namespace is typically the BD_ADDR of a controller, since that is a convenient + unique identifier, but it may be something else. + A special namespace, called the "default" namespace, is used when instantiating this + class without a namespace. With the default namespace, reading from a file will + load an existing namespace if there is only one, which may be convenient for reading + from a file with a single key set and for which the namespace isn't known. If the + file does not include any existing key set, or if there are more than one and none + has the default name, a new one will be created with the name "__DEFAULT__". + """ + APP_NAME = 'Bumble' APP_AUTHOR = 'Google' KEYS_DIR = 'Pairing' DEFAULT_NAMESPACE = '__DEFAULT__' + DEFAULT_BASE_NAME = "keys" def __init__(self, namespace, filename=None): self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE @@ -208,8 +242,9 @@ class JsonKeyStore(KeyStore): self.directory_name = os.path.join( appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR ) + base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace json_filename = ( - f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p') + f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p') ) self.filename = os.path.join(self.directory_name, json_filename) else: @@ -219,11 +254,12 @@ class JsonKeyStore(KeyStore): logger.debug(f'JSON keystore: {self.filename}') @staticmethod - def from_device(device: Device) -> Optional[JsonKeyStore]: - if not device.config.keystore: - return None - - params = device.config.keystore.split(':', 1)[1:] + def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]: + if not filename: + # Extract the filename from the config + params = device.config.keystore.split(':', 1)[1:] + if params: + filename = params[0] # Use a namespace based on the device address if device.public_address not in (Address.ANY, Address.ANY_RANDOM): @@ -232,19 +268,31 @@ class JsonKeyStore(KeyStore): namespace = str(device.random_address) else: namespace = JsonKeyStore.DEFAULT_NAMESPACE - if params: - filename = params[0] - else: - filename = None return JsonKeyStore(namespace, filename) async def load(self): + # Try to open the file, without failing. If the file does not exist, it + # will be created upon saving. try: with open(self.filename, 'r', encoding='utf-8') as json_file: - return json.load(json_file) + db = json.load(json_file) except FileNotFoundError: - return {} + db = {} + + # First, look for a namespace match + if self.namespace in db: + return (db, db[self.namespace]) + + # Then, if the namespace is the default namespace, and there's + # only one entry in the db, use that + if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1: + return next(iter(db.items())) + + # Finally, just create an empty key map for the namespace + key_map = {} + db[self.namespace] = key_map + return (db, key_map) async def save(self, db): # Create the directory if it doesn't exist @@ -260,53 +308,30 @@ class JsonKeyStore(KeyStore): os.replace(temp_filename, self.filename) async def delete(self, name: str) -> None: - db = await self.load() - - namespace = db.get(self.namespace) - if namespace is None: - raise KeyError(name) - - del namespace[name] + db, key_map = await self.load() + del key_map[name] await self.save(db) async def update(self, name, keys): - db = await self.load() - - namespace = db.setdefault(self.namespace, {}) - namespace.setdefault(name, {}).update(keys.to_dict()) - + db, key_map = await self.load() + key_map.setdefault(name, {}).update(keys.to_dict()) await self.save(db) async def get_all(self): - db = await self.load() - - namespace = db.get(self.namespace) - if namespace is None: - return [] - - return [ - (name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items() - ] + _, key_map = await self.load() + return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()] async def delete_all(self): - db = await self.load() - - db.pop(self.namespace, None) - + db, key_map = await self.load() + key_map.clear() await self.save(db) async def get(self, name: str) -> Optional[PairingKeys]: - db = await self.load() - - namespace = db.get(self.namespace) - if namespace is None: + _, key_map = await self.load() + if name not in key_map: return None - keys = namespace.get(name) - if keys is None: - return None - - return PairingKeys.from_dict(keys) + return PairingKeys.from_dict(key_map[name]) # ----------------------------------------------------------------------------- From 6826f68478a039175f391039d6972f8a8892402d Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 5 May 2023 16:16:55 -0700 Subject: [PATCH 4/7] fix linter warnings --- bumble/device.py | 2 +- bumble/keys.py | 9 +++++---- bumble/smp.py | 4 +--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index def02f5..5602fdf 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -3091,7 +3091,7 @@ class Device(CompositeEventEmitter): def on_pairing( self, connection: Connection, - identity_address: Address, + identity_address: Optional[Address], keys: PairingKeys, sc: bool, ) -> None: diff --git a/bumble/keys.py b/bumble/keys.py index b09f1c1..198d5c4 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -256,10 +256,11 @@ class JsonKeyStore(KeyStore): @staticmethod def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]: if not filename: - # Extract the filename from the config - params = device.config.keystore.split(':', 1)[1:] - if params: - filename = params[0] + # Extract the filename from the config if there is one + if device.config.keystore is not None: + params = device.config.keystore.split(':', 1)[1:] + if params: + filename = params[0] # Use a namespace based on the device address if device.public_address not in (Address.ANY, Address.ANY_RANDOM): diff --git a/bumble/smp.py b/bumble/smp.py index 79ae578..3cdcae1 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1805,9 +1805,7 @@ class Manager(EventEmitter): self.device.abort_on('flush', store_keys()) # Notify the device - self.device.on_pairing( - session.connection, identity_address, keys, session.sc - ) + self.device.on_pairing(session.connection, identity_address, keys, session.sc) def on_pairing_failure(self, session: Session, reason: int) -> None: self.device.on_pairing_failure(session.connection, reason) From 27fbb58447b4295e3a976aa8d2141dac529f71ec Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Sun, 4 Jun 2023 13:01:07 -0700 Subject: [PATCH 5/7] add basic keystore test --- tests/keystore_test.py | 179 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 tests/keystore_test.py diff --git a/tests/keystore_test.py b/tests/keystore_test.py new file mode 100644 index 0000000..2e73039 --- /dev/null +++ b/tests/keystore_test.py @@ -0,0 +1,179 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import json +import logging +import tempfile +import os + +from bumble.keys import JsonKeyStore, PairingKeys + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + +JSON1 = """ + { + "my_namespace": { + "14:7D:DA:4E:53:A8/P": { + "address_type": 0, + "irk": { + "authenticated": false, + "value": "e7b2543b206e4e46b44f9e51dad22bd1" + }, + "link_key": { + "authenticated": false, + "value": "0745dd9691e693d9dca740f7d8dfea75" + }, + "ltk": { + "authenticated": false, + "value": "d1897ee10016eb1a08e4e037fd54c683" + } + } + } + } + """ + +JSON2 = """ + { + "my_namespace1": { + }, + "my_namespace2": { + } + } + """ + +JSON3 = """ + { + "my_namespace1": { + }, + "__DEFAULT__": { + "14:7D:DA:4E:53:A8/P": { + "address_type": 0, + "irk": { + "authenticated": false, + "value": "e7b2543b206e4e46b44f9e51dad22bd1" + } + } + } + } + """ + + +# ----------------------------------------------------------------------------- +async def test_basic(): + with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file: + keystore = JsonKeyStore('my_namespace', file.name) + file.write("{}") + file.flush() + + keys = await keystore.get_all() + assert len(keys) == 0 + + keys = PairingKeys() + await keystore.update('foo', keys) + foo = await keystore.get('foo') + assert foo is not None + assert foo.ltk is None + ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) + keys.ltk = PairingKeys.Key(ltk) + await keystore.update('foo', keys) + foo = await keystore.get('foo') + assert foo is not None + assert foo.ltk is not None + assert foo.ltk.value == ltk + + file.flush() + with open(file.name, "r", encoding="utf-8") as json_file: + json_data = json.load(json_file) + assert 'my_namespace' in json_data + assert 'foo' in json_data['my_namespace'] + assert 'ltk' in json_data['my_namespace']['foo'] + + +# ----------------------------------------------------------------------------- +async def test_parsing(): + with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: + keystore = JsonKeyStore('my_namespace', file.name) + file.write(JSON1) + file.flush() + + foo = await keystore.get('14:7D:DA:4E:53:A8/P') + assert foo is not None + assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683') + + +# ----------------------------------------------------------------------------- +async def test_default_namespace(): + with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: + keystore = JsonKeyStore(None, file.name) + file.write(JSON1) + file.flush() + + all_keys = await keystore.get_all() + assert len(all_keys) == 1 + name, keys = all_keys[0] + assert name == '14:7D:DA:4E:53:A8/P' + assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') + + with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: + keystore = JsonKeyStore(None, file.name) + file.write(JSON2) + file.flush() + + keys = PairingKeys() + ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) + keys.ltk = PairingKeys.Key(ltk) + await keystore.update('foo', keys) + file.flush() + with open(file.name, "r", encoding="utf-8") as json_file: + json_data = json.load(json_file) + assert '__DEFAULT__' in json_data + assert 'foo' in json_data['__DEFAULT__'] + assert 'ltk' in json_data['__DEFAULT__']['foo'] + + with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: + keystore = JsonKeyStore(None, file.name) + file.write(JSON3) + file.flush() + + all_keys = await keystore.get_all() + assert len(all_keys) == 1 + name, keys = all_keys[0] + assert name == '14:7D:DA:4E:53:A8/P' + assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') + + +# ----------------------------------------------------------------------------- +async def run_tests(): + await test_basic() + await test_parsing() + await test_default_namespace() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run_tests()) From 39db278f2edf5f2e228054719df52f7797eac93b Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 5 Jun 2023 16:03:02 +0800 Subject: [PATCH 6/7] Add typing for HFP --- bumble/hfp.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/bumble/hfp.py b/bumble/hfp.py index 7bb9f08..9080a55 100644 --- a/bumble/hfp.py +++ b/bumble/hfp.py @@ -18,10 +18,11 @@ import logging import asyncio import collections +from typing import Union +from . import rfcomm from .colors import color - # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -34,7 +35,12 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- class HfpProtocol: - def __init__(self, dlc): + dlc: rfcomm.DLC + buffer: str + lines: collections.deque + lines_available: asyncio.Event + + def __init__(self, dlc: rfcomm.DLC) -> None: self.dlc = dlc self.buffer = '' self.lines = collections.deque() @@ -42,7 +48,7 @@ class HfpProtocol: dlc.sink = self.feed - def feed(self, data): + def feed(self, data: Union[bytes, str]) -> None: # Convert the data to a string if needed if isinstance(data, bytes): data = data.decode('utf-8') @@ -57,19 +63,19 @@ class HfpProtocol: if len(line) > 0: self.on_line(line) - def on_line(self, line): + def on_line(self, line: str) -> None: self.lines.append(line) self.lines_available.set() - def send_command_line(self, line): + def send_command_line(self, line: str) -> None: logger.debug(color(f'>>> {line}', 'yellow')) self.dlc.write(line + '\r') - def send_response_line(self, line): + def send_response_line(self, line: str) -> None: logger.debug(color(f'>>> {line}', 'yellow')) self.dlc.write('\r\n' + line + '\r\n') - async def next_line(self): + async def next_line(self) -> str: await self.lines_available.wait() line = self.lines.popleft() if not self.lines: @@ -77,7 +83,7 @@ class HfpProtocol: logger.debug(color(f'<<< {line}', 'green')) return line - async def initialize_service(self): + async def initialize_service(self) -> None: # Perform Service Level Connection Initialization self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features await (self.next_line()) From 4d74339c04d540045dec1f5f16e074771b8f36ba Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Mon, 5 Jun 2023 16:28:52 +0800 Subject: [PATCH 7/7] Add typing for RFCOMM --- bumble/rfcomm.py | 177 +++++++++++++++++++++++++++++------------------ 1 file changed, 110 insertions(+), 67 deletions(-) diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 71be8dc..0176a78 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -19,8 +19,9 @@ import logging import asyncio from pyee import EventEmitter +from typing import Optional, Tuple, Callable, Dict, Union -from . import core +from . import core, l2cap from .colors import color from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError @@ -105,7 +106,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 # ----------------------------------------------------------------------------- -def compute_fcs(buffer): +def compute_fcs(buffer: bytes) -> int: result = 0xFF for byte in buffer: result = CRC_TABLE[result ^ byte] @@ -114,7 +115,15 @@ def compute_fcs(buffer): # ----------------------------------------------------------------------------- class RFCOMM_Frame: - def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False): + def __init__( + self, + frame_type: int, + c_r: int, + dlci: int, + p_f: int, + information: bytes = b'', + with_credits: bool = False, + ) -> None: self.type = frame_type self.c_r = c_r self.dlci = dlci @@ -136,11 +145,11 @@ class RFCOMM_Frame: else: self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length) - def type_name(self): + def type_name(self) -> str: return RFCOMM_FRAME_TYPE_NAMES[self.type] @staticmethod - def parse_mcc(data): + def parse_mcc(data) -> Tuple[int, int, bytes]: mcc_type = data[0] >> 2 c_r = (data[0] >> 1) & 1 length = data[1] @@ -154,36 +163,36 @@ class RFCOMM_Frame: return (mcc_type, c_r, value) @staticmethod - def make_mcc(mcc_type, c_r, data): + def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes: return ( bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data ) @staticmethod - def sabm(c_r, dlci): + def sabm(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1) @staticmethod - def ua(c_r, dlci): + def ua(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1) @staticmethod - def dm(c_r, dlci): + def dm(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1) @staticmethod - def disc(c_r, dlci): + def disc(c_r: int, dlci: int): return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1) @staticmethod - def uih(c_r, dlci, information, p_f=0): + def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0): return RFCOMM_Frame( RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1) ) @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): # Extract fields dlci = (data[0] >> 2) & 0x3F c_r = (data[0] >> 1) & 0x01 @@ -227,15 +236,23 @@ class RFCOMM_Frame: # ----------------------------------------------------------------------------- class RFCOMM_MCC_PN: + dlci: int + cl: int + priority: int + ack_timer: int + max_frame_size: int + max_retransmissions: int + window_size: int + def __init__( self, - dlci, - cl, - priority, - ack_timer, - max_frame_size, - max_retransmissions, - window_size, + dlci: int, + cl: int, + priority: int, + ack_timer: int, + max_frame_size: int, + max_retransmissions: int, + window_size: int, ): self.dlci = dlci self.cl = cl @@ -246,7 +263,7 @@ class RFCOMM_MCC_PN: self.window_size = window_size @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): return RFCOMM_MCC_PN( dlci=data[0], cl=data[1], @@ -285,7 +302,14 @@ class RFCOMM_MCC_PN: # ----------------------------------------------------------------------------- class RFCOMM_MCC_MSC: - def __init__(self, dlci, fc, rtc, rtr, ic, dv): + dlci: int + fc: int + rtc: int + rtr: int + ic: int + dv: int + + def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int): self.dlci = dlci self.fc = fc self.rtc = rtc @@ -294,7 +318,7 @@ class RFCOMM_MCC_MSC: self.dv = dv @staticmethod - def from_bytes(data): + def from_bytes(data: bytes): return RFCOMM_MCC_MSC( dlci=data[0] >> 2, fc=data[1] >> 1 & 1, @@ -347,7 +371,12 @@ class DLC(EventEmitter): RESET: 'RESET', } - def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits): + connection_result: Optional[asyncio.Future] + sink: Optional[Callable[[bytes], None]] + + def __init__( + self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int + ): super().__init__() self.multiplexer = multiplexer self.dlci = dlci @@ -368,23 +397,23 @@ class DLC(EventEmitter): ) @staticmethod - def state_name(state): + def state_name(state: int) -> str: return DLC.STATE_NAMES[state] - def change_state(self, new_state): + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "magenta")}' ) self.state = new_state - def send_frame(self, frame): + def send_frame(self, frame: RFCOMM_Frame) -> None: self.multiplexer.send_frame(frame) - def on_frame(self, frame): + def on_frame(self, frame: RFCOMM_Frame) -> None: handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler(frame) - def on_sabm_frame(self, _frame): + def on_sabm_frame(self, _frame) -> None: if self.state != DLC.CONNECTING: logger.warning( color('!!! received SABM when not in CONNECTING state', 'red') @@ -404,7 +433,7 @@ class DLC(EventEmitter): self.change_state(DLC.CONNECTED) self.emit('open') - def on_ua_frame(self, _frame): + def on_ua_frame(self, _frame) -> None: if self.state != DLC.CONNECTING: logger.warning( color('!!! received SABM when not in CONNECTING state', 'red') @@ -422,15 +451,15 @@ class DLC(EventEmitter): self.change_state(DLC.CONNECTED) self.multiplexer.on_dlc_open_complete(self) - def on_dm_frame(self, frame): + def on_dm_frame(self, frame) -> None: # TODO: handle all states pass - def on_disc_frame(self, _frame): + def on_disc_frame(self, _frame) -> None: # TODO: handle all states self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci)) - def on_uih_frame(self, frame): + def on_uih_frame(self, frame: RFCOMM_Frame) -> None: data = frame.information if frame.p_f == 1: # With credits @@ -460,10 +489,10 @@ class DLC(EventEmitter): # Check if there's anything to send (including credits) self.process_tx() - def on_ui_frame(self, frame): + def on_ui_frame(self, frame) -> None: pass - def on_mcc_msc(self, c_r, msc): + def on_mcc_msc(self, c_r, msc) -> None: if c_r: # Command logger.debug(f'<<< MCC MSC Command: {msc}') @@ -477,7 +506,7 @@ class DLC(EventEmitter): # Response logger.debug(f'<<< MCC MSC Response: {msc}') - def connect(self): + def connect(self) -> None: if self.state != DLC.INIT: raise InvalidStateError('invalid state') @@ -485,7 +514,7 @@ class DLC(EventEmitter): self.connection_result = asyncio.get_running_loop().create_future() self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci)) - def accept(self): + def accept(self) -> None: if self.state != DLC.INIT: raise InvalidStateError('invalid state') @@ -503,13 +532,13 @@ class DLC(EventEmitter): self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.change_state(DLC.CONNECTING) - def rx_credits_needed(self): + def rx_credits_needed(self) -> int: if self.rx_credits <= self.rx_threshold: return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits return 0 - def process_tx(self): + def process_tx(self) -> None: # Send anything we can (or an empty frame if we need to send rx credits) rx_credits_needed = self.rx_credits_needed() while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0: @@ -547,7 +576,7 @@ class DLC(EventEmitter): rx_credits_needed = 0 # Stream protocol - def write(self, data): + def write(self, data: Union[bytes, str]) -> None: # We can only send bytes if not isinstance(data, bytes): if isinstance(data, str): @@ -559,7 +588,7 @@ class DLC(EventEmitter): self.tx_buffer += data self.process_tx() - def drain(self): + def drain(self) -> None: # TODO pass @@ -592,7 +621,13 @@ class Multiplexer(EventEmitter): RESET: 'RESET', } - def __init__(self, l2cap_channel, role): + connection_result: Optional[asyncio.Future] + disconnection_result: Optional[asyncio.Future] + open_result: Optional[asyncio.Future] + acceptor: Optional[Callable[[int], bool]] + dlcs: Dict[int, DLC] + + def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None: super().__init__() self.role = role self.l2cap_channel = l2cap_channel @@ -607,20 +642,20 @@ class Multiplexer(EventEmitter): l2cap_channel.sink = self.on_pdu @staticmethod - def state_name(state): + def state_name(state: int): return Multiplexer.STATE_NAMES[state] - def change_state(self, new_state): + def change_state(self, new_state: int) -> None: logger.debug( f'{self} state change -> {color(self.state_name(new_state), "cyan")}' ) self.state = new_state - def send_frame(self, frame): + def send_frame(self, frame: RFCOMM_Frame) -> None: logger.debug(f'>>> Multiplexer sending {frame}') self.l2cap_channel.send_pdu(frame) - def on_pdu(self, pdu): + def on_pdu(self, pdu: bytes) -> None: frame = RFCOMM_Frame.from_bytes(pdu) logger.debug(f'<<< Multiplexer received {frame}') @@ -640,18 +675,18 @@ class Multiplexer(EventEmitter): return dlc.on_frame(frame) - def on_frame(self, frame): + def on_frame(self, frame: RFCOMM_Frame) -> None: handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler(frame) - def on_sabm_frame(self, _frame): + def on_sabm_frame(self, _frame) -> None: if self.state != Multiplexer.INIT: logger.debug('not in INIT state, ignoring SABM') return self.change_state(Multiplexer.CONNECTED) self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0)) - def on_ua_frame(self, _frame): + def on_ua_frame(self, _frame) -> None: if self.state == Multiplexer.CONNECTING: self.change_state(Multiplexer.CONNECTED) if self.connection_result: @@ -663,7 +698,7 @@ class Multiplexer(EventEmitter): self.disconnection_result.set_result(None) self.disconnection_result = None - def on_dm_frame(self, _frame): + def on_dm_frame(self, _frame) -> None: if self.state == Multiplexer.OPENING: self.change_state(Multiplexer.CONNECTED) if self.open_result: @@ -678,13 +713,13 @@ class Multiplexer(EventEmitter): else: logger.warning(f'unexpected state for DM: {self}') - def on_disc_frame(self, _frame): + def on_disc_frame(self, _frame) -> None: self.change_state(Multiplexer.DISCONNECTED) self.send_frame( RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0) ) - def on_uih_frame(self, frame): + def on_uih_frame(self, frame: RFCOMM_Frame) -> None: (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) if mcc_type == RFCOMM_MCC_PN_TYPE: @@ -694,10 +729,10 @@ class Multiplexer(EventEmitter): mcs = RFCOMM_MCC_MSC.from_bytes(value) self.on_mcc_msc(c_r, mcs) - def on_ui_frame(self, frame): + def on_ui_frame(self, frame) -> None: pass - def on_mcc_pn(self, c_r, pn): + def on_mcc_pn(self, c_r, pn) -> None: if c_r == 1: # Command logger.debug(f'<<< PN Command: {pn}') @@ -736,14 +771,14 @@ class Multiplexer(EventEmitter): else: logger.warning('ignoring PN response') - def on_mcc_msc(self, c_r, msc): + def on_mcc_msc(self, c_r, msc) -> None: dlc = self.dlcs.get(msc.dlci) if dlc is None: logger.warning(f'no dlc for DLCI {msc.dlci}') return dlc.on_mcc_msc(c_r, msc) - async def connect(self): + async def connect(self) -> None: if self.state != Multiplexer.INIT: raise InvalidStateError('invalid state') @@ -752,7 +787,7 @@ class Multiplexer(EventEmitter): self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0)) return await self.connection_result - async def disconnect(self): + async def disconnect(self) -> None: if self.state != Multiplexer.CONNECTED: return @@ -765,7 +800,7 @@ class Multiplexer(EventEmitter): ) await self.disconnection_result - async def open_dlc(self, channel): + async def open_dlc(self, channel: int) -> DLC: if self.state != Multiplexer.CONNECTED: if self.state == Multiplexer.OPENING: raise InvalidStateError('open already in progress') @@ -796,7 +831,7 @@ class Multiplexer(EventEmitter): self.open_result = None return result - def on_dlc_open_complete(self, dlc): + def on_dlc_open_complete(self, dlc: DLC): logger.debug(f'DLC [{dlc.dlci}] open complete') self.change_state(Multiplexer.CONNECTED) if self.open_result: @@ -808,13 +843,16 @@ class Multiplexer(EventEmitter): # ----------------------------------------------------------------------------- class Client: - def __init__(self, device, connection): + multiplexer: Optional[Multiplexer] + l2cap_channel: Optional[l2cap.Channel] + + def __init__(self, device, connection) -> None: self.device = device self.connection = connection self.l2cap_channel = None self.multiplexer = None - async def start(self): + async def start(self) -> Multiplexer: # Create a new L2CAP connection try: self.l2cap_channel = await self.device.l2cap_channel_manager.connect( @@ -824,6 +862,7 @@ class Client: logger.warning(f'L2CAP connection failed: {error}') raise + assert self.l2cap_channel is not None # Create a mutliplexer to manage DLCs with the server self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR) @@ -832,7 +871,9 @@ class Client: return self.multiplexer - async def shutdown(self): + async def shutdown(self) -> None: + if self.multiplexer is None: + return # Disconnect the multiplexer await self.multiplexer.disconnect() self.multiplexer = None @@ -843,7 +884,9 @@ class Client: # ----------------------------------------------------------------------------- class Server(EventEmitter): - def __init__(self, device): + acceptors: Dict[int, Callable[[DLC], None]] + + def __init__(self, device) -> None: super().__init__() self.device = device self.multiplexer = None @@ -852,7 +895,7 @@ class Server(EventEmitter): # Register ourselves with the L2CAP channel manager device.register_l2cap_server(RFCOMM_PSM, self.on_connection) - def listen(self, acceptor, channel=0): + def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int: if channel: if channel in self.acceptors: # Busy @@ -874,11 +917,11 @@ class Server(EventEmitter): self.acceptors[channel] = acceptor return channel - def on_connection(self, l2cap_channel): + def on_connection(self, l2cap_channel: l2cap.Channel) -> None: logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) - def on_l2cap_channel_open(self, l2cap_channel): + def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') # Create a new multiplexer for the channel @@ -889,10 +932,10 @@ class Server(EventEmitter): # Notify self.emit('start', multiplexer) - def accept_dlc(self, channel_number): + def accept_dlc(self, channel_number: int) -> bool: return channel_number in self.acceptors - def on_dlc(self, dlc): + def on_dlc(self, dlc: DLC) -> None: logger.debug(f'@@@ new DLC connected: {dlc}') # Let the acceptor know