mirror of
https://github.com/google/bumble.git
synced 2026-04-17 00:35:31 +00:00
506 lines
16 KiB
Python
506 lines
16 KiB
Python
# Copyright 2021-2022 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Imports
|
|
# -----------------------------------------------------------------------------
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import io
|
|
import logging
|
|
import struct
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any, Protocol
|
|
|
|
from bumble import core, hci
|
|
from bumble.colors import color
|
|
from bumble.snoop import Snooper
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Logging
|
|
# -----------------------------------------------------------------------------
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Information needed to parse HCI packets with a generic parser.
# For each packet type, the info represents:
#   (length-size, length-offset, unpack-type)
# where:
#   - length-size is the size, in bytes, of the field holding the body length
#   - length-offset is the number of header bytes found between the packet type
#     byte and the body length field
#   - unpack-type is the `struct` format used to decode the body length field
# The header that follows the packet type byte is therefore
# (length-size + length-offset) bytes long.
# HCI encodes multi-byte fields in little-endian order, so the 16-bit formats use an
# explicit '<' prefix: a bare 'H' would be decoded with the byte order of the host.
HCI_PACKET_INFO: dict[int, tuple[int, int, str]] = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
    hci.HCI_ISO_DATA_PACKET: (2, 2, '<H'),
}
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Errors
|
|
# -----------------------------------------------------------------------------
|
|
class TransportLostError(core.BaseBumbleError, RuntimeError):
    """Error raised when the transport has been lost or disconnected."""
|
|
|
|
|
|
class TransportInitError(core.BaseBumbleError, RuntimeError):
    """Error raised when a transport cannot be initialized."""
|
|
|
|
|
|
class TransportSpecError(core.BaseBumbleError, ValueError):
    """Error raised when a transport spec is invalid."""
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Typing Protocols
|
|
# -----------------------------------------------------------------------------
|
|
class TransportSink(Protocol):
    """Structural type for any object that can be handed packets."""

    # Called with one packet at a time, as raw bytes.
    def on_packet(self, packet: bytes) -> None: ...
|
|
|
|
|
|
class TransportSource(Protocol):
    """Structural type for any object that delivers packets to a sink."""

    # Future tracking the end of life of the source (BaseSource, for instance,
    # resolves it when the transport is lost).
    terminated: asyncio.Future[None]

    # Register the sink to which the source should deliver its packets.
    def set_packet_sink(self, sink: TransportSink) -> None: ...
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketPump:
    """
    Pump HCI packets from a reader to a sink.

    Packets are read one at a time from `reader` and handed to `sink.on_packet`
    until the stream behind the reader ends or fails.
    """

    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
        self.reader = reader
        self.sink = sink

    async def run(self) -> None:
        """
        Deliver packets to the sink until the reader can no longer produce any.

        The method returns when the reader reports the end of its stream
        (`asyncio.IncompleteReadError`) or an I/O failure (`OSError`). Any other
        error, whether raised while reading or by the sink, is logged and the pump
        moves on to the next packet.
        """
        while True:
            try:
                packet = await self.reader.next_packet()
            except asyncio.IncompleteReadError:
                # End of stream: every subsequent read would fail the same way,
                # immediately and without yielding to the event loop, so retrying
                # would spin forever.
                logger.debug('end of stream reached, stopping packet pump')
                return
            except OSError:
                # A stream reader keeps re-raising the error that broke its
                # connection, so there is nothing left to pump here either.
                logger.exception('stream error, stopping packet pump')
                return
            except Exception:
                logger.exception('!!!')
                continue

            # Deliver the packet to the sink. This is kept out of the `try` block
            # above so that an error raised by the sink is never mistaken for the
            # end of the stream.
            try:
                self.sink.on_packet(packet)
            except Exception:
                logger.exception('!!!')
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketParser:
    """
    Incremental HCI packet parser.

    Data is pushed in with `feed_data`, in chunks of any size. Each time a complete
    packet has been assembled, it is passed to the `on_packet` method of the sink.
    """

    # pylint: disable=attribute-defined-outside-init

    # Parsing states: what the parser is waiting for next
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    sink: TransportSink | None
    extended_packet_info: dict[int, tuple[int, int, str]]
    packet_info: tuple[int, int, str] | None = None

    def __init__(self, sink: TransportSink | None = None) -> None:
        self.sink = sink
        # Packet types accepted in addition to those listed in HCI_PACKET_INFO
        self.extended_packet_info = {}
        self.reset()

    def reset(self) -> None:
        """Drop any partially assembled packet and wait for a packet type byte."""
        self.packet_info = None
        self.packet = bytearray()
        self.bytes_needed = 1
        self.state = PacketParser.NEED_TYPE

    def feed_data(self, data: bytes) -> None:
        """
        Consume `data`, emitting every packet that it completes.

        Raises `core.InvalidPacketError` when a packet starts with a type that is
        found neither in HCI_PACKET_INFO nor in `extended_packet_info`. The parser
        is reset before raising, and the rest of `data` is not consumed.
        """
        position = 0
        remaining = len(data)
        while remaining and self.bytes_needed:
            # Move as much as possible into the section currently being assembled
            chunk_size = min(self.bytes_needed, remaining)
            self.packet.extend(data[position : position + chunk_size])
            position += chunk_size
            remaining -= chunk_size
            self.bytes_needed -= chunk_size

            if self.bytes_needed != 0:
                # The current section is not complete yet
                continue

            if self.state == PacketParser.NEED_TYPE:
                packet_type = self.packet[0]
                packet_info = HCI_PACKET_INFO.get(packet_type)
                if not packet_info:
                    packet_info = self.extended_packet_info.get(packet_type)
                self.packet_info = packet_info
                if self.packet_info is None:
                    self.reset()
                    raise core.InvalidPacketError(f'invalid packet type {packet_type}')
                # Next: the header, up to and including the length field
                self.state = PacketParser.NEED_LENGTH
                self.bytes_needed = self.packet_info[0] + self.packet_info[1]
            elif self.state == PacketParser.NEED_LENGTH:
                assert self.packet_info is not None
                # The length field position accounts for the packet type byte
                length = struct.unpack_from(
                    self.packet_info[2], self.packet, 1 + self.packet_info[1]
                )
                self.bytes_needed = length[0]
                self.state = PacketParser.NEED_BODY

            # A body may be empty, in which case the packet is already complete
            if self.state != PacketParser.NEED_BODY or self.bytes_needed:
                continue

            # Emit the complete packet
            if self.sink:
                try:
                    self.sink.on_packet(bytes(self.packet))
                except Exception:
                    logger.exception(color('!!! Exception in on_packet', 'red'))
            self.reset()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink to which complete packets are delivered."""
        self.sink = sink
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketReader:
    """
    Blocking reader that extracts HCI packets, one at a time, from a sync source.
    """

    def __init__(self, source: io.BufferedReader) -> None:
        self.at_end = False
        self.source = source

    def next_packet(self) -> bytes | None:
        """
        Read one complete packet from the source.

        Returns the packet, including its type byte, or None when the source has no
        more data, in which case `at_end` is set to True.
        Raises `core.InvalidPacketError` when the packet type is unknown or when the
        source ends in the middle of a packet.
        """
        read = self.source.read

        # The packet starts with a 1-byte type
        type_byte = read(1)
        if len(type_byte) != 1:
            self.at_end = True
            return None

        # The type determines the layout of the header
        packet_type = type_byte[0]
        packet_info = HCI_PACKET_INFO.get(packet_type)
        if packet_info is None:
            raise core.InvalidPacketError(f'invalid packet type {packet_type} found')
        length_size, length_offset, length_format = packet_info

        # The header ends with the length of the body
        header_size = length_size + length_offset
        header = read(header_size)
        if len(header) != header_size:
            raise core.InvalidPacketError('packet too short')

        # The body is whatever the header announced
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = read(body_length)
        if len(body) != body_length:
            raise core.InvalidPacketError('packet too short')

        return type_byte + header + body
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class AsyncPacketReader:
    """
    Reader that extracts HCI packets, one at a time, from an async source.
    """

    def __init__(self, source: asyncio.StreamReader) -> None:
        self.source = source

    async def next_packet(self) -> bytes:
        """
        Read one complete packet from the source.

        Returns the packet, including its type byte.
        Raises `core.InvalidPacketError` when the packet type is unknown. Errors
        raised by the source's `readexactly` method are not caught here.
        """
        read_exactly = self.source.readexactly

        # The packet starts with a 1-byte type
        type_byte = await read_exactly(1)

        # The type determines the layout of the header
        packet_type = type_byte[0]
        packet_info = HCI_PACKET_INFO.get(packet_type)
        if packet_info is None:
            raise core.InvalidPacketError(f'invalid packet type {packet_type} found')
        length_size, length_offset, length_format = packet_info

        # The header ends with the length of the body
        header = await read_exactly(length_size + length_offset)

        # The body is whatever the header announced
        body_length = struct.unpack_from(length_format, header, length_offset)[0]
        body = await read_exactly(body_length)

        return type_byte + header + body
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class AsyncPipeSink:
    """
    Sink that hands each packet over to another sink from an event loop callback,
    rather than directly from within `on_packet`.
    """

    def __init__(self, sink: TransportSink) -> None:
        # The loop is captured once, so instances must be created while it runs
        self.loop = asyncio.get_running_loop()
        self.sink = sink

    def on_packet(self, packet: bytes) -> None:
        """Schedule the delivery of `packet` to the wrapped sink."""
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class BaseSource:
    """
    Common base for transport-specific source classes.

    It keeps track of the sink that packets should be delivered to, and of a
    `terminated` future that is resolved when the transport is lost.
    """

    terminated: asyncio.Future[None]
    sink: TransportSink | None

    def __init__(self) -> None:
        self.sink = None
        # The future belongs to the running loop: instantiate from within it
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink to which packets should be delivered."""
        self.sink = sink

    def on_transport_lost(self) -> None:
        """
        Resolve `terminated`, unless that is already done, then notify the sink if
        it has an `on_transport_lost` method.
        """
        terminated = self.terminated
        if not terminated.done():
            terminated.set_result(None)

        sink = self.sink
        if sink and hasattr(sink, 'on_transport_lost'):
            sink.on_transport_lost()

    async def wait_for_termination(self) -> None:
        """
        Wait until `terminated` is resolved.

        This method only exists for backward compatibility, awaiting the
        `terminated` attribute directly is preferred.
        """
        return await self.terminated

    def close(self) -> None:
        """Do nothing. Subclasses with resources to release override this."""
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class ParserSource(BaseSource):
    """
    Base class for sources that use an HCI parser.

    Subclasses feed whatever bytes they receive to `parser`, which delivers the
    complete packets to the sink.
    """

    parser: PacketParser

    def __init__(self) -> None:
        super().__init__()
        # The parser has no sink until `set_packet_sink` is called
        self.parser = PacketParser()

    def set_packet_sink(self, sink: TransportSink) -> None:
        """Set the sink, both on this source and on its parser."""
        super().set_packet_sink(sink)
        self.parser.set_packet_sink(sink)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class StreamPacketSource(asyncio.Protocol, ParserSource):
    """
    Source that parses HCI packets out of the data received by an asyncio protocol.
    """

    def data_received(self, data: bytes) -> None:
        """Feed the received bytes to the parser."""
        try:
            self.parser.feed_data(data)
        except core.InvalidPacketError:
            # Best effort: the error is reported instead of being propagated. The
            # parser stops at the offending byte, so the rest of `data` is dropped.
            logger.warning("invalid packet, ignoring data")
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class StreamPacketSink:
    """Sink that writes packets, as they come, to an asyncio write transport."""

    def __init__(self, transport: asyncio.WriteTransport) -> None:
        self.transport = transport

    def on_packet(self, packet: bytes) -> None:
        """Write `packet` to the transport."""
        output = self.transport
        output.write(packet)

    def close(self) -> None:
        """Close the transport."""
        output = self.transport
        output.close()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class Transport:
    """
    Base class for all transports.

    A Transport represents a source and a sink together.
    An instance must be closed by calling close() when no longer used. Instances
    implement the ContextManager protocol so that they may be used in a `async with`
    statement.
    An instance is iterable. The iterator yields, in order, its source and sink, so
    that it may be used with a convenient call syntax like:

    async with create_transport() as (source, sink):
        ...
    """

    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args) -> None:
        # The transport is closed whether or not the `async with` body raised
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self) -> None:
        """
        Close the source and the sink, for those that have a `close` method.

        The sink is closed even when closing the source raises, so that a failure
        on one side does not leak the other. The error is still propagated.
        """
        try:
            if hasattr(self.source, 'close'):
                self.source.close()
        finally:
            if hasattr(self.sink, 'close'):
                self.sink.close()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedPacketSource(ParserSource):
    """
    Source that gets its data by repeatedly awaiting a receive function, from a
    background task.
    """

    pump_task: asyncio.Task[None] | None

    def __init__(self, receive) -> None:
        super().__init__()
        self.pump_task = None
        # Awaited with no argument, the result is fed to the parser
        self.receive_function = receive

    def start(self) -> None:
        """Create the task that pumps received data into the parser."""

        async def pump_packets() -> None:
            # The loop only ends by raising, which also ends the task
            try:
                while True:
                    self.parser.feed_data(await self.receive_function())
            except asyncio.CancelledError:
                logger.debug('source pump task done')
                if not self.terminated.done():
                    self.terminated.set_result(None)
            except Exception as error:
                logger.exception('exception while waiting for packet')
                if not self.terminated.done():
                    self.terminated.set_exception(error)

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self) -> None:
        """Cancel the pump task, if it was started."""
        task = self.pump_task
        if task:
            task.cancel()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedPacketSink:
    """
    Sink that queues packets, and sends them one by one by awaiting a send
    function, from a background task.
    """

    pump_task: asyncio.Task[None] | None

    def __init__(self, send: Callable[[bytes], Awaitable[Any]]):
        self.pump_task = None
        self.packet_queue = asyncio.Queue[bytes]()
        self.send_function = send

    def on_packet(self, packet: bytes) -> None:
        """Queue `packet`, to be sent by the pump task."""
        self.packet_queue.put_nowait(packet)

    def start(self) -> None:
        """Create the task that pumps queued packets into the send function."""

        async def pump_packets():
            # The loop only ends by raising, which also ends the task
            try:
                while True:
                    await self.send_function(await self.packet_queue.get())
            except asyncio.CancelledError:
                logger.debug('sink pump task done')
            except Exception:
                logger.exception('exception while sending packet')

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        """Cancel the pump task, if it was started."""
        task = self.pump_task
        if task:
            task.cancel()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedTransport(Transport):
    """Transport made of a pumped source and a pumped sink."""

    source: PumpedPacketSource
    sink: PumpedPacketSink

    def __init__(
        self,
        source: PumpedPacketSource,
        sink: PumpedPacketSink,
    ) -> None:
        super().__init__(source, sink)

    def start(self) -> None:
        """Start the source, then the sink."""
        self.source.start()
        self.sink.start()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class SnoopingTransport(Transport):
    """Transport wrapper that snoops on packets to/from a wrapped transport."""

    @staticmethod
    def create_with(
        transport: Transport, snooper: contextlib.AbstractContextManager[Snooper]
    ) -> SnoopingTransport:
        """
        Create an instance given a snooper that works as a context manager.

        The returned instance will exit the snooper context when it is closed.
        """
        with contextlib.ExitStack() as exit_stack:
            # The snooper context is entered first, then `pop_all()` moves its exit
            # callback to a new stack: leaving this `with` block does not exit the
            # snooper context, closing the new stack (done by the instance) does.
            return SnoopingTransport(
                transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
            )
        raise core.UnreachableError()  # Satisfy the type checker

    class Source:
        """Source wrapper that snoops on the packets coming from the controller."""

        sink: TransportSink

        @property
        def metadata(self) -> dict[str, Any]:
            # Expose the metadata of the wrapped source, if it has any
            return getattr(self.source, 'metadata', {})

        def __init__(self, source: TransportSource, snooper: Snooper):
            self.source = source
            self.snooper = snooper
            # Share the termination future of the wrapped source
            self.terminated = source.terminated

        def set_packet_sink(self, sink: TransportSink) -> None:
            self.sink = sink
            # Receive the packets in place of the sink, in order to snoop on them
            self.source.set_packet_sink(self)

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
            if self.sink:
                self.sink.on_packet(packet)

    class Sink:
        """Sink wrapper that snoops on the packets going to the controller."""

        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
            self.sink = sink
            self.snooper = snooper

        def on_packet(self, packet: bytes) -> None:
            self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
            if self.sink:
                self.sink.on_packet(packet)

    def __init__(
        self,
        transport: Transport,
        snooper: Snooper,
        close_snooper=None,
    ) -> None:
        """
        Wrap `transport` so that its packets, in both directions, go through
        `snooper`.

        `close_snooper`, if set, is called with no argument when the instance is
        closed.
        """
        super().__init__(
            self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
        )
        self.transport = transport
        self.close_snooper = close_snooper

    async def close(self) -> None:
        """
        Close the wrapped transport, then the snooper.

        The snooper is closed even when closing the wrapped transport raises, so
        that whatever it holds is always released. The error is still propagated.
        """
        try:
            await self.transport.close()
        finally:
            if self.close_snooper:
                self.close_snooper()
|