mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Squashed commits: [cd479ba] formatting and linting automation [7fbfabb] formatting and linting automation [c4f9505] fix after rebase [f506ad4] rename job [441d517] update doc (+7 squashed commits) [2e1b416] fix invoke and github action [6ae5bb4] doc for git blame [44b5461] add GitHub action [b07474f] add docs [4cd9a6f] more linter fixes [db71901] wip [540dc88] wip
338 lines
11 KiB
Python
338 lines
11 KiB
Python
# Copyright 2021-2022 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Imports
|
|
# -----------------------------------------------------------------------------
|
|
import struct
|
|
import asyncio
|
|
import logging
|
|
from colors import color
|
|
|
|
from .. import hci
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Logging
|
|
# -----------------------------------------------------------------------------
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
# where:
#   length-size:   size, in bytes, of the body-length field
#   length-offset: number of header bytes between the packet type byte and the
#                  body-length field
#   unpack-type:   `struct` format used to decode the body-length field
# Multi-byte length fields use an explicit little-endian format ('<H') so that
# decoding does not depend on the byte order of the host.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
}
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink

    The reader must expose an awaitable `next_packet()` returning the raw bytes
    of one packet; the sink must expose `on_packet(packet)`.
    '''

    def __init__(self, reader, sink):
        self.reader = reader
        self.sink = sink

    async def run(self):
        '''
        Forward packets from the reader to the sink until the stream ends.

        Errors raised while reading, decoding or delivering a packet are logged
        and the pump moves on to the next packet. The one exception is
        `asyncio.IncompleteReadError`: it signals that the underlying stream hit
        EOF, after which every further read would fail immediately, so the pump
        stops instead of spinning forever without yielding to the event loop.
        '''
        while True:
            try:
                # Get a packet from the source
                packet_bytes = await self.reader.next_packet()
                packet = hci.HCI_Packet.from_bytes(packet_bytes)

                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except asyncio.IncompleteReadError:
                # End of stream: there is nothing left to pump
                return
            except Exception as error:
                logger.warning(f'!!! {error}')
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has been
    parsed

    The parser is a three-state machine: it first needs the packet type byte,
    then the header up to and including the body-length field, then the body.
    Data may be fed in chunks of any size; partial packets are buffered across
    calls to `feed_data`.
    '''

    # pylint: disable=attribute-defined-outside-init

    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    def __init__(self, sink=None):
        self.sink = sink
        # Per-instance packet type table, consulted when a type is not found in
        # HCI_PACKET_INFO. Entries use the same tuple layout.
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''Discard any partially parsed packet and wait for a new type byte.'''
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data):
        '''
        Consume `data`, emitting every packet that it completes to the sink.

        Exceptions raised by the sink's `on_packet` are logged and swallowed.

        Raises:
            ValueError: if a packet starts with an unknown type byte. The parser
                is reset before raising, so that subsequent calls resume looking
                for a packet type instead of silently discarding all input. The
                rest of `data` from this call is not processed.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Without a reset, bytes_needed would stay at 0 and no
                        # later call could ever consume data again.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # The buffer starts with the type byte, hence the `1 +`
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

                # Emit a packet if one is complete (this also covers packets
                # with an empty body, for which no further data is needed)
                if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                    if self.sink:
                        try:
                            self.sink.on_packet(bytes(self.packet))
                        except Exception as error:
                            logger.warning(
                                color(f'!!! Exception in on_packet: {error}', 'red')
                            )
                    self.reset()

    def set_packet_sink(self, sink):
        '''Set the object whose `on_packet` receives each complete packet.'''
        self.sink = sink
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PacketReader:
    '''
    Blocking reader that extracts one HCI packet at a time from a source
    exposing a synchronous `read(size)` method
    '''

    def __init__(self, source):
        self.source = source

    def _read_or_fail(self, size):
        # Read `size` bytes, treating a short read as a truncated packet
        chunk = self.source.read(size)
        if len(chunk) != size:
            raise ValueError('packet too short')
        return chunk

    def next_packet(self):
        '''
        Return the next packet (type byte + header + body), or None when the
        source cannot supply a type byte.

        Raises ValueError for an unknown packet type or a truncated packet.
        '''
        # The first byte identifies the packet type
        packet_type = self.source.read(1)
        if len(packet_type) != 1:
            return None

        # The type selects the layout of the header
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # The header runs up to and including the length field
        length_offset = packet_info[1]
        header = self._read_or_fail(packet_info[0] + length_offset)

        # The length field tells how many body bytes follow
        body_length = struct.unpack_from(packet_info[2], header, length_offset)[0]
        body = self._read_or_fail(body_length)

        return packet_type + header + body
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class AsyncPacketReader:
    '''
    Reader that extracts one HCI packet at a time from a source exposing an
    awaitable `readexactly(size)` method
    '''

    def __init__(self, source):
        self.source = source

    async def next_packet(self):
        '''
        Return the next packet as type byte + header + body.

        Raises ValueError for an unknown packet type; errors raised by the
        source's `readexactly` propagate unchanged.
        '''
        # The first byte identifies the packet type
        packet_type = await self.source.readexactly(1)

        # The type selects the layout of the header
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # The header runs up to and including the length field
        length_offset = packet_info[1]
        header = await self.source.readexactly(packet_info[0] + length_offset)

        # The length field tells how many body bytes follow
        body_length = struct.unpack_from(packet_info[2], header, length_offset)[0]
        body = await self.source.readexactly(body_length)

        return packet_type + header + body
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class AsyncPipeSink:
    '''
    Sink that relays each packet to another sink through the event loop, so the
    wrapped sink is called on a later loop iteration rather than synchronously

    Instances must be created while an event loop is running, since the
    running loop is captured at construction time.
    '''

    def __init__(self, sink):
        self.sink = sink
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        # Schedule the delivery instead of calling the wrapped sink directly
        deliver = self.sink.on_packet
        self.loop.call_soon(deliver, packet)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class ParserSource:
    """
    Foundation for transport-specific source classes: owns a packet parser and
    a `terminated` future that can be awaited via `wait_for_termination`

    Instances must be created while an event loop is running.
    """

    def __init__(self):
        self.parser = PacketParser()
        running_loop = asyncio.get_running_loop()
        self.terminated = running_loop.create_future()

    def set_packet_sink(self, sink):
        """Route packets produced by the parser to `sink`."""
        self.parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        """Wait until the `terminated` future is resolved and return its result."""
        outcome = await self.terminated
        return outcome

    def close(self):
        """Do nothing; subclasses override this to release their resources."""
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class StreamPacketSource(asyncio.Protocol, ParserSource):
    '''
    Packet source usable as an asyncio protocol: every chunk of bytes received
    from the stream is handed to the packet parser
    '''

    def data_received(self, data):
        # Chunks need not align with packet boundaries; the parser buffers
        parser = self.parser
        parser.feed_data(data)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class StreamPacketSink:
    '''
    Packet sink that writes each packet to a transport object exposing
    `write(data)` and `close()`
    '''

    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        # Packets are written as-is, without any extra framing
        stream = self.transport
        stream.write(packet)

    def close(self):
        # Closing the sink closes the transport it writes to
        stream = self.transport
        stream.close()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class Transport:
    '''
    Pairing of a packet source and a packet sink

    A transport can be used as an async context manager (it is closed on exit)
    and can be unpacked as `source, sink = transport`.
    '''

    def __init__(self, source, sink):
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        # Allows unpacking into (source, sink)
        return iter((self.source, self.sink))

    async def close(self):
        '''
        Close the source, then the sink.

        The sink is closed even if closing the source raises, so that a failure
        on one side does not leak the other; the exception still propagates.
        '''
        try:
            self.source.close()
        finally:
            self.sink.close()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedPacketSource(ParserSource):
    '''
    Packet source driven by a background task that repeatedly awaits a
    `receive` coroutine function and feeds what it returns to the parser
    '''

    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def _set_terminated(self, result):
        # Resolve the termination future at most once
        if not self.terminated.done():
            self.terminated.set_result(result)

    def start(self):
        '''Start the background task that pumps received data into the parser.'''

        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.exceptions.CancelledError:
                    logger.debug('source pump task done')
                    # Release anyone blocked in wait_for_termination(), which
                    # would otherwise never return once the source is closed
                    self._set_terminated(None)
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    # The error is reported as the termination result
                    self._set_terminated(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        if self.pump_task:
            self.pump_task.cancel()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedPacketSink:
    '''
    Packet sink that queues packets and lets a background task pass them, one
    at a time and in order, to an awaitable `send` function
    '''

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet):
        # Never blocks: the packet is only queued here, and sent by the pump
        self.packet_queue.put_nowait(packet)

    async def _pump_packets(self):
        # Drain the queue until cancelled or until a send fails
        try:
            while True:
                packet = await self.packet_queue.get()
                await self.send_function(packet)
        except asyncio.exceptions.CancelledError:
            logger.debug('sink pump task done')
        except Exception as error:
            logger.warning(f'exception while sending packet: {error}')

    def start(self):
        '''Start the background task that sends queued packets.'''
        self.pump_task = asyncio.create_task(self._pump_packets())

    def close(self):
        '''Cancel the pump task, if it was started.'''
        if self.pump_task:
            self.pump_task.cancel()
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
class PumpedTransport(Transport):
    '''
    Transport whose source and sink are both driven by pump tasks, with an
    additional awaitable `close_function` invoked when the transport is closed
    '''

    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        '''Start the source and sink pumps.'''
        self.source.start()
        self.sink.start()

    async def close(self):
        '''
        Close the source and sink, then await the close function.

        The close function is awaited even if closing the source or sink
        raises, so that the underlying resource is always released; the
        exception still propagates.
        '''
        try:
            await super().close()
        finally:
            await self.close_function()
|