mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
549 lines
16 KiB
Python
549 lines
16 KiB
Python
# Copyright 2021-2022 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Imports
|
|
# -----------------------------------------------------------------------------
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import collections
|
|
import enum
|
|
import functools
|
|
import logging
|
|
import warnings
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import (
|
|
Any,
|
|
Protocol,
|
|
TypeVar,
|
|
overload,
|
|
)
|
|
|
|
import pyee
|
|
import pyee.asyncio
|
|
from typing_extensions import Self
|
|
|
|
from bumble.colors import color
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Logging
|
|
# -----------------------------------------------------------------------------
|
|
# Logger for this module (named after the module via __name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def setup_event_forwarding(emitter, forwarder, event_name):
    """Re-emit every `event_name` event raised by `emitter` on `forwarder`.

    Positional and keyword arguments of the original event are passed through
    unchanged to `forwarder.emit`.
    """

    def forward(*event_args, **event_kwargs):
        forwarder.emit(event_name, *event_args, **event_kwargs)

    emitter.on(event_name, forward)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def wrap_async(function):
    """
    Return an async callable that invokes `function` with the given arguments.

    The result is a `functools.partial` of `async_call`, so awaiting
    `wrap_async(f)(*args, **kwargs)` yields the value of `f(*args, **kwargs)`.
    """
    async_wrapper = functools.partial(async_call, function)
    return async_wrapper
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def deprecated(msg: str):
    """
    Decorator factory: emit a `DeprecationWarning` carrying `msg` every time the
    decorated function is called, then delegate to it.

    The warning is attributed to the caller of the decorated function
    (stacklevel=2).
    """

    def decorate(function):
        @functools.wraps(function)
        def warn_then_call(*call_args, **call_kwargs):
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            result = function(*call_args, **call_kwargs)
            return result

        return warn_then_call

    return decorate
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def experimental(msg: str):
    """
    Decorator factory: emit a `FutureWarning` carrying `msg` every time the
    decorated function is called, then delegate to it.

    The warning is attributed to the caller of the decorated function
    (stacklevel=2).
    """

    def decorate(function):
        @functools.wraps(function)
        def warn_then_call(*call_args, **call_kwargs):
            warnings.warn(msg, FutureWarning, stacklevel=2)
            result = function(*call_args, **call_kwargs)
            return result

        return warn_then_call

    return decorate
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def composite_listener(cls):
    """
    Class decorator that installs `_bumble_register_composite` and
    `_bumble_deregister_composite` on `cls`.

    Each of them takes `(self, emitter)` and walks every attribute of `cls`
    whose name starts with `on_` (as listed by `dir(cls)` at call time). For an
    attribute `on_<event_name>`, it adds or removes the bound method as a
    listener for `<event_name>` on `emitter`.
    """
    # pylint: disable=protected-access

    def handler_names():
        return [name for name in dir(cls) if name.startswith('on_')]

    def register(self, emitter):
        for name in handler_names():
            emitter.on(name[3:], getattr(self, name))

    def deregister(self, emitter):
        for name in handler_names():
            emitter.remove_listener(name[3:], getattr(self, name))

    cls._bumble_register_composite = register
    cls._bumble_deregister_composite = deregister
    return cls
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
_Handler = TypeVar('_Handler', bound=Callable)


class EventWatcher:
    '''A wrapper class to control the lifecycle of event handlers better.

    Every handler registered through `on()` or `once()` is remembered, and
    `close()` removes all of them from their emitters at once.

    Usage:
    ```
    watcher = EventWatcher()

    def on_foo():
        ...
    watcher.on(emitter, 'foo', on_foo)

    @watcher.on(emitter, 'bar')
    def on_bar():
        ...

    # Close all event handlers watching through this watcher
    watcher.close()
    ```

    As context:
    ```
    with contextlib.closing(EventWatcher()) as context:
        @context.on(emitter, 'foo')
        def on_foo():
            ...
    # on_foo() has been removed here!
    ```
    '''

    # (emitter, event name, handler) for every registration made through this
    # watcher since it was created or last closed, in registration order.
    handlers: list[tuple[pyee.EventEmitter, str, Callable[..., Any]]]

    def __init__(self) -> None:
        self.handlers = []

    @overload
    def on(
        self, emitter: pyee.EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def on(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler
    ) -> _Handler: ...

    def on(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
    ) -> _Handler | Callable[[_Handler], _Handler]:
        '''Watch an event until the watcher is closed.

        Args:
          emitter: EventEmitter to watch
          event: Event name
          handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
          The handler itself, or (decorator form) a function that registers
          and returns the decorated handler.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.on(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    @overload
    def once(
        self, emitter: pyee.EventEmitter, event: str
    ) -> Callable[[_Handler], _Handler]: ...

    @overload
    def once(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler
    ) -> _Handler: ...

    def once(
        self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
    ) -> _Handler | Callable[[_Handler], _Handler]:
        '''Watch an event only once (registered with `emitter.once`).

        Args:
          emitter: EventEmitter to watch
          event: Event name
          handler: (Optional) Event handler. When nothing is passed, this method
            works as a decorator.

        Returns:
          The handler itself, or (decorator form) a function that registers
          and returns the decorated handler.
        '''

        def wrapper(wrapped: _Handler) -> _Handler:
            self.handlers.append((emitter, event, wrapped))
            emitter.once(event, wrapped)
            return wrapped

        return wrapper if handler is None else wrapper(handler)

    def close(self) -> None:
        '''Remove every handler registered through this watcher.

        Handlers that are no longer listed by their emitter are skipped. The
        watcher forgets all of its handlers, so it holds no references
        afterwards, can be reused, and closing it twice is a no-op.
        '''
        # Detach the list first so the watcher does not keep emitters/handlers
        # alive after closing, and so re-registration during removal is safe.
        handlers, self.handlers = self.handlers, []
        for emitter, event, handler in handlers:
            if handler in emitter.listeners(event):
                emitter.remove_listener(event, handler)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
_T = TypeVar('_T')


def cancel_on_event(
    emitter: pyee.EventEmitter, event: str, awaitable: Awaitable[_T]
) -> Awaitable[_T]:
    """Cancel a coroutine or future if `event` is emitted by `emitter` first.

    The awaitable is wrapped with `asyncio.ensure_future`. If the wrapped future
    is already done, it is returned as is and no listener is added.

    Otherwise a listener for `event` is added to `emitter`. If the event fires
    before the future completes, a Task is cancelled and any other future gets
    an `asyncio.CancelledError`. Both carry the message
    "abort: <event> event occurred.". The listener is removed as soon as the
    future completes, whatever the outcome.

    Returns:
      The future/task wrapping `awaitable`.
    """
    future = asyncio.ensure_future(awaitable)

    if not future.done():

        def on_event(*_args, **_kwargs) -> None:
            # Nothing to cancel if the work already finished.
            if future.done():
                return
            reason = f'abort: {event} event occurred.'
            if isinstance(future, asyncio.Task):
                future.cancel(reason)
            else:
                future.set_exception(asyncio.CancelledError(reason))

        emitter.on(event, on_event)
        future.add_done_callback(
            lambda _completed: emitter.remove_listener(event, on_event)
        )

    return future
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class EventEmitter(pyee.asyncio.AsyncIOEventEmitter):
    """Base event emitter used throughout Bumble.

    Extends pyee's asyncio event emitter with the deprecated `abort_on` helper.
    """

    @deprecated("Use `cancel_on_event` instead.")
    def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
        """Deprecated: cancel `awaitable` if `event` is emitted by this emitter.

        Equivalent to `cancel_on_event(self, event, awaitable)`.
        """
        return cancel_on_event(emitter=self, event=event, awaitable=awaitable)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class CompositeEventEmitter(EventEmitter):
    """Event emitter with a single, swappable composite listener.

    Assigning `listener` first deregisters the previous listener (if any) and
    then registers the new one. It does this by calling the
    `_bumble_deregister_composite` / `_bumble_register_composite` functions
    found directly in the `__dict__` of each class in the listener's MRO, which
    is where `@composite_listener` installs them.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        """The currently attached composite listener, or None."""
        return self._listener

    @listener.setter
    def listener(self, listener):
        # pylint: disable=protected-access
        if self._listener is not None:
            # Call the deregistration methods for each base class that has them.
            # Look up the very attribute we are going to call, so a class that
            # only defines one of the pair cannot cause an AttributeError.
            for cls in self._listener.__class__.mro():
                deregister = cls.__dict__.get('_bumble_deregister_composite')
                if deregister is not None:
                    deregister(self._listener, self)
        self._listener = listener
        if listener is not None:
            # Call the registration methods for each base class that has them
            for cls in listener.__class__.mro():
                register = cls.__dict__.get('_bumble_register_composite')
                if register is not None:
                    register(listener, self)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class AsyncRunner:
    """Helpers for launching coroutines from synchronous code."""

    class WorkQueue:
        """Awaits enqueued coroutines one at a time, in FIFO order.

        Args:
          create_task: when True (default), the first `enqueue()` call starts an
            asyncio task that runs `run()`. When False, the caller is expected
            to await `run()` itself.
        """

        def __init__(self, create_task=True):
            self.queue = None
            self.task = None
            self.create_task = create_task

        def _get_queue(self):
            # Lazily create the asyncio.Queue. Presumably this lets a WorkQueue
            # (like `default_queue` below) be created outside a running event
            # loop. Both `enqueue()` and `run()` go through here, so whichever
            # runs first creates the queue.
            if self.queue is None:
                self.queue = asyncio.Queue()
            return self.queue

        def enqueue(self, coroutine):
            """Schedule `coroutine` to be awaited after previously queued work."""
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Enqueue the work
            self._get_queue().put_nowait(coroutine)

        async def run(self):
            """Await queued coroutines forever, logging (not propagating) errors."""
            while True:
                item = await self._get_queue().get()
                try:
                    await item
                except Exception:
                    logger.exception(color("!!! Exception in work queue", "red"))

    # Shared default queue
    default_queue = WorkQueue()

    # Shared set of running tasks (strong references, see `spawn`)
    running_tasks: set[asyncio.Task] = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and returns None
        right away. If `queue` is None, the coroutine is spawned as its own task
        and any exception it raises is logged. Otherwise it is handed to
        `queue.enqueue()`.
        """

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)
                if queue is None:
                    # Spawn the coroutine as a task
                    async def run():
                        try:
                            await coroutine
                        except Exception:
                            logger.exception(color("!!! Exception in wrapper:", "red"))

                    AsyncRunner.spawn(run())
                else:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)

            return wrapper

        return decorator

    @staticmethod
    def spawn(coroutine):
        """
        Spawn a task to run a coroutine in a "fire and forget" mode.

        Using this method instead of just calling `asyncio.create_task(coroutine)`
        is necessary when you don't keep a reference to the task, because `asyncio`
        only keeps weak references to alive tasks.

        Returns:
          The created task. Existing callers may ignore it; keeping it lets a
          caller await or cancel the work.
        """
        task = asyncio.create_task(coroutine)
        AsyncRunner.running_tasks.add(task)
        task.add_done_callback(AsyncRunner.running_tasks.remove)
        return task
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class FlowControlAsyncPipe:
    """
    Asyncio pipe with flow control. When writing to the pipe, the source is
    paused (by calling a function passed in when the pipe is created) if the
    amount of queued data exceeds a specified threshold.

    Packets are delivered to the sink in the order they were written (FIFO).

    Args:
      pause_source: called with no arguments to stop the source.
      resume_source: called with no arguments to restart the source.
      write_to_sink: called with each packet. While it is None, nothing is
        pumped.
      drain_sink: optional coroutine function awaited after each packet is
        written to the sink.
      threshold: number of queued bytes above which the source is paused.
    """

    def __init__(
        self,
        pause_source,
        resume_source,
        write_to_sink=None,
        drain_sink=None,
        threshold=0,
    ):
        self.pause_source = pause_source
        self.resume_source = resume_source
        self.write_to_sink = write_to_sink
        self.drain_sink = drain_sink
        self.threshold = threshold
        self.queue = collections.deque()  # Queue of packets, oldest on the left
        self.queued_bytes = 0  # Number of bytes in the queue
        self.ready_to_pump = asyncio.Event()
        self.paused = False
        self.source_paused = False
        self.pump_task = None

    def start(self):
        """Start the pump task (if not already started) and re-check readiness."""
        if self.pump_task is None:
            self.pump_task = asyncio.create_task(self.pump())

        self.check_pump()

    def stop(self):
        """Cancel the pump task, if any. Queued packets are kept."""
        if self.pump_task is not None:
            self.pump_task.cancel()
            self.pump_task = None

    def write(self, packet):
        """Queue `packet` for the sink, pausing the source above the threshold."""
        self.queued_bytes += len(packet)
        self.queue.append(packet)

        # Pause the source if we're over the threshold
        if self.queued_bytes > self.threshold and not self.source_paused:
            logger.debug('pausing source (queued=%s)', self.queued_bytes)
            self.pause_source()
            self.source_paused = True

        self.check_pump()

    def pause(self):
        """Stop pumping to the sink and pause the source."""
        if not self.paused:
            self.paused = True
            if not self.source_paused:
                self.pause_source()
                self.source_paused = True
            self.check_pump()

    def resume(self):
        """Resume pumping to the sink and resume the source."""
        if self.paused:
            self.paused = False
            if self.source_paused:
                self.resume_source()
                self.source_paused = False
            self.check_pump()

    def can_pump(self):
        """True when a packet is queued, the pipe is not paused and a sink exists."""
        return bool(self.queue) and not self.paused and self.write_to_sink is not None

    def check_pump(self):
        """Set or clear the `ready_to_pump` event to match `can_pump()`."""
        if self.can_pump():
            self.ready_to_pump.set()
        else:
            self.ready_to_pump.clear()

    async def pump(self):
        """Forever move queued packets to the sink, resuming the source as needed."""
        while True:
            # Wait until we can try to pump packets
            await self.ready_to_pump.wait()

            # Try to pump a packet
            if self.can_pump():
                # `write()` appends on the right, so the oldest packet is on
                # the left. Taking it from there preserves write order.
                packet = self.queue.popleft()
                self.write_to_sink(packet)
                self.queued_bytes -= len(packet)

                # Drain the sink if we can
                if self.drain_sink:
                    await self.drain_sink()

                # Check if we can accept more
                if self.queued_bytes <= self.threshold and self.source_paused:
                    logger.debug('resuming source (queued=%s)', self.queued_bytes)
                    self.source_paused = False
                    self.resume_source()

            self.check_pump()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
async def async_call(function, *args, **kwargs):
    """
    Call `function(*args, **kwargs)` synchronously inside a coroutine and return
    its result.

    The call happens when the coroutine runs; nothing is scheduled separately.
    This exists because Rust's `pyo3_asyncio` library needs functions to be
    marked async to properly inject a running loop.

        result = await async_call(some_function, ...)
    """
    result = function(*args, **kwargs)
    return result
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class OpenIntEnum(enum.IntEnum):
    """
    An `enum.IntEnum` that also accepts integers outside its declared members.

    Unknown integers become pseudo-members named `<ClassName>[<value>]`, which
    is handy for protocols whose set of constants grows over time. Non-integer
    lookups still fail as for a regular enum.
    """

    @classmethod
    def _missing_(cls, value):
        # Returning None tells `enum` the lookup failed (-> ValueError).
        if isinstance(value, int):
            pseudo_member = int.__new__(cls, value)
            pseudo_member._value_ = value
            pseudo_member._name_ = f"{cls.__name__}[{value}]"
            return pseudo_member
        return None
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class CompatibleIntFlag(enum.IntFlag):
    """
    Subclass of `enum.IntFlag` with a `composite_name` property that behaves like
    the `name` property of the `enum.IntFlag` implementation for Python versions
    >= 3.11.
    """

    @property
    def composite_name(self) -> str:
        """Names of the class's iterated flags that share a bit with this value, joined by '|'.

        Returns an empty string when no flag matches.
        """
        matching_names = []
        for member in self.__class__:
            member_name = member.name
            if self.value & member.value and member_name is not None:
                matching_names.append(member_name)
        return '|'.join(matching_names)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class ByteSerializable(Protocol):
    """
    Structural type for classes that round-trip through `bytes`: they can be
    built from bytes with `from_bytes` and serialized with `bytes(obj)`.
    """

    @classmethod
    def from_bytes(cls, data: bytes) -> Self:
        """Build an instance from its serialized form."""
        ...

    def __bytes__(self) -> bytes:
        """Return the serialized form of this instance."""
        ...
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class IntConvertible(Protocol):
    """
    Structural type for classes that can be built from an `int` and converted
    back with `int(obj)`.
    """

    def __init__(self, value: int) -> None:
        """Construct an instance from an integer value."""
        ...

    def __int__(self) -> int:
        """Return the integer value of this instance."""
        ...
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
def crc_16(data: bytes) -> int:
    """Calculate CRC-16-IBM of given data.

    Bit-reflected, processed LSB first, with an initial value of 0 and no final
    XOR. Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed)
    """
    crc = 0
    for octet in data:
        crc ^= octet
        for _ in range(8):
            low_bit = crc & 1
            crc >>= 1
            if low_bit:
                crc ^= 0xA001
    return crc
|