improve smp compatibility with other OS flows

This commit is contained in:
Gilles Boccon-Gibod
2023-02-13 10:53:00 -08:00
parent 1321c7da81
commit e6fc63b2d8
5 changed files with 136 additions and 76 deletions

View File

@@ -71,5 +71,10 @@
"editor.rulers": [88] "editor.rulers": [88]
}, },
"python.formatting.provider": "black", "python.formatting.provider": "black",
"pylint.importStrategy": "useBundled" "pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
} }

View File

@@ -19,8 +19,8 @@ import asyncio
import os import os
import logging import logging
import click import click
import aioconsole
from colors import color from colors import color
from prompt_toolkit.shortcuts import PromptSession
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -42,9 +42,23 @@ from bumble.att import (
) )
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
def terminate(self):
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt): def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__( super().__init__(
{ {
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
@@ -58,7 +72,19 @@ class Delegate(PairingDelegate):
self.mode = mode self.mode = mode
self.peer = Peer(connection) self.peer = Peer(connection)
self.peer_name = None self.peer_name = None
self.prompt = prompt self.do_prompt = do_prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
# with patch_stdout.patch_stdout(raw=True):
response = await session.prompt_async()
return response.lower().strip()
async def update_peer_name(self): async def update_peer_name(self):
if self.peer_name is not None: if self.peer_name is not None:
@@ -73,19 +99,15 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]' self.peer_name = '[?]'
async def accept(self): async def accept(self):
if self.prompt: if self.do_prompt:
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for acceptance # Prompt for acceptance
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
print(color(f'### Pairing request from {self.peer_name}', 'yellow')) self.print(f'### Pairing request from {self.peer_name}')
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
while True: while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow')) response = await self.prompt('>>> Accept? ')
response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
@@ -96,23 +118,17 @@ class Delegate(PairingDelegate):
# Accept silently # Accept silently
return True return True
async def compare_numbers(self, number, digits=6): async def compare_numbers(self, number, digits):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a numeric comparison # Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
print(color(f'### Pairing with {self.peer_name}', 'yellow')) self.print(f'### Pairing with {self.peer_name}')
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
while True: while True:
response = await aioconsole.ainput( response = await self.prompt(
color( f'>>> Does the other device display {number:0{digits}}? '
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
) )
response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
@@ -123,30 +139,24 @@ class Delegate(PairingDelegate):
async def get_number(self): async def get_number(self):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN # Prompt for a PIN
while True: while True:
try: try:
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
print(color(f'### Pairing with {self.peer_name}', 'yellow')) self.print(f'### Pairing with {self.peer_name}')
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow'))) return int(await self.prompt('>>> Enter PIN: '))
except ValueError: except ValueError:
pass pass
async def display_number(self, number, digits=6): async def display_number(self, number, digits):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code # Display a PIN code
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
print(color(f'### Pairing with {self.peer_name}', 'yellow')) self.print(f'### Pairing with {self.peer_name}')
print(color(f'### PIN: {number:0{digits}}', 'yellow')) self.print(f'### PIN: {number:0{digits}}')
print(color('###-----------------------------------', 'yellow')) self.print('###-----------------------------------')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -238,6 +248,7 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan')) print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan')) keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan')) print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -245,6 +256,7 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red')) print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -262,6 +274,8 @@ async def pair(
hci_transport, hci_transport,
address_or_name, address_or_name,
): ):
Waiter.instance = Waiter()
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
@@ -332,7 +346,19 @@ async def pair(
# Advertise so that peers can find us and connect # Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination() # Run until the user asks to exit
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -388,7 +414,13 @@ def main(
hci_transport, hci_transport,
address_or_name, address_or_name,
): ):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) # Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run( asyncio.run(
pair( pair(
mode, mode,

View File

@@ -518,13 +518,14 @@ class PairingDelegate:
async def confirm(self) -> bool: async def confirm(self) -> bool:
return True return True
async def compare_numbers(self, _number: int, _digits: int = 6) -> bool: # pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
return True return True
async def get_number(self) -> int: async def get_number(self) -> int:
return 0 return 0
async def display_number(self, _number: int, _digits: int = 6) -> None: async def display_number(self, number: int, digits: int) -> None:
pass pass
async def key_distribution_response( async def key_distribution_response(
@@ -661,7 +662,8 @@ class Session:
self.peer_expected_distributions = [] self.peer_expected_distributions = []
self.dh_key = None self.dh_key = None
self.confirm_value = None self.confirm_value = None
self.passkey = 0 self.passkey = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0 self.passkey_step = 0
self.passkey_display = False self.passkey_display = False
self.pairing_method = 0 self.pairing_method = 0
@@ -839,6 +841,7 @@ class Session:
# Generate random Passkey/PIN code # Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000) self.passkey = secrets.randbelow(1000000)
logger.debug(f'Pairing PIN CODE: {self.passkey:06}') logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
# The value of TK is computed from the PIN code # The value of TK is computed from the PIN code
if not self.sc: if not self.sc:
@@ -859,6 +862,8 @@ class Session:
self.tk = passkey.to_bytes(16, byteorder='little') self.tk = passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}') logger.debug(f'TK from passkey = {self.tk.hex()}')
self.passkey_ready.set()
if next_steps is not None: if next_steps is not None:
next_steps() next_steps()
@@ -910,17 +915,29 @@ class Session:
logger.debug(f'generated random: {self.r.hex()}') logger.debug(f'generated random: {self.r.hex()}')
if self.sc: if self.sc:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator: async def next_steps():
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z])) if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
else: z = 0
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z])) elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
self.send_command(
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
)
# Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps())
else: else:
confirm_value = crypto.c1( confirm_value = crypto.c1(
self.tk, self.tk,
@@ -933,7 +950,7 @@ class Session:
self.ra, self.ra,
) )
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self): def send_pairing_random_command(self):
self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
@@ -1364,8 +1381,8 @@ class Session:
# Start phase 2 # Start phase 2
if self.sc: if self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display: if self.pairing_method == self.PASSKEY:
self.display_passkey() self.display_or_input_passkey()
self.send_public_key_command() self.send_public_key_command()
else: else:
@@ -1426,18 +1443,22 @@ class Session:
else: else:
srand = self.r srand = self.r
mrand = command.random_value mrand = command.random_value
stk = crypto.s1(self.tk, srand, mrand) self.stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {stk.hex()}') logger.debug(f'STK = {self.stk.hex()}')
# Generate LTK # Generate LTK
self.ltk = crypto.r() self.ltk = crypto.r()
if self.is_initiator: if self.is_initiator:
self.start_encryption(stk) self.start_encryption(self.stk)
else: else:
self.send_pairing_random_command() self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command): def on_smp_pairing_random_command_secure_connections(self, command):
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
if self.is_initiator: if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
@@ -1565,17 +1586,13 @@ class Session:
logger.debug(f'DH key: {self.dh_key.hex()}') logger.debug(f'DH key: {self.dh_key.hex()}')
if self.is_initiator: if self.is_initiator:
if self.pairing_method == self.PASSKEY: self.send_pairing_confirm_command()
if self.passkey_display:
self.send_pairing_confirm_command()
else:
self.input_passkey(self.send_pairing_confirm_command)
else: else:
# Send our public key back to the initiator
if self.pairing_method == self.PASSKEY: if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey(self.send_public_key_command) self.display_or_input_passkey()
else:
self.send_public_key_command() # Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# We can now send the confirmation value # We can now send the confirmation value

View File

@@ -30,7 +30,6 @@ package_dir =
bumble.apps = apps bumble.apps = apps
include-package-data = True include-package-data = True
install_requires = install_requires =
aioconsole >= 0.4.1
ansicolors >= 1.1 ansicolors >= 1.1
appdirs >= 1.4 appdirs >= 1.4
click >= 7.1.2; platform_system!='Emscripten' click >= 7.1.2; platform_system!='Emscripten'

View File

@@ -32,7 +32,6 @@ from bumble.smp import (
PairingDelegate, PairingDelegate,
SMP_PAIRING_NOT_SUPPORTED_ERROR, SMP_PAIRING_NOT_SUPPORTED_ERROR,
SMP_CONFIRM_VALUE_FAILED_ERROR, SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_ID_KEY_DISTRIBUTION_FLAG,
) )
from bumble.core import ProtocolError from bumble.core import ProtocolError
@@ -273,9 +272,15 @@ KEY_DIST = range(16)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST) 'io_caps, sc, mitm, key_dist',
itertools.chain(
itertools.product([IO_CAP], SC, MITM, [15]),
itertools.product(
[[PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]], SC, MITM, KEY_DIST
),
),
) )
async def test_self_smp(io_cap, sc, mitm, key_dist): async def test_self_smp(io_caps, sc, mitm, key_dist):
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__( def __init__(
self, self,
@@ -296,6 +301,7 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
self.peer_delegate = None self.peer_delegate = None
self.number = asyncio.get_running_loop().create_future() self.number = asyncio.get_running_loop().create_future()
# pylint: disable-next=unused-argument
async def compare_numbers(self, number, digits): async def compare_numbers(self, number, digits):
if self.peer_delegate is None: if self.peer_delegate is None:
logger.warning(f'[{self.name}] no peer delegate') logger.warning(f'[{self.name}] no peer delegate')
@@ -331,8 +337,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
pairing_config_sets = [('Initiator', [None]), ('Responder', [None])] pairing_config_sets = [('Initiator', [None]), ('Responder', [None])]
for pairing_config_set in pairing_config_sets: for pairing_config_set in pairing_config_sets:
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist) for io_cap in io_caps:
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate)) delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
for pairing_config1 in pairing_config_sets[0][1]: for pairing_config1 in pairing_config_sets[0][1]:
for pairing_config2 in pairing_config_sets[1][1]: for pairing_config2 in pairing_config_sets[1][1]: