mirror of
https://github.com/google/bumble.git
synced 2026-06-23 10:50:49 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 16dd5ae63d | |||
| 3266d16cf1 | |||
| 65c4f9a698 | |||
| 17bc5566aa | |||
| b6a21fa3c6 | |||
| 7a14ebdabe | |||
| e44eaf2147 | |||
| 17a202bc13 | |||
| ef634953f0 | |||
| 72d821b1f6 | |||
| afe064b4ea | |||
| 8d0cef70c2 |
+41
-24
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
@@ -1946,9 +1947,6 @@ class Stream:
|
||||
await self.rtp_channel.disconnect()
|
||||
self.rtp_channel = None
|
||||
|
||||
# Release the endpoint
|
||||
self.local_endpoint.in_use = 0
|
||||
|
||||
self.change_state(State.IDLE)
|
||||
|
||||
async def on_set_configuration_command(
|
||||
@@ -2039,7 +2037,6 @@ class Stream:
|
||||
|
||||
if self.rtp_channel is None:
|
||||
# No channel to release, we're done
|
||||
self.local_endpoint.in_use = 0
|
||||
self.change_state(State.IDLE)
|
||||
else:
|
||||
# TODO: set a timer as we wait for the RTP channel to be closed
|
||||
@@ -2051,7 +2048,6 @@ class Stream:
|
||||
await self.local_endpoint.on_abort_command()
|
||||
if self.rtp_channel is None:
|
||||
# No need to wait
|
||||
self.local_endpoint.in_use = 0
|
||||
self.change_state(State.IDLE)
|
||||
else:
|
||||
# Wait for the RTP channel to be closed
|
||||
@@ -2074,7 +2070,6 @@ class Stream:
|
||||
def on_l2cap_channel_close(self) -> None:
|
||||
logger.debug(color('<<< stream channel closed', 'magenta'))
|
||||
self.local_endpoint.on_rtp_channel_close()
|
||||
self.local_endpoint.in_use = 0
|
||||
self.rtp_channel = None
|
||||
|
||||
if self.state in (State.CLOSING, State.ABORTING):
|
||||
@@ -2099,7 +2094,6 @@ class Stream:
|
||||
self.state = State.IDLE
|
||||
|
||||
local_endpoint.stream = self
|
||||
local_endpoint.in_use = 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
@@ -2109,14 +2103,16 @@ class Stream:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class StreamEndPoint:
|
||||
class StreamEndPoint(abc.ABC):
|
||||
seid: int
|
||||
media_type: MediaType
|
||||
tsep: StreamEndPointType
|
||||
in_use: int
|
||||
capabilities: Iterable[ServiceCapabilities]
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamEndPointProxy:
|
||||
@@ -2156,14 +2152,30 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
|
||||
in_use: int,
|
||||
capabilities: Iterable[ServiceCapabilities],
|
||||
) -> None:
|
||||
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
|
||||
StreamEndPointProxy.__init__(self, protocol, seid)
|
||||
# StreamEndPoint attributes
|
||||
self.seid = seid
|
||||
self.media_type = media_type
|
||||
self.tsep = tsep
|
||||
self._in_use = in_use
|
||||
self.capabilities = capabilities
|
||||
|
||||
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
return self._in_use
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
stream: Stream | None
|
||||
|
||||
@property
|
||||
def in_use(self) -> int:
|
||||
if self.stream and self.stream.state != State.IDLE:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
EVENT_CONFIGURATION = "configuration"
|
||||
EVENT_OPEN = "open"
|
||||
EVENT_START = "start"
|
||||
@@ -2186,8 +2198,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
capabilities: Iterable[ServiceCapabilities],
|
||||
configuration: Iterable[ServiceCapabilities] | None = None,
|
||||
):
|
||||
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
|
||||
utils.EventEmitter.__init__(self)
|
||||
# StreamEndPoint attributes
|
||||
self.seid = seid
|
||||
self.media_type = media_type
|
||||
self.tsep = tsep
|
||||
self.capabilities = capabilities
|
||||
|
||||
self.protocol = protocol
|
||||
self.configuration = configuration if configuration is not None else []
|
||||
self.stream = None
|
||||
@@ -2273,12 +2290,12 @@ class LocalSource(LocalStreamEndPoint):
|
||||
codec_capabilities,
|
||||
] + list(other_capabilities)
|
||||
super().__init__(
|
||||
protocol,
|
||||
seid,
|
||||
codec_capabilities.media_type,
|
||||
AVDTP_TSEP_SRC,
|
||||
capabilities,
|
||||
capabilities,
|
||||
protocol=protocol,
|
||||
seid=seid,
|
||||
media_type=codec_capabilities.media_type,
|
||||
tsep=AVDTP_TSEP_SRC,
|
||||
capabilities=capabilities,
|
||||
configuration=capabilities,
|
||||
)
|
||||
self.packet_pump = packet_pump
|
||||
|
||||
@@ -2317,11 +2334,11 @@ class LocalSink(LocalStreamEndPoint):
|
||||
codec_capabilities,
|
||||
]
|
||||
super().__init__(
|
||||
protocol,
|
||||
seid,
|
||||
codec_capabilities.media_type,
|
||||
AVDTP_TSEP_SNK,
|
||||
capabilities,
|
||||
protocol=protocol,
|
||||
seid=seid,
|
||||
media_type=codec_capabilities.media_type,
|
||||
tsep=AVDTP_TSEP_SNK,
|
||||
capabilities=capabilities,
|
||||
)
|
||||
|
||||
def on_rtp_channel_open(self) -> None:
|
||||
|
||||
@@ -247,6 +247,7 @@ class Host(utils.EventEmitter):
|
||||
bis_links: dict[int, IsoLink]
|
||||
sco_links: dict[int, ScoLink]
|
||||
bigs: dict[int, set[int]]
|
||||
link_ts_flags: dict[int, int]
|
||||
acl_packet_queue: DataPacketQueue | None = None
|
||||
le_acl_packet_queue: DataPacketQueue | None = None
|
||||
iso_packet_queue: DataPacketQueue | None = None
|
||||
@@ -269,6 +270,7 @@ class Host(utils.EventEmitter):
|
||||
self.bis_links = {} # BIS links, by connection handle
|
||||
self.sco_links = {} # SCO links, by connection handle
|
||||
self.bigs = {} # BIG Handle to BIS Handles
|
||||
self.link_ts_flags = {} # TS_Flag for ISO links, by handle
|
||||
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
|
||||
self.pending_response: (
|
||||
asyncio.Future[
|
||||
@@ -486,6 +488,7 @@ class Host(utils.EventEmitter):
|
||||
hci.HCI_LE_PHY_UPDATE_COMPLETE_EVENT,
|
||||
hci.HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT,
|
||||
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_EVENT,
|
||||
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_ESTABLISHED_V2_EVENT,
|
||||
hci.HCI_LE_PERIODIC_ADVERTISING_REPORT_EVENT,
|
||||
hci.HCI_LE_PERIODIC_ADVERTISING_SYNC_LOST_EVENT,
|
||||
hci.HCI_LE_SCAN_TIMEOUT_EVENT,
|
||||
@@ -1028,6 +1031,82 @@ class Host(utils.EventEmitter):
|
||||
# Look for the connection to which this data belongs
|
||||
if connection := self.connections.get(packet.connection_handle):
|
||||
connection.on_hci_acl_data_packet(packet)
|
||||
return
|
||||
|
||||
# WORKAROUND: Some controllers (e.g. Intel BE200) send ISO data wrapped in ACL packets
|
||||
# using the CIS handle.
|
||||
is_cis = packet.connection_handle in self.cis_links
|
||||
is_bis = packet.connection_handle in self.bis_links
|
||||
|
||||
if is_cis or is_bis:
|
||||
logger.debug(
|
||||
f"Received ISO data wrapped in ACL packet for handle 0x{packet.connection_handle:04X}"
|
||||
)
|
||||
payload = packet.data
|
||||
|
||||
ts_flag = self.link_ts_flags.get(packet.connection_handle)
|
||||
if ts_flag is None:
|
||||
# Learn TS flag from the first packet on this link
|
||||
if is_bis:
|
||||
# BIS packets always have Timestamp according to spec
|
||||
ts_flag = 1
|
||||
elif len(payload) < 8:
|
||||
# Too short to have 8-byte header (TS), must be No TS
|
||||
ts_flag = 0
|
||||
else:
|
||||
psn_no_ts = int.from_bytes(payload[0:2], 'little')
|
||||
psn_has_ts = int.from_bytes(payload[4:6], 'little')
|
||||
if psn_has_ts == 0:
|
||||
ts_flag = 1
|
||||
elif psn_no_ts == 0:
|
||||
ts_flag = 0
|
||||
else:
|
||||
# Fallback heuristic
|
||||
ts_flag = 1 if psn_has_ts < psn_no_ts else 0
|
||||
self.link_ts_flags[packet.connection_handle] = ts_flag
|
||||
logger.info(
|
||||
f"Learned TS_Flag = {ts_flag} for handle 0x{packet.connection_handle:04X}"
|
||||
)
|
||||
|
||||
if ts_flag:
|
||||
header_size = 8
|
||||
sdu_length_offset = 6
|
||||
else:
|
||||
header_size = 4
|
||||
sdu_length_offset = 2
|
||||
|
||||
pb_flag = 0b10
|
||||
if len(payload) >= header_size:
|
||||
sdu_length = int.from_bytes(
|
||||
payload[sdu_length_offset : sdu_length_offset + 2], 'little'
|
||||
)
|
||||
if sdu_length == len(payload) - header_size:
|
||||
pb_flag = 0b10 # Complete SDU
|
||||
else:
|
||||
pb_flag = 0b00 # First fragment
|
||||
else:
|
||||
pb_flag = 0b01 # Continuation
|
||||
ts_flag = 0
|
||||
|
||||
# Reconstruct the raw ISO packet (excluding packet indicator 0x05)
|
||||
pdu_info = packet.connection_handle | (pb_flag << 12) | (ts_flag << 14)
|
||||
header = bytes(
|
||||
[
|
||||
pdu_info & 0xFF,
|
||||
(pdu_info >> 8) & 0xFF,
|
||||
len(payload) & 0xFF,
|
||||
(len(payload) >> 8) & 0xFF,
|
||||
]
|
||||
)
|
||||
raw_iso_packet = header + payload
|
||||
|
||||
try:
|
||||
iso_packet = hci.HCI_IsoDataPacket.from_bytes(
|
||||
bytes([hci.HCI_ISO_DATA_PACKET]) + raw_iso_packet
|
||||
)
|
||||
self.on_hci_iso_data_packet(iso_packet)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct ISO packet from ACL: {e}")
|
||||
|
||||
def on_hci_sco_data_packet(self, packet: hci.HCI_SynchronousDataPacket) -> None:
|
||||
# Experimental
|
||||
@@ -1251,6 +1330,7 @@ class Host(utils.EventEmitter):
|
||||
self.emit('disconnection', handle, event.reason)
|
||||
|
||||
# Remove the handle reference
|
||||
self.link_ts_flags.pop(handle, None)
|
||||
_ = (
|
||||
self.connections.pop(handle, 0)
|
||||
or self.cis_links.pop(handle, 0)
|
||||
@@ -1371,6 +1451,20 @@ class Host(utils.EventEmitter):
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_established_v2_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Established_V2_Event
|
||||
):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_establishment',
|
||||
event.status,
|
||||
event.sync_handle,
|
||||
event.advertising_sid,
|
||||
event.advertiser_address,
|
||||
event.advertiser_phy,
|
||||
event.periodic_advertising_interval,
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_lost_event(
|
||||
self, event: hci.HCI_LE_Periodic_Advertising_Sync_Lost_Event
|
||||
):
|
||||
|
||||
@@ -104,6 +104,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
0,
|
||||
packet[1:],
|
||||
)
|
||||
elif packet_type == hci.HCI_ISO_DATA_PACKET:
|
||||
# Workaround: Send ISO packets over Bulk Out
|
||||
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'unsupported packet type {packet_type}', 'red')
|
||||
|
||||
+97
-30
@@ -59,6 +59,8 @@ USB_BT_HCI_CLASS_TUPLE = (
|
||||
)
|
||||
|
||||
MAX_SCO_PACKET_SIZE = 1024
|
||||
MAX_SCO_IN_PACKETS = 128
|
||||
NUMBER_OF_SCO_IN_TRANSFERS = 2
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -336,6 +338,25 @@ class UsbPacketSink:
|
||||
)
|
||||
self.isochronous_out_transfer.submit()
|
||||
submitted = True
|
||||
elif packet_type == hci.HCI_ISO_DATA_PACKET:
|
||||
if self.isochronous_out_transfer is None:
|
||||
# Workaround: Send ISO packets over Bulk Out when Isochronous endpoints are not enabled
|
||||
self.bulk_or_control_out_transfer.setBulk(
|
||||
self.bulk_out.getAddress(),
|
||||
packet_payload,
|
||||
callback=self.transfer_callback,
|
||||
)
|
||||
self.bulk_or_control_out_transfer.submit()
|
||||
submitted = True
|
||||
else:
|
||||
logger.warning(
|
||||
color(
|
||||
'ISO packets over Isochronous endpoints not supported yet',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
self.out_transfer_ready.release()
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'unsupported packet type {packet_type}', 'red')
|
||||
@@ -388,20 +409,36 @@ class UsbPacketSink:
|
||||
READ_SIZE = 4096
|
||||
|
||||
|
||||
class ScoAccumulator:
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
class PacketSplitter:
|
||||
"""Splitter than can parse a byte stream and extract packets that consist of a
|
||||
header and a body, where the header includes an n-byte 'length' field at a
|
||||
certain offset.
|
||||
Extracted packets are emitted by calling a function passed to the constructor,
|
||||
with the full packet (header + body) as argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, length_offset: int, length_size: int, emit: Callable[[bytes], Any]
|
||||
) -> None:
|
||||
self.emit = emit
|
||||
self.packet = b''
|
||||
self.length_offset = length_offset
|
||||
self.length_size = length_size
|
||||
self.header_size = length_offset + length_size
|
||||
|
||||
def feed(self, data: bytes) -> None:
|
||||
while data:
|
||||
# Accumulate until we have a complete 3-byte header
|
||||
if (bytes_needed := 3 - len(self.packet)) > 0:
|
||||
# Accumulate until we have a complete header
|
||||
if (bytes_needed := self.header_size - len(self.packet)) > 0:
|
||||
self.packet += data[:bytes_needed]
|
||||
data = data[bytes_needed:]
|
||||
continue
|
||||
if len(self.packet) < self.header_size:
|
||||
continue
|
||||
|
||||
packet_length = 3 + self.packet[2]
|
||||
packet_length = self.header_size + int.from_bytes(
|
||||
self.packet[self.length_offset : self.length_offset + self.length_size],
|
||||
'little',
|
||||
)
|
||||
bytes_needed = packet_length - len(self.packet)
|
||||
self.packet += data[:bytes_needed]
|
||||
data = data[bytes_needed:]
|
||||
@@ -411,6 +448,24 @@ class ScoAccumulator:
|
||||
self.packet = b''
|
||||
|
||||
|
||||
class ScoPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 1 byte at offset 2 in the HCI SCO packet header
|
||||
super().__init__(length_offset=2, length_size=1, emit=emit)
|
||||
|
||||
|
||||
class EventPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 1 byte at offset 1 in the HCI Event packet header
|
||||
super().__init__(length_offset=1, length_size=1, emit=emit)
|
||||
|
||||
|
||||
class AclPacketSplitter(PacketSplitter):
|
||||
def __init__(self, emit: Callable[[bytes], Any]) -> None:
|
||||
# The length field is 2 bytes at offset 2 in the HCI ACL packet header
|
||||
super().__init__(length_offset=2, length_size=2, emit=emit)
|
||||
|
||||
|
||||
class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in):
|
||||
super().__init__()
|
||||
@@ -421,17 +476,23 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
self.bulk_in = bulk_in
|
||||
self.bulk_in_transfer = None
|
||||
self.isochronous_in = isochronous_in
|
||||
self.isochronous_in_transfer = None
|
||||
self.isochronous_accumulator = ScoAccumulator(
|
||||
lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet)
|
||||
)
|
||||
self.isochronous_in_transfers = []
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue = asyncio.Queue()
|
||||
self.dequeue_task = None
|
||||
self.done = {
|
||||
hci.HCI_EVENT_PACKET: asyncio.Event(),
|
||||
hci.HCI_ACL_DATA_PACKET: asyncio.Event(),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(),
|
||||
self.done = {}
|
||||
self.splitters = {
|
||||
hci.HCI_EVENT_PACKET: EventPacketSplitter(
|
||||
lambda packet: self.queue_packet(hci.HCI_EVENT_PACKET, packet)
|
||||
),
|
||||
hci.HCI_ACL_DATA_PACKET: AclPacketSplitter(
|
||||
lambda packet: self.queue_packet(hci.HCI_ACL_DATA_PACKET, packet)
|
||||
),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: ScoPacketSplitter(
|
||||
lambda packet: self.queue_packet(
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET, packet
|
||||
)
|
||||
),
|
||||
}
|
||||
self.closed = False
|
||||
self.lock = threading.Lock()
|
||||
@@ -445,6 +506,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_EVENT_PACKET,
|
||||
)
|
||||
self.done[self.interrupt_in_transfer] = asyncio.Event()
|
||||
self.interrupt_in_transfer.submit()
|
||||
|
||||
self.bulk_in_transfer = self.device.getTransfer()
|
||||
@@ -454,17 +516,21 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_ACL_DATA_PACKET,
|
||||
)
|
||||
self.done[self.bulk_in_transfer] = asyncio.Event()
|
||||
self.bulk_in_transfer.submit()
|
||||
|
||||
if self.isochronous_in is not None:
|
||||
self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16)
|
||||
self.isochronous_in_transfer.setIsochronous(
|
||||
self.isochronous_in.getAddress(),
|
||||
16 * self.isochronous_in.getMaxPacketSize(),
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
|
||||
)
|
||||
self.isochronous_in_transfer.submit()
|
||||
for _ in range(NUMBER_OF_SCO_IN_TRANSFERS):
|
||||
transfer = self.device.getTransfer(iso_packets=MAX_SCO_IN_PACKETS)
|
||||
transfer.setIsochronous(
|
||||
self.isochronous_in.getAddress(),
|
||||
MAX_SCO_IN_PACKETS * self.isochronous_in.getMaxPacketSize(),
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
|
||||
)
|
||||
self.isochronous_in_transfers.append(transfer)
|
||||
self.done[transfer] = asyncio.Event()
|
||||
transfer.submit()
|
||||
|
||||
self.dequeue_task = self.loop.create_task(self.dequeue())
|
||||
|
||||
@@ -490,6 +556,8 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
with self.lock:
|
||||
if self.closed:
|
||||
logger.debug("packet source closed, discarding transfer")
|
||||
elif (splitter := self.splitters.get(packet_type)) is None:
|
||||
logger.warning(f'no splitter for packet type {packet_type}')
|
||||
else:
|
||||
if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
|
||||
for iso_status, iso_buffer in transfer.iterISO():
|
||||
@@ -503,11 +571,10 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
len(iso_buffer),
|
||||
iso_buffer.hex(),
|
||||
)
|
||||
self.isochronous_accumulator.feed(iso_buffer)
|
||||
splitter.feed(iso_buffer)
|
||||
else:
|
||||
self.queue_packet(
|
||||
packet_type,
|
||||
transfer.getBuffer()[: transfer.getActualLength()],
|
||||
splitter.feed(
|
||||
transfer.getBuffer()[: transfer.getActualLength()]
|
||||
)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
@@ -518,12 +585,12 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
logger.debug(f"IN[{packet_type}] transfer canceled")
|
||||
self.loop.call_soon_threadsafe(self.done[packet_type].set)
|
||||
self.loop.call_soon_threadsafe(self.done[transfer].set)
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! IN[{packet_type}] transfer not completed', 'red')
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.done[packet_type].set)
|
||||
self.loop.call_soon_threadsafe(self.done[transfer].set)
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
|
||||
async def dequeue(self):
|
||||
@@ -552,7 +619,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
for transfer in (
|
||||
self.interrupt_in_transfer,
|
||||
self.bulk_in_transfer,
|
||||
self.isochronous_in_transfer,
|
||||
*self.isochronous_in_transfers,
|
||||
):
|
||||
if transfer is None:
|
||||
continue
|
||||
@@ -568,7 +635,7 @@ class UsbPacketSource(asyncio.Protocol, BaseSource):
|
||||
f'waiting for IN[{packet_type}] transfer cancellation '
|
||||
'to be done...'
|
||||
)
|
||||
await self.done[packet_type].wait()
|
||||
await self.done[transfer].wait()
|
||||
logger.debug(f'IN[{packet_type}] transfer cancellation done')
|
||||
except usb1.USBError as error:
|
||||
logger.debug(
|
||||
|
||||
+3
-3
@@ -43,9 +43,9 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
build = ["build >= 0.7"]
|
||||
test = [
|
||||
"pytest >= 8.2",
|
||||
"pytest-asyncio >= 0.23.5",
|
||||
"pytest-html >= 3.2.0",
|
||||
"pytest >= 9.0",
|
||||
"pytest-asyncio >= 1.4",
|
||||
"pytest-html >= 4.2",
|
||||
"coverage >= 6.4",
|
||||
]
|
||||
development = [
|
||||
|
||||
+3
-3
@@ -46,7 +46,7 @@ class TwoDevices(test_utils.TwoDevices):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"command,",
|
||||
"command",
|
||||
[
|
||||
avrcp.GetPlayStatusCommand(),
|
||||
avrcp.GetCapabilitiesCommand(
|
||||
@@ -132,7 +132,7 @@ def test_command(command: avrcp.Command):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"event,",
|
||||
"event",
|
||||
[
|
||||
avrcp.UidsChangedEvent(uid_counter=7),
|
||||
avrcp.TrackChangedEvent(uid=12356),
|
||||
@@ -159,7 +159,7 @@ def test_event(event: avrcp.Event):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response,",
|
||||
"response",
|
||||
[
|
||||
avrcp.GetPlayStatusResponse(
|
||||
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
|
||||
|
||||
+1
-1
@@ -72,7 +72,7 @@ def test_sef():
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
|
||||
'sirk_type', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
|
||||
)
|
||||
async def test_csis(sirk_type):
|
||||
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
|
||||
@@ -278,7 +278,7 @@ async def test_legacy_advertising():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
'auto_restart',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -357,7 +357,7 @@ async def test_advertising_and_scanning():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
'own_address_type',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,7 +395,7 @@ async def test_extended_advertising_connection(own_address_type):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
'own_address_type',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
||||
+3
-3
@@ -297,7 +297,7 @@ def test_custom_le_meta_event():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
@@ -313,7 +313,7 @@ def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
@@ -330,7 +330,7 @@ def test_hci_event_subclasses_event_code(clazz: type[hci.HCI_Event]):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
|
||||
+2
-2
@@ -333,7 +333,7 @@ async def test_query_calls_with_calls(
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"operation,",
|
||||
"operation",
|
||||
(
|
||||
hfp.CallHoldOperation.RELEASE_ALL_HELD_CALLS,
|
||||
hfp.CallHoldOperation.RELEASE_ALL_ACTIVE_CALLS,
|
||||
@@ -358,7 +358,7 @@ async def test_hold_call_without_call_index(
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"operation,",
|
||||
"operation",
|
||||
(
|
||||
hfp.CallHoldOperation.RELEASE_SPECIFIC_CALL,
|
||||
hfp.CallHoldOperation.HOLD_ALL_CALLS_EXCEPT,
|
||||
|
||||
+5
-5
@@ -49,19 +49,19 @@ def test_helpers():
|
||||
psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x242311)
|
||||
assert psm == bytes([0x11, 0x23, 0x24])
|
||||
|
||||
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
bytes([0x00, 0x01, 0x00, 0x44]), 1
|
||||
)
|
||||
assert offset == 3
|
||||
assert psm == 0x01
|
||||
|
||||
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
bytes([0x00, 0x23, 0x10, 0x44]), 1
|
||||
)
|
||||
assert offset == 3
|
||||
assert psm == 0x1023
|
||||
|
||||
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
|
||||
bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
|
||||
)
|
||||
assert offset == 4
|
||||
@@ -197,7 +197,7 @@ async def test_basic_connection():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType))
|
||||
@pytest.mark.parametrize("info_type", list(l2cap.L2CAP_Information_Request.InfoType))
|
||||
async def test_l2cap_information_request(monkeypatch, info_type):
|
||||
# TODO: Replace handlers with API when implemented
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
@@ -321,7 +321,7 @@ async def test_mtu():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
|
||||
@pytest.mark.parametrize("mtu", (50, 255, 256, 1000))
|
||||
async def test_enhanced_retransmission_mode(mtu: int):
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
+1
-1
@@ -68,7 +68,7 @@ async def test_self_disconnection():
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'responder_role,',
|
||||
'responder_role',
|
||||
(Role.CENTRAL, Role.PERIPHERAL),
|
||||
)
|
||||
async def test_self_classic_connection(responder_role):
|
||||
|
||||
+67
-4
@@ -24,7 +24,7 @@ import sys
|
||||
import pytest
|
||||
|
||||
from bumble import controller, device, hci, link, transport
|
||||
from bumble.transport import common
|
||||
from bumble.transport import common, usb
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -102,7 +102,7 @@ def test_parser_extensions():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
"address",
|
||||
("127.0.0.1", "::1"),
|
||||
)
|
||||
async def test_tcp_connection(address):
|
||||
@@ -205,7 +205,7 @@ async def test_unix_connection_abstract():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
"address",
|
||||
("127.0.0.1", "[::1]"),
|
||||
)
|
||||
async def test_android_netsim_connection(address):
|
||||
@@ -228,7 +228,7 @@ async def test_android_netsim_connection(address):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"spec,",
|
||||
"spec",
|
||||
(
|
||||
"android-netsim:[::1]:{port},mode=host[a=b,c=d]",
|
||||
"android-netsim:localhost:{port},mode=host[a=b,c=d]",
|
||||
@@ -252,6 +252,69 @@ async def test_open_transport_with_metadata(spec):
|
||||
await controller_transport.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_packet_splitter_complete():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_packet_splitter_chunks():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
splitter.feed(packet[:4])
|
||||
assert emitted == []
|
||||
splitter.feed(packet[4:])
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_packet_splitter_multiple():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
|
||||
splitter.feed(packet1 + packet2)
|
||||
assert emitted == [packet1, packet2]
|
||||
|
||||
|
||||
def test_packet_splitter_partial():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
|
||||
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
|
||||
splitter.feed(packet1 + packet2[:4])
|
||||
assert emitted == [packet1]
|
||||
splitter.feed(packet2[4:])
|
||||
assert emitted == [packet1, packet2]
|
||||
|
||||
|
||||
def test_packet_splitter_empty_payload():
|
||||
emitted = []
|
||||
splitter = usb.AclPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x00, 0x00])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_sco_packet_splitter():
|
||||
emitted = []
|
||||
splitter = usb.ScoPacketSplitter(emitted.append)
|
||||
packet = bytes([0x01, 0x00, 0x03, 0x11, 0x22, 0x33])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
def test_event_packet_splitter():
|
||||
emitted = []
|
||||
splitter = usb.EventPacketSplitter(emitted.append)
|
||||
packet = bytes([0x04, 0x02, 0x11, 0x22])
|
||||
splitter.feed(packet)
|
||||
assert emitted == [packet]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
test_parser()
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from bumble import hci
|
||||
from bumble.transport import usb
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usb_packet_sink_iso_routing():
|
||||
# Mock usb1 device and endpoints
|
||||
mock_device = mock.Mock()
|
||||
mock_bulk_out = mock.Mock()
|
||||
mock_bulk_out.getAddress.return_value = 0x02
|
||||
|
||||
# Scenario 1: Isochronous endpoints are not enabled (isochronous_out is None)
|
||||
mock_transfer = mock.Mock()
|
||||
mock_device.getTransfer.return_value = mock_transfer
|
||||
|
||||
sink = usb.UsbPacketSink(mock_device, mock_bulk_out, isochronous_out=None)
|
||||
sink.start()
|
||||
|
||||
# Send HCI_ISO_DATA_PACKET
|
||||
iso_packet = bytes([hci.HCI_ISO_DATA_PACKET, 0x01, 0x02, 0x03])
|
||||
sink.on_packet(iso_packet)
|
||||
|
||||
# Yield control to let the queue processor run
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Verify it was sent via bulk transfer
|
||||
mock_transfer.setBulk.assert_called_once_with(
|
||||
0x02,
|
||||
bytes([0x01, 0x02, 0x03]),
|
||||
callback=sink.transfer_callback,
|
||||
)
|
||||
mock_transfer.submit.assert_called_once()
|
||||
|
||||
if sink.queue_task:
|
||||
sink.queue_task.cancel()
|
||||
try:
|
||||
await sink.queue_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usb_packet_sink_iso_routing_with_iso_endpoint():
|
||||
# Mock usb1 device and endpoints
|
||||
mock_device = mock.Mock()
|
||||
mock_bulk_out = mock.Mock()
|
||||
mock_bulk_out.getAddress.return_value = 0x02
|
||||
mock_iso_out = mock.Mock()
|
||||
mock_iso_out.getMaxPacketSize.return_value = 64
|
||||
|
||||
# Scenario 2: Isochronous endpoints are enabled
|
||||
mock_transfer_bulk = mock.Mock()
|
||||
mock_transfer_iso = mock.Mock()
|
||||
|
||||
# getTransfer is called twice: once for bulk_or_control and once for isochronous
|
||||
mock_device.getTransfer.side_effect = [mock_transfer_bulk, mock_transfer_iso]
|
||||
|
||||
sink = usb.UsbPacketSink(mock_device, mock_bulk_out, isochronous_out=mock_iso_out)
|
||||
sink.start()
|
||||
|
||||
# Send HCI_ISO_DATA_PACKET
|
||||
iso_packet = bytes([hci.HCI_ISO_DATA_PACKET, 0x01, 0x02, 0x03])
|
||||
sink.on_packet(iso_packet)
|
||||
|
||||
# Yield control to let the queue processor run
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Verify it was NOT sent via bulk transfer
|
||||
mock_transfer_bulk.setBulk.assert_not_called()
|
||||
|
||||
if sink.queue_task:
|
||||
sink.queue_task.cancel()
|
||||
try:
|
||||
await sink.queue_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
Reference in New Issue
Block a user