mirror of
https://github.com/google/bumble.git
synced 2026-06-08 08:42:26 +00:00
Merge pull request #936 from google/gbg/usb-transport-packet-splitter
usb transport packet splitter
This commit is contained in:
+78
-30
@@ -59,6 +59,8 @@ USB_BT_HCI_CLASS_TUPLE = (
|
||||
)
|
||||
|
||||
MAX_SCO_PACKET_SIZE = 1024
|
||||
MAX_SCO_IN_PACKETS = 128
|
||||
NUMBER_OF_SCO_IN_TRANSFERS = 2
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -407,20 +409,36 @@ class UsbPacketSink:
|
||||
READ_SIZE = 4096
|
||||
|
||||
|
||||
class ScoAccumulator:
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
class PacketSplitter:
|
||||
"""Splitter than can parse a byte stream and extract packets that consist of a
|
||||
header and a body, where the header includes an n-byte 'length' field at a
|
||||
certain offset.
|
||||
Extracted packets are emitted by calling a function passed to the constructor,
|
||||
with the full packet (header + body) as argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, length_offset: int, length_size: int, emit: Callable[[bytes], Any]
|
||||
) -> None:
|
||||
self.emit = emit
|
||||
self.packet = b''
|
||||
self.length_offset = length_offset
|
||||
self.length_size = length_size
|
||||
self.header_size = length_offset + length_size
|
||||
|
||||
def feed(self, data: bytes) -> None:
|
||||
while data:
|
||||
# Accumulate until we have a complete 3-byte header
|
||||
if (bytes_needed := 3 - len(self.packet)) > 0:
|
||||
# Accumulate until we have a complete header
|
||||
if (bytes_needed := self.header_size - len(self.packet)) > 0:
|
||||
self.packet += data[:bytes_needed]
|
||||
data = data[bytes_needed:]
|
||||
continue
|
||||
if len(self.packet) < self.header_size:
|
||||
continue
|
||||
|
||||
packet_length = 3 + self.packet[2]
|
||||
packet_length = self.header_size + int.from_bytes(
|
||||
self.packet[self.length_offset : self.length_offset + self.length_size],
|
||||
'little',
|
||||
)
|
||||
bytes_needed = packet_length - len(self.packet)
|
||||
self.packet += data[:bytes_needed]
|
||||
data = data[bytes_needed:]
|
||||
@@ -430,6 +448,24 @@ class ScoAccumulator:
|
||||
self.packet = b''
|
||||
|
||||
|
||||
class ScoPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 1 byte at offset 2 in the HCI SCO packet header
|
||||
super().__init__(length_offset=2, length_size=1, emit=emit)
|
||||
|
||||
|
||||
class EventPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 1 byte at offset 1 in the HCI Event packet header
|
||||
super().__init__(length_offset=1, length_size=1, emit=emit)
|
||||
|
||||
|
||||
class AclPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 2 bytes at offset 2 in the HCI ACL packet header
|
||||
super().__init__(length_offset=2, length_size=2, emit=emit)
|
||||
|
||||
|
||||
class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in):
|
||||
super().__init__()
|
||||
@@ -440,17 +476,23 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
self.bulk_in = bulk_in
|
||||
self.bulk_in_transfer = None
|
||||
self.isochronous_in = isochronous_in
|
||||
self.isochronous_in_transfer = None
|
||||
self.isochronous_accumulator = ScoAccumulator(
|
||||
lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet)
|
||||
)
|
||||
self.isochronous_in_transfers = []
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue = asyncio.Queue()
|
||||
self.dequeue_task = None
|
||||
self.done = {
|
||||
hci.HCI_EVENT_PACKET: asyncio.Event(),
|
||||
hci.HCI_ACL_DATA_PACKET: asyncio.Event(),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(),
|
||||
self.done = {}
|
||||
self.splitters = {
|
||||
hci.HCI_EVENT_PACKET: EventPacketSplitter(
|
||||
lambda packet: self.queue_packet(hci.HCI_EVENT_PACKET, packet)
|
||||
),
|
||||
hci.HCI_ACL_DATA_PACKET: AclPacketSplitter(
|
||||
lambda packet: self.queue_packet(hci.HCI_ACL_DATA_PACKET, packet)
|
||||
),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: ScoPacketSplitter(
|
||||
lambda packet: self.queue_packet(
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET, packet
|
||||
)
|
||||
),
|
||||
}
|
||||
self.closed = False
|
||||
self.lock = threading.Lock()
|
||||
@@ -464,6 +506,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_EVENT_PACKET,
|
||||
)
|
||||
self.done[self.interrupt_in_transfer] = asyncio.Event()
|
||||
self.interrupt_in_transfer.submit()
|
||||
|
||||
self.bulk_in_transfer = self.device.getTransfer()
|
||||
@@ -473,17 +516,21 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_ACL_DATA_PACKET,
|
||||
)
|
||||
self.done[self.bulk_in_transfer] = asyncio.Event()
|
||||
self.bulk_in_transfer.submit()
|
||||
|
||||
if self.isochronous_in is not None:
|
||||
self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16)
|
||||
self.isochronous_in_transfer.setIsochronous(
|
||||
self.isochronous_in.getAddress(),
|
||||
16 * self.isochronous_in.getMaxPacketSize(),
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
|
||||
)
|
||||
self.isochronous_in_transfer.submit()
|
||||
for _ in range(NUMBER_OF_SCO_IN_TRANSFERS):
|
||||
transfer = self.device.getTransfer(iso_packets=MAX_SCO_IN_PACKETS)
|
||||
transfer.setIsochronous(
|
||||
self.isochronous_in.getAddress(),
|
||||
MAX_SCO_IN_PACKETS * self.isochronous_in.getMaxPacketSize(),
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
|
||||
)
|
||||
self.isochronous_in_transfers.append(transfer)
|
||||
self.done[transfer] = asyncio.Event()
|
||||
transfer.submit()
|
||||
|
||||
self.dequeue_task = self.loop.create_task(self.dequeue())
|
||||
|
||||
@@ -509,6 +556,8 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
with self.lock:
|
||||
if self.closed:
|
||||
logger.debug("packet source closed, discarding transfer")
|
||||
elif (splitter := self.splitters.get(packet_type)) is None:
|
||||
logger.warning(f'no splitter for packet type {packet_type}')
|
||||
else:
|
||||
if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
|
||||
for iso_status, iso_buffer in transfer.iterISO():
|
||||
@@ -522,11 +571,10 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
len(iso_buffer),
|
||||
iso_buffer.hex(),
|
||||
)
|
||||
self.isochronous_accumulator.feed(iso_buffer)
|
||||
splitter.feed(iso_buffer)
|
||||
else:
|
||||
self.queue_packet(
|
||||
packet_type,
|
||||
transfer.getBuffer()[: transfer.getActualLength()],
|
||||
splitter.feed(
|
||||
transfer.getBuffer()[: transfer.getActualLength()]
|
||||
)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
@@ -537,12 +585,12 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
logger.debug(f"IN[{packet_type}] transfer canceled")
|
||||
self.loop.call_soon_threadsafe(self.done[packet_type].set)
|
||||
self.loop.call_soon_threadsafe(self.done[transfer].set)
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! IN[{packet_type}] transfer not completed', 'red')
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.done[packet_type].set)
|
||||
self.loop.call_soon_threadsafe(self.done[transfer].set)
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
|
||||
async def dequeue(self):
|
||||
@@ -571,7 +619,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
for transfer in (
|
||||
self.interrupt_in_transfer,
|
||||
self.bulk_in_transfer,
|
||||
self.isochronous_in_transfer,
|
||||
*self.isochronous_in_transfers,
|
||||
):
|
||||
if transfer is None:
|
||||
continue
|
||||
@@ -587,7 +635,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
f'waiting for IN[{packet_type}] transfer cancellation '
|
||||
'to be done...'
|
||||
)
|
||||
await self.done[packet_type].wait()
|
||||
await self.done[transfer].wait()
|
||||
logger.debug(f'IN[{packet_type}] transfer cancellation done')
|
||||
except usb1.USBError as error:
|
||||
logger.debug(
|
||||
|
||||
+64
-1
@@ -24,7 +24,7 @@ import sys
|
||||
import pytest
|
||||
|
||||
from bumble import controller, device, hci, link, transport
|
||||
from bumble.transport import common
|
||||
from bumble.transport import common, usb
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -252,6 +252,69 @@ async def test_open_transport_with_metadata(spec):
|
||||
await controller_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_packet_splitter_complete():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_packet_splitter_chunks():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
splitter.feed(packet[:4])
|
||||
assert emitted == []
|
||||
splitter.feed(packet[4:])
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_packet_splitter_multiple():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
|
||||
splitter.feed(packet1 + packet2)
|
||||
assert emitted == [packet1, packet2]
|
||||
|
||||
|
||||
def test_packet_splitter_partial():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
|
||||
splitter.feed(packet1 + packet2[:4])
|
||||
assert emitted == [packet1]
|
||||
splitter.feed(packet2[4:])
|
||||
assert emitted == [packet1, packet2]
|
||||
|
||||
|
||||
def test_packet_splitter_empty_payload():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x00, 0x00])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_sco_packet_splitter():
|
||||
emitted = []
|
||||
splitter = usb.ScoPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x03, 0x11, 0x22, 0x33])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_event_packet_splitter():
|
||||
emitted = []
|
||||
splitter = usb.EventPacketSplitter(emitted.append)
|
||||
packet = bytes([0x04, 0x02, 0x11, 0x22])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_parser()
|
||||
|
||||
Reference in New Issue
Block a user