L2CAP: Fix Enhanced Retransmission Segmentation

This commit is contained in:
Josh Wu
2026-01-07 00:28:54 +08:00
parent 8ac8724cd8
commit b153d0fcde
2 changed files with 127 additions and 79 deletions

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import enum import enum
import itertools
import logging import logging
import struct import struct
from collections import deque from collections import deque
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
@dataclasses.dataclass @dataclasses.dataclass
class InformationEnhancedControlField(EnhancedControlField): class InformationEnhancedControlField(EnhancedControlField):
tx_seq: int = 0 tx_seq: int
sar: int
req_seq: int = 0 req_seq: int = 0
segmentation_and_reassembly: int = (
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
)
final: int = 1 final: int = 1
frame_type = EnhancedControlField.FieldType.I_FRAME frame_type = EnhancedControlField.FieldType.I_FRAME
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
return cls( return cls(
tx_seq=(data[0] >> 1) & 0b0111111, tx_seq=(data[0] >> 1) & 0b0111111,
final=(data[0] >> 7) & 0b1, final=(data[0] >> 7) & 0b1,
req_seq=(data[1] & 0b001111111), req_seq=(data[1] & 0b00111111),
segmentation_and_reassembly=(data[1] >> 6) & 0b11, sar=(data[1] >> 6) & 0b11,
) )
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return bytes( return bytes(
[ [
self.frame_type | (self.tx_seq << 1) | (self.final << 7), self.frame_type | (self.tx_seq << 1) | (self.final << 7),
self.req_seq | (self.segmentation_and_reassembly << 6), self.req_seq | (self.sar << 6),
] ]
) )
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
class _PendingPdu: class _PendingPdu:
payload: bytes payload: bytes
tx_seq: int tx_seq: int
sar: InformationEnhancedControlField.SegmentationAndReassembly
sdu_length: int = 0
req_seq: int = 0 req_seq: int = 0
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return ( return (
bytes( bytes(
InformationEnhancedControlField( InformationEnhancedControlField(
tx_seq=self.tx_seq, req_seq=self.req_seq tx_seq=self.tx_seq,
req_seq=self.req_seq,
sar=self.sar,
) )
) )
+ (
struct.pack('<H', self.sdu_length)
if self.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
else b''
)
+ self.payload + self.payload
) )
_expected_ack_seq: int = 0 _last_acked_tx_seq: int = 0
_last_acked_rx_seq: int = 0
_next_tx_seq: int = 0 _next_tx_seq: int = 0
_last_tx_seq: int = 0
_req_seq_num: int = 0 _req_seq_num: int = 0
_next_seq_num: int = 0
_remote_is_busy: bool = False _remote_is_busy: bool = False
_in_sdu: bytes = b''
_num_receiver_ready_polls_sent: int = 0 _num_receiver_ready_polls_sent: int = 0
_pending_pdus: list[_PendingPdu] _pending_pdus: list[_PendingPdu]
_tx_window: list[_PendingPdu]
_monitor_handle: asyncio.TimerHandle | None = None _monitor_handle: asyncio.TimerHandle | None = None
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None _receiver_ready_poll_handle: asyncio.TimerHandle | None = None
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
monitor_timeout: float monitor_timeout: float
retransmission_timeout: float retransmission_timeout: float
@classmethod
def _num_frames_between(cls, low: int, high: int) -> int:
if high < low:
high += cls.MAX_SEQ_NUM
return high - low
def __init__( def __init__(
self, self,
channel: ClassicChannel, channel: ClassicChannel,
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
self.peer_mps = peer_mps self.peer_mps = peer_mps
self.peer_tx_window_size = peer_tx_window_size self.peer_tx_window_size = peer_tx_window_size
self._pending_pdus = [] self._pending_pdus = []
self._tx_window = []
self.monitor_timeout = spec.monitor_timeout self.monitor_timeout = spec.monitor_timeout
self.channel = channel self.channel = channel
self.retransmission_timeout = spec.retransmission_timeout self.retransmission_timeout = spec.retransmission_timeout
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
def _send_receiver_ready_poll(self) -> None: def _send_receiver_ready_poll(self) -> None:
self._num_receiver_ready_polls_sent += 1 self._num_receiver_ready_polls_sent += 1
self.channel.send_pdu( self._send_s_frame(
SupervisoryEnhancedControlField( supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, final=1,
final=1,
req_seq=self._next_seq_num,
)
) )
def _get_next_tx_seq(self) -> int: def _get_next_tx_seq(self) -> int:
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
@override @override
def send_sdu(self, sdu: bytes) -> None: def send_sdu(self, sdu: bytes) -> None:
if len(sdu) > self.peer_mps: if len(sdu) <= self.peer_mps:
raise InvalidArgumentError( pdu = self._PendingPdu(
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}' payload=sdu,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
) )
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq()) self._pending_pdus.append(pdu)
self._pending_pdus.append(pdu) else:
for offset in range(0, len(sdu), self.peer_mps):
payload = sdu[offset : offset + self.peer_mps]
if offset == 0:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.START
)
elif offset + len(payload) >= len(sdu):
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
else:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
)
pdu = self._PendingPdu(
payload=payload,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=sar,
sdu_length=len(sdu),
)
self._pending_pdus.append(pdu)
self._process_output() self._process_output()
@override @override
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
control_field = EnhancedControlField.from_bytes(pdu) control_field = EnhancedControlField.from_bytes(pdu)
self._update_ack_seq(control_field.req_seq, control_field.final != 0) self._update_ack_seq(control_field.req_seq, control_field.final != 0)
if isinstance(control_field, InformationEnhancedControlField): if isinstance(control_field, InformationEnhancedControlField):
if control_field.tx_seq != self._next_seq_num: if control_field.tx_seq != self._req_seq_num:
logger.error(
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
control_field.tx_seq,
self._req_seq_num,
)
return return
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
self._req_seq_num = self._next_seq_num
ack_frame = SupervisoryEnhancedControlField( if (
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, control_field.sar
req_seq=self._next_seq_num, == InformationEnhancedControlField.SegmentationAndReassembly.START
) ):
self.channel.send_pdu(ack_frame) # Drop Control Field(2) + SDU Length(2)
self.channel.on_sdu(pdu[2:]) self._in_sdu += pdu[4:]
else:
# Drop Control Field(2)
self._in_sdu += pdu[2:]
if control_field.sar in (
InformationEnhancedControlField.SegmentationAndReassembly.END,
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
):
self.channel.on_sdu(self._in_sdu)
self._in_sdu = b''
# If sink doesn't trigger any I-frame, ack this frame.
if self._req_seq_num != self._last_acked_rx_seq:
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=0,
)
elif isinstance(control_field, SupervisoryEnhancedControlField): elif isinstance(control_field, SupervisoryEnhancedControlField):
self._remote_is_busy = ( self._remote_is_busy = (
control_field.supervision_function control_field.supervision_function
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
SupervisoryEnhancedControlField.SupervisoryFunction.RNR, SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
): ):
if control_field.poll: if control_field.poll:
self.channel.send_pdu( self._send_s_frame(
SupervisoryEnhancedControlField( supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, final=1,
final=1,
req_seq=self._next_seq_num,
)
) )
else: else:
# TODO: Handle Retransmission. # TODO: Handle Retransmission.
pass pass
def _process_output(self) -> None: def _process_output(self) -> None:
if self._remote_is_busy or self._monitor_handle: if self._remote_is_busy:
logger.debug("Remote is busy")
return
if self._monitor_handle:
logger.debug("Monitor handle is not None")
return return
for pdu in self._pending_pdus: pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
if self._num_unacked_frames >= self.peer_tx_window_size: for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
return self._send_i_frame(pdu)
self._send_pdu(pdu) self._pending_pdus = self._pending_pdus[pdu_to_send:]
self._last_tx_seq = pdu.tx_seq
@property def _send_i_frame(self, pdu: _PendingPdu) -> None:
def _num_unacked_frames(self) -> int:
if not self._pending_pdus:
return 0
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
def _send_pdu(self, pdu: _PendingPdu) -> None:
pdu.req_seq = self._req_seq_num pdu.req_seq = self._req_seq_num
self._start_receiver_ready_poll() self._start_receiver_ready_poll()
self._tx_window.append(pdu)
self.channel.send_pdu(bytes(pdu)) self.channel.send_pdu(bytes(pdu))
self._last_acked_rx_seq = self._req_seq_num
def _send_s_frame(
self,
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
final: int,
) -> None:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=supervision_function,
final=final,
req_seq=self._req_seq_num,
)
)
self._last_acked_rx_seq = self._req_seq_num
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None: def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq) num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
if num_frames_acked > self._num_unacked_frames: if num_frames_acked > len(self._tx_window):
logger.error( logger.error(
"Received acknowledgment for %d frames but only %d frames are pending", "Received acknowledgment for %d frames but only %d frames are pending",
num_frames_acked, num_frames_acked,
self._num_unacked_frames, len(self._tx_window),
) )
return return
if is_poll_response and self._monitor_handle: if is_poll_response and self._monitor_handle:
self._monitor_handle.cancel() self._monitor_handle.cancel()
self._monitor_handle = None self._monitor_handle = None
del self._pending_pdus[:num_frames_acked] del self._tx_window[:num_frames_acked]
self._expected_ack_seq = new_seq self._last_acked_tx_seq = new_seq
if ( if (
self._expected_ack_seq == self._next_tx_seq self._last_acked_tx_seq == self._next_tx_seq
and self._receiver_ready_poll_handle and self._receiver_ready_poll_handle
): ):
self._receiver_ready_poll_handle.cancel() self._receiver_ready_poll_handle.cancel()

View File

@@ -239,20 +239,7 @@ async def transfer_payload(
channels[1].sink = received.put_nowait channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523) sdu_lengths = (21, 70, 700, 5523)
if isinstance(channels[1], l2cap.LeCreditBasedChannel): messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
for message in messages: for message in messages:
channels[0].write(message) channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel): if isinstance(channels[0], l2cap.LeCreditBasedChannel):
@@ -334,20 +321,26 @@ async def test_mtu():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_enhanced_retransmission_mode(): @pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices() devices = TwoDevices()
await devices.setup_connection() await devices.setup_connection()
server_channels = asyncio.Queue[l2cap.ClassicChannel]() server_channels = asyncio.Queue[l2cap.ClassicChannel]()
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=l2cap.ClassicChannelSpec( spec=l2cap.ClassicChannelSpec(
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=256,
), ),
handler=server_channels.put_nowait, handler=server_channels.put_nowait,
) )
client_channel = await devices.connections[0].create_l2cap_channel( client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.ClassicChannelSpec( spec=l2cap.ClassicChannelSpec(
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION server.psm,
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=1024,
) )
) )
server_channel = await server_channels.get() server_channel = await server_channels.get()