Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod b50a48fed3 add pyee group util 2023-09-19 16:14:53 -07:00
4 changed files with 204 additions and 224 deletions
+16 -44
View File
@@ -13,7 +13,6 @@
# limitations under the License.
import asyncio
import contextlib
import grpc
import logging
@@ -28,8 +27,8 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -295,35 +294,23 @@ class SecurityService(SecurityServicer):
try:
self.log.debug('Pair...')
security_result = asyncio.get_running_loop().create_future()
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
with contextlib.closing(EventWatcher()) as watcher:
connection.request_pairing()
@watcher.on(connection, 'pairing')
def on_pairing(*_: Any) -> None:
security_result.set_result('success')
await wait_for_security
else:
await connection.pair()
@watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
connection.request_pairing()
else:
await connection.pair()
result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
self.log.debug('Paired')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -382,7 +369,6 @@ class SecurityService(SecurityServicer):
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -429,10 +415,6 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -443,7 +425,6 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
}
# register event handlers
@@ -471,15 +452,6 @@ class SecurityService(SecurityServicer):
pass
self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs)
def reached_security_level(
@@ -551,7 +523,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with contextlib.suppress(KeyError):
with suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()
+7 -20
View File
@@ -37,7 +37,6 @@ from typing import (
Optional,
Tuple,
Type,
cast,
)
from pyee import EventEmitter
@@ -1772,26 +1771,7 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
) -> None:
connection.emit('security_request', request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
return
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
@@ -1802,6 +1782,13 @@ class Manager(EventEmitter):
)
self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session
session.on_smp_command(command)
+82 -106
View File
@@ -15,25 +15,23 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import traceback
import collections
from contextlib import ExitStack
import sys
from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any,
ContextManager,
Mapping,
Optional,
Union,
overload,
Set,
Tuple,
TypeVar,
)
from functools import wraps
import functools
from pyee import EventEmitter
from .colors import color
@@ -76,102 +74,6 @@ def composite_listener(cls):
return cls
# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)
class EventWatcher:
'''A wrapper class to control the lifecycle of event handlers better.
Usage:
```
watcher = EventWatcher()
def on_foo():
...
watcher.on(emitter, 'foo', on_foo)
@watcher.on(emitter, 'bar')
def on_bar():
...
# Close all event handlers watching through this watcher
watcher.close()
```
As context:
```
with contextlib.closing(EventWatcher()) as context:
@context.on(emitter, 'foo')
def on_foo():
...
# on_foo() has been removed here!
```
'''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None:
self.handlers = []
@overload
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@overload
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
def close(self) -> None:
for emitter, event, handler in self.handlers:
if handler in emitter.listeners(event):
emitter.remove_listener(event, handler)
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
@@ -275,7 +177,7 @@ class AsyncRunner:
"""
def decorator(func):
@wraps(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
@@ -410,3 +312,77 @@ class FlowControlAsyncPipe:
self.resume_source()
self.check_pump()
# -----------------------------------------------------------------------------
def event_emitter_once_for_group(
emitter: EventEmitter,
handlers: Mapping[str, Callable],
context: Optional[ContextManager] = None,
) -> None:
"""
Register event listeners as a group, with optional context manager.
For each entry in the map, an event listener is registered with the emitter.
When any of the event names in the handlers map is emitted by the emitter,
the corresponding handler is invoked, then all of the listeners are removed from
the emitter.
If a context manager is specified, it will be entered prior to registering the
listeners, and exited when any of the events is emitted.
Args:
emitter:
The event emitter with which to register the event listeners.
handlers:
A map that associates an event name with an event handler.
context:
A context manager to manager resources, or None if not needed.
"""
event_emitters_once_for_group(
{(emitter, event_name): handler for event_name, handler in handlers.items()},
context,
)
# -----------------------------------------------------------------------------
def event_emitters_once_for_group(
handlers: Mapping[Tuple[EventEmitter, str], Callable],
context: Optional[ContextManager] = None,
) -> None:
"""
Register event listeners as a group, with optional context manager.
Similar to event_emitter_once_for_group, but instead of registering the listeners
with a single emitter, each event may by associate with a different emitter.
Args:
handlers:
A map that associates an (emitter, event_name) pair with an event handler.
context:
A context manager to manager resources, or None if not needed.
"""
# Setup an exit stack to enter and exit the context, if any.
if context is not None:
exit_stack = ExitStack()
exit_stack.enter_context(context)
else:
exit_stack = None
def on_event(handler, *args, **kwargs) -> None:
# Invoke the handler.
handler(*args, **kwargs)
# Release the context, if any.
if exit_stack is not None:
exit_stack.close()
# Remove all listeners.
for (emitter, event_name), listener in listeners.items():
emitter.remove_listener(event_name, listener)
listeners = {
(emitter, event_name): emitter.on(
event_name, functools.partial(on_event, handler)
)
for (emitter, event_name), handler in handlers.items()
}
+99 -54
View File
@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,66 +12,111 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import contextlib
import logging
import os
from bumble import utils
from pyee import EventEmitter
from unittest.mock import MagicMock
def test_on() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.on(emitter, 'event', mock)
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_on_decorator() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
@context.on(emitter, 'event')
def on_event(*_) -> None:
mock()
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_multiple_handlers() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.once(emitter, 'a', mock)
context.once(emitter, 'b', mock)
emitter.emit('b', 'b')
assert not emitter.listeners('a')
assert not emitter.listeners('b')
mock.assert_called_once_with('b')
from bumble.utils import event_emitter_once_for_group, event_emitters_once_for_group
# -----------------------------------------------------------------------------
def run_tests():
test_on()
test_on_decorator()
test_multiple_handlers()
def test_event_emitter_once_for_group():
results = {'event1': None, 'event2': None, 'released': 0}
def handler1(x):
results['event1'] = x
def handler2(y):
results['event2'] = y
emitter = EventEmitter()
event_emitter_once_for_group(
emitter,
{
'event1': handler1,
'event2': handler2,
},
)
emitter.emit('event1', 'hello')
assert results['event1'] == 'hello'
assert results['event2'] is None
results['event1'] = None
emitter.emit('event1', 'hello')
emitter.emit('event2', 1234)
assert results['event1'] is None
assert results['event2'] is None
@contextlib.contextmanager
def managed():
try:
yield 1234
finally:
results['released'] += 1
event_emitter_once_for_group(
emitter,
{
'event1': handler1,
'event2': handler2,
},
managed(),
)
assert results['released'] == 0
emitter.emit('event2', 7756)
assert results['event1'] is None
assert results['event2'] == 7756
assert results['released'] == 1
# -----------------------------------------------------------------------------
def test_event_emitters_once_for_group():
results = {'event1': None, 'event2': None, 'released': 0}
def handler1(x):
results['event1'] = x
def handler2(y):
results['event2'] = y
emitter1 = EventEmitter()
emitter2 = EventEmitter()
event_emitters_once_for_group(
{
(emitter1, 'event1'): handler1,
(emitter2, 'event2'): handler2,
},
)
emitter1.emit('event1', 'hello')
emitter2.emit('event1', 'foobar')
assert results['event1'] == 'hello'
assert results['event2'] is None
results['event1'] = None
emitter1.emit('event1', 'hello')
emitter1.emit('event2', 1234)
emitter2.emit('event1', 'hello')
emitter2.emit('event2', 1234)
assert results['event1'] is None
assert results['event2'] is None
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
run_tests()
test_event_emitter_once_for_group()
test_event_emitters_once_for_group()