improve smp compatibility with other OS flows

This commit is contained in:
Gilles Boccon-Gibod
2023-02-13 10:53:00 -08:00
parent 1321c7da81
commit e6fc63b2d8
5 changed files with 136 additions and 76 deletions

View File

@@ -71,5 +71,10 @@
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled"
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -19,8 +19,8 @@ import asyncio
import os
import logging
import click
import aioconsole
from colors import color
from prompt_toolkit.shortcuts import PromptSession
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
@@ -42,9 +42,23 @@ from bumble.att import (
)
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
def terminate(self):
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
@@ -58,7 +72,19 @@ class Delegate(PairingDelegate):
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.prompt = prompt
self.do_prompt = do_prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
# with patch_stdout.patch_stdout(raw=True):
response = await session.prompt_async()
return response.lower().strip()
async def update_peer_name(self):
if self.peer_name is not None:
@@ -73,19 +99,15 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]'
async def accept(self):
if self.prompt:
if self.do_prompt:
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
response = await self.prompt('>>> Accept? ')
if response == 'yes':
return True
@@ -96,23 +118,17 @@ class Delegate(PairingDelegate):
# Accept silently
return True
async def compare_numbers(self, number, digits=6):
async def compare_numbers(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
response = response.lower().strip()
if response == 'yes':
return True
@@ -123,30 +139,24 @@ class Delegate(PairingDelegate):
async def get_number(self):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN
while True:
try:
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
return int(await self.prompt('>>> Enter PIN: '))
except ValueError:
pass
async def display_number(self, number, digits=6):
async def display_number(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print(f'### PIN: {number:0{digits}}')
self.print('###-----------------------------------')
# -----------------------------------------------------------------------------
@@ -238,6 +248,7 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -245,6 +256,7 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -262,6 +274,8 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -332,7 +346,19 @@ async def pair(
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# Run until the user asks to exit
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
# -----------------------------------------------------------------------------
@@ -388,7 +414,13 @@ def main(
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run(
pair(
mode,

View File

@@ -518,13 +518,14 @@ class PairingDelegate:
async def confirm(self) -> bool:
return True
async def compare_numbers(self, _number: int, _digits: int = 6) -> bool:
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
return True
async def get_number(self) -> int:
return 0
async def display_number(self, _number: int, _digits: int = 6) -> None:
async def display_number(self, number: int, digits: int) -> None:
pass
async def key_distribution_response(
@@ -661,7 +662,8 @@ class Session:
self.peer_expected_distributions = []
self.dh_key = None
self.confirm_value = None
self.passkey = 0
self.passkey = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
@@ -839,6 +841,7 @@ class Session:
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
# The value of TK is computed from the PIN code
if not self.sc:
@@ -859,6 +862,8 @@ class Session:
self.tk = passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
self.passkey_ready.set()
if next_steps is not None:
next_steps()
@@ -910,17 +915,29 @@ class Session:
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
async def next_steps():
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
self.send_command(
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
)
# Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps())
else:
confirm_value = crypto.c1(
self.tk,
@@ -933,7 +950,7 @@ class Session:
self.ra,
)
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self):
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
@@ -1364,8 +1381,8 @@ class Session:
# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
self.display_passkey()
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
@@ -1426,18 +1443,22 @@ class Session:
else:
srand = self.r
mrand = command.random_value
stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {stk.hex()}')
self.stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {self.stk.hex()}')
# Generate LTK
self.ltk = crypto.r()
if self.is_initiator:
self.start_encryption(stk)
self.start_encryption(self.stk)
else:
self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command):
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
@@ -1565,17 +1586,13 @@ class Session:
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.is_initiator:
if self.pairing_method == self.PASSKEY:
if self.passkey_display:
self.send_pairing_confirm_command()
else:
self.input_passkey(self.send_pairing_confirm_command)
self.send_pairing_confirm_command()
else:
# Send our public key back to the initiator
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey(self.send_public_key_command)
else:
self.send_public_key_command()
self.display_or_input_passkey()
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# We can now send the confirmation value

View File

@@ -30,7 +30,6 @@ package_dir =
bumble.apps = apps
include-package-data = True
install_requires =
aioconsole >= 0.4.1
ansicolors >= 1.1
appdirs >= 1.4
click >= 7.1.2; platform_system!='Emscripten'

View File

@@ -32,7 +32,6 @@ from bumble.smp import (
PairingDelegate,
SMP_PAIRING_NOT_SUPPORTED_ERROR,
SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_ID_KEY_DISTRIBUTION_FLAG,
)
from bumble.core import ProtocolError
@@ -273,9 +272,15 @@ KEY_DIST = range(16)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST)
'io_caps, sc, mitm, key_dist',
itertools.chain(
itertools.product([IO_CAP], SC, MITM, [15]),
itertools.product(
[[PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]], SC, MITM, KEY_DIST
),
),
)
async def test_self_smp(io_cap, sc, mitm, key_dist):
async def test_self_smp(io_caps, sc, mitm, key_dist):
class Delegate(PairingDelegate):
def __init__(
self,
@@ -296,6 +301,7 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
self.peer_delegate = None
self.number = asyncio.get_running_loop().create_future()
# pylint: disable-next=unused-argument
async def compare_numbers(self, number, digits):
if self.peer_delegate is None:
logger.warning(f'[{self.name}] no peer delegate')
@@ -331,8 +337,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
pairing_config_sets = [('Initiator', [None]), ('Responder', [None])]
for pairing_config_set in pairing_config_sets:
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
for io_cap in io_caps:
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
for pairing_config1 in pairing_config_sets[0][1]:
for pairing_config2 in pairing_config_sets[1][1]: