diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index b6e684f..c10f063 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -59,6 +59,8 @@ USB_BT_HCI_CLASS_TUPLE = ( ) MAX_SCO_PACKET_SIZE = 1024 +MAX_SCO_IN_PACKETS = 128 +NUMBER_OF_SCO_IN_TRANSFERS = 2 # ----------------------------------------------------------------------------- @@ -388,20 +390,35 @@ class UsbPacketSink: READ_SIZE = 4096 -class ScoAccumulator: - def __init__(self, emit: Callable[[bytes], Any]) -> None: +class PacketSplitter: + """Splitter than can parse a byte stream and extract packets that consist of a + header and a body, where the header includes an n-byte 'length' field at a + certain offset. + Extracted packets are emitted by calling a function passed to the constructor, + with the full packet (header + body) as argument. + """ + + def __init__( + self, length_offset: int, length_size: int, emit: Callable[[bytes], Any] + ) -> None: self.emit = emit self.packet = b'' + self.length_offset = length_offset + self.length_size = length_size + self.header_size = length_offset + length_size def feed(self, data: bytes) -> None: while data: - # Accumulate until we have a complete 3-byte header - if (bytes_needed := 3 - len(self.packet)) > 0: + # Accumulate until we have a complete header + if (bytes_needed := self.header_size - len(self.packet)) > 0: self.packet += data[:bytes_needed] data = data[bytes_needed:] continue - packet_length = 3 + self.packet[2] + packet_length = self.header_size + int.from_bytes( + self.packet[self.length_offset : self.length_offset + self.length_size], + 'little', + ) bytes_needed = packet_length - len(self.packet) self.packet += data[:bytes_needed] data = data[bytes_needed:] @@ -411,6 +428,24 @@ class ScoAccumulator: self.packet = b'' +class ScoPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 1 byte at offset 2 in the HCI SCO packet header + super().__init__(length_offset=2, length_size=1, emit=emit) + + +class EventPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 1 byte at offset 1 in the HCI Event packet header + super().__init__(length_offset=1, length_size=1, emit=emit) + + +class AclPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 2 bytes at offset 2 in the HCI ACL packet header + super().__init__(length_offset=2, length_size=2, emit=emit) + + class UsbPacketSource(asyncio.Protocol, BaseSource): def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in): super().__init__() @@ -422,9 +457,6 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): self.bulk_in_transfer = None self.isochronous_in = isochronous_in self.isochronous_in_transfer = None - self.isochronous_accumulator = ScoAccumulator( - lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet) - ) self.loop = asyncio.get_running_loop() self.queue = asyncio.Queue() self.dequeue_task = None @@ -433,6 +465,19 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): hci.HCI_ACL_DATA_PACKET: asyncio.Event(), hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(), } + self.splitters = { + hci.HCI_EVENT_PACKET: EventPacketSplitter( + lambda packet: self.queue_packet(hci.HCI_EVENT_PACKET, packet) + ), + hci.HCI_ACL_DATA_PACKET: AclPacketSplitter( + lambda packet: self.queue_packet(hci.HCI_ACL_DATA_PACKET, packet) + ), + hci.HCI_SYNCHRONOUS_DATA_PACKET: ScoPacketSplitter( + lambda packet: self.queue_packet( + hci.HCI_SYNCHRONOUS_DATA_PACKET, packet + ) + ), + } self.closed = False self.lock = threading.Lock() @@ -457,10 +502,12 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): self.bulk_in_transfer.submit() if self.isochronous_in is not None: - self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16) + self.isochronous_in_transfer = self.device.getTransfer( + iso_packets=MAX_SCO_IN_PACKETS + ) self.isochronous_in_transfer.setIsochronous( self.isochronous_in.getAddress(), - 16 * self.isochronous_in.getMaxPacketSize(), + MAX_SCO_IN_PACKETS * self.isochronous_in.getMaxPacketSize(), callback=self.transfer_callback, user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET, ) @@ -490,6 +537,8 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): with self.lock: if self.closed: logger.debug("packet source closed, discarding transfer") + elif (splitter := self.splitters.get(packet_type)) is None: + logger.warning(f'no splitter for packet type {packet_type}') else: if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: for iso_status, iso_buffer in transfer.iterISO(): @@ -503,11 +552,10 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): len(iso_buffer), iso_buffer.hex(), ) - self.isochronous_accumulator.feed(iso_buffer) + splitter.feed(iso_buffer) else: - self.queue_packet( - packet_type, - transfer.getBuffer()[: transfer.getActualLength()], + splitter.feed( + transfer.getBuffer()[: transfer.getActualLength()] ) # Re-submit the transfer so we can receive more data