diff --git a/apps/auracast.py b/apps/auracast.py index ede4eada..42c1cab0 100644 --- a/apps/auracast.py +++ b/apps/auracast.py @@ -825,10 +825,24 @@ async def run_broadcast( ), ) print('Setup ISO Data Path') + + def on_drain(packet_queue): + print( + f'\rPACKETS: pending={packet_queue.pending}, ' + f'queued={packet_queue.queued}, completed={packet_queue.completed}', + end='', + ) + + packet_queue = None for bis_link in big.bis_links: await bis_link.setup_data_path( direction=bis_link.Direction.HOST_TO_CONTROLLER ) + if packet_queue is None: + packet_queue = bis_link.data_packet_queue + + if packet_queue: + packet_queue.on('drain', lambda: on_drain(packet_queue)) for frame in itertools.cycle(frames): mid = len(frame) // 2 diff --git a/apps/controller_info.py b/apps/controller_info.py index 89c830c9..2c8d9aab 100644 --- a/apps/controller_info.py +++ b/apps/controller_info.py @@ -37,6 +37,8 @@ from bumble.hci import ( HCI_Command_Status_Event, HCI_READ_BUFFER_SIZE_COMMAND, HCI_Read_Buffer_Size_Command, + HCI_LE_READ_BUFFER_SIZE_V2_COMMAND, + HCI_LE_Read_Buffer_Size_V2_Command, HCI_READ_BD_ADDR_COMMAND, HCI_Read_BD_ADDR_Command, HCI_READ_LOCAL_NAME_COMMAND, @@ -147,7 +149,7 @@ async def get_le_info(host: Host) -> None: # ----------------------------------------------------------------------------- -async def get_acl_flow_control_info(host: Host) -> None: +async def get_flow_control_info(host: Host) -> None: print() if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): @@ -160,14 +162,28 @@ async def get_acl_flow_control_info(host: Host) -> None: f'packets of size {response.return_parameters.hc_acl_data_packet_length}', ) - if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): + if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): + response = await host.send_command( + HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True + ) + print( + color('LE ACL Flow Control:', 'yellow'), + f'{response.return_parameters.total_num_le_acl_data_packets} ' + f'packets of size {response.return_parameters.le_acl_data_packet_length}', + ) + print( + color('LE ISO Flow Control:', 'yellow'), + f'{response.return_parameters.total_num_iso_data_packets} ' + f'packets of size {response.return_parameters.iso_data_packet_length}', + ) + elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): response = await host.send_command( HCI_LE_Read_Buffer_Size_Command(), check_result=True ) print( color('LE ACL Flow Control:', 'yellow'), - f'{response.return_parameters.hc_total_num_le_acl_data_packets} ' - f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}', + f'{response.return_parameters.total_num_le_acl_data_packets} ' + f'packets of size {response.return_parameters.le_acl_data_packet_length}', ) @@ -274,8 +290,8 @@ async def async_main(latency_probes, transport): # Get the LE info await get_le_info(host) - # Print the ACL flow control info - await get_acl_flow_control_info(host) + # Print the flow control info + await get_flow_control_info(host) # Get codec info await get_codecs_info(host) diff --git a/bumble/controller.py b/bumble/controller.py index 03d3c14d..9366c1d9 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -154,15 +154,17 @@ class Controller: '0000000060000000' ) # BR/EDR Not Supported, LE Supported (Controller) self.manufacturer_name = 0xFFFF - self.hc_data_packet_length = 27 - self.hc_total_num_data_packets = 64 - self.hc_le_data_packet_length = 27 - self.hc_total_num_le_data_packets = 64 + self.acl_data_packet_length = 27 + self.total_num_acl_data_packets = 64 + self.le_acl_data_packet_length = 27 + self.total_num_le_acl_data_packets = 64 + self.iso_data_packet_length = 960 + self.total_num_iso_data_packets = 64 self.event_mask = 0 self.event_mask_page_2 = 0 self.supported_commands = bytes.fromhex( '2000800000c000000000e4000000a822000000000000040000f7ffff7f000000' - '30f0f9ff01008004000000000000000000000000000000000000000000000000' + '30f0f9ff01008004002000000000000000000000000000000000000000000000' ) self.le_event_mask = 0 self.advertising_parameters = None @@ -1181,9 +1183,9 @@ class Controller: return struct.pack( ' None: - """Write an ISO SDU. + """Write an ISO SDU.""" + self.device.host.send_iso_sdu(connection_handle=self.handle, sdu=sdu) - This will automatically increase the packet sequence number. - """ - self.device.host.send_hci_packet( - hci.HCI_IsoDataPacket( - connection_handle=self.handle, - data_total_length=len(sdu) + 4, - packet_sequence_number=self.packet_sequence_number, - pb_flag=0b10, - packet_status_flag=0, - iso_sdu_length=len(sdu), - iso_sdu_fragment=sdu, - ) - ) - self.packet_sequence_number = (self.packet_sequence_number + 1) % 0x10000 + @property + def data_packet_queue(self) -> DataPacketQueue | None: + return self.device.host.get_data_packet_queue(self.handle) # ----------------------------------------------------------------------------- @@ -1426,7 +1415,6 @@ class CisLink(CompositeEventEmitter, _IsoLink): def __post_init__(self) -> None: super().__init__() - self.packet_sequence_number = 0 async def disconnect( self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR @@ -1443,7 +1431,6 @@ class BisLink(_IsoLink): def __post_init__(self) -> None: self.device = self.big.device - self.packet_sequence_number = 0 # ----------------------------------------------------------------------------- @@ -1691,6 +1678,10 @@ class Connection(CompositeEventEmitter): self.peer_le_features = await self.device.get_remote_le_features(self) return self.peer_le_features + @property + def data_packet_queue(self) -> DataPacketQueue | None: + return self.device.host.get_data_packet_queue(self.handle) + async def __aenter__(self): return self diff --git a/bumble/hci.py b/bumble/hci.py index 77a95e72..2a1eb85e 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -3585,8 +3585,8 @@ class HCI_LE_Set_Event_Mask_Command(HCI_Command): @HCI_Command.command( return_parameters_fields=[ ('status', STATUS_SPEC), - ('hc_le_acl_data_packet_length', 2), - ('hc_total_num_le_acl_data_packets', 1), + ('le_acl_data_packet_length', 2), + ('total_num_le_acl_data_packets', 1), ] ) class HCI_LE_Read_Buffer_Size_Command(HCI_Command): @@ -3595,6 +3595,22 @@ class HCI_LE_Read_Buffer_Size_Command(HCI_Command): ''' +# ----------------------------------------------------------------------------- +@HCI_Command.command( + return_parameters_fields=[ + ('status', STATUS_SPEC), + ('le_acl_data_packet_length', 2), + ('total_num_le_acl_data_packets', 1), + ('iso_data_packet_length', 2), + ('total_num_iso_data_packets', 1), + ] +) +class HCI_LE_Read_Buffer_Size_V2_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.8.2 LE Read Buffer Size V2 Command + ''' + + # ----------------------------------------------------------------------------- @HCI_Command.command( return_parameters_fields=[('status', STATUS_SPEC), ('le_features', 8)] @@ -7555,7 +7571,7 @@ class HCI_IsoDataPacket(HCI_Packet): if should_include_sdu_info: packet_sequence_number, sdu_info = struct.unpack_from('> 14 + packet_status_flag = (sdu_info >> 15) & 1 pos += 4 iso_sdu_fragment = packet[pos:] @@ -7589,7 +7605,7 @@ class HCI_IsoDataPacket(HCI_Packet): fmt += 'HH' args += [ self.packet_sequence_number, - self.iso_sdu_length | self.packet_status_flag << 14, + self.iso_sdu_length | self.packet_status_flag << 15, ] return struct.pack(fmt, *args) + self.iso_sdu_fragment @@ -7597,9 +7613,10 @@ class HCI_IsoDataPacket(HCI_Packet): return ( f'{color("ISO", "blue")}: ' f'handle=0x{self.connection_handle:04x}, ' + f'pb={self.pb_flag}, ' f'ps={self.packet_status_flag}, ' f'data_total_length={self.data_total_length}, ' - f'sdu={self.iso_sdu_fragment.hex()}' + f'sdu_fragment={self.iso_sdu_fragment.hex()}' ) diff --git a/bumble/host.py b/bumble/host.py index 1ce4263a..48c03e04 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -21,7 +21,6 @@ import collections import dataclasses import logging import struct -import itertools from typing import ( Any, @@ -35,6 +34,8 @@ from typing import ( TYPE_CHECKING, ) +import pyee + from bumble.colors import color from bumble.l2cap import L2CAP_PDU from bumble.snoop import Snooper @@ -60,7 +61,19 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -class AclPacketQueue: +class DataPacketQueue(pyee.EventEmitter): + """ + Flow-control queue for host->controller data packets (ACL, ISO). + + The queue holds packets associated with a connection handle. The packets + are sent to the controller, up to a maximum total number of packets in flight. + A packet is considered to be "in flight" when it has been sent to the controller + but not completed yet. Packets are no longer "in flight" when the controller + declares them as completed. + + The queue emits a 'drain' event whenever one or more packets are completed. + """ + max_packet_size: int def __init__( @@ -69,40 +82,105 @@ class AclPacketQueue: max_in_flight: int, send: Callable[[hci.HCI_Packet], None], ) -> None: + super().__init__() self.max_packet_size = max_packet_size self.max_in_flight = max_in_flight - self.in_flight = 0 - self.send = send - self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque() + self._in_flight = 0 # Total number of packets in flight across all connections + self._in_flight_per_connection: dict[int, int] = collections.defaultdict( + int + ) # Number of packets in flight per connection + self._send = send + self._packets: Deque[tuple[hci.HCI_Packet, int]] = collections.deque() + self._queued = 0 + self._completed = 0 - def enqueue(self, packet: hci.HCI_AclDataPacket) -> None: - self.packets.appendleft(packet) - self.check_queue() + @property + def queued(self) -> int: + """Total number of packets queued since creation.""" + return self._queued - if self.packets: + @property + def completed(self) -> int: + """Total number of packets completed since creation.""" + return self._completed + + @property + def pending(self) -> int: + """Number of packets that have been queued but not completed.""" + return self._queued - self._completed + + def enqueue(self, packet: hci.HCI_Packet, connection_handle: int) -> None: + """Enqueue a packet associated with a connection""" + self._packets.appendleft((packet, connection_handle)) + self._queued += 1 + self._check_queue() + + if self._packets: logger.debug( - f'{self.in_flight} ACL packets in flight, ' - f'{len(self.packets)} in queue' + f'{self._in_flight} packets in flight, ' + f'{len(self._packets)} in queue' ) - def check_queue(self) -> None: - while self.packets and self.in_flight < self.max_in_flight: - packet = self.packets.pop() - self.send(packet) - self.in_flight += 1 + def flush(self, connection_handle: int) -> None: + """ + Remove all packets associated with a connection. - def on_packets_completed(self, packet_count: int) -> None: - if packet_count > self.in_flight: + All packets associated with the connection that are in flight are implicitly + marked as completed, but no 'drain' event is emitted. + """ + + packets_to_keep = [ + (packet, handle) + for (packet, handle) in self._packets + if handle != connection_handle + ] + if flushed_count := len(self._packets) - len(packets_to_keep): + self._completed += flushed_count + self._packets = collections.deque(packets_to_keep) + + if connection_handle in self._in_flight_per_connection: + in_flight = self._in_flight_per_connection[connection_handle] + self._completed += in_flight + self._in_flight -= in_flight + del self._in_flight_per_connection[connection_handle] + + def _check_queue(self) -> None: + while self._packets and self._in_flight < self.max_in_flight: + packet, connection_handle = self._packets.pop() + self._send(packet) + self._in_flight += 1 + self._in_flight_per_connection[connection_handle] += 1 + + def on_packets_completed(self, packet_count: int, connection_handle: int) -> None: + """Mark one or more packets associated with a connection as completed.""" + if connection_handle not in self._in_flight_per_connection: logger.warning( - color( - '!!! {packet_count} completed but only ' - f'{self.in_flight} in flight' - ) + f'received completion for unknown connection {connection_handle}' ) - packet_count = self.in_flight + return - self.in_flight -= packet_count - self.check_queue() + in_flight_for_connection = self._in_flight_per_connection[connection_handle] + if packet_count <= in_flight_for_connection: + self._in_flight_per_connection[connection_handle] -= packet_count + else: + logger.warning( + f'{packet_count} completed for {connection_handle} ' + f'but only {in_flight_for_connection} in flight' + ) + self._in_flight_per_connection[connection_handle] = 0 + + if packet_count <= self._in_flight: + self._in_flight -= packet_count + self._completed += packet_count + else: + logger.warning( + f'{packet_count} completed but only {self._in_flight} in flight' + ) + self._in_flight = 0 + self._completed = self._queued + + self._check_queue() + self.emit('drain') # ----------------------------------------------------------------------------- @@ -115,7 +193,7 @@ class Connection: self.peer_address = peer_address self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu) self.transport = transport - acl_packet_queue: Optional[AclPacketQueue] = ( + acl_packet_queue: Optional[DataPacketQueue] = ( host.le_acl_packet_queue if transport == BT_LE_TRANSPORT else host.acl_packet_queue @@ -130,29 +208,37 @@ class Connection: l2cap_pdu = L2CAP_PDU.from_bytes(pdu) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) + def __str__(self) -> str: + return ( + f'Connection(transport={self.transport}, peer_address={self.peer_address})' + ) + # ----------------------------------------------------------------------------- @dataclasses.dataclass class ScoLink: peer_address: hci.Address - handle: int + connection_handle: int # ----------------------------------------------------------------------------- @dataclasses.dataclass -class CisLink: - peer_address: hci.Address +class IsoLink: handle: int + packet_queue: DataPacketQueue = dataclasses.field(repr=False) + packet_sequence_number: int = 0 # ----------------------------------------------------------------------------- class Host(AbortableEventEmitter): connections: Dict[int, Connection] - cis_links: Dict[int, CisLink] + cis_links: Dict[int, IsoLink] + bis_links: Dict[int, IsoLink] sco_links: Dict[int, ScoLink] bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles - acl_packet_queue: Optional[AclPacketQueue] = None - le_acl_packet_queue: Optional[AclPacketQueue] = None + acl_packet_queue: Optional[DataPacketQueue] = None + le_acl_packet_queue: Optional[DataPacketQueue] = None + iso_packet_queue: Optional[DataPacketQueue] = None hci_sink: Optional[TransportSink] = None hci_metadata: Dict[str, Any] long_term_key_provider: Optional[ @@ -171,6 +257,7 @@ class Host(AbortableEventEmitter): self.ready = False # True when we can accept incoming packets self.connections = {} # Connections, by connection handle self.cis_links = {} # CIS links, by connection handle + self.bis_links = {} # BIS links, by connection handle self.sco_links = {} # SCO links, by connection handle self.pending_command = None self.pending_response: Optional[asyncio.Future[Any]] = None @@ -413,39 +500,70 @@ class Host(AbortableEventEmitter): f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}' ) - self.acl_packet_queue = AclPacketQueue( + self.acl_packet_queue = DataPacketQueue( max_packet_size=hc_acl_data_packet_length, max_in_flight=hc_total_num_acl_data_packets, send=self.send_hci_packet, ) - hc_le_acl_data_packet_length = 0 - hc_total_num_le_acl_data_packets = 0 - if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): + le_acl_data_packet_length = 0 + total_num_le_acl_data_packets = 0 + iso_data_packet_length = 0 + total_num_iso_data_packets = 0 + if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): + response = await self.send_command( + hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True + ) + le_acl_data_packet_length = ( + response.return_parameters.le_acl_data_packet_length + ) + total_num_le_acl_data_packets = ( + response.return_parameters.total_num_le_acl_data_packets + ) + iso_data_packet_length = response.return_parameters.iso_data_packet_length + total_num_iso_data_packets = ( + response.return_parameters.total_num_iso_data_packets + ) + + logger.debug( + 'HCI LE flow control: ' + f'le_acl_data_packet_length={le_acl_data_packet_length},' + f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}' + f'iso_data_packet_length={iso_data_packet_length},' + f'total_num_iso_data_packets={total_num_iso_data_packets}' + ) + elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): response = await self.send_command( hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True ) - hc_le_acl_data_packet_length = ( - response.return_parameters.hc_le_acl_data_packet_length + le_acl_data_packet_length = ( + response.return_parameters.le_acl_data_packet_length ) - hc_total_num_le_acl_data_packets = ( - response.return_parameters.hc_total_num_le_acl_data_packets + total_num_le_acl_data_packets = ( + response.return_parameters.total_num_le_acl_data_packets ) logger.debug( 'HCI LE ACL flow control: ' - f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},' - f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}' + f'le_acl_data_packet_length={le_acl_data_packet_length},' + f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}' ) - if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0: + if le_acl_data_packet_length == 0 or total_num_le_acl_data_packets == 0: # LE and Classic share the same queue self.le_acl_packet_queue = self.acl_packet_queue else: # Create a separate queue for LE - self.le_acl_packet_queue = AclPacketQueue( - max_packet_size=hc_le_acl_data_packet_length, - max_in_flight=hc_total_num_le_acl_data_packets, + self.le_acl_packet_queue = DataPacketQueue( + max_packet_size=le_acl_data_packet_length, + max_in_flight=total_num_le_acl_data_packets, + send=self.send_hci_packet, + ) + + if iso_data_packet_length and total_num_iso_data_packets: + self.iso_packet_queue = DataPacketQueue( + max_packet_size=iso_data_packet_length, + max_in_flight=total_num_iso_data_packets, send=self.send_hci_packet, ) @@ -597,11 +715,78 @@ class Host(AbortableEventEmitter): data=l2cap_pdu[offset : offset + data_total_length], ) logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}') - packet_queue.enqueue(acl_packet) + packet_queue.enqueue(acl_packet, connection_handle) pb_flag = 1 offset += data_total_length bytes_remaining -= data_total_length + def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None: + if connection := self.connections.get(connection_handle): + return connection.acl_packet_queue + + if iso_link := self.cis_links.get(connection_handle) or self.bis_links.get( + connection_handle + ): + return iso_link.packet_queue + + return None + + def send_iso_sdu(self, connection_handle: int, sdu: bytes) -> None: + if not ( + iso_link := self.cis_links.get(connection_handle) + or self.bis_links.get(connection_handle) + ): + logger.warning(f"no ISO link for connection handle {connection_handle}") + return + + if iso_link.packet_queue is None: + logger.warning("ISO link has no data packet queue") + return + + bytes_remaining = len(sdu) + offset = 0 + while bytes_remaining: + is_first_fragment = offset == 0 + header_length = 4 if is_first_fragment else 0 + assert iso_link.packet_queue.max_packet_size > header_length + fragment_length = min( + bytes_remaining, iso_link.packet_queue.max_packet_size - header_length + ) + is_last_fragment = bytes_remaining == fragment_length + iso_sdu_fragment = sdu[offset : offset + fragment_length] + iso_link.packet_queue.enqueue( + ( + hci.HCI_IsoDataPacket( + connection_handle=connection_handle, + data_total_length=header_length + fragment_length, + packet_sequence_number=iso_link.packet_sequence_number, + pb_flag=0b10 if is_last_fragment else 0b00, + packet_status_flag=0, + iso_sdu_length=len(sdu), + iso_sdu_fragment=iso_sdu_fragment, + ) + if is_first_fragment + else hci.HCI_IsoDataPacket( + connection_handle=connection_handle, + data_total_length=fragment_length, + pb_flag=0b11 if is_last_fragment else 0b01, + iso_sdu_fragment=iso_sdu_fragment, + ) + ), + connection_handle, + ) + + offset += fragment_length + bytes_remaining -= fragment_length + + iso_link.packet_sequence_number = (iso_link.packet_sequence_number + 1) & 0xFFFF + + def remove_big(self, big_handle: int) -> None: + if big := self.bigs.pop(big_handle, None): + for connection_handle in big: + if bis_link := self.bis_links.pop(connection_handle, None): + bis_link.packet_queue.flush(bis_link.handle) + def supports_command(self, op_code: int) -> bool: return ( self.local_supported_commands @@ -729,17 +914,31 @@ class Host(AbortableEventEmitter): def on_hci_command_status_event(self, event): return self.on_command_processed(event) - def on_hci_number_of_completed_packets_event(self, event): + def on_hci_number_of_completed_packets_event( + self, event: hci.HCI_Number_Of_Completed_Packets_Event + ) -> None: for connection_handle, num_completed_packets in zip( event.connection_handles, event.num_completed_packets ): if connection := self.connections.get(connection_handle): - connection.acl_packet_queue.on_packets_completed(num_completed_packets) - elif connection_handle not in itertools.chain( - self.cis_links.keys(), - self.sco_links.keys(), - itertools.chain.from_iterable(self.bigs.values()), - ): + connection.acl_packet_queue.on_packets_completed( + num_completed_packets, connection_handle + ) + return + + if cis_link := self.cis_links.get(connection_handle): + cis_link.packet_queue.on_packets_completed( + num_completed_packets, connection_handle + ) + return + + if bis_link := self.bis_links.get(connection_handle): + bis_link.packet_queue.on_packets_completed( + num_completed_packets, connection_handle + ) + return + + if connection_handle not in self.sco_links: logger.warning( 'received packet completion event for unknown handle ' f'0x{connection_handle:04X}' @@ -857,11 +1056,7 @@ class Host(AbortableEventEmitter): return if event.status == hci.HCI_SUCCESS: - logger.debug( - f'### DISCONNECTION: [0x{handle:04X}] ' - f'{connection.peer_address} ' - f'reason={event.reason}' - ) + logger.debug(f'### DISCONNECTION: {connection}, reason={event.reason}') # Notify the listeners self.emit('disconnection', handle, event.reason) @@ -872,6 +1067,12 @@ class Host(AbortableEventEmitter): or self.cis_links.pop(handle, 0) or self.sco_links.pop(handle, 0) ) + + # Flush the data queues + self.acl_packet_queue.flush(handle) + self.le_acl_packet_queue.flush(handle) + if self.iso_packet_queue: + self.iso_packet_queue.flush(handle) else: logger.debug(f'### DISCONNECTION FAILED: {event.status}') @@ -958,6 +1159,14 @@ class Host(AbortableEventEmitter): def on_hci_le_create_big_complete_event(self, event): self.bigs[event.big_handle] = set(event.connection_handle) + if self.iso_packet_queue is None: + logger.warning("BIS established but ISO packets not supported") + + for connection_handle in event.connection_handle: + self.bis_links[connection_handle] = IsoLink( + connection_handle, self.iso_packet_queue + ) + self.emit( 'big_establishment', event.status, @@ -975,6 +1184,12 @@ class Host(AbortableEventEmitter): ) def on_hci_le_big_sync_established_event(self, event): + self.bigs[event.big_handle] = set(event.connection_handle) + for connection_handle in event.connection_handle: + self.bis_links[connection_handle] = IsoLink( + connection_handle, self.iso_packet_queue + ) + self.emit( 'big_sync_establishment', event.status, @@ -990,22 +1205,20 @@ class Host(AbortableEventEmitter): ) def on_hci_le_big_sync_lost_event(self, event): - self.emit( - 'big_sync_lost', - event.big_handle, - event.reason, - ) + self.remove_big(event.big_handle) + self.emit('big_sync_lost', event.big_handle, event.reason) def on_hci_le_terminate_big_complete_event(self, event): - self.bigs.pop(event.big_handle) + self.remove_big(event.big_handle) self.emit('big_termination', event.reason, event.big_handle) def on_hci_le_cis_established_event(self, event): # The remaining parameters are unused for now. if event.status == hci.HCI_SUCCESS: - self.cis_links[event.connection_handle] = CisLink( - handle=event.connection_handle, - peer_address=hci.Address.ANY, + if self.iso_packet_queue is None: + logger.warning("CIS established but ISO packets not supported") + self.cis_links[event.connection_handle] = IsoLink( + handle=event.connection_handle, packet_queue=self.iso_packet_queue ) self.emit('cis_establishment', event.connection_handle) else: @@ -1075,7 +1288,7 @@ class Host(AbortableEventEmitter): self.sco_links[event.connection_handle] = ScoLink( peer_address=event.bd_addr, - handle=event.connection_handle, + connection_handle=event.connection_handle, ) # Notify the client diff --git a/tests/device_test.py b/tests/device_test.py index 1f6175ab..350c0a46 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -34,7 +34,7 @@ from bumble.device import ( Device, PeriodicAdvertisingParameters, ) -from bumble.host import AclPacketQueue, Host +from bumble.host import DataPacketQueue, Host from bumble.hci import ( HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, @@ -90,9 +90,9 @@ async def test_device_connect_parallel(): def _send(packet): pass - d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send) - d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send) - d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send) + d0.host.acl_packet_queue = DataPacketQueue(0, 0, _send) + d1.host.acl_packet_queue = DataPacketQueue(0, 0, _send) + d2.host.acl_packet_queue = DataPacketQueue(0, 0, _send) # enable classic d0.classic_enabled = True diff --git a/tests/hci_test.py b/tests/hci_test.py index ee4ef8a5..eac641e0 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -170,8 +170,8 @@ def test_HCI_Command_Complete_Event(): command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND, return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters( status=0, - hc_le_acl_data_packet_length=1234, - hc_total_num_le_acl_data_packets=56, + le_acl_data_packet_length=1234, + total_num_le_acl_data_packets=56, ), ) basic_check(event) diff --git a/tests/host_test.py b/tests/host_test.py index 51704977..5789b5ec 100644 --- a/tests/host_test.py +++ b/tests/host_test.py @@ -16,11 +16,14 @@ # Imports # ----------------------------------------------------------------------------- import logging +import unittest.mock import pytest +import unittest from bumble.controller import Controller -from bumble.host import Host +from bumble.host import Host, DataPacketQueue from bumble.transport import AsyncPipeSink +from bumble.hci import HCI_AclDataPacket # ----------------------------------------------------------------------------- # Logging @@ -60,3 +63,90 @@ async def test_reset(supported_commands: str, lmp_features: str): assert host.local_lmp_features == int.from_bytes( bytes.fromhex(lmp_features), 'little' ) + + +# ----------------------------------------------------------------------------- +def test_data_packet_queue(): + controller = unittest.mock.Mock() + queue = DataPacketQueue(10, 2, controller.send) + assert queue.queued == 0 + assert queue.completed == 0 + packet = HCI_AclDataPacket( + connection_handle=123, pb_flag=0, bc_flag=0, data_total_length=0, data=b'' + ) + + queue.enqueue(packet, packet.connection_handle) + assert queue.queued == 1 + assert queue.completed == 0 + assert controller.send.call_count == 1 + + queue.enqueue(packet, packet.connection_handle) + assert queue.queued == 2 + assert queue.completed == 0 + assert controller.send.call_count == 2 + + queue.enqueue(packet, packet.connection_handle) + assert queue.queued == 3 + assert queue.completed == 0 + assert controller.send.call_count == 2 + + queue.on_packets_completed(1, 8000) + assert queue.queued == 3 + assert queue.completed == 0 + assert controller.send.call_count == 2 + + queue.on_packets_completed(1, 123) + assert queue.queued == 3 + assert queue.completed == 1 + assert controller.send.call_count == 3 + + queue.enqueue(packet, packet.connection_handle) + assert queue.queued == 4 + assert queue.completed == 1 + assert controller.send.call_count == 3 + + queue.on_packets_completed(2, 123) + assert queue.queued == 4 + assert queue.completed == 3 + assert controller.send.call_count == 4 + + queue.on_packets_completed(1, 123) + assert queue.queued == 4 + assert queue.completed == 4 + assert controller.send.call_count == 4 + + queue.enqueue(packet, 123) + queue.enqueue(packet, 123) + queue.enqueue(packet, 123) + queue.enqueue(packet, 124) + queue.enqueue(packet, 124) + queue.enqueue(packet, 124) + queue.on_packets_completed(1, 123) + assert queue.queued == 10 + assert queue.completed == 5 + queue.flush(123) + queue.flush(124) + assert queue.queued == 10 + assert queue.completed == 10 + + queue.enqueue(packet, 123) + queue.on_packets_completed(1, 124) + assert queue.queued == 11 + assert queue.completed == 10 + queue.on_packets_completed(1000, 123) + assert queue.queued == 11 + assert queue.completed == 11 + + drain_listener = unittest.mock.Mock() + queue.on('drain', drain_listener.on_drain) + queue.enqueue(packet, 123) + assert drain_listener.on_drain.call_count == 0 + queue.on_packets_completed(1, 123) + assert drain_listener.on_drain.call_count == 1 + queue.enqueue(packet, 123) + queue.enqueue(packet, 123) + queue.enqueue(packet, 123) + queue.flush(123) + assert drain_listener.on_drain.call_count == 1 + assert queue.queued == 15 + assert queue.completed == 15