diff --git a/.vscode/settings.json b/.vscode/settings.json index 521fb84..864fe69 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -71,5 +71,10 @@ "editor.rulers": [88] }, "python.formatting.provider": "black", - "pylint.importStrategy": "useBundled" + "pylint.importStrategy": "useBundled", + "python.testing.pytestArgs": [ + "." + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/apps/pair.py b/apps/pair.py index cba88f8..94d4eec 100644 --- a/apps/pair.py +++ b/apps/pair.py @@ -19,8 +19,8 @@ import asyncio import os import logging import click -import aioconsole from colors import color +from prompt_toolkit.shortcuts import PromptSession from bumble.device import Device, Peer from bumble.transport import open_transport_or_link @@ -42,9 +42,23 @@ from bumble.att import ( ) +# ----------------------------------------------------------------------------- +class Waiter: + instance = None + + def __init__(self): + self.done = asyncio.get_running_loop().create_future() + + def terminate(self): + self.done.set_result(None) + + async def wait_until_terminated(self): + return await self.done + + # ----------------------------------------------------------------------------- class Delegate(PairingDelegate): - def __init__(self, mode, connection, capability_string, prompt): + def __init__(self, mode, connection, capability_string, do_prompt): super().__init__( { 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, @@ -58,7 +72,19 @@ class Delegate(PairingDelegate): self.mode = mode self.peer = Peer(connection) self.peer_name = None - self.prompt = prompt + self.do_prompt = do_prompt + + def print(self, message): + print(color(message, 'yellow')) + + async def prompt(self, message): + # Wait a bit to allow some of the log lines to print before we prompt + await asyncio.sleep(1) + + session = PromptSession(message) + # with patch_stdout.patch_stdout(raw=True): + response = await session.prompt_async() + return response.lower().strip() async def update_peer_name(self): if self.peer_name is not None: @@ -73,19 +99,15 @@ class Delegate(PairingDelegate): self.peer_name = '[?]' async def accept(self): - if self.prompt: + if self.do_prompt: await self.update_peer_name() - # Wait a bit to allow some of the log lines to print before we prompt - await asyncio.sleep(1) - # Prompt for acceptance - print(color('###-----------------------------------', 'yellow')) - print(color(f'### Pairing request from {self.peer_name}', 'yellow')) - print(color('###-----------------------------------', 'yellow')) + self.print('###-----------------------------------') + self.print(f'### Pairing request from {self.peer_name}') + self.print('###-----------------------------------') while True: - response = await aioconsole.ainput(color('>>> Accept? ', 'yellow')) - response = response.lower().strip() + response = await self.prompt('>>> Accept? ') if response == 'yes': return True @@ -96,23 +118,17 @@ class Delegate(PairingDelegate): # Accept silently return True - async def compare_numbers(self, number, digits=6): + async def compare_numbers(self, number, digits): await self.update_peer_name() - # Wait a bit to allow some of the log lines to print before we prompt - await asyncio.sleep(1) - # Prompt for a numeric comparison - print(color('###-----------------------------------', 'yellow')) - print(color(f'### Pairing with {self.peer_name}', 'yellow')) - print(color('###-----------------------------------', 'yellow')) + self.print('###-----------------------------------') + self.print(f'### Pairing with {self.peer_name}') + self.print('###-----------------------------------') while True: - response = await aioconsole.ainput( - color( - f'>>> Does the other device display {number:0{digits}}? ', 'yellow' - ) + response = await self.prompt( + f'>>> Does the other device display {number:0{digits}}? ' ) - response = response.lower().strip() if response == 'yes': return True @@ -123,30 +139,24 @@ class Delegate(PairingDelegate): async def get_number(self): await self.update_peer_name() - # Wait a bit to allow some of the log lines to print before we prompt - await asyncio.sleep(1) - # Prompt for a PIN while True: try: - print(color('###-----------------------------------', 'yellow')) - print(color(f'### Pairing with {self.peer_name}', 'yellow')) - print(color('###-----------------------------------', 'yellow')) - return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow'))) + self.print('###-----------------------------------') + self.print(f'### Pairing with {self.peer_name}') + self.print('###-----------------------------------') + return int(await self.prompt('>>> Enter PIN: ')) except ValueError: pass - async def display_number(self, number, digits=6): + async def display_number(self, number, digits): await self.update_peer_name() - # Wait a bit to allow some of the log lines to print before we prompt - await asyncio.sleep(1) - # Display a PIN code - print(color('###-----------------------------------', 'yellow')) - print(color(f'### Pairing with {self.peer_name}', 'yellow')) - print(color(f'### PIN: {number:0{digits}}', 'yellow')) - print(color('###-----------------------------------', 'yellow')) + self.print('###-----------------------------------') + self.print(f'### Pairing with {self.peer_name}') + self.print(f'### PIN: {number:0{digits}}') + self.print('###-----------------------------------') # ----------------------------------------------------------------------------- @@ -238,6 +248,7 @@ def on_pairing(keys): print(color('*** Paired!', 'cyan')) keys.print(prefix=color('*** ', 'cyan')) print(color('***-----------------------------------', 'cyan')) + Waiter.instance.terminate() # ----------------------------------------------------------------------------- @@ -245,6 +256,7 @@ def on_pairing_failure(reason): print(color('***-----------------------------------', 'red')) print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red')) print(color('***-----------------------------------', 'red')) + Waiter.instance.terminate() # ----------------------------------------------------------------------------- @@ -262,6 +274,8 @@ async def pair( hci_transport, address_or_name, ): + Waiter.instance = Waiter() + print('<<< connecting to HCI...') async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): print('<<< connected') @@ -332,7 +346,19 @@ async def pair( # Advertise so that peers can find us and connect await device.start_advertising(auto_restart=True) - await hci_source.wait_for_termination() + # Run until the user asks to exit + await Waiter.instance.wait_until_terminated() + + +# ----------------------------------------------------------------------------- +class LogHandler(logging.Handler): + def __init__(self): + super().__init__() + self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s')) + + def emit(self, record): + message = self.format(record) + print(message) # ----------------------------------------------------------------------------- @@ -388,7 +414,13 @@ def main( hci_transport, address_or_name, ): - logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + # Setup logging + log_handler = LogHandler() + root_logger = logging.getLogger() + root_logger.addHandler(log_handler) + root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + + # Pair asyncio.run( pair( mode, diff --git a/bumble/smp.py b/bumble/smp.py index 8c0c50a..512ffb6 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -518,13 +518,14 @@ class PairingDelegate: async def confirm(self) -> bool: return True - async def compare_numbers(self, _number: int, _digits: int = 6) -> bool: + # pylint: disable-next=unused-argument + async def compare_numbers(self, number: int, digits: int) -> bool: return True async def get_number(self) -> int: return 0 - async def display_number(self, _number: int, _digits: int = 6) -> None: + async def display_number(self, number: int, digits: int) -> None: pass async def key_distribution_response( @@ -661,7 +662,8 @@ class Session: self.peer_expected_distributions = [] self.dh_key = None self.confirm_value = None - self.passkey = 0 + self.passkey = None + self.passkey_ready = asyncio.Event() self.passkey_step = 0 self.passkey_display = False self.pairing_method = 0 @@ -839,6 +841,7 @@ class Session: # Generate random Passkey/PIN code self.passkey = secrets.randbelow(1000000) logger.debug(f'Pairing PIN CODE: {self.passkey:06}') + self.passkey_ready.set() # The value of TK is computed from the PIN code if not self.sc: @@ -859,6 +862,8 @@ class Session: self.tk = passkey.to_bytes(16, byteorder='little') logger.debug(f'TK from passkey = {self.tk.hex()}') + self.passkey_ready.set() + if next_steps is not None: next_steps() @@ -910,17 +915,29 @@ class Session: logger.debug(f'generated random: {self.r.hex()}') if self.sc: - if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): - z = 0 - elif self.pairing_method == self.PASSKEY: - z = 0x80 + ((self.passkey >> self.passkey_step) & 1) - else: - return - if self.is_initiator: - confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z])) - else: - confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z])) + async def next_steps(): + if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): + z = 0 + elif self.pairing_method == self.PASSKEY: + # We need a passkey + await self.passkey_ready.wait() + + z = 0x80 + ((self.passkey >> self.passkey_step) & 1) + else: + return + + if self.is_initiator: + confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z])) + else: + confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z])) + + self.send_command( + SMP_Pairing_Confirm_Command(confirm_value=confirm_value) + ) + + # Perform the next steps asynchronously in case we need to wait for input + self.connection.abort_on('disconnection', next_steps()) else: confirm_value = crypto.c1( self.tk, @@ -933,7 +950,7 @@ class Session: self.ra, ) - self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) + self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value)) def send_pairing_random_command(self): self.send_command(SMP_Pairing_Random_Command(random_value=self.r)) @@ -1364,8 +1381,8 @@ class Session: # Start phase 2 if self.sc: - if self.pairing_method == self.PASSKEY and self.passkey_display: - self.display_passkey() + if self.pairing_method == self.PASSKEY: + self.display_or_input_passkey() self.send_public_key_command() else: @@ -1426,18 +1443,22 @@ class Session: else: srand = self.r mrand = command.random_value - stk = crypto.s1(self.tk, srand, mrand) - logger.debug(f'STK = {stk.hex()}') + self.stk = crypto.s1(self.tk, srand, mrand) + logger.debug(f'STK = {self.stk.hex()}') # Generate LTK self.ltk = crypto.r() if self.is_initiator: - self.start_encryption(stk) + self.start_encryption(self.stk) else: self.send_pairing_random_command() def on_smp_pairing_random_command_secure_connections(self, command): + if self.pairing_method == self.PASSKEY and self.passkey is None: + logger.warning('no passkey entered, ignoring command') + return + # pylint: disable=too-many-return-statements if self.is_initiator: if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): @@ -1565,17 +1586,13 @@ class Session: logger.debug(f'DH key: {self.dh_key.hex()}') if self.is_initiator: - if self.pairing_method == self.PASSKEY: - if self.passkey_display: - self.send_pairing_confirm_command() - else: - self.input_passkey(self.send_pairing_confirm_command) + self.send_pairing_confirm_command() else: - # Send our public key back to the initiator if self.pairing_method == self.PASSKEY: - self.display_or_input_passkey(self.send_public_key_command) - else: - self.send_public_key_command() + self.display_or_input_passkey() + + # Send our public key back to the initiator + self.send_public_key_command() if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON): # We can now send the confirmation value diff --git a/setup.cfg b/setup.cfg index 066dfd7..781c367 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,6 @@ package_dir = bumble.apps = apps include-package-data = True install_requires = - aioconsole >= 0.4.1 ansicolors >= 1.1 appdirs >= 1.4 click >= 7.1.2; platform_system!='Emscripten' diff --git a/tests/self_test.py b/tests/self_test.py index 55f7e0c..751825f 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -32,7 +32,6 @@ from bumble.smp import ( PairingDelegate, SMP_PAIRING_NOT_SUPPORTED_ERROR, SMP_CONFIRM_VALUE_FAILED_ERROR, - SMP_ID_KEY_DISTRIBUTION_FLAG, ) from bumble.core import ProtocolError @@ -273,9 +272,15 @@ KEY_DIST = range(16) @pytest.mark.asyncio @pytest.mark.parametrize( - 'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST) + 'io_caps, sc, mitm, key_dist', + itertools.chain( + itertools.product([IO_CAP], SC, MITM, [15]), + itertools.product( + [[PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]], SC, MITM, KEY_DIST + ), + ), ) -async def test_self_smp(io_cap, sc, mitm, key_dist): +async def test_self_smp(io_caps, sc, mitm, key_dist): class Delegate(PairingDelegate): def __init__( self, @@ -296,6 +301,7 @@ async def test_self_smp(io_cap, sc, mitm, key_dist): self.peer_delegate = None self.number = asyncio.get_running_loop().create_future() + # pylint: disable-next=unused-argument async def compare_numbers(self, number, digits): if self.peer_delegate is None: logger.warning(f'[{self.name}] no peer delegate') @@ -331,8 +337,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist): pairing_config_sets = [('Initiator', [None]), ('Responder', [None])] for pairing_config_set in pairing_config_sets: - delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist) - pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate)) + for io_cap in io_caps: + delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist) + pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate)) for pairing_config1 in pairing_config_sets[0][1]: for pairing_config2 in pairing_config_sets[1][1]: