# Copyright 2021-2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio import collections import enum import functools import logging import warnings from collections.abc import Awaitable, Callable from typing import ( Any, Protocol, TypeVar, overload, ) import pyee import pyee.asyncio from typing_extensions import Self from bumble.colors import color # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- def setup_event_forwarding(emitter, forwarder, event_name): def emit(*args, **kwargs): forwarder.emit(event_name, *args, **kwargs) emitter.on(event_name, emit) # ----------------------------------------------------------------------------- def wrap_async(function): """ Wraps the provided function in an async function. """ return functools.partial(async_call, function) # ----------------------------------------------------------------------------- def deprecated(msg: str): """ Throw deprecation warning before execution. """ def wrapper(function): @functools.wraps(function) def inner(*args, **kwargs): warnings.warn(msg, DeprecationWarning, stacklevel=2) return function(*args, **kwargs) return inner return wrapper # ----------------------------------------------------------------------------- def experimental(msg: str): """ Throws a future warning before execution. """ def wrapper(function): @functools.wraps(function) def inner(*args, **kwargs): warnings.warn(msg, FutureWarning, stacklevel=2) return function(*args, **kwargs) return inner return wrapper # ----------------------------------------------------------------------------- def composite_listener(cls): """ Decorator that adds a `register` and `deregister` method to a class, which registers/deregisters all methods named `on_` as a listener for the event with an emitter. """ # pylint: disable=protected-access def register(self, emitter): for method_name in dir(cls): if method_name.startswith('on_'): emitter.on(method_name[3:], getattr(self, method_name)) def deregister(self, emitter): for method_name in dir(cls): if method_name.startswith('on_'): emitter.remove_listener(method_name[3:], getattr(self, method_name)) cls._bumble_register_composite = register cls._bumble_deregister_composite = deregister return cls # ----------------------------------------------------------------------------- _Handler = TypeVar('_Handler', bound=Callable) class EventWatcher: '''A wrapper class to control the lifecycle of event handlers better. Usage: ``` watcher = EventWatcher() def on_foo(): ... watcher.on(emitter, 'foo', on_foo) @watcher.on(emitter, 'bar') def on_bar(): ... # Close all event handlers watching through this watcher watcher.close() ``` As context: ``` with contextlib.closing(EventWatcher()) as context: @context.on(emitter, 'foo') def on_foo(): ... # on_foo() has been removed here! ``` ''' handlers: list[tuple[pyee.EventEmitter, str, Callable[..., Any]]] def __init__(self) -> None: self.handlers = [] @overload def on( self, emitter: pyee.EventEmitter, event: str ) -> Callable[[_Handler], _Handler]: ... @overload def on( self, emitter: pyee.EventEmitter, event: str, handler: _Handler ) -> _Handler: ... def on( self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None ) -> _Handler | Callable[[_Handler], _Handler]: '''Watch an event until the context is closed. Args: emitter: EventEmitter to watch event: Event name handler: (Optional) Event handler. When nothing is passed, this method works as a decorator. ''' def wrapper(wrapped: _Handler) -> _Handler: self.handlers.append((emitter, event, wrapped)) emitter.on(event, wrapped) return wrapped return wrapper if handler is None else wrapper(handler) @overload def once( self, emitter: pyee.EventEmitter, event: str ) -> Callable[[_Handler], _Handler]: ... @overload def once( self, emitter: pyee.EventEmitter, event: str, handler: _Handler ) -> _Handler: ... def once( self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None ) -> _Handler | Callable[[_Handler], _Handler]: '''Watch an event for once. Args: emitter: EventEmitter to watch event: Event name handler: (Optional) Event handler. When nothing passed, this method works as a decorator. ''' def wrapper(wrapped: _Handler) -> _Handler: self.handlers.append((emitter, event, wrapped)) emitter.once(event, wrapped) return wrapped return wrapper if handler is None else wrapper(handler) def close(self) -> None: for emitter, event, handler in self.handlers: if handler in emitter.listeners(event): emitter.remove_listener(event, handler) # ----------------------------------------------------------------------------- _T = TypeVar('_T') def cancel_on_event( emitter: pyee.EventEmitter, event: str, awaitable: Awaitable[_T] ) -> Awaitable[_T]: """Set a coroutine or future to cancel when an event occur.""" future = asyncio.ensure_future(awaitable) if future.done(): return future def on_event(*args, **kwargs) -> None: del args, kwargs if future.done(): return msg = f'abort: {event} event occurred.' if isinstance(future, asyncio.Task): future.cancel(msg) else: future.set_exception(asyncio.CancelledError(msg)) def on_done(_): emitter.remove_listener(event, on_event) emitter.on(event, on_event) future.add_done_callback(on_done) return future # ----------------------------------------------------------------------------- class EventEmitter(pyee.asyncio.AsyncIOEventEmitter): """A Base EventEmitter for Bumble.""" @deprecated("Use `cancel_on_event` instead.") def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]: """Set a coroutine or future to abort when an event occur.""" return cancel_on_event(self, event, awaitable) # ----------------------------------------------------------------------------- class CompositeEventEmitter(EventEmitter): def __init__(self): super().__init__() self._listener = None @property def listener(self): return self._listener @listener.setter def listener(self, listener): # pylint: disable=protected-access if self._listener: # Call the deregistration methods for each base class that has them for cls in self._listener.__class__.mro(): if '_bumble_register_composite' in cls.__dict__: cls._bumble_deregister_composite(self._listener, self) self._listener = listener if listener: # Call the registration methods for each base class that has them for cls in listener.__class__.mro(): if '_bumble_deregister_composite' in cls.__dict__: cls._bumble_register_composite(listener, self) # ----------------------------------------------------------------------------- class AsyncRunner: class WorkQueue: def __init__(self, create_task=True): self.queue = None self.task = None self.create_task = create_task def enqueue(self, coroutine): # Create a task now if we need to and haven't done so already if self.create_task and self.task is None: self.task = asyncio.create_task(self.run()) # Lazy-create the coroutine queue if self.queue is None: self.queue = asyncio.Queue() # Enqueue the work self.queue.put_nowait(coroutine) async def run(self): while True: item = await self.queue.get() try: await item except Exception: logger.exception(color("!!! Exception in work queue", "red")) # Shared default queue default_queue = WorkQueue() # Shared set of running tasks running_tasks: set[Awaitable] = set() @staticmethod def run_in_task(queue=None): """ Function decorator used to adapt an async function into a sync function """ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): coroutine = func(*args, **kwargs) if queue is None: # Spawn the coroutine as a task async def run(): try: await coroutine except Exception: logger.exception(color("!!! Exception in wrapper:", "red")) AsyncRunner.spawn(run()) else: # Queue the coroutine to be awaited by the work queue queue.enqueue(coroutine) return wrapper return decorator @staticmethod def spawn(coroutine): """ Spawn a task to run a coroutine in a "fire and forget" mode. Using this method instead of just calling `asyncio.create_task(coroutine)` is necessary when you don't keep a reference to the task, because `asyncio` only keeps weak references to alive tasks. """ task = asyncio.create_task(coroutine) AsyncRunner.running_tasks.add(task) task.add_done_callback(AsyncRunner.running_tasks.remove) # ----------------------------------------------------------------------------- class FlowControlAsyncPipe: """ Asyncio pipe with flow control. When writing to the pipe, the source is paused (by calling a function passed in when the pipe is created) if the amount of queued data exceeds a specified threshold. """ def __init__( self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0, ): self.pause_source = pause_source self.resume_source = resume_source self.write_to_sink = write_to_sink self.drain_sink = drain_sink self.threshold = threshold self.queue = collections.deque() # Queue of packets self.queued_bytes = 0 # Number of bytes in the queue self.ready_to_pump = asyncio.Event() self.paused = False self.source_paused = False self.pump_task = None def start(self): if self.pump_task is None: self.pump_task = asyncio.create_task(self.pump()) self.check_pump() def stop(self): if self.pump_task is not None: self.pump_task.cancel() self.pump_task = None def write(self, packet): self.queued_bytes += len(packet) self.queue.append(packet) # Pause the source if we're over the threshold if self.queued_bytes > self.threshold and not self.source_paused: logger.debug(f'pausing source (queued={self.queued_bytes})') self.pause_source() self.source_paused = True self.check_pump() def pause(self): if not self.paused: self.paused = True if not self.source_paused: self.pause_source() self.source_paused = True self.check_pump() def resume(self): if self.paused: self.paused = False if self.source_paused: self.resume_source() self.source_paused = False self.check_pump() def can_pump(self): return self.queue and not self.paused and self.write_to_sink is not None def check_pump(self): if self.can_pump(): self.ready_to_pump.set() else: self.ready_to_pump.clear() async def pump(self): while True: # Wait until we can try to pump packets await self.ready_to_pump.wait() # Try to pump a packet if self.can_pump(): packet = self.queue.pop() self.write_to_sink(packet) self.queued_bytes -= len(packet) # Drain the sink if we can if self.drain_sink: await self.drain_sink() # Check if we can accept more if self.queued_bytes <= self.threshold and self.source_paused: logger.debug(f'resuming source (queued={self.queued_bytes})') self.source_paused = False self.resume_source() self.check_pump() # ----------------------------------------------------------------------------- async def async_call(function, *args, **kwargs): """ Immediately calls the function with provided args and kwargs, wrapping it in an async function. Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop. result = await async_call(some_function, ...) """ return function(*args, **kwargs) # ----------------------------------------------------------------------------- class OpenIntEnum(enum.IntEnum): """ Subclass of enum.IntEnum that can hold integer values outside the set of predefined values. This is convenient for implementing protocols where some integer constants may be added over time. """ @classmethod def _missing_(cls, value): if not isinstance(value, int): return None obj = int.__new__(cls, value) obj._value_ = value obj._name_ = f"{cls.__name__}[{value}]" return obj # ----------------------------------------------------------------------------- class CompatibleIntFlag(enum.IntFlag): """ Subclass of `enum.IntFlag` with a `composite_name` property that behaves like the `name` property of the `enum.IntFlag` implementation for python vesions >= 3.11 """ @property def composite_name(self) -> str: return '|'.join( name for flag in self.__class__ if self.value & flag.value and (name := flag.name) is not None ) # ----------------------------------------------------------------------------- class ByteSerializable(Protocol): """ Type protocol for classes that can be instantiated from bytes and serialized to bytes. """ @classmethod def from_bytes(cls, data: bytes) -> Self: ... def __bytes__(self) -> bytes: ... # ----------------------------------------------------------------------------- class IntConvertible(Protocol): """ Type protocol for classes that can be instantiated from int and converted to int. """ def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... # ----------------------------------------------------------------------------- def crc_16(data: bytes) -> int: """Calculate CRC-16-IBM of given data. Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed) """ crc = 0x0000 for byte in data: crc ^= byte for _ in range(8): if (crc & 0x0001) > 0: crc = (crc >> 1) ^ 0xA001 else: crc = crc >> 1 return crc