use cancel_on_disconnection helper

This commit is contained in:
Gilles Boccon-Gibod
2025-06-10 13:28:08 -04:00
parent 39518c89f5
commit 8137caf37b
9 changed files with 43 additions and 66 deletions

View File

@@ -26,7 +26,6 @@ from __future__ import annotations
import logging
import asyncio
import enum
import secrets
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
@@ -896,7 +895,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
utils.cancel_on_event(self.connection, 'disconnection', prompt())
self.connection.cancel_on_disconnection(prompt())
def prompt_user_for_numeric_comparison(
self, code: int, next_steps: Callable[[], None]
@@ -915,7 +914,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
utils.cancel_on_event(self.connection, 'disconnection', prompt())
self.connection.cancel_on_disconnection(prompt())
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt() -> None:
@@ -932,7 +931,7 @@ class Session:
logger.warning(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
utils.cancel_on_event(self.connection, 'disconnection', prompt())
self.connection.cancel_on_disconnection(prompt())
async def display_passkey(self) -> None:
# Get the passkey value from the delegate
@@ -974,11 +973,7 @@ class Session:
next_steps()
try:
utils.cancel_on_event(
self.connection,
'disconnection',
display_passkey(),
)
self.connection.cancel_on_disconnection(display_passkey())
except Exception as error:
logger.warning(f'exception while displaying passkey: {error}')
else:
@@ -1050,7 +1045,7 @@ class Session:
)
# Perform the next steps asynchronously in case we need to wait for input
utils.cancel_on_event(self.connection, 'disconnection', next_steps())
self.connection.cancel_on_disconnection(next_steps())
else:
confirm_value = crypto.c1(
self.tk,
@@ -1173,8 +1168,8 @@ class Session:
self.connection.transport == PhysicalTransport.BR_EDR
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = utils.cancel_on_event(
self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
@@ -1212,8 +1207,8 @@ class Session:
self.connection.transport == PhysicalTransport.BR_EDR
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = utils.cancel_on_event(
self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
@@ -1305,9 +1300,7 @@ class Session:
# Wait for the pairing process to finish
assert self.pairing_result
await utils.cancel_on_event(
self.connection, 'disconnection', self.pairing_result
)
await self.connection.cancel_on_disconnection(self.pairing_result)
def on_disconnection(self, _: int) -> None:
self.connection.remove_listener(
@@ -1328,7 +1321,7 @@ class Session:
if self.is_initiator:
self.distribute_keys()
utils.cancel_on_event(self.connection, 'disconnection', self.on_pairing())
self.connection.cancel_on_disconnection(self.on_pairing())
def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted and not self.completed:
@@ -1439,10 +1432,8 @@ class Session:
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
) -> None:
utils.cancel_on_event(
self.connection,
'disconnection',
self.on_smp_pairing_request_command_async(command),
self.connection.cancel_on_disconnection(
self.on_smp_pairing_request_command_async(command)
)
async def on_smp_pairing_request_command_async(
@@ -1854,7 +1845,7 @@ class Session:
self.send_pairing_confirm_command()
else:
def next_steps():
def next_steps() -> None:
# Send our public key back to the initiator
self.send_public_key_command()
@@ -1891,7 +1882,7 @@ class Session:
self.wait_before_continuing = None
self.send_pairing_dhkey_check_command()
utils.cancel_on_event(self.connection, 'disconnection', next_steps())
self.connection.cancel_on_disconnection(next_steps())
else:
self.send_pairing_dhkey_check_command()
else: