add support for ACL and ISO HCI packet queues

This commit is contained in:
Gilles Boccon-Gibod
2025-01-22 13:15:37 -05:00
parent af466c2970
commit cbd46adbcf
9 changed files with 471 additions and 115 deletions

View File

@@ -825,10 +825,24 @@ async def run_broadcast(
), ),
) )
print('Setup ISO Data Path') print('Setup ISO Data Path')
def on_drain(packet_queue):
print(
f'\rPACKETS: pending={packet_queue.pending}, '
f'queued={packet_queue.queued}, completed={packet_queue.completed}',
end='',
)
packet_queue = None
for bis_link in big.bis_links: for bis_link in big.bis_links:
await bis_link.setup_data_path( await bis_link.setup_data_path(
direction=bis_link.Direction.HOST_TO_CONTROLLER direction=bis_link.Direction.HOST_TO_CONTROLLER
) )
if packet_queue is None:
packet_queue = bis_link.data_packet_queue
if packet_queue:
packet_queue.on('drain', lambda: on_drain(packet_queue))
for frame in itertools.cycle(frames): for frame in itertools.cycle(frames):
mid = len(frame) // 2 mid = len(frame) // 2

View File

@@ -37,6 +37,8 @@ from bumble.hci import (
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command, HCI_Read_Buffer_Size_Command,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
@@ -147,7 +149,7 @@ async def get_le_info(host: Host) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None: async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
@@ -160,14 +162,28 @@ async def get_acl_flow_control_info(host: Host) -> None:
f'packets of size {response.return_parameters.hc_acl_data_packet_length}', f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} ' f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}', f'packets of size {response.return_parameters.le_acl_data_packet_length}',
) )
@@ -274,8 +290,8 @@ async def async_main(latency_probes, transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info # Print the flow control info
await get_acl_flow_control_info(host) await get_flow_control_info(host)
# Get codec info # Get codec info
await get_codecs_info(host) await get_codecs_info(host)

View File

@@ -154,15 +154,17 @@ class Controller:
'0000000060000000' '0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller) ) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27 self.acl_data_packet_length = 27
self.hc_total_num_data_packets = 64 self.total_num_acl_data_packets = 64
self.hc_le_data_packet_length = 27 self.le_acl_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.total_num_le_acl_data_packets = 64
self.iso_data_packet_length = 960
self.total_num_iso_data_packets = 64
self.event_mask = 0 self.event_mask = 0
self.event_mask_page_2 = 0 self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex( self.supported_commands = bytes.fromhex(
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000' '2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000' '30f0f9ff01008004002000000000000000000000000000000000000000000000'
) )
self.le_event_mask = 0 self.le_event_mask = 0
self.advertising_parameters = None self.advertising_parameters = None
@@ -1181,9 +1183,9 @@ class Controller:
return struct.pack( return struct.pack(
'<BHBHH', '<BHBHH',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_data_packet_length, self.acl_data_packet_length,
0, 0,
self.hc_total_num_data_packets, self.total_num_acl_data_packets,
0, 0,
) )
@@ -1212,8 +1214,21 @@ class Controller:
return struct.pack( return struct.pack(
'<BHB', '<BHB',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_le_data_packet_length, self.le_acl_data_packet_length,
self.hc_total_num_le_data_packets, self.total_num_le_acl_data_packets,
)
def on_hci_le_read_buffer_size_v2_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack(
'<BHBHB',
HCI_SUCCESS,
self.le_acl_data_packet_length,
self.total_num_le_acl_data_packets,
self.iso_data_packet_length,
self.total_num_iso_data_packets,
) )
def on_hci_le_read_local_supported_features_command(self, _command): def on_hci_le_read_local_supported_features_command(self, _command):

View File

@@ -52,7 +52,7 @@ from pyee import EventEmitter
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from .gatt import Characteristic, Descriptor, Service from .gatt import Characteristic, Descriptor, Service
from .host import Host from .host import DataPacketQueue, Host
from .profiles.gap import GenericAccessService from .profiles.gap import GenericAccessService
from .core import ( from .core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
@@ -1329,7 +1329,6 @@ class ScoLink(CompositeEventEmitter):
class _IsoLink: class _IsoLink:
handle: int handle: int
device: Device device: Device
packet_sequence_number: int
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
class Direction(IntEnum): class Direction(IntEnum):
@@ -1391,22 +1390,12 @@ class _IsoLink:
return response.return_parameters.status return response.return_parameters.status
def write(self, sdu: bytes) -> None: def write(self, sdu: bytes) -> None:
"""Write an ISO SDU. """Write an ISO SDU."""
self.device.host.send_iso_sdu(connection_handle=self.handle, sdu=sdu)
This will automatically increase the packet sequence number. @property
""" def data_packet_queue(self) -> DataPacketQueue | None:
self.device.host.send_hci_packet( return self.device.host.get_data_packet_queue(self.handle)
hci.HCI_IsoDataPacket(
connection_handle=self.handle,
data_total_length=len(sdu) + 4,
packet_sequence_number=self.packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=len(sdu),
iso_sdu_fragment=sdu,
)
)
self.packet_sequence_number = (self.packet_sequence_number + 1) % 0x10000
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1426,7 +1415,6 @@ class CisLink(CompositeEventEmitter, _IsoLink):
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__init__() super().__init__()
self.packet_sequence_number = 0
async def disconnect( async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
@@ -1443,7 +1431,6 @@ class BisLink(_IsoLink):
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.device = self.big.device self.device = self.big.device
self.packet_sequence_number = 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1691,6 +1678,10 @@ class Connection(CompositeEventEmitter):
self.peer_le_features = await self.device.get_remote_le_features(self) self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features return self.peer_le_features
@property
def data_packet_queue(self) -> DataPacketQueue | None:
return self.device.host.get_data_packet_queue(self.handle)
async def __aenter__(self): async def __aenter__(self):
return self return self

View File

@@ -3585,8 +3585,8 @@ class HCI_LE_Set_Event_Mask_Command(HCI_Command):
@HCI_Command.command( @HCI_Command.command(
return_parameters_fields=[ return_parameters_fields=[
('status', STATUS_SPEC), ('status', STATUS_SPEC),
('hc_le_acl_data_packet_length', 2), ('le_acl_data_packet_length', 2),
('hc_total_num_le_acl_data_packets', 1), ('total_num_le_acl_data_packets', 1),
] ]
) )
class HCI_LE_Read_Buffer_Size_Command(HCI_Command): class HCI_LE_Read_Buffer_Size_Command(HCI_Command):
@@ -3595,6 +3595,22 @@ class HCI_LE_Read_Buffer_Size_Command(HCI_Command):
''' '''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('le_acl_data_packet_length', 2),
('total_num_le_acl_data_packets', 1),
('iso_data_packet_length', 2),
('total_num_iso_data_packets', 1),
]
)
class HCI_LE_Read_Buffer_Size_V2_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.2 LE Read Buffer Size V2 Command
'''
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command( @HCI_Command.command(
return_parameters_fields=[('status', STATUS_SPEC), ('le_features', 8)] return_parameters_fields=[('status', STATUS_SPEC), ('le_features', 8)]
@@ -7555,7 +7571,7 @@ class HCI_IsoDataPacket(HCI_Packet):
if should_include_sdu_info: if should_include_sdu_info:
packet_sequence_number, sdu_info = struct.unpack_from('<HH', packet, pos) packet_sequence_number, sdu_info = struct.unpack_from('<HH', packet, pos)
iso_sdu_length = sdu_info & 0xFFF iso_sdu_length = sdu_info & 0xFFF
packet_status_flag = sdu_info >> 14 packet_status_flag = (sdu_info >> 15) & 1
pos += 4 pos += 4
iso_sdu_fragment = packet[pos:] iso_sdu_fragment = packet[pos:]
@@ -7589,7 +7605,7 @@ class HCI_IsoDataPacket(HCI_Packet):
fmt += 'HH' fmt += 'HH'
args += [ args += [
self.packet_sequence_number, self.packet_sequence_number,
self.iso_sdu_length | self.packet_status_flag << 14, self.iso_sdu_length | self.packet_status_flag << 15,
] ]
return struct.pack(fmt, *args) + self.iso_sdu_fragment return struct.pack(fmt, *args) + self.iso_sdu_fragment
@@ -7597,9 +7613,10 @@ class HCI_IsoDataPacket(HCI_Packet):
return ( return (
f'{color("ISO", "blue")}: ' f'{color("ISO", "blue")}: '
f'handle=0x{self.connection_handle:04x}, ' f'handle=0x{self.connection_handle:04x}, '
f'pb={self.pb_flag}, '
f'ps={self.packet_status_flag}, ' f'ps={self.packet_status_flag}, '
f'data_total_length={self.data_total_length}, ' f'data_total_length={self.data_total_length}, '
f'sdu={self.iso_sdu_fragment.hex()}' f'sdu_fragment={self.iso_sdu_fragment.hex()}'
) )

View File

@@ -21,7 +21,6 @@ import collections
import dataclasses import dataclasses
import logging import logging
import struct import struct
import itertools
from typing import ( from typing import (
Any, Any,
@@ -35,6 +34,8 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
) )
import pyee
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper from bumble.snoop import Snooper
@@ -60,7 +61,19 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AclPacketQueue: class DataPacketQueue(pyee.EventEmitter):
"""
Flow-control queue for host->controller data packets (ACL, ISO).
The queue holds packets associated with a connection handle. The packets
are sent to the controller, up to a maximum total number of packets in flight.
A packet is considered to be "in flight" when it has been sent to the controller
but not completed yet. Packets are no longer "in flight" when the controller
declares them as completed.
The queue emits a 'drain' event whenever one or more packets are completed.
"""
max_packet_size: int max_packet_size: int
def __init__( def __init__(
@@ -69,40 +82,105 @@ class AclPacketQueue:
max_in_flight: int, max_in_flight: int,
send: Callable[[hci.HCI_Packet], None], send: Callable[[hci.HCI_Packet], None],
) -> None: ) -> None:
super().__init__()
self.max_packet_size = max_packet_size self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight self.max_in_flight = max_in_flight
self.in_flight = 0 self._in_flight = 0 # Total number of packets in flight across all connections
self.send = send self._in_flight_per_connection: dict[int, int] = collections.defaultdict(
self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque() int
) # Number of packets in flight per connection
self._send = send
self._packets: Deque[tuple[hci.HCI_Packet, int]] = collections.deque()
self._queued = 0
self._completed = 0
def enqueue(self, packet: hci.HCI_AclDataPacket) -> None: @property
self.packets.appendleft(packet) def queued(self) -> int:
self.check_queue() """Total number of packets queued since creation."""
return self._queued
if self.packets: @property
def completed(self) -> int:
"""Total number of packets completed since creation."""
return self._completed
@property
def pending(self) -> int:
"""Number of packets that have been queued but not completed."""
return self._queued - self._completed
def enqueue(self, packet: hci.HCI_Packet, connection_handle: int) -> None:
"""Enqueue a packet associated with a connection"""
self._packets.appendleft((packet, connection_handle))
self._queued += 1
self._check_queue()
if self._packets:
logger.debug( logger.debug(
f'{self.in_flight} ACL packets in flight, ' f'{self._in_flight} packets in flight, '
f'{len(self.packets)} in queue' f'{len(self._packets)} in queue'
) )
def check_queue(self) -> None: def flush(self, connection_handle: int) -> None:
while self.packets and self.in_flight < self.max_in_flight: """
packet = self.packets.pop() Remove all packets associated with a connection.
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None: All packets associated with the connection that are in flight are implicitly
if packet_count > self.in_flight: marked as completed, but no 'drain' event is emitted.
"""
packets_to_keep = [
(packet, handle)
for (packet, handle) in self._packets
if handle != connection_handle
]
if flushed_count := len(self._packets) - len(packets_to_keep):
self._completed += flushed_count
self._packets = collections.deque(packets_to_keep)
if connection_handle in self._in_flight_per_connection:
in_flight = self._in_flight_per_connection[connection_handle]
self._completed += in_flight
self._in_flight -= in_flight
del self._in_flight_per_connection[connection_handle]
def _check_queue(self) -> None:
while self._packets and self._in_flight < self.max_in_flight:
packet, connection_handle = self._packets.pop()
self._send(packet)
self._in_flight += 1
self._in_flight_per_connection[connection_handle] += 1
def on_packets_completed(self, packet_count: int, connection_handle: int) -> None:
"""Mark one or more packets associated with a connection as completed."""
if connection_handle not in self._in_flight_per_connection:
logger.warning( logger.warning(
color( f'received completion for unknown connection {connection_handle}'
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
) )
) return
packet_count = self.in_flight
self.in_flight -= packet_count in_flight_for_connection = self._in_flight_per_connection[connection_handle]
self.check_queue() if packet_count <= in_flight_for_connection:
self._in_flight_per_connection[connection_handle] -= packet_count
else:
logger.warning(
f'{packet_count} completed for {connection_handle} '
f'but only {in_flight_for_connection} in flight'
)
self._in_flight_per_connection[connection_handle] = 0
if packet_count <= self._in_flight:
self._in_flight -= packet_count
self._completed += packet_count
else:
logger.warning(
f'{packet_count} completed but only {self._in_flight} in flight'
)
self._in_flight = 0
self._completed = self._queued
self._check_queue()
self.emit('drain')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -115,7 +193,7 @@ class Connection:
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = ( acl_packet_queue: Optional[DataPacketQueue] = (
host.le_acl_packet_queue host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT if transport == BT_LE_TRANSPORT
else host.acl_packet_queue else host.acl_packet_queue
@@ -130,29 +208,37 @@ class Connection:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
def __str__(self) -> str:
return (
f'Connection(transport={self.transport}, peer_address={self.peer_address})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class ScoLink: class ScoLink:
peer_address: hci.Address peer_address: hci.Address
handle: int connection_handle: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class CisLink: class IsoLink:
peer_address: hci.Address
handle: int handle: int
packet_queue: DataPacketQueue = dataclasses.field(repr=False)
packet_sequence_number: int = 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
connections: Dict[int, Connection] connections: Dict[int, Connection]
cis_links: Dict[int, CisLink] cis_links: Dict[int, IsoLink]
bis_links: Dict[int, IsoLink]
sco_links: Dict[int, ScoLink] sco_links: Dict[int, ScoLink]
bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles
acl_packet_queue: Optional[AclPacketQueue] = None acl_packet_queue: Optional[DataPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None le_acl_packet_queue: Optional[DataPacketQueue] = None
iso_packet_queue: Optional[DataPacketQueue] = None
hci_sink: Optional[TransportSink] = None hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any] hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[ long_term_key_provider: Optional[
@@ -171,6 +257,7 @@ class Host(AbortableEventEmitter):
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle self.cis_links = {} # CIS links, by connection handle
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.pending_command = None self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None self.pending_response: Optional[asyncio.Future[Any]] = None
@@ -413,39 +500,70 @@ class Host(AbortableEventEmitter):
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}' f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
) )
self.acl_packet_queue = AclPacketQueue( self.acl_packet_queue = DataPacketQueue(
max_packet_size=hc_acl_data_packet_length, max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets, max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
hc_le_acl_data_packet_length = 0 le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0 total_num_le_acl_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
)
logger.debug(
'HCI LE flow control: '
f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
f'iso_data_packet_length={iso_data_packet_length},'
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
hc_le_acl_data_packet_length = ( le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length response.return_parameters.le_acl_data_packet_length
) )
hc_total_num_le_acl_data_packets = ( total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets response.return_parameters.total_num_le_acl_data_packets
) )
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},' f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}' f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
) )
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0: if le_acl_data_packet_length == 0 or total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue # LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue self.le_acl_packet_queue = self.acl_packet_queue
else: else:
# Create a separate queue for LE # Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue( self.le_acl_packet_queue = DataPacketQueue(
max_packet_size=hc_le_acl_data_packet_length, max_packet_size=le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets, max_in_flight=total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if iso_data_packet_length and total_num_iso_data_packets:
self.iso_packet_queue = DataPacketQueue(
max_packet_size=iso_data_packet_length,
max_in_flight=total_num_iso_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
@@ -597,11 +715,78 @@ class Host(AbortableEventEmitter):
data=l2cap_pdu[offset : offset + data_total_length], data=l2cap_pdu[offset : offset + data_total_length],
) )
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}') logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet) packet_queue.enqueue(acl_packet, connection_handle)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
bytes_remaining -= data_total_length bytes_remaining -= data_total_length
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
if connection := self.connections.get(connection_handle):
return connection.acl_packet_queue
if iso_link := self.cis_links.get(connection_handle) or self.bis_links.get(
connection_handle
):
return iso_link.packet_queue
return None
def send_iso_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (
iso_link := self.cis_links.get(connection_handle)
or self.bis_links.get(connection_handle)
):
logger.warning(f"no ISO link for connection handle {connection_handle}")
return
if iso_link.packet_queue is None:
logger.warning("ISO link has no data packet queue")
return
bytes_remaining = len(sdu)
offset = 0
while bytes_remaining:
is_first_fragment = offset == 0
header_length = 4 if is_first_fragment else 0
assert iso_link.packet_queue.max_packet_size > header_length
fragment_length = min(
bytes_remaining, iso_link.packet_queue.max_packet_size - header_length
)
is_last_fragment = bytes_remaining == fragment_length
iso_sdu_fragment = sdu[offset : offset + fragment_length]
iso_link.packet_queue.enqueue(
(
hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=header_length + fragment_length,
packet_sequence_number=iso_link.packet_sequence_number,
pb_flag=0b10 if is_last_fragment else 0b00,
packet_status_flag=0,
iso_sdu_length=len(sdu),
iso_sdu_fragment=iso_sdu_fragment,
)
if is_first_fragment
else hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=fragment_length,
pb_flag=0b11 if is_last_fragment else 0b01,
iso_sdu_fragment=iso_sdu_fragment,
)
),
connection_handle,
)
offset += fragment_length
bytes_remaining -= fragment_length
iso_link.packet_sequence_number = (iso_link.packet_sequence_number + 1) & 0xFFFF
def remove_big(self, big_handle: int) -> None:
if big := self.bigs.pop(big_handle, None):
for connection_handle in big:
if bis_link := self.bis_links.pop(connection_handle, None):
bis_link.packet_queue.flush(bis_link.handle)
def supports_command(self, op_code: int) -> bool: def supports_command(self, op_code: int) -> bool:
return ( return (
self.local_supported_commands self.local_supported_commands
@@ -729,17 +914,31 @@ class Host(AbortableEventEmitter):
def on_hci_command_status_event(self, event): def on_hci_command_status_event(self, event):
return self.on_command_processed(event) return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event): def on_hci_number_of_completed_packets_event(
self, event: hci.HCI_Number_Of_Completed_Packets_Event
) -> None:
for connection_handle, num_completed_packets in zip( for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets event.connection_handles, event.num_completed_packets
): ):
if connection := self.connections.get(connection_handle): if connection := self.connections.get(connection_handle):
connection.acl_packet_queue.on_packets_completed(num_completed_packets) connection.acl_packet_queue.on_packets_completed(
elif connection_handle not in itertools.chain( num_completed_packets, connection_handle
self.cis_links.keys(), )
self.sco_links.keys(), return
itertools.chain.from_iterable(self.bigs.values()),
): if cis_link := self.cis_links.get(connection_handle):
cis_link.packet_queue.on_packets_completed(
num_completed_packets, connection_handle
)
return
if bis_link := self.bis_links.get(connection_handle):
bis_link.packet_queue.on_packets_completed(
num_completed_packets, connection_handle
)
return
if connection_handle not in self.sco_links:
logger.warning( logger.warning(
'received packet completion event for unknown handle ' 'received packet completion event for unknown handle '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -857,11 +1056,7 @@ class Host(AbortableEventEmitter):
return return
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
logger.debug( logger.debug(f'### DISCONNECTION: {connection}, reason={event.reason}')
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
# Notify the listeners # Notify the listeners
self.emit('disconnection', handle, event.reason) self.emit('disconnection', handle, event.reason)
@@ -872,6 +1067,12 @@ class Host(AbortableEventEmitter):
or self.cis_links.pop(handle, 0) or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0) or self.sco_links.pop(handle, 0)
) )
# Flush the data queues
self.acl_packet_queue.flush(handle)
self.le_acl_packet_queue.flush(handle)
if self.iso_packet_queue:
self.iso_packet_queue.flush(handle)
else: else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
@@ -958,6 +1159,14 @@ class Host(AbortableEventEmitter):
def on_hci_le_create_big_complete_event(self, event): def on_hci_le_create_big_complete_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle) self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
logger.warning("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit( self.emit(
'big_establishment', 'big_establishment',
event.status, event.status,
@@ -975,6 +1184,12 @@ class Host(AbortableEventEmitter):
) )
def on_hci_le_big_sync_established_event(self, event): def on_hci_le_big_sync_established_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle)
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit( self.emit(
'big_sync_establishment', 'big_sync_establishment',
event.status, event.status,
@@ -990,22 +1205,20 @@ class Host(AbortableEventEmitter):
) )
def on_hci_le_big_sync_lost_event(self, event): def on_hci_le_big_sync_lost_event(self, event):
self.emit( self.remove_big(event.big_handle)
'big_sync_lost', self.emit('big_sync_lost', event.big_handle, event.reason)
event.big_handle,
event.reason,
)
def on_hci_le_terminate_big_complete_event(self, event): def on_hci_le_terminate_big_complete_event(self, event):
self.bigs.pop(event.big_handle) self.remove_big(event.big_handle)
self.emit('big_termination', event.reason, event.big_handle) self.emit('big_termination', event.reason, event.big_handle)
def on_hci_le_cis_established_event(self, event): def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now. # The remaining parameters are unused for now.
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink( if self.iso_packet_queue is None:
handle=event.connection_handle, logger.warning("CIS established but ISO packets not supported")
peer_address=hci.Address.ANY, self.cis_links[event.connection_handle] = IsoLink(
handle=event.connection_handle, packet_queue=self.iso_packet_queue
) )
self.emit('cis_establishment', event.connection_handle) self.emit('cis_establishment', event.connection_handle)
else: else:
@@ -1075,7 +1288,7 @@ class Host(AbortableEventEmitter):
self.sco_links[event.connection_handle] = ScoLink( self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr, peer_address=event.bd_addr,
handle=event.connection_handle, connection_handle=event.connection_handle,
) )
# Notify the client # Notify the client

View File

@@ -34,7 +34,7 @@ from bumble.device import (
Device, Device,
PeriodicAdvertisingParameters, PeriodicAdvertisingParameters,
) )
from bumble.host import AclPacketQueue, Host from bumble.host import DataPacketQueue, Host
from bumble.hci import ( from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING, HCI_COMMAND_STATUS_PENDING,
@@ -90,9 +90,9 @@ async def test_device_connect_parallel():
def _send(packet): def _send(packet):
pass pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d0.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d1.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d2.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
# enable classic # enable classic
d0.classic_enabled = True d0.classic_enabled = True

View File

@@ -170,8 +170,8 @@ def test_HCI_Command_Complete_Event():
command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND, command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters( return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
status=0, status=0,
hc_le_acl_data_packet_length=1234, le_acl_data_packet_length=1234,
hc_total_num_le_acl_data_packets=56, total_num_le_acl_data_packets=56,
), ),
) )
basic_check(event) basic_check(event)

View File

@@ -16,11 +16,14 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import unittest.mock
import pytest import pytest
import unittest
from bumble.controller import Controller from bumble.controller import Controller
from bumble.host import Host from bumble.host import Host, DataPacketQueue
from bumble.transport import AsyncPipeSink from bumble.transport import AsyncPipeSink
from bumble.hci import HCI_AclDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -60,3 +63,90 @@ async def test_reset(supported_commands: str, lmp_features: str):
assert host.local_lmp_features == int.from_bytes( assert host.local_lmp_features == int.from_bytes(
bytes.fromhex(lmp_features), 'little' bytes.fromhex(lmp_features), 'little'
) )
# -----------------------------------------------------------------------------
def test_data_packet_queue():
controller = unittest.mock.Mock()
queue = DataPacketQueue(10, 2, controller.send)
assert queue.queued == 0
assert queue.completed == 0
packet = HCI_AclDataPacket(
connection_handle=123, pb_flag=0, bc_flag=0, data_total_length=0, data=b''
)
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 1
assert queue.completed == 0
assert controller.send.call_count == 1
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 2
assert queue.completed == 0
assert controller.send.call_count == 2
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 3
assert queue.completed == 0
assert controller.send.call_count == 2
queue.on_packets_completed(1, 8000)
assert queue.queued == 3
assert queue.completed == 0
assert controller.send.call_count == 2
queue.on_packets_completed(1, 123)
assert queue.queued == 3
assert queue.completed == 1
assert controller.send.call_count == 3
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 4
assert queue.completed == 1
assert controller.send.call_count == 3
queue.on_packets_completed(2, 123)
assert queue.queued == 4
assert queue.completed == 3
assert controller.send.call_count == 4
queue.on_packets_completed(1, 123)
assert queue.queued == 4
assert queue.completed == 4
assert controller.send.call_count == 4
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 124)
queue.enqueue(packet, 124)
queue.enqueue(packet, 124)
queue.on_packets_completed(1, 123)
assert queue.queued == 10
assert queue.completed == 5
queue.flush(123)
queue.flush(124)
assert queue.queued == 10
assert queue.completed == 10
queue.enqueue(packet, 123)
queue.on_packets_completed(1, 124)
assert queue.queued == 11
assert queue.completed == 10
queue.on_packets_completed(1000, 123)
assert queue.queued == 11
assert queue.completed == 11
drain_listener = unittest.mock.Mock()
queue.on('drain', drain_listener.on_drain)
queue.enqueue(packet, 123)
assert drain_listener.on_drain.call_count == 0
queue.on_packets_completed(1, 123)
assert drain_listener.on_drain.call_count == 1
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.flush(123)
assert drain_listener.on_drain.call_count == 1
assert queue.queued == 15
assert queue.completed == 15