Merge pull request #840 from zxzxwu/credit

L2CAP: Enhanced Credit-based Flow Control Mode
This commit is contained in:
zxzxwu
2025-12-30 20:26:44 +08:00
committed by GitHub
2 changed files with 489 additions and 192 deletions

View File

@@ -16,15 +16,17 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import itertools
import logging
import os
import random
import struct
from collections.abc import Sequence
from unittest import mock
import pytest
from bumble import l2cap
from bumble.core import ProtocolError
from bumble import core, l2cap
from .test_utils import TwoDevices, async_barrier
@@ -143,7 +145,7 @@ async def test_basic_connection():
psm = 1234
# Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError):
with pytest.raises(core.ProtocolError):
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
@@ -231,48 +233,63 @@ async def test_l2cap_information_request(monkeypatch, info_type):
# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
devices = TwoDevices()
await devices.setup_connection()
async def transfer_payload(
channels: Sequence[l2cap.ClassicChannel | l2cap.LeCreditBasedChannel],
):
received = asyncio.Queue[bytes]()
channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523)
received = []
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
def on_coc(channel):
def on_data(data):
received.append(data)
channel.sink = on_data
server = devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=on_coc,
)
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
if random.randint(0, 5) == 1:
await l2cap_channel.drain()
channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
if random.randint(0, 5) == 1:
await channels[0].drain()
await l2cap_channel.drain()
await l2cap_channel.disconnect()
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
await channels[0].drain()
sent_bytes = b''.join(messages)
received_bytes = b''.join(received)
received_bytes = b''
while len(received_bytes) < len(sent_bytes):
received_bytes += await received.get()
assert sent_bytes == received_bytes
@pytest.mark.asyncio
async def test_transfer():
for max_credits in (1, 10, 100, 10000):
for mtu in (50, 255, 256, 1000):
for mps in (50, 255, 256, 1000):
# print(max_credits, mtu, mps)
await transfer_payload(max_credits, mtu, mps)
@pytest.mark.parametrize(
"max_credits, mtu, mps",
itertools.product((1, 10, 100, 10000), (50, 255, 256, 1000), (50, 255, 256, 1000)),
)
async def test_transfer(max_credits: int, mtu: int, mps: int):
devices = await TwoDevices.create_with_connection()
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server = devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=server_channels.put_nowait,
)
assert (connection := devices.connections[0])
client = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
server_channel = await server_channels.get()
await transfer_payload((client, server_channel))
await client.disconnect()
# -----------------------------------------------------------------------------
@@ -281,45 +298,18 @@ async def test_bidirectional_transfer():
devices = TwoDevices()
await devices.setup_connection()
client_received = []
server_received = []
server_channel = None
def on_server_coc(channel):
nonlocal server_channel
server_channel = channel
def on_server_data(data):
server_received.append(data)
channel.sink = on_server_data
def on_client_data(data):
client_received.append(data)
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server = devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc
spec=l2cap.LeCreditBasedChannelSpec(),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
client = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
client_channel.sink = on_client_data
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)]
for message in messages:
client_channel.write(message)
await client_channel.drain()
await asyncio.sleep(0)
server_channel.write(message)
await server_channel.drain()
await client_channel.disconnect()
message_bytes = b''.join(messages)
client_received_bytes = b''.join(client_received)
server_received_bytes = b''.join(server_received)
assert client_received_bytes == message_bytes
assert server_received_bytes == message_bytes
server_channel = await server_channels.get()
await transfer_payload((client, server_channel))
await transfer_payload((server_channel, client))
await client.disconnect()
# -----------------------------------------------------------------------------
@@ -363,18 +353,8 @@ async def test_enhanced_retransmission_mode():
)
server_channel = await server_channels.get()
sinks = [asyncio.Queue[bytes]() for _ in range(2)]
server_channel.sink = sinks[0].put_nowait
client_channel.sink = sinks[1].put_nowait
for i in range(128):
server_channel.write(struct.pack('<I', i))
for i in range(128):
assert (await sinks[1].get()) == struct.pack('<I', i)
for i in range(128):
client_channel.write(struct.pack('<I', i))
for i in range(128):
assert (await sinks[0].get()) == struct.pack('<I', i)
await transfer_payload((client_channel, server_channel))
await transfer_payload((server_channel, client_channel))
# -----------------------------------------------------------------------------
@@ -399,6 +379,78 @@ async def test_mode_mismatching(server_mode, client_mode):
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection():
devices = await TwoDevices.create_with_connection()
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server = devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(), handler=server_channels.put_nowait
)
client_channels = await devices[
0
].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(psm=server.psm), count=5
)
assert len(client_channels) == 5
for client_channel in client_channels:
server_channel = await server_channels.get()
await transfer_payload((client_channel, server_channel))
await transfer_payload((server_channel, client_channel))
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_no_psm():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(psm=12345), count=5
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_insufficient_resource_client_side():
devices = await TwoDevices.create_with_connection()
server = devices[1].create_l2cap_server(spec=l2cap.LeCreditBasedChannelSpec())
with pytest.raises(core.OutOfResourcesError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0],
l2cap.LeCreditBasedChannelSpec(server.psm),
count=(
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_END
- l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_START
)
* 2,
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_insufficient_resource_server_side():
devices = await TwoDevices.create_with_connection()
server = devices[1].create_l2cap_server(spec=l2cap.LeCreditBasedChannelSpec())
# Simulate that the server side has no available CID.
channels = {
cid: mock.Mock()
for cid in range(
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_START,
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1,
)
}
devices[1].l2cap_channel_manager.channels[devices.connections[1].handle] = channels
with pytest.raises(l2cap.L2capError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(server.psm), count=1
)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'cid, payload, expected',