mirror of
https://github.com/google/bumble.git
synced 2026-05-08 03:58:01 +00:00
host: spawn each asynchronous task with the right aliveness
This commit is contained in:
@@ -535,7 +535,7 @@ class Connection(CompositeEventEmitter):
|
|||||||
self.on('disconnection_failure', abort.set_exception)
|
self.on('disconnection_failure', abort.set_exception)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(abort, timeout)
|
await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1592,7 +1592,7 @@ class Device(CompositeEventEmitter):
|
|||||||
if transport == BT_LE_TRANSPORT:
|
if transport == BT_LE_TRANSPORT:
|
||||||
self.le_connecting = True
|
self.le_connecting = True
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
return await pending_connection
|
return await self.abort_on('flush', pending_connection)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
@@ -1609,7 +1609,7 @@ class Device(CompositeEventEmitter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await pending_connection
|
return await self.abort_on('flush', pending_connection)
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
finally:
|
finally:
|
||||||
@@ -1661,6 +1661,7 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for a request or a completed connection
|
# Wait for a request or a completed connection
|
||||||
|
pending_request = self.abort_on('flush', pending_request)
|
||||||
result = await (
|
result = await (
|
||||||
asyncio.wait_for(pending_request, timeout)
|
asyncio.wait_for(pending_request, timeout)
|
||||||
if timeout
|
if timeout
|
||||||
@@ -1682,6 +1683,9 @@ class Device(CompositeEventEmitter):
|
|||||||
# Otherwise, result came from `on_connection_request`
|
# Otherwise, result came from `on_connection_request`
|
||||||
peer_address, class_of_device, link_type = result
|
peer_address, class_of_device, link_type = result
|
||||||
|
|
||||||
|
# Create a future so that we can wait for the connection's result
|
||||||
|
pending_connection = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
def on_connection(connection):
|
def on_connection(connection):
|
||||||
if (
|
if (
|
||||||
connection.transport == BT_BR_EDR_TRANSPORT
|
connection.transport == BT_BR_EDR_TRANSPORT
|
||||||
@@ -1696,8 +1700,6 @@ class Device(CompositeEventEmitter):
|
|||||||
):
|
):
|
||||||
pending_connection.set_exception(error)
|
pending_connection.set_exception(error)
|
||||||
|
|
||||||
# Create a future so that we can wait for the connection's result
|
|
||||||
pending_connection = asyncio.get_running_loop().create_future()
|
|
||||||
self.on('connection', on_connection)
|
self.on('connection', on_connection)
|
||||||
self.on('connection_failure', on_connection_failure)
|
self.on('connection_failure', on_connection_failure)
|
||||||
|
|
||||||
@@ -1713,7 +1715,7 @@ class Device(CompositeEventEmitter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Wait for connection complete
|
# Wait for connection complete
|
||||||
return await pending_connection
|
return await self.abort_on('flush', pending_connection)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.remove_listener('connection', on_connection)
|
self.remove_listener('connection', on_connection)
|
||||||
@@ -1782,7 +1784,7 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
# Wait for the disconnection process to complete
|
# Wait for the disconnection process to complete
|
||||||
self.disconnecting = True
|
self.disconnecting = True
|
||||||
return await pending_disconnection
|
return await self.abort_on('flush', pending_disconnection)
|
||||||
finally:
|
finally:
|
||||||
connection.remove_listener(
|
connection.remove_listener(
|
||||||
'disconnection', pending_disconnection.set_result
|
'disconnection', pending_disconnection.set_result
|
||||||
@@ -1910,7 +1912,7 @@ class Device(CompositeEventEmitter):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await peer_address
|
return await self.abort_on('flush', peer_address)
|
||||||
finally:
|
finally:
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
self.remove_listener(event_name, handler)
|
self.remove_listener(event_name, handler)
|
||||||
@@ -1994,7 +1996,7 @@ class Device(CompositeEventEmitter):
|
|||||||
connection.authenticating = True
|
connection.authenticating = True
|
||||||
|
|
||||||
# Wait for the authentication to complete
|
# Wait for the authentication to complete
|
||||||
await pending_authentication
|
await connection.abort_on('disconnection', pending_authentication)
|
||||||
finally:
|
finally:
|
||||||
connection.authenticating = False
|
connection.authenticating = False
|
||||||
connection.remove_listener('connection_authentication', on_authentication)
|
connection.remove_listener('connection_authentication', on_authentication)
|
||||||
@@ -2068,7 +2070,7 @@ class Device(CompositeEventEmitter):
|
|||||||
raise HCI_StatusError(result)
|
raise HCI_StatusError(result)
|
||||||
|
|
||||||
# Wait for the result
|
# Wait for the result
|
||||||
await pending_encryption
|
await connection.abort_on('disconnection', pending_encryption)
|
||||||
finally:
|
finally:
|
||||||
connection.remove_listener(
|
connection.remove_listener(
|
||||||
'connection_encryption_change', on_encryption_change
|
'connection_encryption_change', on_encryption_change
|
||||||
@@ -2116,11 +2118,18 @@ class Device(CompositeEventEmitter):
|
|||||||
raise HCI_StatusError(result)
|
raise HCI_StatusError(result)
|
||||||
|
|
||||||
# Wait for the result
|
# Wait for the result
|
||||||
return await pending_name
|
return await self.abort_on('flush', pending_name)
|
||||||
finally:
|
finally:
|
||||||
self.remove_listener('remote_name', handler)
|
self.remove_listener('remote_name', handler)
|
||||||
self.remove_listener('remote_name_failure', failure_handler)
|
self.remove_listener('remote_name_failure', failure_handler)
|
||||||
|
|
||||||
|
@host_event_handler
|
||||||
|
def on_flush(self):
|
||||||
|
self.emit('flush')
|
||||||
|
for _, connection in self.connections.items():
|
||||||
|
connection.emit('disconnection', 0)
|
||||||
|
self.connections = {}
|
||||||
|
|
||||||
# [Classic only]
|
# [Classic only]
|
||||||
@host_event_handler
|
@host_event_handler
|
||||||
def on_link_key(self, bd_addr, link_key, key_type):
|
def on_link_key(self, bd_addr, link_key, key_type):
|
||||||
@@ -2135,7 +2144,7 @@ class Device(CompositeEventEmitter):
|
|||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.warn(f'!!! error while storing keys: {error}')
|
logger.warn(f'!!! error while storing keys: {error}')
|
||||||
|
|
||||||
asyncio.create_task(store_keys())
|
self.abort_on('flush', store_keys())
|
||||||
|
|
||||||
if connection := self.find_connection_by_bd_addr(
|
if connection := self.find_connection_by_bd_addr(
|
||||||
bd_addr, transport=BT_BR_EDR_TRANSPORT
|
bd_addr, transport=BT_BR_EDR_TRANSPORT
|
||||||
@@ -2227,10 +2236,10 @@ class Device(CompositeEventEmitter):
|
|||||||
async def new_connection():
|
async def new_connection():
|
||||||
# Figure out which PHY we're connected with
|
# Figure out which PHY we're connected with
|
||||||
if self.host.supports_command(HCI_LE_READ_PHY_COMMAND):
|
if self.host.supports_command(HCI_LE_READ_PHY_COMMAND):
|
||||||
result = await self.send_command(
|
result = await asyncio.shield(self.send_command(
|
||||||
HCI_LE_Read_PHY_Command(connection_handle=connection_handle),
|
HCI_LE_Read_PHY_Command(connection_handle=connection_handle),
|
||||||
check_result=True,
|
check_result=True,
|
||||||
)
|
))
|
||||||
phy = ConnectionPHY(
|
phy = ConnectionPHY(
|
||||||
result.return_parameters.tx_phy, result.return_parameters.rx_phy
|
result.return_parameters.tx_phy, result.return_parameters.rx_phy
|
||||||
)
|
)
|
||||||
@@ -2261,7 +2270,7 @@ class Device(CompositeEventEmitter):
|
|||||||
# Emit an event to notify listeners of the new connection
|
# Emit an event to notify listeners of the new connection
|
||||||
self.emit('connection', connection)
|
self.emit('connection', connection)
|
||||||
|
|
||||||
asyncio.create_task(new_connection())
|
self.abort_on('flush', new_connection())
|
||||||
|
|
||||||
@host_event_handler
|
@host_event_handler
|
||||||
def on_connection_failure(self, transport, peer_address, error_code):
|
def on_connection_failure(self, transport, peer_address, error_code):
|
||||||
@@ -2338,7 +2347,7 @@ class Device(CompositeEventEmitter):
|
|||||||
# Restart advertising if auto-restart is enabled
|
# Restart advertising if auto-restart is enabled
|
||||||
if self.auto_restart_advertising:
|
if self.auto_restart_advertising:
|
||||||
logger.debug('restarting advertising')
|
logger.debug('restarting advertising')
|
||||||
asyncio.create_task(
|
self.abort_on('flush',
|
||||||
self.start_advertising(
|
self.start_advertising(
|
||||||
advertising_type=self.advertising_type, auto_restart=True
|
advertising_type=self.advertising_type, auto_restart=True
|
||||||
)
|
)
|
||||||
@@ -2460,17 +2469,19 @@ class Device(CompositeEventEmitter):
|
|||||||
if can_compare:
|
if can_compare:
|
||||||
|
|
||||||
async def compare_numbers():
|
async def compare_numbers():
|
||||||
numbers_match = await pairing_config.delegate.compare_numbers(
|
numbers_match = await connection.abort_on('disconnection',
|
||||||
code, digits=6
|
pairing_config.delegate.compare_numbers(
|
||||||
|
code, digits=6
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if numbers_match:
|
if numbers_match:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Confirmation_Request_Reply_Command(
|
HCI_User_Confirmation_Request_Reply_Command(
|
||||||
bd_addr=connection.peer_address
|
bd_addr=connection.peer_address
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
||||||
bd_addr=connection.peer_address
|
bd_addr=connection.peer_address
|
||||||
)
|
)
|
||||||
@@ -2480,15 +2491,16 @@ class Device(CompositeEventEmitter):
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
async def confirm():
|
async def confirm():
|
||||||
confirm = await pairing_config.delegate.confirm()
|
confirm = await connection.abort_on('disconnection',
|
||||||
|
pairing_config.delegate.confirm())
|
||||||
if confirm:
|
if confirm:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Confirmation_Request_Reply_Command(
|
HCI_User_Confirmation_Request_Reply_Command(
|
||||||
bd_addr=connection.peer_address
|
bd_addr=connection.peer_address
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
||||||
bd_addr=connection.peer_address
|
bd_addr=connection.peer_address
|
||||||
)
|
)
|
||||||
@@ -2512,15 +2524,16 @@ class Device(CompositeEventEmitter):
|
|||||||
if can_input:
|
if can_input:
|
||||||
|
|
||||||
async def get_number():
|
async def get_number():
|
||||||
number = await pairing_config.delegate.get_number()
|
number = await connection.abort_on('disconnection',
|
||||||
|
pairing_config.delegate.get_number())
|
||||||
if number is not None:
|
if number is not None:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Passkey_Request_Reply_Command(
|
HCI_User_Passkey_Request_Reply_Command(
|
||||||
bd_addr=connection.peer_address, numeric_value=number
|
bd_addr=connection.peer_address, numeric_value=number
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.host.send_command_sync(
|
await self.host.send_command(
|
||||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||||
bd_addr=connection.peer_address
|
bd_addr=connection.peer_address
|
||||||
)
|
)
|
||||||
@@ -2541,7 +2554,7 @@ class Device(CompositeEventEmitter):
|
|||||||
# Ask what the pairing config should be for this connection
|
# Ask what the pairing config should be for this connection
|
||||||
pairing_config = self.pairing_config_factory(connection)
|
pairing_config = self.pairing_config_factory(connection)
|
||||||
|
|
||||||
asyncio.create_task(pairing_config.delegate.display_number(passkey))
|
connection.abort_on('disconnection', pairing_config.delegate.display_number(passkey))
|
||||||
|
|
||||||
# [Classic only]
|
# [Classic only]
|
||||||
@host_event_handler
|
@host_event_handler
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from pyee import EventEmitter
|
|
||||||
from colors import color
|
from colors import color
|
||||||
|
|
||||||
from .hci import *
|
from .hci import *
|
||||||
@@ -26,6 +25,7 @@ from .att import *
|
|||||||
from .gatt import *
|
from .gatt import *
|
||||||
from .smp import *
|
from .smp import *
|
||||||
from .core import ConnectionParameters
|
from .core import ConnectionParameters
|
||||||
|
from .utils import AbortableEventEmitter
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Logging
|
# Logging
|
||||||
@@ -65,7 +65,7 @@ class Connection:
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Host(EventEmitter):
|
class Host(AbortableEventEmitter):
|
||||||
def __init__(self, controller_source=None, controller_sink=None):
|
def __init__(self, controller_source=None, controller_sink=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -96,7 +96,19 @@ class Host(EventEmitter):
|
|||||||
if controller_sink:
|
if controller_sink:
|
||||||
self.set_packet_sink(controller_sink)
|
self.set_packet_sink(controller_sink)
|
||||||
|
|
||||||
|
async def flush(self):
|
||||||
|
# Make sure no command is pending
|
||||||
|
await self.command_semaphore.acquire()
|
||||||
|
|
||||||
|
# Flush current host state, then release command semaphore
|
||||||
|
self.emit('flush')
|
||||||
|
self.command_semaphore.release()
|
||||||
|
|
||||||
async def reset(self):
|
async def reset(self):
|
||||||
|
if self.ready:
|
||||||
|
self.ready = False
|
||||||
|
await self.flush()
|
||||||
|
|
||||||
await self.send_command(HCI_Reset_Command(), check_result=True)
|
await self.send_command(HCI_Reset_Command(), check_result=True)
|
||||||
self.ready = True
|
self.ready = True
|
||||||
|
|
||||||
@@ -604,9 +616,9 @@ class Host(EventEmitter):
|
|||||||
logger.debug('no long term key provider')
|
logger.debug('no long term key provider')
|
||||||
long_term_key = None
|
long_term_key = None
|
||||||
else:
|
else:
|
||||||
long_term_key = await self.long_term_key_provider(
|
long_term_key = await self.abort_on('flush', self.long_term_key_provider(
|
||||||
connection.handle, event.random_number, event.encryption_diversifier
|
connection.handle, event.random_number, event.encryption_diversifier
|
||||||
)
|
))
|
||||||
if long_term_key:
|
if long_term_key:
|
||||||
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
|
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
|
||||||
connection_handle=event.connection_handle,
|
connection_handle=event.connection_handle,
|
||||||
@@ -719,7 +731,7 @@ class Host(EventEmitter):
|
|||||||
logger.debug('no link key provider')
|
logger.debug('no link key provider')
|
||||||
link_key = None
|
link_key = None
|
||||||
else:
|
else:
|
||||||
link_key = await self.link_key_provider(event.bd_addr)
|
link_key = await self.abort_on('flush', self.link_key_provider(event.bd_addr))
|
||||||
if link_key:
|
if link_key:
|
||||||
response = HCI_Link_Key_Request_Reply_Command(
|
response = HCI_Link_Key_Request_Reply_Command(
|
||||||
bd_addr=event.bd_addr, link_key=link_key
|
bd_addr=event.bd_addr, link_key=link_key
|
||||||
|
|||||||
@@ -766,7 +766,7 @@ class Session:
|
|||||||
|
|
||||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||||
|
|
||||||
asyncio.create_task(prompt())
|
self.connection.abort_on('disconnection', prompt())
|
||||||
|
|
||||||
def prompt_user_for_numeric_comparison(self, code, next_steps):
|
def prompt_user_for_numeric_comparison(self, code, next_steps):
|
||||||
async def prompt():
|
async def prompt():
|
||||||
@@ -783,7 +783,7 @@ class Session:
|
|||||||
|
|
||||||
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
|
||||||
|
|
||||||
asyncio.create_task(prompt())
|
self.connection.abort_on('disconnection', prompt())
|
||||||
|
|
||||||
def prompt_user_for_number(self, next_steps):
|
def prompt_user_for_number(self, next_steps):
|
||||||
async def prompt():
|
async def prompt():
|
||||||
@@ -796,7 +796,7 @@ class Session:
|
|||||||
logger.warn(f'exception while prompting: {error}')
|
logger.warn(f'exception while prompting: {error}')
|
||||||
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
|
||||||
|
|
||||||
asyncio.create_task(prompt())
|
self.connection.abort_on('disconnection', prompt())
|
||||||
|
|
||||||
def display_passkey(self):
|
def display_passkey(self):
|
||||||
# Generate random Passkey/PIN code
|
# Generate random Passkey/PIN code
|
||||||
@@ -808,7 +808,7 @@ class Session:
|
|||||||
self.tk = self.passkey.to_bytes(16, byteorder='little')
|
self.tk = self.passkey.to_bytes(16, byteorder='little')
|
||||||
logger.debug(f'TK from passkey = {self.tk.hex()}')
|
logger.debug(f'TK from passkey = {self.tk.hex()}')
|
||||||
|
|
||||||
asyncio.create_task(
|
self.connection.abort_on('disconnection',
|
||||||
self.pairing_config.delegate.display_number(self.passkey, digits=6)
|
self.pairing_config.delegate.display_number(self.passkey, digits=6)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -921,14 +921,12 @@ class Session:
|
|||||||
def start_encryption(self, key):
|
def start_encryption(self, key):
|
||||||
# We can now encrypt the connection with the short term key, so that we can
|
# We can now encrypt the connection with the short term key, so that we can
|
||||||
# distribute the long term and/or other keys over an encrypted connection
|
# distribute the long term and/or other keys over an encrypted connection
|
||||||
asyncio.create_task(
|
self.manager.device.host.send_command_sync(
|
||||||
self.manager.device.host.send_command(
|
HCI_LE_Enable_Encryption_Command(
|
||||||
HCI_LE_Enable_Encryption_Command(
|
connection_handle=self.connection.handle,
|
||||||
connection_handle=self.connection.handle,
|
random_number=bytes(8),
|
||||||
random_number=bytes(8),
|
encrypted_diversifier=0,
|
||||||
encrypted_diversifier=0,
|
long_term_key=key
|
||||||
long_term_key=key,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -950,7 +948,7 @@ class Session:
|
|||||||
self.connection.transport == BT_BR_EDR_TRANSPORT
|
self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||||
):
|
):
|
||||||
self.ctkd_task = asyncio.create_task(self.derive_ltk())
|
self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
|
||||||
elif not self.sc:
|
elif not self.sc:
|
||||||
# Distribute the LTK, EDIV and RAND
|
# Distribute the LTK, EDIV and RAND
|
||||||
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||||
@@ -997,7 +995,7 @@ class Session:
|
|||||||
self.connection.transport == BT_BR_EDR_TRANSPORT
|
self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||||
):
|
):
|
||||||
self.ctkd_task = asyncio.create_task(self.derive_ltk())
|
self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
|
||||||
# Distribute the LTK, EDIV and RAND
|
# Distribute the LTK, EDIV and RAND
|
||||||
elif not self.sc:
|
elif not self.sc:
|
||||||
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
|
||||||
@@ -1094,7 +1092,7 @@ class Session:
|
|||||||
self.send_pairing_request_command()
|
self.send_pairing_request_command()
|
||||||
|
|
||||||
# Wait for the pairing process to finish
|
# Wait for the pairing process to finish
|
||||||
await self.pairing_result
|
await self.connection.abort_on('disconnection', self.pairing_result)
|
||||||
|
|
||||||
def on_disconnection(self, reason):
|
def on_disconnection(self, reason):
|
||||||
self.connection.remove_listener('disconnection', self.on_disconnection)
|
self.connection.remove_listener('disconnection', self.on_disconnection)
|
||||||
@@ -1112,7 +1110,7 @@ class Session:
|
|||||||
if self.is_initiator:
|
if self.is_initiator:
|
||||||
self.distribute_keys()
|
self.distribute_keys()
|
||||||
|
|
||||||
asyncio.create_task(self.on_pairing())
|
self.connection.abort_on('disconnection', self.on_pairing())
|
||||||
|
|
||||||
def on_connection_encryption_change(self):
|
def on_connection_encryption_change(self):
|
||||||
if self.connection.is_encrypted:
|
if self.connection.is_encrypted:
|
||||||
@@ -1219,7 +1217,7 @@ class Session:
|
|||||||
logger.error(color('SMP command not handled???', 'red'))
|
logger.error(color('SMP command not handled???', 'red'))
|
||||||
|
|
||||||
def on_smp_pairing_request_command(self, command):
|
def on_smp_pairing_request_command(self, command):
|
||||||
asyncio.create_task(self.on_smp_pairing_request_command_async(command))
|
self.connection.abort_on('disconnection', self.on_smp_pairing_request_command_async(command))
|
||||||
|
|
||||||
async def on_smp_pairing_request_command_async(self, command):
|
async def on_smp_pairing_request_command_async(self, command):
|
||||||
# Check if the request should proceed
|
# Check if the request should proceed
|
||||||
@@ -1572,7 +1570,7 @@ class Session:
|
|||||||
self.wait_before_continuing = None
|
self.wait_before_continuing = None
|
||||||
self.send_pairing_dhkey_check_command()
|
self.send_pairing_dhkey_check_command()
|
||||||
|
|
||||||
asyncio.create_task(next_steps())
|
self.connection.abort_on('disconnection', next_steps())
|
||||||
else:
|
else:
|
||||||
self.send_pairing_dhkey_check_command()
|
self.send_pairing_dhkey_check_command()
|
||||||
else:
|
else:
|
||||||
@@ -1688,7 +1686,7 @@ class Manager(EventEmitter):
|
|||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.warn(f'!!! error while storing keys: {error}')
|
logger.warn(f'!!! error while storing keys: {error}')
|
||||||
|
|
||||||
asyncio.create_task(store_keys())
|
self.device.abort_on('flush', store_keys())
|
||||||
|
|
||||||
# Notify the device
|
# Notify the device
|
||||||
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
self.device.on_pairing(session.connection.handle, keys, session.sc)
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
import collections
|
import collections
|
||||||
|
import sys
|
||||||
|
from typing import Awaitable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from colors import color
|
from colors import color
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
@@ -62,7 +64,37 @@ def composite_listener(cls):
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class CompositeEventEmitter(EventEmitter):
|
class AbortableEventEmitter(EventEmitter):
|
||||||
|
|
||||||
|
def abort_on(self, event: str, awaitable: Awaitable):
|
||||||
|
"""
|
||||||
|
Set a coroutine or future to abort when an event occur.
|
||||||
|
"""
|
||||||
|
future = asyncio.ensure_future(awaitable)
|
||||||
|
if future.done():
|
||||||
|
return future
|
||||||
|
|
||||||
|
def on_event(*_):
|
||||||
|
msg = f'abort: {event} event occurred.'
|
||||||
|
if isinstance(future, asyncio.Task):
|
||||||
|
# python prior to 3.9 does not support passing a message on `Task.cancel`
|
||||||
|
if sys.version_info < (3, 9, 0):
|
||||||
|
future.cancel()
|
||||||
|
else:
|
||||||
|
future.cancel(msg)
|
||||||
|
else:
|
||||||
|
future.set_exception(asyncio.CancelledError(msg))
|
||||||
|
|
||||||
|
def on_done(_):
|
||||||
|
self.remove_listener(event, on_event)
|
||||||
|
|
||||||
|
self.on(event, on_event)
|
||||||
|
future.add_done_callback(on_done)
|
||||||
|
return future
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class CompositeEventEmitter(AbortableEventEmitter):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._listener = None
|
self._listener = None
|
||||||
|
|||||||
@@ -223,8 +223,16 @@ async def test_device_connect_parallel():
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def run_test_device():
|
@pytest.mark.asyncio
|
||||||
await test_device_connect_parallel()
|
async def test_flush():
|
||||||
|
d0 = Device(host=Host(None, None))
|
||||||
|
task = d0.abort_on('flush', asyncio.sleep(10000))
|
||||||
|
await d0.host.flush()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
assert False
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -248,6 +256,14 @@ def test_gatt_services_without_gas():
|
|||||||
assert len(device.gatt_server.attributes) == 0
|
assert len(device.gatt_server.attributes) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
async def run_test_device():
|
||||||
|
await test_device_connect_parallel()
|
||||||
|
await test_flush()
|
||||||
|
await test_gatt_services_with_gas()
|
||||||
|
await test_gatt_services_without_gas()
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
|||||||
Reference in New Issue
Block a user