# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import asyncio
import logging

from colors import color

from .. import hci


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
#
#   length-size:   number of bytes encoding the payload length
#   length-offset: offset of the length field, counted from the first byte
#                  after the one-byte packet-type indicator
#   unpack-type:   `struct` format used to decode the length field
#
# HCI multi-byte fields are little-endian. A bare 'H' in `struct` means
# host-native byte order, so two-byte lengths use an explicit '<H'. That
# decodes identically on little-endian hosts and correctly on big-endian ones.
HCI_PACKET_INFO = {
    hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
    hci.HCI_ACL_DATA_PACKET: (2, 2, '<H'),
    hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
    hci.HCI_EVENT_PACKET: (1, 1, 'B'),
}


# -----------------------------------------------------------------------------
class PacketPump:
    '''
    Pump HCI packets from a reader to a sink
    '''

    def __init__(self, reader, sink):
        # `reader` must provide an awaitable `next_packet()` returning raw
        # packet bytes; `sink` must provide `on_packet(packet)`.
        self.reader = reader
        self.sink = sink

    async def run(self):
        '''
        Read packets from the reader, decode each with `hci.HCI_Packet.from_bytes`
        and deliver the result to the sink.

        Runs until the reader reaches end-of-stream. A failure while reading,
        decoding or delivering a single packet is logged and the pump continues.
        '''
        while True:
            try:
                # Get a packet from the source
                packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())

                # Deliver the packet to the sink
                self.sink.on_packet(packet)
            except asyncio.IncompleteReadError:
                # `StreamReader.readexactly` raises this once the stream is at
                # EOF, and keeps raising it immediately on every later call.
                # Looping again would spin without ever yielding to the event
                # loop, so stop pumping instead.
                logger.debug('packet pump: end of stream')
                return
            except Exception as error:
                logger.warning(f'!!! {error}')


# -----------------------------------------------------------------------------
class PacketParser:
    '''
    In-line parser that accepts data and emits 'on_packet' when a full packet has
    been parsed
    '''

    # pylint: disable=attribute-defined-outside-init

    # Parser states: waiting for the packet-type byte, for the header that
    # contains the length, or for the body.
    NEED_TYPE = 0
    NEED_LENGTH = 1
    NEED_BODY = 2

    def __init__(self, sink=None):
        self.sink = sink
        # Additional {packet_type: (length-size, length-offset, unpack-type)}
        # entries consulted when a type is not in HCI_PACKET_INFO.
        self.extended_packet_info = {}
        self.reset()

    def reset(self):
        '''
        Discard any partial packet and wait for a new packet-type byte.
        '''
        self.state = PacketParser.NEED_TYPE
        self.bytes_needed = 1
        self.packet = bytearray()
        self.packet_info = None

    def feed_data(self, data):
        '''
        Consume `data`, which may hold any number of whole or partial packets.

        Each complete packet (type byte + header + body) is passed as `bytes`
        to `self.sink.on_packet`, if a sink is set. Exceptions raised by the
        sink are logged, not propagated.

        Raises ValueError if a packet-type byte is not in HCI_PACKET_INFO or
        `extended_packet_info`. The parser is reset first, so parsing can
        resume with the next call. The rest of the current `data` is not
        consumed.
        '''
        data_offset = 0
        data_left = len(data)
        while data_left and self.bytes_needed:
            consumed = min(self.bytes_needed, data_left)
            self.packet.extend(data[data_offset : data_offset + consumed])
            data_offset += consumed
            data_left -= consumed
            self.bytes_needed -= consumed

            if self.bytes_needed == 0:
                if self.state == PacketParser.NEED_TYPE:
                    packet_type = self.packet[0]
                    self.packet_info = HCI_PACKET_INFO.get(
                        packet_type
                    ) or self.extended_packet_info.get(packet_type)
                    if self.packet_info is None:
                        # Without this reset, bytes_needed would stay at 0.
                        # The while-loop would then never run again, and the
                        # parser would silently drop all future data.
                        self.reset()
                        raise ValueError(f'invalid packet type {packet_type}')
                    self.state = PacketParser.NEED_LENGTH
                    self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                elif self.state == PacketParser.NEED_LENGTH:
                    # The length field sits after the type byte (offset 1)
                    # plus the type-specific length-offset.
                    body_length = struct.unpack_from(
                        self.packet_info[2], self.packet, 1 + self.packet_info[1]
                    )[0]
                    self.bytes_needed = body_length
                    self.state = PacketParser.NEED_BODY

            # Emit a packet if one is complete (this also covers a zero-length
            # body, where NEED_LENGTH just set bytes_needed to 0)
            if self.state == PacketParser.NEED_BODY and not self.bytes_needed:
                if self.sink:
                    try:
                        self.sink.on_packet(bytes(self.packet))
                    except Exception as error:
                        logger.warning(
                            color(f'!!! Exception in on_packet: {error}', 'red')
                        )
                self.reset()

    def set_packet_sink(self, sink):
        self.sink = sink


# -----------------------------------------------------------------------------
class PacketReader:
    '''
    Reader that reads HCI packets from a sync source
    '''

    def __init__(self, source):
        # `source` must provide a blocking `read(n)` returning bytes.
        self.source = source

    def next_packet(self):
        '''
        Read one packet (type byte + header + body) from the source.

        Returns the packet as bytes, or None if the source returned no byte
        for the packet type.

        Raises ValueError for an unknown packet type or a truncated packet.
        '''
        # Get the packet type
        packet_type = self.source.read(1)
        if len(packet_type) != 1:
            return None

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = self.source.read(header_size)
        if len(header) != header_size:
            raise ValueError('packet too short')

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = self.source.read(body_length)
        if len(body) != body_length:
            raise ValueError('packet too short')

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPacketReader:
    '''
    Reader that reads HCI packets from an async source
    '''

    def __init__(self, source):
        # `source` must provide an awaitable `readexactly(n)`, as
        # asyncio.StreamReader does.
        self.source = source

    async def next_packet(self):
        '''
        Read one packet (type byte + header + body) from the source.

        Raises ValueError for an unknown packet type. Errors raised by
        `source.readexactly` (e.g. asyncio.IncompleteReadError at EOF for a
        StreamReader) are propagated.
        '''
        # Get the packet type
        packet_type = await self.source.readexactly(1)

        # Get the packet info based on its type
        packet_info = HCI_PACKET_INFO.get(packet_type[0])
        if packet_info is None:
            raise ValueError(f'invalid packet type {packet_type} found')

        # Read the header (that includes the length)
        header_size = packet_info[0] + packet_info[1]
        header = await self.source.readexactly(header_size)

        # Read the body
        body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
        body = await self.source.readexactly(body_length)

        return packet_type + header + body


# -----------------------------------------------------------------------------
class AsyncPipeSink:
    '''
    Sink that forwards packets asynchronously to another sink
    '''

    def __init__(self, sink):
        self.sink = sink
        # Must be constructed while an event loop is running.
        self.loop = asyncio.get_running_loop()

    def on_packet(self, packet):
        # Defer delivery to a later iteration of the event loop.
        self.loop.call_soon(self.sink.on_packet, packet)


# -----------------------------------------------------------------------------
class ParserSource:
    """
    Base class designed to be subclassed by transport-specific source classes
    """

    def __init__(self):
        self.parser = PacketParser()
        # Resolved by subclasses when the source stops producing data.
        self.terminated = asyncio.get_running_loop().create_future()

    def set_packet_sink(self, sink):
        self.parser.set_packet_sink(sink)

    async def wait_for_termination(self):
        return await self.terminated

    def close(self):
        pass


# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
    def data_received(self, data):
        # asyncio.Protocol callback: feed incoming bytes to the packet parser.
        self.parser.feed_data(data)


# -----------------------------------------------------------------------------
class StreamPacketSink:
    def __init__(self, transport):
        self.transport = transport

    def on_packet(self, packet):
        self.transport.write(packet)

    def close(self):
        self.transport.close()


# -----------------------------------------------------------------------------
class Transport:
    '''
    A (source, sink) pair, usable as an async context manager and unpackable
    as `source, sink = transport`.
    '''

    def __init__(self, source, sink):
        self.source = source
        self.sink = sink

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __iter__(self):
        return iter((self.source, self.sink))

    async def close(self):
        self.source.close()
        self.sink.close()


# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
    '''
    Source that repeatedly awaits `receive()` and feeds the result to the parser.
    '''

    def __init__(self, receive):
        super().__init__()
        self.receive_function = receive
        self.pump_task = None

    def start(self):
        async def pump_packets():
            while True:
                try:
                    packet = await self.receive_function()
                    self.parser.feed_data(packet)
                except asyncio.exceptions.CancelledError:
                    logger.debug('source pump task done')
                    break
                except Exception as error:
                    logger.warning(f'exception while waiting for packet: {error}')
                    # The future may already be done, e.g. cancelled along
                    # with a task awaiting wait_for_termination(). Calling
                    # set_result on it again would raise InvalidStateError
                    # inside this task.
                    if not self.terminated.done():
                        self.terminated.set_result(error)
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedPacketSink:
    '''
    Sink that queues packets and sends them one at a time with `send()`.
    '''

    def __init__(self, send):
        self.send_function = send
        self.packet_queue = asyncio.Queue()
        self.pump_task = None

    def on_packet(self, packet):
        self.packet_queue.put_nowait(packet)

    def start(self):
        async def pump_packets():
            while True:
                try:
                    packet = await self.packet_queue.get()
                    await self.send_function(packet)
                except asyncio.exceptions.CancelledError:
                    logger.debug('sink pump task done')
                    break
                except Exception as error:
                    # NOTE: the pump stops here; packets queued afterwards are
                    # never sent.
                    logger.warning(f'exception while sending packet: {error}')
                    break

        self.pump_task = asyncio.create_task(pump_packets())

    def close(self):
        if self.pump_task:
            self.pump_task.cancel()


# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
    '''
    Transport whose source and sink run pump tasks, plus an extra async
    `close_function` awaited after both are closed.
    '''

    def __init__(self, source, sink, close_function):
        super().__init__(source, sink)
        self.close_function = close_function

    def start(self):
        self.source.start()
        self.sink.start()

    async def close(self):
        await super().close()
        await self.close_function()