forked from auracaster/bumble_mirror
Compare commits
15 Commits
gbg/multi-
...
v0.0.170
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2491b686fa | ||
|
|
efd02b2f3e | ||
|
|
3b14078646 | ||
|
|
eb9d5632bc | ||
|
|
45f60edbb6 | ||
|
|
393ea6a7bb | ||
|
|
0d36d99a73 | ||
|
|
d8a9f5a724 | ||
|
|
2c66e1a042 | ||
|
|
d5eccdb00f | ||
|
|
32626573a6 | ||
|
|
caa82b8f7e | ||
|
|
5af347b499 | ||
|
|
4ed5bb5a9e | ||
|
|
f39f5f531c |
@@ -3,7 +3,7 @@ import click
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from bumble.pandora import PandoraDevice, serve
|
from bumble.pandora import PandoraDevice, Config, serve
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||||
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
|
|||||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||||
|
|
||||||
bumble_config = retrieve_config(config)
|
bumble_config = retrieve_config(config)
|
||||||
if 'transport' not in bumble_config.keys():
|
bumble_config.setdefault('transport', transport)
|
||||||
bumble_config.update({'transport': transport})
|
|
||||||
device = PandoraDevice(bumble_config)
|
device = PandoraDevice(bumble_config)
|
||||||
|
|
||||||
|
server_config = Config()
|
||||||
|
server_config.load_from_dict(bumble_config.get('server', {}))
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
asyncio.run(serve(device, port=grpc_port))
|
asyncio.run(serve(device, config=server_config, port=grpc_port))
|
||||||
|
|
||||||
|
|
||||||
def retrieve_config(config: str) -> Dict[str, Any]:
|
def retrieve_config(config: str) -> Dict[str, Any]:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class BaseError(Exception):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
error_code: int | None,
|
error_code: Optional[int],
|
||||||
error_namespace: str = '',
|
error_namespace: str = '',
|
||||||
error_name: str = '',
|
error_name: str = '',
|
||||||
details: str = '',
|
details: str = '',
|
||||||
|
|||||||
@@ -4397,7 +4397,7 @@ class HCI_Event(HCI_Packet):
|
|||||||
if len(parameters) != length:
|
if len(parameters) != length:
|
||||||
raise ValueError('invalid packet length')
|
raise ValueError('invalid packet length')
|
||||||
|
|
||||||
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
|
cls: Any
|
||||||
if event_code == HCI_LE_META_EVENT:
|
if event_code == HCI_LE_META_EVENT:
|
||||||
# We do this dispatch here and not in the subclass in order to avoid call
|
# We do this dispatch here and not in the subclass in order to avoid call
|
||||||
# loops
|
# loops
|
||||||
|
|||||||
@@ -757,7 +757,7 @@ class Channel(EventEmitter):
|
|||||||
)
|
)
|
||||||
self.state = new_state
|
self.state = new_state
|
||||||
|
|
||||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||||
|
|
||||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||||
@@ -1098,7 +1098,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
elif new_state == self.DISCONNECTED:
|
elif new_state == self.DISCONNECTED:
|
||||||
self.emit('close')
|
self.emit('close')
|
||||||
|
|
||||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||||
|
|
||||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||||
@@ -1571,7 +1571,7 @@ class ChannelManager:
|
|||||||
if connection_handle in self.identifiers:
|
if connection_handle in self.identifiers:
|
||||||
del self.identifiers[connection_handle]
|
del self.identifiers[connection_handle]
|
||||||
|
|
||||||
def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import grpc
|
import grpc
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -27,8 +28,8 @@ from bumble.core import (
|
|||||||
)
|
)
|
||||||
from bumble.device import Connection as BumbleConnection, Device
|
from bumble.device import Connection as BumbleConnection, Device
|
||||||
from bumble.hci import HCI_Error
|
from bumble.hci import HCI_Error
|
||||||
|
from bumble.utils import EventWatcher
|
||||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||||
from contextlib import suppress
|
|
||||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||||
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
|
|||||||
try:
|
try:
|
||||||
self.log.debug('Pair...')
|
self.log.debug('Pair...')
|
||||||
|
|
||||||
if (
|
security_result = asyncio.get_running_loop().create_future()
|
||||||
connection.transport == BT_LE_TRANSPORT
|
|
||||||
and connection.role == BT_PERIPHERAL_ROLE
|
|
||||||
):
|
|
||||||
wait_for_security: asyncio.Future[
|
|
||||||
bool
|
|
||||||
] = asyncio.get_running_loop().create_future()
|
|
||||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
|
||||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
|
||||||
|
|
||||||
connection.request_pairing()
|
with contextlib.closing(EventWatcher()) as watcher:
|
||||||
|
|
||||||
await wait_for_security
|
@watcher.on(connection, 'pairing')
|
||||||
else:
|
def on_pairing(*_: Any) -> None:
|
||||||
await connection.pair()
|
security_result.set_result('success')
|
||||||
|
|
||||||
self.log.debug('Paired')
|
@watcher.on(connection, 'pairing_failure')
|
||||||
|
def on_pairing_failure(*_: Any) -> None:
|
||||||
|
security_result.set_result('pairing_failure')
|
||||||
|
|
||||||
|
@watcher.on(connection, 'disconnection')
|
||||||
|
def on_disconnection(*_: Any) -> None:
|
||||||
|
security_result.set_result('connection_died')
|
||||||
|
|
||||||
|
if (
|
||||||
|
connection.transport == BT_LE_TRANSPORT
|
||||||
|
and connection.role == BT_PERIPHERAL_ROLE
|
||||||
|
):
|
||||||
|
connection.request_pairing()
|
||||||
|
else:
|
||||||
|
await connection.pair()
|
||||||
|
|
||||||
|
result = await security_result
|
||||||
|
|
||||||
|
self.log.debug(f'Pairing session complete, status={result}')
|
||||||
|
if result != 'success':
|
||||||
|
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.log.warning("Connection died during encryption")
|
self.log.warning("Connection died during encryption")
|
||||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||||
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
|
|||||||
str
|
str
|
||||||
] = asyncio.get_running_loop().create_future()
|
] = asyncio.get_running_loop().create_future()
|
||||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||||
|
pair_task: Optional[asyncio.Future[None]] = None
|
||||||
|
|
||||||
async def authenticate() -> None:
|
async def authenticate() -> None:
|
||||||
assert connection
|
assert connection
|
||||||
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
|
|||||||
if authenticate_task is None:
|
if authenticate_task is None:
|
||||||
authenticate_task = asyncio.create_task(authenticate())
|
authenticate_task = asyncio.create_task(authenticate())
|
||||||
|
|
||||||
|
def pair(*_: Any) -> None:
|
||||||
|
if self.need_pairing(connection, level):
|
||||||
|
pair_task = asyncio.create_task(connection.pair())
|
||||||
|
|
||||||
listeners: Dict[str, Callable[..., None]] = {
|
listeners: Dict[str, Callable[..., None]] = {
|
||||||
'disconnection': set_failure('connection_died'),
|
'disconnection': set_failure('connection_died'),
|
||||||
'pairing_failure': set_failure('pairing_failure'),
|
'pairing_failure': set_failure('pairing_failure'),
|
||||||
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
|
|||||||
'connection_encryption_change': on_encryption_change,
|
'connection_encryption_change': on_encryption_change,
|
||||||
'classic_pairing': try_set_success,
|
'classic_pairing': try_set_success,
|
||||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||||
|
'security_request': pair,
|
||||||
}
|
}
|
||||||
|
|
||||||
# register event handlers
|
# register event handlers
|
||||||
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
|
|||||||
pass
|
pass
|
||||||
self.log.debug('Authenticated')
|
self.log.debug('Authenticated')
|
||||||
|
|
||||||
|
# wait for `pair` to finish if any
|
||||||
|
if pair_task is not None:
|
||||||
|
self.log.debug('Wait for authentication...')
|
||||||
|
try:
|
||||||
|
await pair_task # type: ignore
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.log.debug('paired')
|
||||||
|
|
||||||
return WaitSecurityResponse(**kwargs)
|
return WaitSecurityResponse(**kwargs)
|
||||||
|
|
||||||
def reached_security_level(
|
def reached_security_level(
|
||||||
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
|||||||
self.log.debug(f"DeleteBond: {address}")
|
self.log.debug(f"DeleteBond: {address}")
|
||||||
|
|
||||||
if self.device.keystore is not None:
|
if self.device.keystore is not None:
|
||||||
with suppress(KeyError):
|
with contextlib.suppress(KeyError):
|
||||||
await self.device.keystore.delete(str(address))
|
await self.device.keystore.delete(str(address))
|
||||||
|
|
||||||
return empty_pb2.Empty()
|
return empty_pb2.Empty()
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
|
|||||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||||
|
|
||||||
|
def on_smp_security_request_command(
|
||||||
|
self, connection: Connection, request: SMP_Security_Request_Command
|
||||||
|
) -> None:
|
||||||
|
connection.emit('security_request', request.auth_req)
|
||||||
|
|
||||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||||
|
# Parse the L2CAP payload into an SMP Command object
|
||||||
|
command = SMP_Command.from_bytes(pdu)
|
||||||
|
logger.debug(
|
||||||
|
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||||
|
f'{connection.peer_address}: {command}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security request is more than just pairing, so let applications handle them
|
||||||
|
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||||
|
self.on_smp_security_request_command(
|
||||||
|
connection, cast(SMP_Security_Request_Command, command)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Look for a session with this connection, and create one if none exists
|
# Look for a session with this connection, and create one if none exists
|
||||||
if not (session := self.sessions.get(connection.handle)):
|
if not (session := self.sessions.get(connection.handle)):
|
||||||
if connection.role == BT_CENTRAL_ROLE:
|
if connection.role == BT_CENTRAL_ROLE:
|
||||||
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
|
|||||||
)
|
)
|
||||||
self.sessions[connection.handle] = session
|
self.sessions[connection.handle] = session
|
||||||
|
|
||||||
# Parse the L2CAP payload into an SMP Command object
|
|
||||||
command = SMP_Command.from_bytes(pdu)
|
|
||||||
logger.debug(
|
|
||||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
|
||||||
f'{connection.peer_address}: {command}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delegate the handling of the command to the session
|
# Delegate the handling of the command to the session
|
||||||
session.on_smp_command(command)
|
session.on_smp_command(command)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import grpc.aio
|
import grpc.aio
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_android_emulator_transport(spec: str | None) -> Transport:
|
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a transport connection to an Android emulator via its gRPC interface.
|
Open a transport connection to an Android emulator via its gRPC interface.
|
||||||
The parameter string has this syntax:
|
The parameter string has this syntax:
|
||||||
@@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
|
|||||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||||
channel = grpc.aio.insecure_channel(server_address)
|
channel = grpc.aio.insecure_channel(server_address)
|
||||||
|
|
||||||
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
|
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||||
if mode == 'host':
|
if mode == 'host':
|
||||||
# Connect as a host
|
# Connect as a host
|
||||||
service = EmulatedBluetoothServiceStub(channel)
|
service = EmulatedBluetoothServiceStub(channel)
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ def publish_grpc_port(grpc_port) -> bool:
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_android_netsim_controller_transport(
|
async def open_android_netsim_controller_transport(
|
||||||
server_host: str | None, server_port: int
|
server_host: Optional[str], server_port: int
|
||||||
) -> Transport:
|
) -> Transport:
|
||||||
if not server_port:
|
if not server_port:
|
||||||
raise ValueError('invalid port')
|
raise ValueError('invalid port')
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import socket
|
|||||||
import ctypes
|
import ctypes
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport, ParserSource
|
from .common import Transport, ParserSource
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_hci_socket_transport(spec: str | None) -> Transport:
|
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open an HCI Socket (only available on some platforms).
|
Open an HCI Socket (only available on some platforms).
|
||||||
The parameter string is either empty (to use the first/default Bluetooth adapter)
|
The parameter string is either empty (to use the first/default Bluetooth adapter)
|
||||||
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
|
|||||||
# Create a raw HCI socket
|
# Create a raw HCI socket
|
||||||
try:
|
try:
|
||||||
hci_socket = socket.socket(
|
hci_socket = socket.socket(
|
||||||
socket.AF_BLUETOOTH,
|
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||||
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
|
socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
|
||||||
socket.BTPROTO_HCI, # type: ignore
|
socket.BTPROTO_HCI, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
except AttributeError as error:
|
except AttributeError as error:
|
||||||
# Not supported on this platform
|
# Not supported on this platform
|
||||||
@@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
|
|||||||
bind_address = struct.pack(
|
bind_address = struct.pack(
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
'<HHH',
|
'<HHH',
|
||||||
socket.AF_BLUETOOTH,
|
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||||
adapter_index,
|
adapter_index,
|
||||||
HCI_CHANNEL_USER,
|
HCI_CHANNEL_USER,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import atexit
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport, StreamPacketSource, StreamPacketSink
|
from .common import Transport, StreamPacketSource, StreamPacketSink
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_pty_transport(spec: str | None) -> Transport:
|
async def open_pty_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a PTY transport.
|
Open a PTY transport.
|
||||||
The parameter string may be empty, or a path name where a symbolic link
|
The parameter string may be empty, or a path name where a symbolic link
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport
|
from .common import Transport
|
||||||
from .file import open_file_transport
|
from .file import open_file_transport
|
||||||
|
|
||||||
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_vhci_transport(spec: str | None) -> Transport:
|
async def open_vhci_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a VHCI transport (only available on some platforms).
|
Open a VHCI transport (only available on some platforms).
|
||||||
The parameter string is either empty (to use the default VHCI device
|
The parameter string is either empty (to use the default VHCI device
|
||||||
|
|||||||
110
bumble/utils.py
110
bumble/utils.py
@@ -15,12 +15,24 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
from typing import Awaitable, Set, TypeVar
|
from typing import (
|
||||||
|
Awaitable,
|
||||||
|
Set,
|
||||||
|
TypeVar,
|
||||||
|
List,
|
||||||
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
@@ -64,6 +76,102 @@ def composite_listener(cls):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_Handler = TypeVar('_Handler', bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
|
class EventWatcher:
|
||||||
|
'''A wrapper class to control the lifecycle of event handlers better.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
watcher = EventWatcher()
|
||||||
|
|
||||||
|
def on_foo():
|
||||||
|
...
|
||||||
|
watcher.on(emitter, 'foo', on_foo)
|
||||||
|
|
||||||
|
@watcher.on(emitter, 'bar')
|
||||||
|
def on_bar():
|
||||||
|
...
|
||||||
|
|
||||||
|
# Close all event handlers watching through this watcher
|
||||||
|
watcher.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
As context:
|
||||||
|
```
|
||||||
|
with contextlib.closing(EventWatcher()) as context:
|
||||||
|
@context.on(emitter, 'foo')
|
||||||
|
def on_foo():
|
||||||
|
...
|
||||||
|
# on_foo() has been removed here!
|
||||||
|
```
|
||||||
|
'''
|
||||||
|
|
||||||
|
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.handlers = []
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on(
|
||||||
|
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||||
|
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||||
|
'''Watch an event until the context is closed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emitter: EventEmitter to watch
|
||||||
|
event: Event name
|
||||||
|
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def wrapper(f: _Handler) -> _Handler:
|
||||||
|
self.handlers.append((emitter, event, f))
|
||||||
|
emitter.on(event, f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return wrapper if handler is None else wrapper(handler)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||||
|
...
|
||||||
|
|
||||||
|
def once(
|
||||||
|
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||||
|
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||||
|
'''Watch an event for once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emitter: EventEmitter to watch
|
||||||
|
event: Event name
|
||||||
|
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def wrapper(f: _Handler) -> _Handler:
|
||||||
|
self.handlers.append((emitter, event, f))
|
||||||
|
emitter.once(event, f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return wrapper if handler is None else wrapper(handler)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
for emitter, event, handler in self.handlers:
|
||||||
|
if handler in emitter.listeners(event):
|
||||||
|
emitter.remove_listener(event, handler)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
_T = TypeVar('_T')
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ development =
|
|||||||
black == 22.10
|
black == 22.10
|
||||||
grpcio-tools >= 1.57.0
|
grpcio-tools >= 1.57.0
|
||||||
invoke >= 1.7.3
|
invoke >= 1.7.3
|
||||||
mypy == 1.2.0
|
mypy == 1.5.0
|
||||||
nox >= 2022
|
nox >= 2022
|
||||||
pylint == 2.15.8
|
pylint == 2.15.8
|
||||||
types-appdirs >= 1.4.3
|
types-appdirs >= 1.4.3
|
||||||
|
|||||||
77
tests/utils_test.py
Normal file
77
tests/utils_test.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2021-2023 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from bumble import utils
|
||||||
|
from pyee import EventEmitter
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def test_on() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
context.on(emitter, 'event', mock)
|
||||||
|
|
||||||
|
emitter.emit('event')
|
||||||
|
|
||||||
|
assert not emitter.listeners('event')
|
||||||
|
assert mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_decorator() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@context.on(emitter, 'event')
|
||||||
|
def on_event(*_) -> None:
|
||||||
|
mock()
|
||||||
|
|
||||||
|
emitter.emit('event')
|
||||||
|
|
||||||
|
assert not emitter.listeners('event')
|
||||||
|
assert mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_handlers() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
context.once(emitter, 'a', mock)
|
||||||
|
context.once(emitter, 'b', mock)
|
||||||
|
|
||||||
|
emitter.emit('b', 'b')
|
||||||
|
|
||||||
|
assert not emitter.listeners('a')
|
||||||
|
assert not emitter.listeners('b')
|
||||||
|
|
||||||
|
mock.assert_called_once_with('b')
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def run_tests():
|
||||||
|
test_on()
|
||||||
|
test_on_decorator()
|
||||||
|
test_multiple_handlers()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
run_tests()
|
||||||
Reference in New Issue
Block a user