host: spawn each asynchronous task with the right aliveness

This commit is contained in:
Abel Lucas
2022-11-29 00:59:04 +00:00
parent f5fe3d87f2
commit 287df94090
5 changed files with 125 additions and 54 deletions

View File

@@ -535,7 +535,7 @@ class Connection(CompositeEventEmitter):
self.on('disconnection_failure', abort.set_exception) self.on('disconnection_failure', abort.set_exception)
try: try:
await asyncio.wait_for(abort, timeout) await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
@@ -1592,7 +1592,7 @@ class Device(CompositeEventEmitter):
if transport == BT_LE_TRANSPORT: if transport == BT_LE_TRANSPORT:
self.le_connecting = True self.le_connecting = True
if timeout is None: if timeout is None:
return await pending_connection return await self.abort_on('flush', pending_connection)
else: else:
try: try:
return await asyncio.wait_for( return await asyncio.wait_for(
@@ -1609,7 +1609,7 @@ class Device(CompositeEventEmitter):
) )
try: try:
return await pending_connection return await self.abort_on('flush', pending_connection)
except ConnectionError: except ConnectionError:
raise TimeoutError() raise TimeoutError()
finally: finally:
@@ -1661,6 +1661,7 @@ class Device(CompositeEventEmitter):
try: try:
# Wait for a request or a completed connection # Wait for a request or a completed connection
pending_request = self.abort_on('flush', pending_request)
result = await ( result = await (
asyncio.wait_for(pending_request, timeout) asyncio.wait_for(pending_request, timeout)
if timeout if timeout
@@ -1682,6 +1683,9 @@ class Device(CompositeEventEmitter):
# Otherwise, result came from `on_connection_request` # Otherwise, result came from `on_connection_request`
peer_address, class_of_device, link_type = result peer_address, class_of_device, link_type = result
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
def on_connection(connection): def on_connection(connection):
if ( if (
connection.transport == BT_BR_EDR_TRANSPORT connection.transport == BT_BR_EDR_TRANSPORT
@@ -1696,8 +1700,6 @@ class Device(CompositeEventEmitter):
): ):
pending_connection.set_exception(error) pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
self.on('connection', on_connection) self.on('connection', on_connection)
self.on('connection_failure', on_connection_failure) self.on('connection_failure', on_connection_failure)
@@ -1713,7 +1715,7 @@ class Device(CompositeEventEmitter):
) )
# Wait for connection complete # Wait for connection complete
return await pending_connection return await self.abort_on('flush', pending_connection)
finally: finally:
self.remove_listener('connection', on_connection) self.remove_listener('connection', on_connection)
@@ -1782,7 +1784,7 @@ class Device(CompositeEventEmitter):
# Wait for the disconnection process to complete # Wait for the disconnection process to complete
self.disconnecting = True self.disconnecting = True
return await pending_disconnection return await self.abort_on('flush', pending_disconnection)
finally: finally:
connection.remove_listener( connection.remove_listener(
'disconnection', pending_disconnection.set_result 'disconnection', pending_disconnection.set_result
@@ -1910,7 +1912,7 @@ class Device(CompositeEventEmitter):
else: else:
return None return None
return await peer_address return await self.abort_on('flush', peer_address)
finally: finally:
if handler is not None: if handler is not None:
self.remove_listener(event_name, handler) self.remove_listener(event_name, handler)
@@ -1994,7 +1996,7 @@ class Device(CompositeEventEmitter):
connection.authenticating = True connection.authenticating = True
# Wait for the authentication to complete # Wait for the authentication to complete
await pending_authentication await connection.abort_on('disconnection', pending_authentication)
finally: finally:
connection.authenticating = False connection.authenticating = False
connection.remove_listener('connection_authentication', on_authentication) connection.remove_listener('connection_authentication', on_authentication)
@@ -2068,7 +2070,7 @@ class Device(CompositeEventEmitter):
raise HCI_StatusError(result) raise HCI_StatusError(result)
# Wait for the result # Wait for the result
await pending_encryption await connection.abort_on('disconnection', pending_encryption)
finally: finally:
connection.remove_listener( connection.remove_listener(
'connection_encryption_change', on_encryption_change 'connection_encryption_change', on_encryption_change
@@ -2116,11 +2118,18 @@ class Device(CompositeEventEmitter):
raise HCI_StatusError(result) raise HCI_StatusError(result)
# Wait for the result # Wait for the result
return await pending_name return await self.abort_on('flush', pending_name)
finally: finally:
self.remove_listener('remote_name', handler) self.remove_listener('remote_name', handler)
self.remove_listener('remote_name_failure', failure_handler) self.remove_listener('remote_name_failure', failure_handler)
@host_event_handler
def on_flush(self):
self.emit('flush')
for _, connection in self.connections.items():
connection.emit('disconnection', 0)
self.connections = {}
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler
def on_link_key(self, bd_addr, link_key, key_type): def on_link_key(self, bd_addr, link_key, key_type):
@@ -2135,7 +2144,7 @@ class Device(CompositeEventEmitter):
except Exception as error: except Exception as error:
logger.warn(f'!!! error while storing keys: {error}') logger.warn(f'!!! error while storing keys: {error}')
asyncio.create_task(store_keys()) self.abort_on('flush', store_keys())
if connection := self.find_connection_by_bd_addr( if connection := self.find_connection_by_bd_addr(
bd_addr, transport=BT_BR_EDR_TRANSPORT bd_addr, transport=BT_BR_EDR_TRANSPORT
@@ -2227,10 +2236,10 @@ class Device(CompositeEventEmitter):
async def new_connection(): async def new_connection():
# Figure out which PHY we're connected with # Figure out which PHY we're connected with
if self.host.supports_command(HCI_LE_READ_PHY_COMMAND): if self.host.supports_command(HCI_LE_READ_PHY_COMMAND):
result = await self.send_command( result = await asyncio.shield(self.send_command(
HCI_LE_Read_PHY_Command(connection_handle=connection_handle), HCI_LE_Read_PHY_Command(connection_handle=connection_handle),
check_result=True, check_result=True,
) ))
phy = ConnectionPHY( phy = ConnectionPHY(
result.return_parameters.tx_phy, result.return_parameters.rx_phy result.return_parameters.tx_phy, result.return_parameters.rx_phy
) )
@@ -2261,7 +2270,7 @@ class Device(CompositeEventEmitter):
# Emit an event to notify listeners of the new connection # Emit an event to notify listeners of the new connection
self.emit('connection', connection) self.emit('connection', connection)
asyncio.create_task(new_connection()) self.abort_on('flush', new_connection())
@host_event_handler @host_event_handler
def on_connection_failure(self, transport, peer_address, error_code): def on_connection_failure(self, transport, peer_address, error_code):
@@ -2338,7 +2347,7 @@ class Device(CompositeEventEmitter):
# Restart advertising if auto-restart is enabled # Restart advertising if auto-restart is enabled
if self.auto_restart_advertising: if self.auto_restart_advertising:
logger.debug('restarting advertising') logger.debug('restarting advertising')
asyncio.create_task( self.abort_on('flush',
self.start_advertising( self.start_advertising(
advertising_type=self.advertising_type, auto_restart=True advertising_type=self.advertising_type, auto_restart=True
) )
@@ -2460,17 +2469,19 @@ class Device(CompositeEventEmitter):
if can_compare: if can_compare:
async def compare_numbers(): async def compare_numbers():
numbers_match = await pairing_config.delegate.compare_numbers( numbers_match = await connection.abort_on('disconnection',
code, digits=6 pairing_config.delegate.compare_numbers(
code, digits=6
)
) )
if numbers_match: if numbers_match:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address bd_addr=connection.peer_address
) )
) )
else: else:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command( HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address bd_addr=connection.peer_address
) )
@@ -2480,15 +2491,16 @@ class Device(CompositeEventEmitter):
else: else:
async def confirm(): async def confirm():
confirm = await pairing_config.delegate.confirm() confirm = await connection.abort_on('disconnection',
pairing_config.delegate.confirm())
if confirm: if confirm:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address bd_addr=connection.peer_address
) )
) )
else: else:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command( HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address bd_addr=connection.peer_address
) )
@@ -2512,15 +2524,16 @@ class Device(CompositeEventEmitter):
if can_input: if can_input:
async def get_number(): async def get_number():
number = await pairing_config.delegate.get_number() number = await connection.abort_on('disconnection',
pairing_config.delegate.get_number())
if number is not None: if number is not None:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command( HCI_User_Passkey_Request_Reply_Command(
bd_addr=connection.peer_address, numeric_value=number bd_addr=connection.peer_address, numeric_value=number
) )
) )
else: else:
self.host.send_command_sync( await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command( HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address bd_addr=connection.peer_address
) )
@@ -2541,7 +2554,7 @@ class Device(CompositeEventEmitter):
# Ask what the pairing config should be for this connection # Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
asyncio.create_task(pairing_config.delegate.display_number(passkey)) connection.abort_on('disconnection', pairing_config.delegate.display_number(passkey))
# [Classic only] # [Classic only]
@host_event_handler @host_event_handler

View File

@@ -17,7 +17,6 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging import logging
from pyee import EventEmitter
from colors import color from colors import color
from .hci import * from .hci import *
@@ -26,6 +25,7 @@ from .att import *
from .gatt import * from .gatt import *
from .smp import * from .smp import *
from .core import ConnectionParameters from .core import ConnectionParameters
from .utils import AbortableEventEmitter
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -65,7 +65,7 @@ class Connection:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(EventEmitter): class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None): def __init__(self, controller_source=None, controller_sink=None):
super().__init__() super().__init__()
@@ -96,7 +96,19 @@ class Host(EventEmitter):
if controller_sink: if controller_sink:
self.set_packet_sink(controller_sink) self.set_packet_sink(controller_sink)
async def flush(self):
# Make sure no command is pending
await self.command_semaphore.acquire()
# Flush current host state, then release command semaphore
self.emit('flush')
self.command_semaphore.release()
async def reset(self): async def reset(self):
if self.ready:
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True) await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True self.ready = True
@@ -604,9 +616,9 @@ class Host(EventEmitter):
logger.debug('no long term key provider') logger.debug('no long term key provider')
long_term_key = None long_term_key = None
else: else:
long_term_key = await self.long_term_key_provider( long_term_key = await self.abort_on('flush', self.long_term_key_provider(
connection.handle, event.random_number, event.encryption_diversifier connection.handle, event.random_number, event.encryption_diversifier
) ))
if long_term_key: if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command( response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle=event.connection_handle, connection_handle=event.connection_handle,
@@ -719,7 +731,7 @@ class Host(EventEmitter):
logger.debug('no link key provider') logger.debug('no link key provider')
link_key = None link_key = None
else: else:
link_key = await self.link_key_provider(event.bd_addr) link_key = await self.abort_on('flush', self.link_key_provider(event.bd_addr))
if link_key: if link_key:
response = HCI_Link_Key_Request_Reply_Command( response = HCI_Link_Key_Request_Reply_Command(
bd_addr=event.bd_addr, link_key=link_key bd_addr=event.bd_addr, link_key=link_key

View File

@@ -766,7 +766,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
asyncio.create_task(prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps): def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt(): async def prompt():
@@ -783,7 +783,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
asyncio.create_task(prompt()) self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps): def prompt_user_for_number(self, next_steps):
async def prompt(): async def prompt():
@@ -796,7 +796,7 @@ class Session:
logger.warn(f'exception while prompting: {error}') logger.warn(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR) self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
asyncio.create_task(prompt()) self.connection.abort_on('disconnection', prompt())
def display_passkey(self): def display_passkey(self):
# Generate random Passkey/PIN code # Generate random Passkey/PIN code
@@ -808,7 +808,7 @@ class Session:
self.tk = self.passkey.to_bytes(16, byteorder='little') self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}') logger.debug(f'TK from passkey = {self.tk.hex()}')
asyncio.create_task( self.connection.abort_on('disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6) self.pairing_config.delegate.display_number(self.passkey, digits=6)
) )
@@ -921,14 +921,12 @@ class Session:
def start_encryption(self, key): def start_encryption(self, key):
# We can now encrypt the connection with the short term key, so that we can # We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection # distribute the long term and/or other keys over an encrypted connection
asyncio.create_task( self.manager.device.host.send_command_sync(
self.manager.device.host.send_command( HCI_LE_Enable_Encryption_Command(
HCI_LE_Enable_Encryption_Command( connection_handle=self.connection.handle,
connection_handle=self.connection.handle, random_number=bytes(8),
random_number=bytes(8), encrypted_diversifier=0,
encrypted_diversifier=0, long_term_key=key
long_term_key=key,
)
) )
) )
@@ -950,7 +948,7 @@ class Session:
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == BT_BR_EDR_TRANSPORT
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = asyncio.create_task(self.derive_ltk()) self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
elif not self.sc: elif not self.sc:
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -997,7 +995,7 @@ class Session:
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == BT_BR_EDR_TRANSPORT
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = asyncio.create_task(self.derive_ltk()) self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
elif not self.sc: elif not self.sc:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG: if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -1094,7 +1092,7 @@ class Session:
self.send_pairing_request_command() self.send_pairing_request_command()
# Wait for the pairing process to finish # Wait for the pairing process to finish
await self.pairing_result await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, reason): def on_disconnection(self, reason):
self.connection.remove_listener('disconnection', self.on_disconnection) self.connection.remove_listener('disconnection', self.on_disconnection)
@@ -1112,7 +1110,7 @@ class Session:
if self.is_initiator: if self.is_initiator:
self.distribute_keys() self.distribute_keys()
asyncio.create_task(self.on_pairing()) self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self): def on_connection_encryption_change(self):
if self.connection.is_encrypted: if self.connection.is_encrypted:
@@ -1219,7 +1217,7 @@ class Session:
logger.error(color('SMP command not handled???', 'red')) logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command): def on_smp_pairing_request_command(self, command):
asyncio.create_task(self.on_smp_pairing_request_command_async(command)) self.connection.abort_on('disconnection', self.on_smp_pairing_request_command_async(command))
async def on_smp_pairing_request_command_async(self, command): async def on_smp_pairing_request_command_async(self, command):
# Check if the request should proceed # Check if the request should proceed
@@ -1572,7 +1570,7 @@ class Session:
self.wait_before_continuing = None self.wait_before_continuing = None
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
asyncio.create_task(next_steps()) self.connection.abort_on('disconnection', next_steps())
else: else:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
else: else:
@@ -1688,7 +1686,7 @@ class Manager(EventEmitter):
except Exception as error: except Exception as error:
logger.warn(f'!!! error while storing keys: {error}') logger.warn(f'!!! error while storing keys: {error}')
asyncio.create_task(store_keys()) self.device.abort_on('flush', store_keys())
# Notify the device # Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc) self.device.on_pairing(session.connection.handle, keys, session.sc)

View File

@@ -19,6 +19,8 @@ import asyncio
import logging import logging
import traceback import traceback
import collections import collections
import sys
from typing import Awaitable
from functools import wraps from functools import wraps
from colors import color from colors import color
from pyee import EventEmitter from pyee import EventEmitter
@@ -62,7 +64,37 @@ def composite_listener(cls):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter): class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable):
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
def on_event(*_):
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python prior to 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future
# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._listener = None self._listener = None

View File

@@ -223,8 +223,16 @@ async def test_device_connect_parallel():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run_test_device(): @pytest.mark.asyncio
await test_device_connect_parallel() async def test_flush():
d0 = Device(host=Host(None, None))
task = d0.abort_on('flush', asyncio.sleep(10000))
await d0.host.flush()
try:
await task
assert False
except asyncio.CancelledError:
pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -248,6 +256,14 @@ def test_gatt_services_without_gas():
assert len(device.gatt_server.attributes) == 0 assert len(device.gatt_server.attributes) == 0
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()
await test_flush()
await test_gatt_services_with_gas()
await test_gatt_services_without_gas()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())