forked from auracaster/bumble_mirror
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b50a48fed3 |
+16
-44
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import grpc
|
||||
import logging
|
||||
|
||||
@@ -28,8 +27,8 @@ from bumble.core import (
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.utils import EventWatcher
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
from contextlib import suppress
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||
@@ -295,35 +294,23 @@ class SecurityService(SecurityServicer):
|
||||
try:
|
||||
self.log.debug('Pair...')
|
||||
|
||||
security_result = asyncio.get_running_loop().create_future()
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
wait_for_security: asyncio.Future[
|
||||
bool
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
||||
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
connection.request_pairing()
|
||||
|
||||
@watcher.on(connection, 'pairing')
|
||||
def on_pairing(*_: Any) -> None:
|
||||
security_result.set_result('success')
|
||||
await wait_for_security
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
@watcher.on(connection, 'pairing_failure')
|
||||
def on_pairing_failure(*_: Any) -> None:
|
||||
security_result.set_result('pairing_failure')
|
||||
|
||||
@watcher.on(connection, 'disconnection')
|
||||
def on_disconnection(*_: Any) -> None:
|
||||
security_result.set_result('connection_died')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
connection.request_pairing()
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
result = await security_result
|
||||
|
||||
self.log.debug(f'Pairing session complete, status={result}')
|
||||
if result != 'success':
|
||||
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||
self.log.debug('Paired')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
@@ -382,7 +369,6 @@ class SecurityService(SecurityServicer):
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
pair_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
@@ -429,10 +415,6 @@ class SecurityService(SecurityServicer):
|
||||
if authenticate_task is None:
|
||||
authenticate_task = asyncio.create_task(authenticate())
|
||||
|
||||
def pair(*_: Any) -> None:
|
||||
if self.need_pairing(connection, level):
|
||||
pair_task = asyncio.create_task(connection.pair())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
@@ -443,7 +425,6 @@ class SecurityService(SecurityServicer):
|
||||
'connection_encryption_change': on_encryption_change,
|
||||
'classic_pairing': try_set_success,
|
||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
@@ -471,15 +452,6 @@ class SecurityService(SecurityServicer):
|
||||
pass
|
||||
self.log.debug('Authenticated')
|
||||
|
||||
# wait for `pair` to finish if any
|
||||
if pair_task is not None:
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
await pair_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.debug('paired')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
def reached_security_level(
|
||||
@@ -551,7 +523,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
||||
self.log.debug(f"DeleteBond: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
with contextlib.suppress(KeyError):
|
||||
with suppress(KeyError):
|
||||
await self.device.keystore.delete(str(address))
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
+7
-20
@@ -37,7 +37,6 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
@@ -1772,26 +1771,7 @@ class Manager(EventEmitter):
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
|
||||
def on_smp_security_request_command(
|
||||
self, connection: Connection, request: SMP_Security_Request_Command
|
||||
) -> None:
|
||||
connection.emit('security_request', request.auth_req)
|
||||
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Security request is more than just pairing, so let applications handle them
|
||||
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||
self.on_smp_security_request_command(
|
||||
connection, cast(SMP_Security_Request_Command, command)
|
||||
)
|
||||
return
|
||||
|
||||
# Look for a session with this connection, and create one if none exists
|
||||
if not (session := self.sessions.get(connection.handle)):
|
||||
if connection.role == BT_CENTRAL_ROLE:
|
||||
@@ -1802,6 +1782,13 @@ class Manager(EventEmitter):
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Delegate the handling of the command to the session
|
||||
session.on_smp_command(command)
|
||||
|
||||
|
||||
+82
-106
@@ -15,25 +15,23 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import collections
|
||||
from contextlib import ExitStack
|
||||
import sys
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Set,
|
||||
TypeVar,
|
||||
List,
|
||||
Tuple,
|
||||
Callable,
|
||||
Any,
|
||||
ContextManager,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
overload,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from functools import wraps
|
||||
import functools
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
@@ -76,102 +74,6 @@ def composite_listener(cls):
|
||||
return cls
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_Handler = TypeVar('_Handler', bound=Callable)
|
||||
|
||||
|
||||
class EventWatcher:
|
||||
'''A wrapper class to control the lifecycle of event handlers better.
|
||||
|
||||
Usage:
|
||||
```
|
||||
watcher = EventWatcher()
|
||||
|
||||
def on_foo():
|
||||
...
|
||||
watcher.on(emitter, 'foo', on_foo)
|
||||
|
||||
@watcher.on(emitter, 'bar')
|
||||
def on_bar():
|
||||
...
|
||||
|
||||
# Close all event handlers watching through this watcher
|
||||
watcher.close()
|
||||
```
|
||||
|
||||
As context:
|
||||
```
|
||||
with contextlib.closing(EventWatcher()) as context:
|
||||
@context.on(emitter, 'foo')
|
||||
def on_foo():
|
||||
...
|
||||
# on_foo() has been removed here!
|
||||
```
|
||||
'''
|
||||
|
||||
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handlers = []
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
|
||||
def on(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||
'''Watch an event until the context is closed.
|
||||
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(f: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, f))
|
||||
emitter.on(event, f)
|
||||
return f
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
|
||||
def once(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||
'''Watch an event for once.
|
||||
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(f: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, f))
|
||||
emitter.once(event, f)
|
||||
return f
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
def close(self) -> None:
|
||||
for emitter, event, handler in self.handlers:
|
||||
if handler in emitter.listeners(event):
|
||||
emitter.remove_listener(event, handler)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_T = TypeVar('_T')
|
||||
|
||||
@@ -275,7 +177,7 @@ class AsyncRunner:
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
coroutine = func(*args, **kwargs)
|
||||
if queue is None:
|
||||
@@ -410,3 +312,77 @@ class FlowControlAsyncPipe:
|
||||
self.resume_source()
|
||||
|
||||
self.check_pump()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def event_emitter_once_for_group(
|
||||
emitter: EventEmitter,
|
||||
handlers: Mapping[str, Callable],
|
||||
context: Optional[ContextManager] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register event listeners as a group, with optional context manager.
|
||||
|
||||
For each entry in the map, an event listener is registered with the emitter.
|
||||
When any of the event names in the handlers map is emitted by the emitter,
|
||||
the corresponding handler is invoked, then all of the listeners are removed from
|
||||
the emitter.
|
||||
If a context manager is specified, it will be entered prior to registering the
|
||||
listeners, and exited when any of the events is emitted.
|
||||
|
||||
Args:
|
||||
emitter:
|
||||
The event emitter with which to register the event listeners.
|
||||
handlers:
|
||||
A map that associates an event name with an event handler.
|
||||
context:
|
||||
A context manager to manager resources, or None if not needed.
|
||||
"""
|
||||
event_emitters_once_for_group(
|
||||
{(emitter, event_name): handler for event_name, handler in handlers.items()},
|
||||
context,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def event_emitters_once_for_group(
|
||||
handlers: Mapping[Tuple[EventEmitter, str], Callable],
|
||||
context: Optional[ContextManager] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register event listeners as a group, with optional context manager.
|
||||
|
||||
Similar to event_emitter_once_for_group, but instead of registering the listeners
|
||||
with a single emitter, each event may by associate with a different emitter.
|
||||
|
||||
Args:
|
||||
handlers:
|
||||
A map that associates an (emitter, event_name) pair with an event handler.
|
||||
context:
|
||||
A context manager to manager resources, or None if not needed.
|
||||
"""
|
||||
# Setup an exit stack to enter and exit the context, if any.
|
||||
if context is not None:
|
||||
exit_stack = ExitStack()
|
||||
exit_stack.enter_context(context)
|
||||
else:
|
||||
exit_stack = None
|
||||
|
||||
def on_event(handler, *args, **kwargs) -> None:
|
||||
# Invoke the handler.
|
||||
handler(*args, **kwargs)
|
||||
|
||||
# Release the context, if any.
|
||||
if exit_stack is not None:
|
||||
exit_stack.close()
|
||||
|
||||
# Remove all listeners.
|
||||
for (emitter, event_name), listener in listeners.items():
|
||||
emitter.remove_listener(event_name, listener)
|
||||
|
||||
listeners = {
|
||||
(emitter, event_name): emitter.on(
|
||||
event_name, functools.partial(on_event, handler)
|
||||
)
|
||||
for (emitter, event_name), handler in handlers.items()
|
||||
}
|
||||
|
||||
+99
-54
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,66 +12,111 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from bumble import utils
|
||||
from pyee import EventEmitter
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def test_on() -> None:
|
||||
emitter = EventEmitter()
|
||||
with contextlib.closing(utils.EventWatcher()) as context:
|
||||
mock = MagicMock()
|
||||
context.on(emitter, 'event', mock)
|
||||
|
||||
emitter.emit('event')
|
||||
|
||||
assert not emitter.listeners('event')
|
||||
assert mock.call_count == 1
|
||||
|
||||
|
||||
def test_on_decorator() -> None:
|
||||
emitter = EventEmitter()
|
||||
with contextlib.closing(utils.EventWatcher()) as context:
|
||||
mock = MagicMock()
|
||||
|
||||
@context.on(emitter, 'event')
|
||||
def on_event(*_) -> None:
|
||||
mock()
|
||||
|
||||
emitter.emit('event')
|
||||
|
||||
assert not emitter.listeners('event')
|
||||
assert mock.call_count == 1
|
||||
|
||||
|
||||
def test_multiple_handlers() -> None:
|
||||
emitter = EventEmitter()
|
||||
with contextlib.closing(utils.EventWatcher()) as context:
|
||||
mock = MagicMock()
|
||||
|
||||
context.once(emitter, 'a', mock)
|
||||
context.once(emitter, 'b', mock)
|
||||
|
||||
emitter.emit('b', 'b')
|
||||
|
||||
assert not emitter.listeners('a')
|
||||
assert not emitter.listeners('b')
|
||||
|
||||
mock.assert_called_once_with('b')
|
||||
from bumble.utils import event_emitter_once_for_group, event_emitters_once_for_group
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def run_tests():
|
||||
test_on()
|
||||
test_on_decorator()
|
||||
test_multiple_handlers()
|
||||
def test_event_emitter_once_for_group():
|
||||
results = {'event1': None, 'event2': None, 'released': 0}
|
||||
|
||||
def handler1(x):
|
||||
results['event1'] = x
|
||||
|
||||
def handler2(y):
|
||||
results['event2'] = y
|
||||
|
||||
emitter = EventEmitter()
|
||||
|
||||
event_emitter_once_for_group(
|
||||
emitter,
|
||||
{
|
||||
'event1': handler1,
|
||||
'event2': handler2,
|
||||
},
|
||||
)
|
||||
|
||||
emitter.emit('event1', 'hello')
|
||||
|
||||
assert results['event1'] == 'hello'
|
||||
assert results['event2'] is None
|
||||
|
||||
results['event1'] = None
|
||||
|
||||
emitter.emit('event1', 'hello')
|
||||
emitter.emit('event2', 1234)
|
||||
|
||||
assert results['event1'] is None
|
||||
assert results['event2'] is None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def managed():
|
||||
try:
|
||||
yield 1234
|
||||
finally:
|
||||
results['released'] += 1
|
||||
|
||||
event_emitter_once_for_group(
|
||||
emitter,
|
||||
{
|
||||
'event1': handler1,
|
||||
'event2': handler2,
|
||||
},
|
||||
managed(),
|
||||
)
|
||||
|
||||
assert results['released'] == 0
|
||||
|
||||
emitter.emit('event2', 7756)
|
||||
|
||||
assert results['event1'] is None
|
||||
assert results['event2'] == 7756
|
||||
assert results['released'] == 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_event_emitters_once_for_group():
|
||||
results = {'event1': None, 'event2': None, 'released': 0}
|
||||
|
||||
def handler1(x):
|
||||
results['event1'] = x
|
||||
|
||||
def handler2(y):
|
||||
results['event2'] = y
|
||||
|
||||
emitter1 = EventEmitter()
|
||||
emitter2 = EventEmitter()
|
||||
|
||||
event_emitters_once_for_group(
|
||||
{
|
||||
(emitter1, 'event1'): handler1,
|
||||
(emitter2, 'event2'): handler2,
|
||||
},
|
||||
)
|
||||
|
||||
emitter1.emit('event1', 'hello')
|
||||
emitter2.emit('event1', 'foobar')
|
||||
|
||||
assert results['event1'] == 'hello'
|
||||
assert results['event2'] is None
|
||||
|
||||
results['event1'] = None
|
||||
|
||||
emitter1.emit('event1', 'hello')
|
||||
emitter1.emit('event2', 1234)
|
||||
emitter2.emit('event1', 'hello')
|
||||
emitter2.emit('event2', 1234)
|
||||
|
||||
assert results['event1'] is None
|
||||
assert results['event2'] is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
run_tests()
|
||||
test_event_emitter_once_for_group()
|
||||
test_event_emitters_once_for_group()
|
||||
|
||||
Reference in New Issue
Block a user