Compare commits

..

12 Commits

Author SHA1 Message Date
Lucas Abel
7f987dc3cd Merge pull request #198 from qiaoccolato/main
reformat protobuf import
2023-06-07 10:05:30 -07:00
qiaoccolato
689745040f Merge branch 'google:main' into main 2023-06-07 09:19:54 -07:00
Qiao Yang
809d4a18f5 reformat protobuf import 2023-06-07 09:14:50 -07:00
Gilles Boccon-Gibod
54be8b328a Merge pull request #197 from zxzxwu/typing
Add typing for HFP and RFCOMM
2023-06-07 07:09:26 -07:00
Gilles Boccon-Gibod
57b469198a Merge pull request #196 from google/gbg/better-address-resolving
pairing event improvement
2023-06-07 07:03:53 -07:00
Josh Wu
4d74339c04 Add typing for RFCOMM 2023-06-06 00:04:25 +08:00
Josh Wu
39db278f2e Add typing for HFP 2023-06-05 23:54:42 +08:00
Gilles Boccon-Gibod
27fbb58447 add basic keystore test 2023-06-04 13:01:07 -07:00
Gilles Boccon-Gibod
6826f68478 fix linter warnings 2023-05-05 16:16:55 -07:00
Gilles Boccon-Gibod
f80c83d0b3 better doc and default behavior for json keystore 2023-05-05 16:11:20 -07:00
Gilles Boccon-Gibod
3de35193bc rebase 2023-05-05 16:09:01 -07:00
Gilles Boccon-Gibod
740a2e0ca0 instantiate keystore after power_on 2023-05-05 16:07:16 -07:00
11 changed files with 451 additions and 173 deletions

View File

@@ -207,7 +207,7 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
@@ -242,9 +242,9 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
def on_pairing(keys):
def on_pairing(address, keys):
print(color('***-----------------------------------', 'cyan'))
print(color('*** Paired!', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
@@ -283,17 +283,6 @@ async def pair(
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
@@ -323,6 +312,17 @@ async def pair(
# Get things going
await device.power_on()
# Set a custom keystore if specified on the command line
if keystore_file:
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
# Print the existing keys before pairing
if print_keys and device.keystore:
print(color('@@@-----------------------------------', 'blue'))
print(color('@@@ Pairing Keys:', 'blue'))
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc, mitm, bond, Delegate(mode, connection, io, prompt)

View File

@@ -133,15 +133,16 @@ async def scan(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
device.keystore = keystore
else:
resolver = None
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file)
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -149,8 +150,6 @@ async def scan(
else:
device.on('advertisement', printer.on_advertisement)
await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:

View File

@@ -22,40 +22,58 @@ import click
from bumble.device import Device
from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address):
if address is None:
return await keystore.print()
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# -----------------------------------------------------------------------------
async def unbond(keystore_file, device_config, address):
# Create a device to manage the host
device = Device.from_config_file(device_config)
# Get all entries in the keystore
async def unbond(keystore_file, device_config, hci_transport, address):
# With a keystore file, we can instantiate the keystore directly
if keystore_file:
keystore = JsonKeyStore(None, keystore_file)
else:
keystore = device.keystore
return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address)
if keystore is None:
print('no keystore')
return
# Without a keystore file, we need to obtain the keystore from the device
async with await open_transport(hci_transport) as (hci_source, hci_sink):
# Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
if address is None:
await keystore.print()
else:
try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# Power-on the device to ensure we have a key store
await device.power_on()
return await unbond_with_keystore(device.keystore, address)
# -----------------------------------------------------------------------------
@click.command()
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.option('--keystore-file', help='File in which the pairing keys are stored')
@click.option('--hci-transport', help='HCI transport for the controller')
@click.argument('device-config', required=False)
@click.argument('address', required=False)
def main(keystore_file, device_config, address):
def main(keystore_file, hci_transport, device_config, address):
"""
Remove pairing keys for a device, given its address.
If no keystore file is specified, the --hci-transport option must be used to
connect to a controller, so that the keystore for that controller can be
instantiated.
If no address is passed, the existing pairing keys for all addresses are printed.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address))
if not keystore_file and not hci_transport:
print('either --keystore-file or --hci-transport must be specified.')
return
asyncio.run(unbond(keystore_file, device_config, hci_transport, address))
# -----------------------------------------------------------------------------

View File

@@ -3094,7 +3094,16 @@ class Device(CompositeEventEmitter):
def on_pairing_start(self, connection: Connection) -> None:
connection.emit('pairing_start')
def on_pairing(self, connection: Connection, keys: PairingKeys, sc: bool) -> None:
def on_pairing(
self,
connection: Connection,
identity_address: Optional[Address],
keys: PairingKeys,
sc: bool,
) -> None:
if identity_address is not None:
connection.peer_resolvable_address = connection.peer_address
connection.peer_address = identity_address
connection.sc = sc
connection.authenticated = True
connection.emit('pairing', keys)

View File

@@ -18,10 +18,11 @@
import logging
import asyncio
import collections
from typing import Union
from . import rfcomm
from .colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -34,7 +35,12 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class HfpProtocol:
def __init__(self, dlc):
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
@@ -42,7 +48,7 @@ class HfpProtocol:
dlc.sink = self.feed
def feed(self, data):
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
@@ -57,19 +63,19 @@ class HfpProtocol:
if len(line) > 0:
self.on_line(line)
def on_line(self, line):
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def send_command_line(self, line):
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def send_response_line(self, line):
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
async def next_line(self):
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
@@ -77,7 +83,7 @@ class HfpProtocol:
logger.debug(color(f'<<< {line}', 'green'))
return line
async def initialize_service(self):
async def initialize_service(self) -> None:
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
await (self.next_line())

View File

@@ -190,10 +190,44 @@ class KeyStore:
# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
"""
KeyStore implementation that is backed by a JSON file.
This implementation supports storing a hierarchy of key sets in a single file.
A key set is a representation of a PairingKeys object. Each key set is stored
in a map, with the address of paired peer as the key. Maps are themselves grouped
into namespaces, grouping pairing keys by controller addresses.
The JSON object model looks like:
{
"<namespace>": {
"peer-address": {
"address_type": <n>,
"irk" : {
"authenticated": <true/false>,
"value": "hex-encoded-key"
},
... other keys ...
},
... other peers ...
}
... other namespaces ...
}
A namespace is typically the BD_ADDR of a controller, since that is a convenient
unique identifier, but it may be something else.
A special namespace, called the "default" namespace, is used when instantiating this
class without a namespace. With the default namespace, reading from a file will
load an existing namespace if there is only one, which may be convenient for reading
from a file with a single key set and for which the namespace isn't known. If the
file does not include any existing key set, or if there are more than one and none
has the default name, a new one will be created with the name "__DEFAULT__".
"""
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
@@ -208,8 +242,9 @@ class JsonKeyStore(KeyStore):
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else:
@@ -219,11 +254,13 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device) -> Optional[JsonKeyStore]:
if not device.config.keystore:
return None
params = device.config.keystore.split(':', 1)[1:]
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
params = device.config.keystore.split(':', 1)[1:]
if params:
filename = params[0]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
@@ -232,19 +269,31 @@ class JsonKeyStore(KeyStore):
namespace = str(device.random_address)
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
if params:
filename = params[0]
else:
filename = None
return JsonKeyStore(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
db = json.load(json_file)
except FileNotFoundError:
return {}
db = {}
# First, look for a namespace match
if self.namespace in db:
return (db, db[self.namespace])
# Then, if the namespace is the default namespace, and there's
# only one entry in the db, use that
if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db):
# Create the directory if it doesn't exist
@@ -260,53 +309,30 @@ class JsonKeyStore(KeyStore):
os.replace(temp_filename, self.filename)
async def delete(self, name: str) -> None:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
raise KeyError(name)
del namespace[name]
db, key_map = await self.load()
del key_map[name]
await self.save(db)
async def update(self, name, keys):
db = await self.load()
namespace = db.setdefault(self.namespace, {})
namespace.setdefault(name, {}).update(keys.to_dict())
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self):
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
return []
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
db, key_map = await self.load()
key_map.clear()
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]:
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
_, key_map = await self.load()
if name not in key_map:
return None
keys = namespace.get(name)
if keys is None:
return None
return PairingKeys.from_dict(keys)
return PairingKeys.from_dict(key_map[name])
# -----------------------------------------------------------------------------

View File

@@ -43,7 +43,8 @@ from bumble.hci import (
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address,
)
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer
from pandora.host_pb2 import (
NOT_CONNECTABLE,

View File

@@ -29,12 +29,9 @@ from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import (
any_pb2,
empty_pb2,
wrappers_pb2,
) # pytype: disable=pyi-error
from google.protobuf.wrappers_pb2 import BoolValue # pytype: disable=pyi-error
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
@@ -513,7 +510,7 @@ class SecurityStorageService(SecurityStorageServicer):
else:
is_bonded = False
return BoolValue(value=is_bonded)
return wrappers_pb2.BoolValue(value=is_bonded)
@utils.rpc
async def DeleteBond(

View File

@@ -19,8 +19,9 @@ import logging
import asyncio
from pyee import EventEmitter
from typing import Optional, Tuple, Callable, Dict, Union
from . import core
from . import core, l2cap
from .colors import color
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
@@ -105,7 +106,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def compute_fcs(buffer):
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
for byte in buffer:
result = CRC_TABLE[result ^ byte]
@@ -114,7 +115,15 @@ def compute_fcs(buffer):
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
def __init__(
self,
frame_type: int,
c_r: int,
dlci: int,
p_f: int,
information: bytes = b'',
with_credits: bool = False,
) -> None:
self.type = frame_type
self.c_r = c_r
self.dlci = dlci
@@ -136,11 +145,11 @@ class RFCOMM_Frame:
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self):
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data):
def parse_mcc(data) -> Tuple[int, int, bytes]:
mcc_type = data[0] >> 2
c_r = (data[0] >> 1) & 1
length = data[1]
@@ -154,36 +163,36 @@ class RFCOMM_Frame:
return (mcc_type, c_r, value)
@staticmethod
def make_mcc(mcc_type, c_r, data):
def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:
return (
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod
def sabm(c_r, dlci):
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
@staticmethod
def ua(c_r, dlci):
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
@staticmethod
def dm(c_r, dlci):
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
@staticmethod
def disc(c_r, dlci):
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r, dlci, information, p_f=0):
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
@@ -227,15 +236,23 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
class RFCOMM_MCC_PN:
dlci: int
cl: int
priority: int
ack_timer: int
max_frame_size: int
max_retransmissions: int
window_size: int
def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
):
self.dlci = dlci
self.cl = cl
@@ -246,7 +263,7 @@ class RFCOMM_MCC_PN:
self.window_size = window_size
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
return RFCOMM_MCC_PN(
dlci=data[0],
cl=data[1],
@@ -285,7 +302,14 @@ class RFCOMM_MCC_PN:
# -----------------------------------------------------------------------------
class RFCOMM_MCC_MSC:
def __init__(self, dlci, fc, rtc, rtr, ic, dv):
dlci: int
fc: int
rtc: int
rtr: int
ic: int
dv: int
def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int):
self.dlci = dlci
self.fc = fc
self.rtc = rtc
@@ -294,7 +318,7 @@ class RFCOMM_MCC_MSC:
self.dv = dv
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes):
return RFCOMM_MCC_MSC(
dlci=data[0] >> 2,
fc=data[1] >> 1 & 1,
@@ -347,7 +371,12 @@ class DLC(EventEmitter):
RESET: 'RESET',
}
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]]
def __init__(
self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int
):
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
@@ -368,23 +397,23 @@ class DLC(EventEmitter):
)
@staticmethod
def state_name(state):
def state_name(state: int) -> str:
return DLC.STATE_NAMES[state]
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state
def send_frame(self, frame):
def send_frame(self, frame: RFCOMM_Frame) -> None:
self.multiplexer.send_frame(frame)
def on_frame(self, frame):
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, _frame) -> None:
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
@@ -404,7 +433,7 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTED)
self.emit('open')
def on_ua_frame(self, _frame):
def on_ua_frame(self, _frame) -> None:
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
@@ -422,15 +451,15 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
def on_dm_frame(self, frame):
def on_dm_frame(self, frame) -> None:
# TODO: handle all states
pass
def on_disc_frame(self, _frame):
def on_disc_frame(self, _frame) -> None:
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
def on_uih_frame(self, frame):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
data = frame.information
if frame.p_f == 1:
# With credits
@@ -460,10 +489,10 @@ class DLC(EventEmitter):
# Check if there's anything to send (including credits)
self.process_tx()
def on_ui_frame(self, frame):
def on_ui_frame(self, frame) -> None:
pass
def on_mcc_msc(self, c_r, msc):
def on_mcc_msc(self, c_r, msc) -> None:
if c_r:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
@@ -477,7 +506,7 @@ class DLC(EventEmitter):
# Response
logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self):
def connect(self) -> None:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
@@ -485,7 +514,7 @@ class DLC(EventEmitter):
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
def accept(self):
def accept(self) -> None:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
@@ -503,13 +532,13 @@ class DLC(EventEmitter):
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTING)
def rx_credits_needed(self):
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0
def process_tx(self):
def process_tx(self) -> None:
# Send anything we can (or an empty frame if we need to send rx credits)
rx_credits_needed = self.rx_credits_needed()
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
@@ -547,7 +576,7 @@ class DLC(EventEmitter):
rx_credits_needed = 0
# Stream protocol
def write(self, data):
def write(self, data: Union[bytes, str]) -> None:
# We can only send bytes
if not isinstance(data, bytes):
if isinstance(data, str):
@@ -559,7 +588,7 @@ class DLC(EventEmitter):
self.tx_buffer += data
self.process_tx()
def drain(self):
def drain(self) -> None:
# TODO
pass
@@ -592,7 +621,13 @@ class Multiplexer(EventEmitter):
RESET: 'RESET',
}
def __init__(self, l2cap_channel, role):
connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None:
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
@@ -607,20 +642,20 @@ class Multiplexer(EventEmitter):
l2cap_channel.sink = self.on_pdu
@staticmethod
def state_name(state):
def state_name(state: int):
return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state
def send_frame(self, frame):
def send_frame(self, frame: RFCOMM_Frame) -> None:
logger.debug(f'>>> Multiplexer sending {frame}')
self.l2cap_channel.send_pdu(frame)
def on_pdu(self, pdu):
def on_pdu(self, pdu: bytes) -> None:
frame = RFCOMM_Frame.from_bytes(pdu)
logger.debug(f'<<< Multiplexer received {frame}')
@@ -640,18 +675,18 @@ class Multiplexer(EventEmitter):
return
dlc.on_frame(frame)
def on_frame(self, frame):
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, _frame) -> None:
if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, _frame):
def on_ua_frame(self, _frame) -> None:
if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED)
if self.connection_result:
@@ -663,7 +698,7 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_dm_frame(self, _frame):
def on_dm_frame(self, _frame) -> None:
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
@@ -678,13 +713,13 @@ class Multiplexer(EventEmitter):
else:
logger.warning(f'unexpected state for DM: {self}')
def on_disc_frame(self, _frame):
def on_disc_frame(self, _frame) -> None:
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == RFCOMM_MCC_PN_TYPE:
@@ -694,10 +729,10 @@ class Multiplexer(EventEmitter):
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
def on_ui_frame(self, frame):
def on_ui_frame(self, frame) -> None:
pass
def on_mcc_pn(self, c_r, pn):
def on_mcc_pn(self, c_r, pn) -> None:
if c_r == 1:
# Command
logger.debug(f'<<< PN Command: {pn}')
@@ -736,14 +771,14 @@ class Multiplexer(EventEmitter):
else:
logger.warning('ignoring PN response')
def on_mcc_msc(self, c_r, msc):
def on_mcc_msc(self, c_r, msc) -> None:
dlc = self.dlcs.get(msc.dlci)
if dlc is None:
logger.warning(f'no dlc for DLCI {msc.dlci}')
return
dlc.on_mcc_msc(c_r, msc)
async def connect(self):
async def connect(self) -> None:
if self.state != Multiplexer.INIT:
raise InvalidStateError('invalid state')
@@ -752,7 +787,7 @@ class Multiplexer(EventEmitter):
self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
return await self.connection_result
async def disconnect(self):
async def disconnect(self) -> None:
if self.state != Multiplexer.CONNECTED:
return
@@ -765,7 +800,7 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(self, channel):
async def open_dlc(self, channel: int) -> DLC:
if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress')
@@ -796,7 +831,7 @@ class Multiplexer(EventEmitter):
self.open_result = None
return result
def on_dlc_open_complete(self, dlc):
def on_dlc_open_complete(self, dlc: DLC):
logger.debug(f'DLC [{dlc.dlci}] open complete')
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
@@ -808,13 +843,16 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device, connection):
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.Channel]
def __init__(self, device, connection) -> None:
self.device = device
self.connection = connection
self.l2cap_channel = None
self.multiplexer = None
async def start(self):
async def start(self) -> Multiplexer:
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
@@ -824,6 +862,7 @@ class Client:
logger.warning(f'L2CAP connection failed: {error}')
raise
assert self.l2cap_channel is not None
# Create a mutliplexer to manage DLCs with the server
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR)
@@ -832,7 +871,9 @@ class Client:
return self.multiplexer
async def shutdown(self):
async def shutdown(self) -> None:
if self.multiplexer is None:
return
# Disconnect the multiplexer
await self.multiplexer.disconnect()
self.multiplexer = None
@@ -843,7 +884,9 @@ class Client:
# -----------------------------------------------------------------------------
class Server(EventEmitter):
def __init__(self, device):
acceptors: Dict[int, Callable[[DLC], None]]
def __init__(self, device) -> None:
super().__init__()
self.device = device
self.multiplexer = None
@@ -852,7 +895,7 @@ class Server(EventEmitter):
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor, channel=0):
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
if channel:
if channel in self.acceptors:
# Busy
@@ -874,11 +917,11 @@ class Server(EventEmitter):
self.acceptors[channel] = acceptor
return channel
def on_connection(self, l2cap_channel):
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel):
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
# Create a new multiplexer for the channel
@@ -889,10 +932,10 @@ class Server(EventEmitter):
# Notify
self.emit('start', multiplexer)
def accept_dlc(self, channel_number):
def accept_dlc(self, channel_number: int) -> bool:
return channel_number in self.acceptors
def on_dlc(self, dlc):
def on_dlc(self, dlc: DLC) -> None:
logger.debug(f'@@@ new DLC connected: {dlc}')
# Let the acceptor know

View File

@@ -1805,7 +1805,7 @@ class Manager(EventEmitter):
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection, keys, session.sc)
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: int) -> None:
self.device.on_pairing_failure(session.connection, reason)

179
tests/keystore_test.py Normal file
View File

@@ -0,0 +1,179 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import json
import logging
import tempfile
import os
from bumble.keys import JsonKeyStore, PairingKeys
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
JSON1 = """
{
"my_namespace": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
},
"link_key": {
"authenticated": false,
"value": "0745dd9691e693d9dca740f7d8dfea75"
},
"ltk": {
"authenticated": false,
"value": "d1897ee10016eb1a08e4e037fd54c683"
}
}
}
}
"""
JSON2 = """
{
"my_namespace1": {
},
"my_namespace2": {
}
}
"""
JSON3 = """
{
"my_namespace1": {
},
"__DEFAULT__": {
"14:7D:DA:4E:53:A8/P": {
"address_type": 0,
"irk": {
"authenticated": false,
"value": "e7b2543b206e4e46b44f9e51dad22bd1"
}
}
}
}
"""
# -----------------------------------------------------------------------------
async def test_basic():
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write("{}")
file.flush()
keys = await keystore.get_all()
assert len(keys) == 0
keys = PairingKeys()
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is None
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert 'my_namespace' in json_data
assert 'foo' in json_data['my_namespace']
assert 'ltk' in json_data['my_namespace']['foo']
# -----------------------------------------------------------------------------
async def test_parsing():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write(JSON1)
file.flush()
foo = await keystore.get('14:7D:DA:4E:53:A8/P')
assert foo is not None
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# -----------------------------------------------------------------------------
async def test_default_namespace():
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON1)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON2)
file.flush()
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
file.flush()
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert '__DEFAULT__' in json_data
assert 'foo' in json_data['__DEFAULT__']
assert 'ltk' in json_data['__DEFAULT__']['foo']
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON3)
file.flush()
all_keys = await keystore.get_all()
assert len(all_keys) == 1
name, keys = all_keys[0]
assert name == '14:7D:DA:4E:53:A8/P'
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_tests())