From b153d0fcde5e27a2ac041e4fd70513b170824649 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 7 Jan 2026 00:28:54 +0800 Subject: [PATCH] L2CAP: Fix Enhanced Retransmission Segmentation --- bumble/l2cap.py | 179 +++++++++++++++++++++++++++++--------------- tests/l2cap_test.py | 27 +++---- 2 files changed, 127 insertions(+), 79 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index b595ec6..7023a87 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -20,6 +20,7 @@ from __future__ import annotations import asyncio import dataclasses import enum +import itertools import logging import struct from collections import deque @@ -302,11 +303,9 @@ class EnhancedControlField(ControlField): @dataclasses.dataclass class InformationEnhancedControlField(EnhancedControlField): - tx_seq: int = 0 + tx_seq: int + sar: int req_seq: int = 0 - segmentation_and_reassembly: int = ( - EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED - ) final: int = 1 frame_type = EnhancedControlField.FieldType.I_FRAME @@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField): return cls( tx_seq=(data[0] >> 1) & 0b0111111, final=(data[0] >> 7) & 0b1, - req_seq=(data[1] & 0b001111111), - segmentation_and_reassembly=(data[1] >> 6) & 0b11, + req_seq=(data[1] & 0b00111111), + sar=(data[1] >> 6) & 0b11, ) def __bytes__(self) -> bytes: return bytes( [ self.frame_type | (self.tx_seq << 1) | (self.final << 7), - self.req_seq | (self.segmentation_and_reassembly << 6), + self.req_seq | (self.sar << 6), ] ) @@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor): class _PendingPdu: payload: bytes tx_seq: int + sar: InformationEnhancedControlField.SegmentationAndReassembly + sdu_length: int = 0 req_seq: int = 0 def __bytes__(self) -> bytes: return ( bytes( InformationEnhancedControlField( - tx_seq=self.tx_seq, req_seq=self.req_seq + tx_seq=self.tx_seq, + req_seq=self.req_seq, + sar=self.sar, ) ) + + ( + struct.pack(' int: - if high < low: - high += cls.MAX_SEQ_NUM - return high - low - def __init__( self, channel: ClassicChannel, @@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor): self.peer_mps = peer_mps self.peer_tx_window_size = peer_tx_window_size self._pending_pdus = [] + self._tx_window = [] self.monitor_timeout = spec.monitor_timeout self.channel = channel self.retransmission_timeout = spec.retransmission_timeout @@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor): def _send_receiver_ready_poll(self) -> None: self._num_receiver_ready_polls_sent += 1 - self.channel.send_pdu( - SupervisoryEnhancedControlField( - supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, - final=1, - req_seq=self._next_seq_num, - ) + self._send_s_frame( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + final=1, ) def _get_next_tx_seq(self) -> int: @@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor): @override def send_sdu(self, sdu: bytes) -> None: - if len(sdu) > self.peer_mps: - raise InvalidArgumentError( - f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}' + if len(sdu) <= self.peer_mps: + pdu = self._PendingPdu( + payload=sdu, + tx_seq=self._get_next_tx_seq(), + req_seq=self._req_seq_num, + sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED, ) - pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq()) - self._pending_pdus.append(pdu) + self._pending_pdus.append(pdu) + else: + for offset in range(0, len(sdu), self.peer_mps): + payload = sdu[offset : offset + self.peer_mps] + if offset == 0: + sar = ( + InformationEnhancedControlField.SegmentationAndReassembly.START + ) + elif offset + len(payload) >= len(sdu): + sar = InformationEnhancedControlField.SegmentationAndReassembly.END + else: + sar = ( + InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION + ) + pdu = self._PendingPdu( + payload=payload, + tx_seq=self._get_next_tx_seq(), + req_seq=self._req_seq_num, + sar=sar, + sdu_length=len(sdu), + ) + self._pending_pdus.append(pdu) self._process_output() @override @@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor): control_field = EnhancedControlField.from_bytes(pdu) self._update_ack_seq(control_field.req_seq, control_field.final != 0) if isinstance(control_field, InformationEnhancedControlField): - if control_field.tx_seq != self._next_seq_num: + if control_field.tx_seq != self._req_seq_num: + logger.error( + "tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d", + control_field.tx_seq, + self._req_seq_num, + ) return - self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM - self._req_seq_num = self._next_seq_num + self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM - ack_frame = SupervisoryEnhancedControlField( - supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, - req_seq=self._next_seq_num, - ) - self.channel.send_pdu(ack_frame) - self.channel.on_sdu(pdu[2:]) + if ( + control_field.sar + == InformationEnhancedControlField.SegmentationAndReassembly.START + ): + # Drop Control Field(2) + SDU Length(2) + self._in_sdu += pdu[4:] + else: + # Drop Control Field(2) + self._in_sdu += pdu[2:] + if control_field.sar in ( + InformationEnhancedControlField.SegmentationAndReassembly.END, + InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED, + ): + self.channel.on_sdu(self._in_sdu) + self._in_sdu = b'' + + # If sink doesn't trigger any I-frame, ack this frame. + if self._req_seq_num != self._last_acked_rx_seq: + self._send_s_frame( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + final=0, + ) elif isinstance(control_field, SupervisoryEnhancedControlField): self._remote_is_busy = ( control_field.supervision_function @@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor): SupervisoryEnhancedControlField.SupervisoryFunction.RNR, ): if control_field.poll: - self.channel.send_pdu( - SupervisoryEnhancedControlField( - supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, - final=1, - req_seq=self._next_seq_num, - ) + self._send_s_frame( + supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, + final=1, ) else: # TODO: Handle Retransmission. pass def _process_output(self) -> None: - if self._remote_is_busy or self._monitor_handle: + if self._remote_is_busy: + logger.debug("Remote is busy") + return + if self._monitor_handle: + logger.debug("Monitor handle is not None") return - for pdu in self._pending_pdus: - if self._num_unacked_frames >= self.peer_tx_window_size: - return - self._send_pdu(pdu) - self._last_tx_seq = pdu.tx_seq + pdu_to_send = self.peer_tx_window_size - len(self._tx_window) + for pdu in itertools.islice(self._pending_pdus, pdu_to_send): + self._send_i_frame(pdu) + self._pending_pdus = self._pending_pdus[pdu_to_send:] - @property - def _num_unacked_frames(self) -> int: - if not self._pending_pdus: - return 0 - return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1) - - def _send_pdu(self, pdu: _PendingPdu) -> None: + def _send_i_frame(self, pdu: _PendingPdu) -> None: pdu.req_seq = self._req_seq_num self._start_receiver_ready_poll() + self._tx_window.append(pdu) self.channel.send_pdu(bytes(pdu)) + self._last_acked_rx_seq = self._req_seq_num + + def _send_s_frame( + self, + supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction, + final: int, + ) -> None: + self.channel.send_pdu( + SupervisoryEnhancedControlField( + supervision_function=supervision_function, + final=final, + req_seq=self._req_seq_num, + ) + ) + self._last_acked_rx_seq = self._req_seq_num def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None: - num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq) - if num_frames_acked > self._num_unacked_frames: + num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM + if num_frames_acked > len(self._tx_window): logger.error( "Received acknowledgment for %d frames but only %d frames are pending", num_frames_acked, - self._num_unacked_frames, + len(self._tx_window), ) return if is_poll_response and self._monitor_handle: self._monitor_handle.cancel() self._monitor_handle = None - del self._pending_pdus[:num_frames_acked] - self._expected_ack_seq = new_seq + del self._tx_window[:num_frames_acked] + self._last_acked_tx_seq = new_seq if ( - self._expected_ack_seq == self._next_tx_seq + self._last_acked_tx_seq == self._next_tx_seq and self._receiver_ready_poll_handle ): self._receiver_ready_poll_handle.cancel() diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 689704c..0cb1db9 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -239,20 +239,7 @@ async def transfer_payload( channels[1].sink = received.put_nowait sdu_lengths = (21, 70, 700, 5523) - if isinstance(channels[1], l2cap.LeCreditBasedChannel): - mps = channels[1].mps - elif isinstance( - processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor - ): - mps = processor.mps - else: - mps = channels[1].mtu - - messages = [ - bytes([i % 8 for i in range(sdu_length)]) - for sdu_length in sdu_lengths - if sdu_length <= mps - ] + messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths] for message in messages: channels[0].write(message) if isinstance(channels[0], l2cap.LeCreditBasedChannel): @@ -334,20 +321,26 @@ async def test_mtu(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_enhanced_retransmission_mode(): +@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000)) +async def test_enhanced_retransmission_mode(mtu: int): devices = TwoDevices() await devices.setup_connection() server_channels = asyncio.Queue[l2cap.ClassicChannel]() server = devices.devices[1].create_l2cap_server( spec=l2cap.ClassicChannelSpec( - mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION + mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION, + mtu=mtu, + mps=256, ), handler=server_channels.put_nowait, ) client_channel = await devices.connections[0].create_l2cap_channel( spec=l2cap.ClassicChannelSpec( - server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION + server.psm, + mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION, + mtu=mtu, + mps=1024, ) ) server_channel = await server_channels.get()