mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
L2CAP: Fix Enhanced Retransmission Segmentation
This commit is contained in:
179
bumble/l2cap.py
179
bumble/l2cap.py
@@ -20,6 +20,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class InformationEnhancedControlField(EnhancedControlField):
|
class InformationEnhancedControlField(EnhancedControlField):
|
||||||
tx_seq: int = 0
|
tx_seq: int
|
||||||
|
sar: int
|
||||||
req_seq: int = 0
|
req_seq: int = 0
|
||||||
segmentation_and_reassembly: int = (
|
|
||||||
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
|
|
||||||
)
|
|
||||||
final: int = 1
|
final: int = 1
|
||||||
|
|
||||||
frame_type = EnhancedControlField.FieldType.I_FRAME
|
frame_type = EnhancedControlField.FieldType.I_FRAME
|
||||||
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
|
|||||||
return cls(
|
return cls(
|
||||||
tx_seq=(data[0] >> 1) & 0b0111111,
|
tx_seq=(data[0] >> 1) & 0b0111111,
|
||||||
final=(data[0] >> 7) & 0b1,
|
final=(data[0] >> 7) & 0b1,
|
||||||
req_seq=(data[1] & 0b001111111),
|
req_seq=(data[1] & 0b00111111),
|
||||||
segmentation_and_reassembly=(data[1] >> 6) & 0b11,
|
sar=(data[1] >> 6) & 0b11,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
return bytes(
|
return bytes(
|
||||||
[
|
[
|
||||||
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
|
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
|
||||||
self.req_seq | (self.segmentation_and_reassembly << 6),
|
self.req_seq | (self.sar << 6),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
class _PendingPdu:
|
class _PendingPdu:
|
||||||
payload: bytes
|
payload: bytes
|
||||||
tx_seq: int
|
tx_seq: int
|
||||||
|
sar: InformationEnhancedControlField.SegmentationAndReassembly
|
||||||
|
sdu_length: int = 0
|
||||||
req_seq: int = 0
|
req_seq: int = 0
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
return (
|
return (
|
||||||
bytes(
|
bytes(
|
||||||
InformationEnhancedControlField(
|
InformationEnhancedControlField(
|
||||||
tx_seq=self.tx_seq, req_seq=self.req_seq
|
tx_seq=self.tx_seq,
|
||||||
|
req_seq=self.req_seq,
|
||||||
|
sar=self.sar,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
+ (
|
||||||
|
struct.pack('<H', self.sdu_length)
|
||||||
|
if self.sar
|
||||||
|
== InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||||
|
else b''
|
||||||
|
)
|
||||||
+ self.payload
|
+ self.payload
|
||||||
)
|
)
|
||||||
|
|
||||||
_expected_ack_seq: int = 0
|
_last_acked_tx_seq: int = 0
|
||||||
|
_last_acked_rx_seq: int = 0
|
||||||
_next_tx_seq: int = 0
|
_next_tx_seq: int = 0
|
||||||
_last_tx_seq: int = 0
|
|
||||||
_req_seq_num: int = 0
|
_req_seq_num: int = 0
|
||||||
_next_seq_num: int = 0
|
|
||||||
_remote_is_busy: bool = False
|
_remote_is_busy: bool = False
|
||||||
|
_in_sdu: bytes = b''
|
||||||
|
|
||||||
_num_receiver_ready_polls_sent: int = 0
|
_num_receiver_ready_polls_sent: int = 0
|
||||||
_pending_pdus: list[_PendingPdu]
|
_pending_pdus: list[_PendingPdu]
|
||||||
|
_tx_window: list[_PendingPdu]
|
||||||
_monitor_handle: asyncio.TimerHandle | None = None
|
_monitor_handle: asyncio.TimerHandle | None = None
|
||||||
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
|
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
|
||||||
|
|
||||||
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
monitor_timeout: float
|
monitor_timeout: float
|
||||||
retransmission_timeout: float
|
retransmission_timeout: float
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _num_frames_between(cls, low: int, high: int) -> int:
|
|
||||||
if high < low:
|
|
||||||
high += cls.MAX_SEQ_NUM
|
|
||||||
return high - low
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channel: ClassicChannel,
|
channel: ClassicChannel,
|
||||||
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
self.peer_mps = peer_mps
|
self.peer_mps = peer_mps
|
||||||
self.peer_tx_window_size = peer_tx_window_size
|
self.peer_tx_window_size = peer_tx_window_size
|
||||||
self._pending_pdus = []
|
self._pending_pdus = []
|
||||||
|
self._tx_window = []
|
||||||
self.monitor_timeout = spec.monitor_timeout
|
self.monitor_timeout = spec.monitor_timeout
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.retransmission_timeout = spec.retransmission_timeout
|
self.retransmission_timeout = spec.retransmission_timeout
|
||||||
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
|
|
||||||
def _send_receiver_ready_poll(self) -> None:
|
def _send_receiver_ready_poll(self) -> None:
|
||||||
self._num_receiver_ready_polls_sent += 1
|
self._num_receiver_ready_polls_sent += 1
|
||||||
self.channel.send_pdu(
|
self._send_s_frame(
|
||||||
SupervisoryEnhancedControlField(
|
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
final=1,
|
||||||
final=1,
|
|
||||||
req_seq=self._next_seq_num,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_next_tx_seq(self) -> int:
|
def _get_next_tx_seq(self) -> int:
|
||||||
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def send_sdu(self, sdu: bytes) -> None:
|
def send_sdu(self, sdu: bytes) -> None:
|
||||||
if len(sdu) > self.peer_mps:
|
if len(sdu) <= self.peer_mps:
|
||||||
raise InvalidArgumentError(
|
pdu = self._PendingPdu(
|
||||||
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}'
|
payload=sdu,
|
||||||
|
tx_seq=self._get_next_tx_seq(),
|
||||||
|
req_seq=self._req_seq_num,
|
||||||
|
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
|
||||||
)
|
)
|
||||||
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq())
|
self._pending_pdus.append(pdu)
|
||||||
self._pending_pdus.append(pdu)
|
else:
|
||||||
|
for offset in range(0, len(sdu), self.peer_mps):
|
||||||
|
payload = sdu[offset : offset + self.peer_mps]
|
||||||
|
if offset == 0:
|
||||||
|
sar = (
|
||||||
|
InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||||
|
)
|
||||||
|
elif offset + len(payload) >= len(sdu):
|
||||||
|
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
|
||||||
|
else:
|
||||||
|
sar = (
|
||||||
|
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
|
||||||
|
)
|
||||||
|
pdu = self._PendingPdu(
|
||||||
|
payload=payload,
|
||||||
|
tx_seq=self._get_next_tx_seq(),
|
||||||
|
req_seq=self._req_seq_num,
|
||||||
|
sar=sar,
|
||||||
|
sdu_length=len(sdu),
|
||||||
|
)
|
||||||
|
self._pending_pdus.append(pdu)
|
||||||
self._process_output()
|
self._process_output()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
control_field = EnhancedControlField.from_bytes(pdu)
|
control_field = EnhancedControlField.from_bytes(pdu)
|
||||||
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
|
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
|
||||||
if isinstance(control_field, InformationEnhancedControlField):
|
if isinstance(control_field, InformationEnhancedControlField):
|
||||||
if control_field.tx_seq != self._next_seq_num:
|
if control_field.tx_seq != self._req_seq_num:
|
||||||
|
logger.error(
|
||||||
|
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
|
||||||
|
control_field.tx_seq,
|
||||||
|
self._req_seq_num,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM
|
self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
|
||||||
self._req_seq_num = self._next_seq_num
|
|
||||||
|
|
||||||
ack_frame = SupervisoryEnhancedControlField(
|
if (
|
||||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
control_field.sar
|
||||||
req_seq=self._next_seq_num,
|
== InformationEnhancedControlField.SegmentationAndReassembly.START
|
||||||
)
|
):
|
||||||
self.channel.send_pdu(ack_frame)
|
# Drop Control Field(2) + SDU Length(2)
|
||||||
self.channel.on_sdu(pdu[2:])
|
self._in_sdu += pdu[4:]
|
||||||
|
else:
|
||||||
|
# Drop Control Field(2)
|
||||||
|
self._in_sdu += pdu[2:]
|
||||||
|
if control_field.sar in (
|
||||||
|
InformationEnhancedControlField.SegmentationAndReassembly.END,
|
||||||
|
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
|
||||||
|
):
|
||||||
|
self.channel.on_sdu(self._in_sdu)
|
||||||
|
self._in_sdu = b''
|
||||||
|
|
||||||
|
# If sink doesn't trigger any I-frame, ack this frame.
|
||||||
|
if self._req_seq_num != self._last_acked_rx_seq:
|
||||||
|
self._send_s_frame(
|
||||||
|
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||||
|
final=0,
|
||||||
|
)
|
||||||
elif isinstance(control_field, SupervisoryEnhancedControlField):
|
elif isinstance(control_field, SupervisoryEnhancedControlField):
|
||||||
self._remote_is_busy = (
|
self._remote_is_busy = (
|
||||||
control_field.supervision_function
|
control_field.supervision_function
|
||||||
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
|
|||||||
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
|
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
|
||||||
):
|
):
|
||||||
if control_field.poll:
|
if control_field.poll:
|
||||||
self.channel.send_pdu(
|
self._send_s_frame(
|
||||||
SupervisoryEnhancedControlField(
|
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
||||||
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
|
final=1,
|
||||||
final=1,
|
|
||||||
req_seq=self._next_seq_num,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# TODO: Handle Retransmission.
|
# TODO: Handle Retransmission.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _process_output(self) -> None:
|
def _process_output(self) -> None:
|
||||||
if self._remote_is_busy or self._monitor_handle:
|
if self._remote_is_busy:
|
||||||
|
logger.debug("Remote is busy")
|
||||||
|
return
|
||||||
|
if self._monitor_handle:
|
||||||
|
logger.debug("Monitor handle is not None")
|
||||||
return
|
return
|
||||||
|
|
||||||
for pdu in self._pending_pdus:
|
pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
|
||||||
if self._num_unacked_frames >= self.peer_tx_window_size:
|
for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
|
||||||
return
|
self._send_i_frame(pdu)
|
||||||
self._send_pdu(pdu)
|
self._pending_pdus = self._pending_pdus[pdu_to_send:]
|
||||||
self._last_tx_seq = pdu.tx_seq
|
|
||||||
|
|
||||||
@property
|
def _send_i_frame(self, pdu: _PendingPdu) -> None:
|
||||||
def _num_unacked_frames(self) -> int:
|
|
||||||
if not self._pending_pdus:
|
|
||||||
return 0
|
|
||||||
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
|
|
||||||
|
|
||||||
def _send_pdu(self, pdu: _PendingPdu) -> None:
|
|
||||||
pdu.req_seq = self._req_seq_num
|
pdu.req_seq = self._req_seq_num
|
||||||
|
|
||||||
self._start_receiver_ready_poll()
|
self._start_receiver_ready_poll()
|
||||||
|
self._tx_window.append(pdu)
|
||||||
self.channel.send_pdu(bytes(pdu))
|
self.channel.send_pdu(bytes(pdu))
|
||||||
|
self._last_acked_rx_seq = self._req_seq_num
|
||||||
|
|
||||||
|
def _send_s_frame(
|
||||||
|
self,
|
||||||
|
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
|
||||||
|
final: int,
|
||||||
|
) -> None:
|
||||||
|
self.channel.send_pdu(
|
||||||
|
SupervisoryEnhancedControlField(
|
||||||
|
supervision_function=supervision_function,
|
||||||
|
final=final,
|
||||||
|
req_seq=self._req_seq_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._last_acked_rx_seq = self._req_seq_num
|
||||||
|
|
||||||
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
|
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
|
||||||
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq)
|
num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
|
||||||
if num_frames_acked > self._num_unacked_frames:
|
if num_frames_acked > len(self._tx_window):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Received acknowledgment for %d frames but only %d frames are pending",
|
"Received acknowledgment for %d frames but only %d frames are pending",
|
||||||
num_frames_acked,
|
num_frames_acked,
|
||||||
self._num_unacked_frames,
|
len(self._tx_window),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if is_poll_response and self._monitor_handle:
|
if is_poll_response and self._monitor_handle:
|
||||||
self._monitor_handle.cancel()
|
self._monitor_handle.cancel()
|
||||||
self._monitor_handle = None
|
self._monitor_handle = None
|
||||||
|
|
||||||
del self._pending_pdus[:num_frames_acked]
|
del self._tx_window[:num_frames_acked]
|
||||||
self._expected_ack_seq = new_seq
|
self._last_acked_tx_seq = new_seq
|
||||||
if (
|
if (
|
||||||
self._expected_ack_seq == self._next_tx_seq
|
self._last_acked_tx_seq == self._next_tx_seq
|
||||||
and self._receiver_ready_poll_handle
|
and self._receiver_ready_poll_handle
|
||||||
):
|
):
|
||||||
self._receiver_ready_poll_handle.cancel()
|
self._receiver_ready_poll_handle.cancel()
|
||||||
|
|||||||
@@ -239,20 +239,7 @@ async def transfer_payload(
|
|||||||
channels[1].sink = received.put_nowait
|
channels[1].sink = received.put_nowait
|
||||||
sdu_lengths = (21, 70, 700, 5523)
|
sdu_lengths = (21, 70, 700, 5523)
|
||||||
|
|
||||||
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
|
messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
|
||||||
mps = channels[1].mps
|
|
||||||
elif isinstance(
|
|
||||||
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
|
|
||||||
):
|
|
||||||
mps = processor.mps
|
|
||||||
else:
|
|
||||||
mps = channels[1].mtu
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
bytes([i % 8 for i in range(sdu_length)])
|
|
||||||
for sdu_length in sdu_lengths
|
|
||||||
if sdu_length <= mps
|
|
||||||
]
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
channels[0].write(message)
|
channels[0].write(message)
|
||||||
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
|
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
|
||||||
@@ -334,20 +321,26 @@ async def test_mtu():
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enhanced_retransmission_mode():
|
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
|
||||||
|
async def test_enhanced_retransmission_mode(mtu: int):
|
||||||
devices = TwoDevices()
|
devices = TwoDevices()
|
||||||
await devices.setup_connection()
|
await devices.setup_connection()
|
||||||
|
|
||||||
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
|
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
|
||||||
server = devices.devices[1].create_l2cap_server(
|
server = devices.devices[1].create_l2cap_server(
|
||||||
spec=l2cap.ClassicChannelSpec(
|
spec=l2cap.ClassicChannelSpec(
|
||||||
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
|
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
|
||||||
|
mtu=mtu,
|
||||||
|
mps=256,
|
||||||
),
|
),
|
||||||
handler=server_channels.put_nowait,
|
handler=server_channels.put_nowait,
|
||||||
)
|
)
|
||||||
client_channel = await devices.connections[0].create_l2cap_channel(
|
client_channel = await devices.connections[0].create_l2cap_channel(
|
||||||
spec=l2cap.ClassicChannelSpec(
|
spec=l2cap.ClassicChannelSpec(
|
||||||
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
|
server.psm,
|
||||||
|
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
|
||||||
|
mtu=mtu,
|
||||||
|
mps=1024,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
server_channel = await server_channels.get()
|
server_channel = await server_channels.get()
|
||||||
|
|||||||
Reference in New Issue
Block a user