diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index fa1f046..84c7a7c 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -59,6 +59,8 @@ USB_BT_HCI_CLASS_TUPLE = ( ) MAX_SCO_PACKET_SIZE = 1024 +MAX_SCO_IN_PACKETS = 128 +NUMBER_OF_SCO_IN_TRANSFERS = 2 # ----------------------------------------------------------------------------- @@ -407,20 +409,36 @@ class UsbPacketSink: READ_SIZE = 4096 -class ScoAccumulator: - def __init__(self, emit: Callable[[bytes], Any]) -> None: +class PacketSplitter: + """Splitter than can parse a byte stream and extract packets that consist of a + header and a body, where the header includes an n-byte 'length' field at a + certain offset. + Extracted packets are emitted by calling a function passed to the constructor, + with the full packet (header + body) as argument. + """ + + def __init__( + self, length_offset: int, length_size: int, emit: Callable[[bytes], Any] + ) -> None: self.emit = emit self.packet = b'' + self.length_offset = length_offset + self.length_size = length_size + self.header_size = length_offset + length_size def feed(self, data: bytes) -> None: while data: - # Accumulate until we have a complete 3-byte header - if (bytes_needed := 3 - len(self.packet)) > 0: + # Accumulate until we have a complete header + if (bytes_needed := self.header_size - len(self.packet)) > 0: self.packet += data[:bytes_needed] data = data[bytes_needed:] - continue + if len(self.packet) < self.header_size: + continue - packet_length = 3 + self.packet[2] + packet_length = self.header_size + int.from_bytes( + self.packet[self.length_offset : self.length_offset + self.length_size], + 'little', + ) bytes_needed = packet_length - len(self.packet) self.packet += data[:bytes_needed] data = data[bytes_needed:] @@ -430,6 +448,24 @@ class ScoAccumulator: self.packet = b'' +class ScoPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 1 byte at offset 2 in the HCI SCO packet header + super().__init__(length_offset=2, length_size=1, emit=emit) + + +class EventPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 1 byte at offset 1 in the HCI Event packet header + super().__init__(length_offset=1, length_size=1, emit=emit) + + +class AclPacketSplitter(PacketSplitter): + def __init__(self, emit: Callable[[bytes], Any]) -> None: + # The length field is 2 bytes at offset 2 in the HCI ACL packet header + super().__init__(length_offset=2, length_size=2, emit=emit) + + class UsbPacketSource(asyncio.Protocol, BaseSource): def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in): super().__init__() @@ -440,17 +476,23 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): self.bulk_in = bulk_in self.bulk_in_transfer = None self.isochronous_in = isochronous_in - self.isochronous_in_transfer = None - self.isochronous_accumulator = ScoAccumulator( - lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet) - ) + self.isochronous_in_transfers = [] self.loop = asyncio.get_running_loop() self.queue = asyncio.Queue() self.dequeue_task = None - self.done = { - hci.HCI_EVENT_PACKET: asyncio.Event(), - hci.HCI_ACL_DATA_PACKET: asyncio.Event(), - hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(), + self.done = {} + self.splitters = { + hci.HCI_EVENT_PACKET: EventPacketSplitter( + lambda packet: self.queue_packet(hci.HCI_EVENT_PACKET, packet) + ), + hci.HCI_ACL_DATA_PACKET: AclPacketSplitter( + lambda packet: self.queue_packet(hci.HCI_ACL_DATA_PACKET, packet) + ), + hci.HCI_SYNCHRONOUS_DATA_PACKET: ScoPacketSplitter( + lambda packet: self.queue_packet( + hci.HCI_SYNCHRONOUS_DATA_PACKET, packet + ) + ), } self.closed = False self.lock = threading.Lock() @@ -464,6 +506,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): callback=self.transfer_callback, user_data=hci.HCI_EVENT_PACKET, ) + self.done[self.interrupt_in_transfer] = asyncio.Event() self.interrupt_in_transfer.submit() self.bulk_in_transfer = self.device.getTransfer() @@ -473,17 +516,21 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): callback=self.transfer_callback, user_data=hci.HCI_ACL_DATA_PACKET, ) + self.done[self.bulk_in_transfer] = asyncio.Event() self.bulk_in_transfer.submit() if self.isochronous_in is not None: - self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16) - self.isochronous_in_transfer.setIsochronous( - self.isochronous_in.getAddress(), - 16 * self.isochronous_in.getMaxPacketSize(), - callback=self.transfer_callback, - user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET, - ) - self.isochronous_in_transfer.submit() + for _ in range(NUMBER_OF_SCO_IN_TRANSFERS): + transfer = self.device.getTransfer(iso_packets=MAX_SCO_IN_PACKETS) + transfer.setIsochronous( + self.isochronous_in.getAddress(), + MAX_SCO_IN_PACKETS * self.isochronous_in.getMaxPacketSize(), + callback=self.transfer_callback, + user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET, + ) + self.isochronous_in_transfers.append(transfer) + self.done[transfer] = asyncio.Event() + transfer.submit() self.dequeue_task = self.loop.create_task(self.dequeue()) @@ -509,6 +556,8 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): with self.lock: if self.closed: logger.debug("packet source closed, discarding transfer") + elif (splitter := self.splitters.get(packet_type)) is None: + logger.warning(f'no splitter for packet type {packet_type}') else: if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: for iso_status, iso_buffer in transfer.iterISO(): @@ -522,11 +571,10 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): len(iso_buffer), iso_buffer.hex(), ) - self.isochronous_accumulator.feed(iso_buffer) + splitter.feed(iso_buffer) else: - self.queue_packet( - packet_type, - transfer.getBuffer()[: transfer.getActualLength()], + splitter.feed( + transfer.getBuffer()[: transfer.getActualLength()] ) # Re-submit the transfer so we can receive more data @@ -537,12 +585,12 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): self.loop.call_soon_threadsafe(self.on_transport_lost) elif status == usb1.TRANSFER_CANCELLED: logger.debug(f"IN[{packet_type}] transfer canceled") - self.loop.call_soon_threadsafe(self.done[packet_type].set) + self.loop.call_soon_threadsafe(self.done[transfer].set) else: logger.warning( color(f'!!! IN[{packet_type}] transfer not completed', 'red') ) - self.loop.call_soon_threadsafe(self.done[packet_type].set) + self.loop.call_soon_threadsafe(self.done[transfer].set) self.loop.call_soon_threadsafe(self.on_transport_lost) async def dequeue(self): @@ -571,7 +619,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): for transfer in ( self.interrupt_in_transfer, self.bulk_in_transfer, - self.isochronous_in_transfer, + *self.isochronous_in_transfers, ): if transfer is None: continue @@ -587,7 +635,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource): f'waiting for IN[{packet_type}] transfer cancellation ' 'to be done...' ) - await self.done[packet_type].wait() + await self.done[transfer].wait() logger.debug(f'IN[{packet_type}] transfer cancellation done') except usb1.USBError as error: logger.debug( diff --git a/tests/transport_test.py b/tests/transport_test.py index bf7d8d4..c7c3c1c 100644 --- a/tests/transport_test.py +++ b/tests/transport_test.py @@ -24,7 +24,7 @@ import sys import pytest from bumble import controller, device, hci, link, transport -from bumble.transport import common +from bumble.transport import common, usb # ----------------------------------------------------------------------------- @@ -252,6 +252,69 @@ async def test_open_transport_with_metadata(spec): await controller_transport.close() +# ----------------------------------------------------------------------------- +def test_packet_splitter_complete(): + emitted = [] + splitter = usb.AclPacketSplitter(emitted.append) + packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44]) + splitter.feed(packet) + assert emitted == [packet] + + +def test_packet_splitter_chunks(): + emitted = [] + splitter = usb.AclPacketSplitter(emitted.append) + packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44]) + splitter.feed(packet[:4]) + assert emitted == [] + splitter.feed(packet[4:]) + assert emitted == [packet] + + +def test_packet_splitter_multiple(): + emitted = [] + splitter = usb.AclPacketSplitter(emitted.append) + packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44]) + packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66]) + splitter.feed(packet1 + packet2) + assert emitted == [packet1, packet2] + + +def test_packet_splitter_partial(): + emitted = [] + splitter = usb.AclPacketSplitter(emitted.append) + packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44]) + packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66]) + splitter.feed(packet1 + packet2[:4]) + assert emitted == [packet1] + splitter.feed(packet2[4:]) + assert emitted == [packet1, packet2] + + +def test_packet_splitter_empty_payload(): + emitted = [] + splitter = usb.AclPacketSplitter(emitted.append) + packet = bytes([0x01, 0x00, 0x00, 0x00]) + splitter.feed(packet) + assert emitted == [packet] + + +def test_sco_packet_splitter(): + emitted = [] + splitter = usb.ScoPacketSplitter(emitted.append) + packet = bytes([0x01, 0x00, 0x03, 0x11, 0x22, 0x33]) + splitter.feed(packet) + assert emitted == [packet] + + +def test_event_packet_splitter(): + emitted = [] + splitter = usb.EventPacketSplitter(emitted.append) + packet = bytes([0x04, 0x02, 0x11, 0x22]) + splitter.feed(packet) + assert emitted == [packet] + + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_parser()