L2CAP: Fix Enhanced Retransmission Segmentation

This commit is contained in:
Josh Wu
2026-01-07 00:28:54 +08:00
parent 8ac8724cd8
commit b153d0fcde
2 changed files with 127 additions and 79 deletions

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio
import dataclasses
import enum
import itertools
import logging
import struct
from collections import deque
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
@dataclasses.dataclass
class InformationEnhancedControlField(EnhancedControlField):
tx_seq: int = 0
tx_seq: int
sar: int
req_seq: int = 0
segmentation_and_reassembly: int = (
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
)
final: int = 1
frame_type = EnhancedControlField.FieldType.I_FRAME
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
return cls(
tx_seq=(data[0] >> 1) & 0b0111111,
final=(data[0] >> 7) & 0b1,
req_seq=(data[1] & 0b001111111),
segmentation_and_reassembly=(data[1] >> 6) & 0b11,
req_seq=(data[1] & 0b00111111),
sar=(data[1] >> 6) & 0b11,
)
def __bytes__(self) -> bytes:
return bytes(
[
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
self.req_seq | (self.segmentation_and_reassembly << 6),
self.req_seq | (self.sar << 6),
]
)
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
class _PendingPdu:
payload: bytes
tx_seq: int
sar: InformationEnhancedControlField.SegmentationAndReassembly
sdu_length: int = 0
req_seq: int = 0
def __bytes__(self) -> bytes:
return (
bytes(
InformationEnhancedControlField(
tx_seq=self.tx_seq, req_seq=self.req_seq
tx_seq=self.tx_seq,
req_seq=self.req_seq,
sar=self.sar,
)
)
+ (
struct.pack('<H', self.sdu_length)
if self.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
else b''
)
+ self.payload
)
_expected_ack_seq: int = 0
_last_acked_tx_seq: int = 0
_last_acked_rx_seq: int = 0
_next_tx_seq: int = 0
_last_tx_seq: int = 0
_req_seq_num: int = 0
_next_seq_num: int = 0
_remote_is_busy: bool = False
_in_sdu: bytes = b''
_num_receiver_ready_polls_sent: int = 0
_pending_pdus: list[_PendingPdu]
_tx_window: list[_PendingPdu]
_monitor_handle: asyncio.TimerHandle | None = None
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
monitor_timeout: float
retransmission_timeout: float
@classmethod
def _num_frames_between(cls, low: int, high: int) -> int:
if high < low:
high += cls.MAX_SEQ_NUM
return high - low
def __init__(
self,
channel: ClassicChannel,
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
self.peer_mps = peer_mps
self.peer_tx_window_size = peer_tx_window_size
self._pending_pdus = []
self._tx_window = []
self.monitor_timeout = spec.monitor_timeout
self.channel = channel
self.retransmission_timeout = spec.retransmission_timeout
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
def _send_receiver_ready_poll(self) -> None:
self._num_receiver_ready_polls_sent += 1
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
def _get_next_tx_seq(self) -> int:
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
@override
def send_sdu(self, sdu: bytes) -> None:
if len(sdu) > self.peer_mps:
raise InvalidArgumentError(
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}'
if len(sdu) <= self.peer_mps:
pdu = self._PendingPdu(
payload=sdu,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
)
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq())
self._pending_pdus.append(pdu)
self._pending_pdus.append(pdu)
else:
for offset in range(0, len(sdu), self.peer_mps):
payload = sdu[offset : offset + self.peer_mps]
if offset == 0:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.START
)
elif offset + len(payload) >= len(sdu):
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
else:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
)
pdu = self._PendingPdu(
payload=payload,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=sar,
sdu_length=len(sdu),
)
self._pending_pdus.append(pdu)
self._process_output()
@override
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
control_field = EnhancedControlField.from_bytes(pdu)
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
if isinstance(control_field, InformationEnhancedControlField):
if control_field.tx_seq != self._next_seq_num:
if control_field.tx_seq != self._req_seq_num:
logger.error(
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
control_field.tx_seq,
self._req_seq_num,
)
return
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM
self._req_seq_num = self._next_seq_num
self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
ack_frame = SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
req_seq=self._next_seq_num,
)
self.channel.send_pdu(ack_frame)
self.channel.on_sdu(pdu[2:])
if (
control_field.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
):
# Drop Control Field(2) + SDU Length(2)
self._in_sdu += pdu[4:]
else:
# Drop Control Field(2)
self._in_sdu += pdu[2:]
if control_field.sar in (
InformationEnhancedControlField.SegmentationAndReassembly.END,
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
):
self.channel.on_sdu(self._in_sdu)
self._in_sdu = b''
# If sink doesn't trigger any I-frame, ack this frame.
if self._req_seq_num != self._last_acked_rx_seq:
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=0,
)
elif isinstance(control_field, SupervisoryEnhancedControlField):
self._remote_is_busy = (
control_field.supervision_function
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
):
if control_field.poll:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
else:
# TODO: Handle Retransmission.
pass
def _process_output(self) -> None:
if self._remote_is_busy or self._monitor_handle:
if self._remote_is_busy:
logger.debug("Remote is busy")
return
if self._monitor_handle:
logger.debug("Monitor handle is not None")
return
for pdu in self._pending_pdus:
if self._num_unacked_frames >= self.peer_tx_window_size:
return
self._send_pdu(pdu)
self._last_tx_seq = pdu.tx_seq
pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
self._send_i_frame(pdu)
self._pending_pdus = self._pending_pdus[pdu_to_send:]
@property
def _num_unacked_frames(self) -> int:
if not self._pending_pdus:
return 0
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
def _send_pdu(self, pdu: _PendingPdu) -> None:
def _send_i_frame(self, pdu: _PendingPdu) -> None:
pdu.req_seq = self._req_seq_num
self._start_receiver_ready_poll()
self._tx_window.append(pdu)
self.channel.send_pdu(bytes(pdu))
self._last_acked_rx_seq = self._req_seq_num
def _send_s_frame(
self,
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
final: int,
) -> None:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=supervision_function,
final=final,
req_seq=self._req_seq_num,
)
)
self._last_acked_rx_seq = self._req_seq_num
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq)
if num_frames_acked > self._num_unacked_frames:
num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
if num_frames_acked > len(self._tx_window):
logger.error(
"Received acknowledgment for %d frames but only %d frames are pending",
num_frames_acked,
self._num_unacked_frames,
len(self._tx_window),
)
return
if is_poll_response and self._monitor_handle:
self._monitor_handle.cancel()
self._monitor_handle = None
del self._pending_pdus[:num_frames_acked]
self._expected_ack_seq = new_seq
del self._tx_window[:num_frames_acked]
self._last_acked_tx_seq = new_seq
if (
self._expected_ack_seq == self._next_tx_seq
self._last_acked_tx_seq == self._next_tx_seq
and self._receiver_ready_poll_handle
):
self._receiver_ready_poll_handle.cancel()

View File

@@ -239,20 +239,7 @@ async def transfer_payload(
channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523)
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
for message in messages:
channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
@@ -334,20 +321,26 @@ async def test_mtu():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_retransmission_mode():
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices()
await devices.setup_connection()
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
server = devices.devices[1].create_l2cap_server(
spec=l2cap.ClassicChannelSpec(
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=256,
),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
server.psm,
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=1024,
)
)
server_channel = await server_channels.get()