mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
L2CAP: Fix Enhanced Retransmission Segmentation
This commit is contained in:
171
bumble/l2cap.py
171
bumble/l2cap.py
@@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import itertools
|
||||
import logging
|
||||
import struct
|
||||
from collections import deque
|
||||
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InformationEnhancedControlField(EnhancedControlField):
|
||||
tx_seq: int = 0
|
||||
tx_seq: int
|
||||
sar: int
|
||||
req_seq: int = 0
|
||||
segmentation_and_reassembly: int = (
|
||||
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
|
||||
)
|
||||
final: int = 1
|
||||
|
||||
frame_type = EnhancedControlField.FieldType.I_FRAME
|
||||
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
|
||||
return cls(
|
||||
tx_seq=(data[0] >> 1) & 0b0111111,
|
||||
final=(data[0] >> 7) & 0b1,
|
||||
req_seq=(data[1] & 0b001111111),
|
||||
segmentation_and_reassembly=(data[1] >> 6) & 0b11,
|
||||
req_seq=(data[1] & 0b00111111),
|
||||
sar=(data[1] >> 6) & 0b11,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[
|
||||
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
|
||||
self.req_seq | (self.segmentation_and_reassembly << 6),
|
||||
self.req_seq | (self.sar << 6),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
class _PendingPdu:
|
||||
payload: bytes
|
||||
tx_seq: int
|
||||
sar: InformationEnhancedControlField.SegmentationAndReassembly
|
||||
sdu_length: int = 0
|
||||
req_seq: int = 0
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (
|
||||
bytes(
|
||||
InformationEnhancedControlField(
|
||||
tx_seq=self.tx_seq, req_seq=self.req_seq
|
||||
tx_seq=self.tx_seq,
|
||||
req_seq=self.req_seq,
|
||||
sar=self.sar,
|
||||
)
|
||||
)
|
||||
+ (
|
||||
struct.pack('<H', self.sdu_length)
|
||||
if self.sar
|
||||
== InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||
else b''
|
||||
)
|
||||
+ self.payload
|
||||
)
|
||||
|
||||
_expected_ack_seq: int = 0
|
||||
_last_acked_tx_seq: int = 0
|
||||
_last_acked_rx_seq: int = 0
|
||||
_next_tx_seq: int = 0
|
||||
_last_tx_seq: int = 0
|
||||
_req_seq_num: int = 0
|
||||
_next_seq_num: int = 0
|
||||
_remote_is_busy: bool = False
|
||||
_in_sdu: bytes = b''
|
||||
|
||||
_num_receiver_ready_polls_sent: int = 0
|
||||
_pending_pdus: list[_PendingPdu]
|
||||
_tx_window: list[_PendingPdu]
|
||||
_monitor_handle: asyncio.TimerHandle | None = None
|
||||
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
|
||||
|
||||
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
monitor_timeout: float
|
||||
retransmission_timeout: float
|
||||
|
||||
@classmethod
|
||||
def _num_frames_between(cls, low: int, high: int) -> int:
|
||||
if high < low:
|
||||
high += cls.MAX_SEQ_NUM
|
||||
return high - low
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: ClassicChannel,
|
||||
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
self.peer_mps = peer_mps
|
||||
self.peer_tx_window_size = peer_tx_window_size
|
||||
self._pending_pdus = []
|
||||
self._tx_window = []
|
||||
self.monitor_timeout = spec.monitor_timeout
|
||||
self.channel = channel
|
||||
self.retransmission_timeout = spec.retransmission_timeout
|
||||
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
|
||||
def _send_receiver_ready_poll(self) -> None:
|
||||
self._num_receiver_ready_polls_sent += 1
|
||||
self.channel.send_pdu(
|
||||
SupervisoryEnhancedControlField(
|
||||
self._send_s_frame(
|
||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||
final=1,
|
||||
req_seq=self._next_seq_num,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_next_tx_seq(self) -> int:
|
||||
@@ -987,11 +989,34 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
|
||||
@override
|
||||
def send_sdu(self, sdu: bytes) -> None:
|
||||
if len(sdu) > self.peer_mps:
|
||||
raise InvalidArgumentError(
|
||||
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}'
|
||||
if len(sdu) <= self.peer_mps:
|
||||
pdu = self._PendingPdu(
|
||||
payload=sdu,
|
||||
tx_seq=self._get_next_tx_seq(),
|
||||
req_seq=self._req_seq_num,
|
||||
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
|
||||
)
|
||||
self._pending_pdus.append(pdu)
|
||||
else:
|
||||
for offset in range(0, len(sdu), self.peer_mps):
|
||||
payload = sdu[offset : offset + self.peer_mps]
|
||||
if offset == 0:
|
||||
sar = (
|
||||
InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||
)
|
||||
elif offset + len(payload) >= len(sdu):
|
||||
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
|
||||
else:
|
||||
sar = (
|
||||
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
|
||||
)
|
||||
pdu = self._PendingPdu(
|
||||
payload=payload,
|
||||
tx_seq=self._get_next_tx_seq(),
|
||||
req_seq=self._req_seq_num,
|
||||
sar=sar,
|
||||
sdu_length=len(sdu),
|
||||
)
|
||||
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq())
|
||||
self._pending_pdus.append(pdu)
|
||||
self._process_output()
|
||||
|
||||
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
control_field = EnhancedControlField.from_bytes(pdu)
|
||||
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
|
||||
if isinstance(control_field, InformationEnhancedControlField):
|
||||
if control_field.tx_seq != self._next_seq_num:
|
||||
return
|
||||
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM
|
||||
self._req_seq_num = self._next_seq_num
|
||||
|
||||
ack_frame = SupervisoryEnhancedControlField(
|
||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||
req_seq=self._next_seq_num,
|
||||
if control_field.tx_seq != self._req_seq_num:
|
||||
logger.error(
|
||||
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
|
||||
control_field.tx_seq,
|
||||
self._req_seq_num,
|
||||
)
|
||||
return
|
||||
self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
|
||||
|
||||
if (
|
||||
control_field.sar
|
||||
== InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||
):
|
||||
# Drop Control Field(2) + SDU Length(2)
|
||||
self._in_sdu += pdu[4:]
|
||||
else:
|
||||
# Drop Control Field(2)
|
||||
self._in_sdu += pdu[2:]
|
||||
if control_field.sar in (
|
||||
InformationEnhancedControlField.SegmentationAndReassembly.END,
|
||||
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
|
||||
):
|
||||
self.channel.on_sdu(self._in_sdu)
|
||||
self._in_sdu = b''
|
||||
|
||||
# If sink doesn't trigger any I-frame, ack this frame.
|
||||
if self._req_seq_num != self._last_acked_rx_seq:
|
||||
self._send_s_frame(
|
||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||
final=0,
|
||||
)
|
||||
self.channel.send_pdu(ack_frame)
|
||||
self.channel.on_sdu(pdu[2:])
|
||||
elif isinstance(control_field, SupervisoryEnhancedControlField):
|
||||
self._remote_is_busy = (
|
||||
control_field.supervision_function
|
||||
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
|
||||
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
|
||||
):
|
||||
if control_field.poll:
|
||||
self.channel.send_pdu(
|
||||
SupervisoryEnhancedControlField(
|
||||
self._send_s_frame(
|
||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||
final=1,
|
||||
req_seq=self._next_seq_num,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO: Handle Retransmission.
|
||||
pass
|
||||
|
||||
def _process_output(self) -> None:
|
||||
if self._remote_is_busy or self._monitor_handle:
|
||||
if self._remote_is_busy:
|
||||
logger.debug("Remote is busy")
|
||||
return
|
||||
if self._monitor_handle:
|
||||
logger.debug("Monitor handle is not None")
|
||||
return
|
||||
|
||||
for pdu in self._pending_pdus:
|
||||
if self._num_unacked_frames >= self.peer_tx_window_size:
|
||||
return
|
||||
self._send_pdu(pdu)
|
||||
self._last_tx_seq = pdu.tx_seq
|
||||
pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
|
||||
for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
|
||||
self._send_i_frame(pdu)
|
||||
self._pending_pdus = self._pending_pdus[pdu_to_send:]
|
||||
|
||||
@property
|
||||
def _num_unacked_frames(self) -> int:
|
||||
if not self._pending_pdus:
|
||||
return 0
|
||||
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
|
||||
|
||||
def _send_pdu(self, pdu: _PendingPdu) -> None:
|
||||
def _send_i_frame(self, pdu: _PendingPdu) -> None:
|
||||
pdu.req_seq = self._req_seq_num
|
||||
|
||||
self._start_receiver_ready_poll()
|
||||
self._tx_window.append(pdu)
|
||||
self.channel.send_pdu(bytes(pdu))
|
||||
self._last_acked_rx_seq = self._req_seq_num
|
||||
|
||||
def _send_s_frame(
|
||||
self,
|
||||
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
|
||||
final: int,
|
||||
) -> None:
|
||||
self.channel.send_pdu(
|
||||
SupervisoryEnhancedControlField(
|
||||
supervision_function=supervision_function,
|
||||
final=final,
|
||||
req_seq=self._req_seq_num,
|
||||
)
|
||||
)
|
||||
self._last_acked_rx_seq = self._req_seq_num
|
||||
|
||||
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
|
||||
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq)
|
||||
if num_frames_acked > self._num_unacked_frames:
|
||||
num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
|
||||
if num_frames_acked > len(self._tx_window):
|
||||
logger.error(
|
||||
"Received acknowledgment for %d frames but only %d frames are pending",
|
||||
num_frames_acked,
|
||||
self._num_unacked_frames,
|
||||
len(self._tx_window),
|
||||
)
|
||||
return
|
||||
if is_poll_response and self._monitor_handle:
|
||||
self._monitor_handle.cancel()
|
||||
self._monitor_handle = None
|
||||
|
||||
del self._pending_pdus[:num_frames_acked]
|
||||
self._expected_ack_seq = new_seq
|
||||
del self._tx_window[:num_frames_acked]
|
||||
self._last_acked_tx_seq = new_seq
|
||||
if (
|
||||
self._expected_ack_seq == self._next_tx_seq
|
||||
self._last_acked_tx_seq == self._next_tx_seq
|
||||
and self._receiver_ready_poll_handle
|
||||
):
|
||||
self._receiver_ready_poll_handle.cancel()
|
||||
|
||||
@@ -239,20 +239,7 @@ async def transfer_payload(
|
||||
channels[1].sink = received.put_nowait
|
||||
sdu_lengths = (21, 70, 700, 5523)
|
||||
|
||||
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
|
||||
mps = channels[1].mps
|
||||
elif isinstance(
|
||||
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
|
||||
):
|
||||
mps = processor.mps
|
||||
else:
|
||||
mps = channels[1].mtu
|
||||
|
||||
messages = [
|
||||
bytes([i % 8 for i in range(sdu_length)])
|
||||
for sdu_length in sdu_lengths
|
||||
if sdu_length <= mps
|
||||
]
|
||||
messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
|
||||
for message in messages:
|
||||
channels[0].write(message)
|
||||
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
|
||||
@@ -334,20 +321,26 @@ async def test_mtu():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_enhanced_retransmission_mode():
|
||||
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
|
||||
async def test_enhanced_retransmission_mode(mtu: int):
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
|
||||
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
|
||||
server = devices.devices[1].create_l2cap_server(
|
||||
spec=l2cap.ClassicChannelSpec(
|
||||
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
|
||||
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
|
||||
mtu=mtu,
|
||||
mps=256,
|
||||
),
|
||||
handler=server_channels.put_nowait,
|
||||
)
|
||||
client_channel = await devices.connections[0].create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(
|
||||
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
|
||||
server.psm,
|
||||
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
|
||||
mtu=mtu,
|
||||
mps=1024,
|
||||
)
|
||||
)
|
||||
server_channel = await server_channels.get()
|
||||
|
||||
Reference in New Issue
Block a user