forked from auracaster/bumble_mirror
use cancel_on_disconnection helper
This commit is contained in:
@@ -35,6 +35,7 @@ import secrets
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Optional,
|
||||
@@ -84,6 +85,7 @@ from bumble.profiles import gatt_service
|
||||
if TYPE_CHECKING:
|
||||
from bumble.transport.common import TransportSource, TransportSink
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -1883,6 +1885,12 @@ class Connection(utils.CompositeEventEmitter):
|
||||
def data_packet_queue(self) -> DataPacketQueue | None:
|
||||
return self.device.host.get_data_packet_queue(self.handle)
|
||||
|
||||
def cancel_on_disconnection(self, awaitable: Awaitable) -> Awaitable[_T]:
|
||||
"""
|
||||
Helper method to call `utils.cancel_on_event` for the 'disconnection' event
|
||||
"""
|
||||
return utils.cancel_on_event(self, self.EVENT_DISCONNECTION, awaitable)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@@ -4358,9 +4366,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
raise hci.HCI_StatusError(result)
|
||||
|
||||
# Wait for the authentication to complete
|
||||
await utils.cancel_on_event(
|
||||
connection, Connection.EVENT_DISCONNECTION, pending_authentication
|
||||
)
|
||||
await connection.cancel_on_disconnection(pending_authentication)
|
||||
finally:
|
||||
connection.remove_listener(
|
||||
connection.EVENT_CONNECTION_AUTHENTICATION, on_authentication
|
||||
@@ -4447,9 +4453,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
raise hci.HCI_StatusError(result)
|
||||
|
||||
# Wait for the result
|
||||
await utils.cancel_on_event(
|
||||
connection, Connection.EVENT_DISCONNECTION, pending_encryption
|
||||
)
|
||||
await connection.cancel_on_disconnection(pending_encryption)
|
||||
finally:
|
||||
connection.remove_listener(
|
||||
connection.EVENT_CONNECTION_ENCRYPTION_CHANGE, on_encryption_change
|
||||
@@ -4493,9 +4497,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
f'{hci.HCI_Constant.error_name(result.status)}'
|
||||
)
|
||||
raise hci.HCI_StatusError(result)
|
||||
await utils.cancel_on_event(
|
||||
connection, Connection.EVENT_DISCONNECTION, pending_role_change
|
||||
)
|
||||
await connection.cancel_on_disconnection(pending_role_change)
|
||||
finally:
|
||||
connection.remove_listener(connection.EVENT_ROLE_CHANGE, on_role_change)
|
||||
connection.remove_listener(
|
||||
@@ -5727,9 +5729,7 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
async def reply() -> None:
|
||||
try:
|
||||
if await utils.cancel_on_event(
|
||||
connection, Connection.EVENT_DISCONNECTION, method()
|
||||
):
|
||||
if await connection.cancel_on_disconnection(method()):
|
||||
await self.host.send_command(
|
||||
hci.HCI_User_Confirmation_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
@@ -5756,10 +5756,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
|
||||
async def reply() -> None:
|
||||
try:
|
||||
number = await utils.cancel_on_event(
|
||||
connection,
|
||||
Connection.EVENT_DISCONNECTION,
|
||||
pairing_config.delegate.get_number(),
|
||||
number = await connection.cancel_on_disconnection(
|
||||
pairing_config.delegate.get_number()
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
@@ -5792,10 +5790,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
if io_capability == hci.HCI_KEYBOARD_ONLY_IO_CAPABILITY:
|
||||
# Ask the user to enter a string
|
||||
async def get_pin_code():
|
||||
pin_code = await utils.cancel_on_event(
|
||||
connection,
|
||||
Connection.EVENT_DISCONNECTION,
|
||||
pairing_config.delegate.get_string(16),
|
||||
pin_code = await connection.cancel_on_disconnection(
|
||||
pairing_config.delegate.get_string(16)
|
||||
)
|
||||
|
||||
if pin_code is not None:
|
||||
@@ -5833,10 +5829,8 @@ class Device(utils.CompositeEventEmitter):
|
||||
pairing_config = self.pairing_config_factory(connection)
|
||||
|
||||
# Show the passkey to the user
|
||||
utils.cancel_on_event(
|
||||
connection,
|
||||
Connection.EVENT_DISCONNECTION,
|
||||
pairing_config.delegate.display_number(passkey, digits=6),
|
||||
connection.cancel_on_disconnection(
|
||||
pairing_config.delegate.display_number(passkey, digits=6)
|
||||
)
|
||||
|
||||
# [Classic only]
|
||||
|
||||
@@ -818,9 +818,7 @@ class ClassicChannel(utils.EventEmitter):
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
try:
|
||||
return await utils.cancel_on_event(
|
||||
self.connection, 'disconnection', self.connection_result
|
||||
)
|
||||
return await self.connection.cancel_on_disconnection(self.connection_result)
|
||||
finally:
|
||||
self.connection_result = None
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
# Update the active preset index if needed
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
utils.cancel_on_event(connection, 'disconnection', on_connection_async())
|
||||
connection.cancel_on_disconnection(on_connection_async())
|
||||
|
||||
def _on_read_active_preset_index(self, connection: Connection) -> bytes:
|
||||
del connection # Unused
|
||||
|
||||
@@ -161,10 +161,8 @@ class VolumeControlService(gatt.TemplateService):
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
if handler(*value[2:]):
|
||||
self.change_counter = (self.change_counter + 1) % 256
|
||||
utils.cancel_on_event(
|
||||
connection,
|
||||
'disconnection',
|
||||
connection.device.notify_subscribers(attribute=self.volume_state),
|
||||
connection.cancel_on_disconnection(
|
||||
connection.device.notify_subscribers(attribute=self.volume_state)
|
||||
)
|
||||
self.emit(self.EVENT_VOLUME_STATE_CHANGE)
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -896,7 +895,7 @@ class Session:
|
||||
|
||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||
|
||||
utils.cancel_on_event(self.connection, 'disconnection', prompt())
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
def prompt_user_for_numeric_comparison(
|
||||
self, code: int, next_steps: Callable[[], None]
|
||||
@@ -915,7 +914,7 @@ class Session:
|
||||
|
||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||
|
||||
utils.cancel_on_event(self.connection, 'disconnection', prompt())
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
|
||||
async def prompt() -> None:
|
||||
@@ -932,7 +931,7 @@ class Session:
|
||||
logger.warning(f'exception while prompting: {error}')
|
||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||
|
||||
utils.cancel_on_event(self.connection, 'disconnection', prompt())
|
||||
self.connection.cancel_on_disconnection(prompt())
|
||||
|
||||
async def display_passkey(self) -> None:
|
||||
# Get the passkey value from the delegate
|
||||
@@ -974,11 +973,7 @@ class Session:
|
||||
next_steps()
|
||||
|
||||
try:
|
||||
utils.cancel_on_event(
|
||||
self.connection,
|
||||
'disconnection',
|
||||
display_passkey(),
|
||||
)
|
||||
self.connection.cancel_on_disconnection(display_passkey())
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while displaying passkey: {error}')
|
||||
else:
|
||||
@@ -1050,7 +1045,7 @@ class Session:
|
||||
)
|
||||
|
||||
# Perform the next steps asynchronously in case we need to wait for input
|
||||
utils.cancel_on_event(self.connection, 'disconnection', next_steps())
|
||||
self.connection.cancel_on_disconnection(next_steps())
|
||||
else:
|
||||
confirm_value = crypto.c1(
|
||||
self.tk,
|
||||
@@ -1173,8 +1168,8 @@ class Session:
|
||||
self.connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = utils.cancel_on_event(
|
||||
self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
|
||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||
self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
elif not self.sc:
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
@@ -1212,8 +1207,8 @@ class Session:
|
||||
self.connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = utils.cancel_on_event(
|
||||
self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
|
||||
self.ctkd_task = self.connection.cancel_on_disconnection(
|
||||
self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
elif not self.sc:
|
||||
@@ -1305,9 +1300,7 @@ class Session:
|
||||
|
||||
# Wait for the pairing process to finish
|
||||
assert self.pairing_result
|
||||
await utils.cancel_on_event(
|
||||
self.connection, 'disconnection', self.pairing_result
|
||||
)
|
||||
await self.connection.cancel_on_disconnection(self.pairing_result)
|
||||
|
||||
def on_disconnection(self, _: int) -> None:
|
||||
self.connection.remove_listener(
|
||||
@@ -1328,7 +1321,7 @@ class Session:
|
||||
if self.is_initiator:
|
||||
self.distribute_keys()
|
||||
|
||||
utils.cancel_on_event(self.connection, 'disconnection', self.on_pairing())
|
||||
self.connection.cancel_on_disconnection(self.on_pairing())
|
||||
|
||||
def on_connection_encryption_change(self) -> None:
|
||||
if self.connection.is_encrypted and not self.completed:
|
||||
@@ -1439,10 +1432,8 @@ class Session:
|
||||
def on_smp_pairing_request_command(
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
utils.cancel_on_event(
|
||||
self.connection,
|
||||
'disconnection',
|
||||
self.on_smp_pairing_request_command_async(command),
|
||||
self.connection.cancel_on_disconnection(
|
||||
self.on_smp_pairing_request_command_async(command)
|
||||
)
|
||||
|
||||
async def on_smp_pairing_request_command_async(
|
||||
@@ -1854,7 +1845,7 @@ class Session:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
|
||||
def next_steps():
|
||||
def next_steps() -> None:
|
||||
# Send our public key back to the initiator
|
||||
self.send_public_key_command()
|
||||
|
||||
@@ -1891,7 +1882,7 @@ class Session:
|
||||
self.wait_before_continuing = None
|
||||
self.send_pairing_dhkey_check_command()
|
||||
|
||||
utils.cancel_on_event(self.connection, 'disconnection', next_steps())
|
||||
self.connection.cancel_on_disconnection(next_steps())
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
|
||||
@@ -75,9 +75,7 @@ async def main() -> None:
|
||||
def on_cis_request(
|
||||
connection: Connection, cis_handle: int, _cig_id: int, _cis_id: int
|
||||
):
|
||||
utils.cancel_on_event(
|
||||
connection, 'disconnection', devices[0].accept_cis_request(cis_handle)
|
||||
)
|
||||
connection.cancel_on_disconnection(devices[0].accept_cis_request(cis_handle))
|
||||
|
||||
devices[0].on('cis_request', on_cis_request)
|
||||
|
||||
|
||||
@@ -61,14 +61,12 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration):
|
||||
else:
|
||||
raise RuntimeError("unknown active codec")
|
||||
|
||||
utils.cancel_on_event(
|
||||
connection,
|
||||
'disconnection',
|
||||
connection.cancel_on_disconnection(
|
||||
connection.device.send_command(
|
||||
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
|
||||
bd_addr=connection.peer_address, **esco_parameters.asdict()
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
handler = functools.partial(on_sco_request, protocol=hf_protocol)
|
||||
|
||||
@@ -170,7 +170,7 @@ async def main() -> None:
|
||||
mcp.on('track_position', on_track_position)
|
||||
await mcp.subscribe_characteristics()
|
||||
|
||||
utils.cancel_on_event(connection, 'disconnection', on_connection_async())
|
||||
connection.cancel_on_disconnection(on_connection_async())
|
||||
|
||||
device.on('connection', on_connection)
|
||||
|
||||
|
||||
@@ -483,8 +483,8 @@ async def test_cis():
|
||||
_cig_id: int,
|
||||
_cis_id: int,
|
||||
):
|
||||
utils.cancel_on_event(
|
||||
acl_connection, 'disconnection', devices[1].accept_cis_request(cis_handle)
|
||||
acl_connection.cancel_on_disconnection(
|
||||
devices[1].accept_cis_request(cis_handle)
|
||||
)
|
||||
peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user