Merge pull request #840 from zxzxwu/credit

L2CAP: Enhanced Credit-based Flow Control Mode
This commit is contained in:
zxzxwu
2025-12-30 20:26:44 +08:00
committed by GitHub
2 changed files with 489 additions and 192 deletions

View File

@@ -829,7 +829,7 @@ class L2CAP_Credit_Based_Connection_Response(L2CAP_Control_Frame):
mtu: int = dataclasses.field(metadata=hci.metadata(2)) mtu: int = dataclasses.field(metadata=hci.metadata(2))
mps: int = dataclasses.field(metadata=hci.metadata(2)) mps: int = dataclasses.field(metadata=hci.metadata(2))
initial_credits: int = dataclasses.field(metadata=hci.metadata(2)) initial_credits: int = dataclasses.field(metadata=hci.metadata(2))
result: int = dataclasses.field(metadata=Result.type_metadata(2)) result: Result = dataclasses.field(metadata=Result.type_metadata(2))
destination_cid: Sequence[int] = dataclasses.field( destination_cid: Sequence[int] = dataclasses.field(
metadata=L2CAP_Credit_Based_Connection_Request.CID_METADATA metadata=L2CAP_Credit_Based_Connection_Request.CID_METADATA
) )
@@ -1559,7 +1559,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
self, self,
manager: ChannelManager, manager: ChannelManager,
connection: Connection, connection: Connection,
le_psm: int, psm: int,
source_cid: int, source_cid: int,
destination_cid: int, destination_cid: int,
mtu: int, mtu: int,
@@ -1573,7 +1573,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
super().__init__() super().__init__()
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.le_psm = le_psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = destination_cid self.destination_cid = destination_cid
self.mtu = mtu self.mtu = mtu
@@ -1629,7 +1629,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
self._change_state(self.State.CONNECTING) self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request( request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier, identifier=identifier,
le_psm=self.le_psm, le_psm=self.psm,
source_cid=self.source_cid, source_cid=self.source_cid,
mtu=self.mtu, mtu=self.mtu,
mps=self.mps, mps=self.mps,
@@ -1772,6 +1772,22 @@ class LeCreditBasedChannel(utils.EventEmitter):
# Cleanup # Cleanup
self.connection_result = None self.connection_result = None
def on_enhanced_connection_response(
self, destination_cid: int, response: L2CAP_Credit_Based_Connection_Response
) -> None:
if (
response.result
== L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL
):
self.destination_cid = destination_cid
self.peer_mtu = response.mtu
self.peer_mps = response.mps
self.credits = response.initial_credits
self.connected = True
self._change_state(self.State.CONNECTED)
else:
self._change_state(self.State.CONNECTION_ERROR)
def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin
self.credits += credits self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}') logger.debug(f'received {credits} credits, total = {self.credits}')
@@ -1884,7 +1900,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
return ( return (
f'CoC({self.source_cid}->{self.destination_cid}, ' f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state.name}, ' f'State={self.state.name}, '
f'PSM={self.le_psm}, ' f'PSM={self.psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, ' f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, ' f'MPS={self.mps}/{self.peer_mps}, '
f'credits={self.credits}/{self.peer_credits})' f'credits={self.credits}/{self.peer_credits})'
@@ -1958,6 +1974,16 @@ class ChannelManager:
le_coc_servers: dict[int, LeCreditBasedChannelServer] le_coc_servers: dict[int, LeCreditBasedChannelServer]
le_coc_requests: dict[int, L2CAP_LE_Credit_Based_Connection_Request] le_coc_requests: dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: dict[int, Optional[Callable[[int, bytes], Any]]] fixed_channels: dict[int, Optional[Callable[[int, bytes], Any]]]
pending_credit_based_connections: dict[
int,
dict[
int,
tuple[
asyncio.Future[None],
list[LeCreditBasedChannel],
],
],
]
_host: Optional[Host] _host: Optional[Host]
connection_parameters_update_response: Optional[asyncio.Future[int]] connection_parameters_update_response: Optional[asyncio.Future[int]]
@@ -1979,6 +2005,9 @@ class ChannelManager:
) # LE CoC channels, mapped by connection and destination cid ) # LE CoC channels, mapped by connection and destination cid
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
self.le_coc_requests = {} # LE CoC connection requests, by identifier self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.pending_credit_based_connections = (
{}
) # Credit-based connection request contexts, by connection handle and identifier
self.extended_features = set(extended_features) self.extended_features = set(extended_features)
self.connectionless_mtu = connectionless_mtu self.connectionless_mtu = connectionless_mtu
self.connection_parameters_update_response = None self.connection_parameters_update_response = None
@@ -2021,18 +2050,26 @@ class ChannelManager:
raise OutOfResourcesError('no free CID available') raise OutOfResourcesError('no free CID available')
@staticmethod @classmethod
def find_free_le_cid(channels: Iterable[int]) -> int: def find_free_le_cid(cls, channels: Iterable[int]) -> int | None:
cids = cls.find_free_le_cids(channels, 1)
return cids[0] if cids else None
@classmethod
def find_free_le_cids(cls, channels: Iterable[int], count: int) -> list[int]:
# Pick the smallest valid CID that's not already in the list # Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is # (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice) # very small in practice)
cids: list[int] = []
for cid in range( for cid in range(
L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1 L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1
): ):
if cid not in channels: if cid not in channels:
return cid cids.append(cid)
if len(cids) == count:
return cids
raise OutOfResourcesError('no free CID') return []
def next_identifier(self, connection: Connection) -> int: def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
@@ -2119,18 +2156,22 @@ class ChannelManager:
return self.le_coc_servers[spec.psm] return self.le_coc_servers[spec.psm]
def on_disconnection(self, connection_handle: int, _reason: int) -> None: def on_disconnection(self, connection_handle: int, reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels') del reason # unused.
if connection_handle in self.channels: logger.debug('disconnection from %d, cleaning up channels', connection_handle)
for _, channel in self.channels[connection_handle].items(): if channels := self.channels.pop(connection_handle, None):
for channel in channels.values():
channel.abort() channel.abort()
del self.channels[connection_handle] if le_coc_channels := self.le_coc_channels.pop(connection_handle, None):
if connection_handle in self.le_coc_channels: for le_coc_channel in le_coc_channels.values():
for _, channel in self.le_coc_channels[connection_handle].items(): le_coc_channel.abort()
channel.abort() if pending_credit_based_connections := self.pending_credit_based_connections.pop(
del self.le_coc_channels[connection_handle] connection_handle, None
if connection_handle in self.identifiers: ):
del self.identifiers[connection_handle] for future, _ in pending_credit_based_connections.values():
if not future.done():
future.cancel("ACL disconnected")
self.identifiers.pop(connection_handle, None)
def send_pdu( def send_pdu(
self, self,
@@ -2242,7 +2283,6 @@ class ChannelManager:
identifier=request.identifier, identifier=request.identifier,
destination_cid=request.source_cid, destination_cid=request.source_cid,
source_cid=0, source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000, status=0x0000,
), ),
@@ -2273,7 +2313,6 @@ class ChannelManager:
identifier=request.identifier, identifier=request.identifier,
destination_cid=request.source_cid, destination_cid=request.source_cid,
source_cid=0, source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000, status=0x0000,
), ),
@@ -2502,89 +2541,7 @@ class ChannelManager:
cid: int, cid: int,
request: L2CAP_LE_Credit_Based_Connection_Request, request: L2CAP_LE_Credit_Based_Connection_Request,
) -> None: ) -> None:
if request.le_psm in self.le_coc_servers: if not (server := self.le_coc_servers.get(request.le_psm)):
server = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
connection.handle, {}
)
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_SUCCESSFUL,
),
)
# Notify
server.on_connection(channel)
else:
logger.info( logger.info(
f'No LE server for connection 0x{connection.handle:04X} ' f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}' f'on PSM {request.le_psm}'
@@ -2598,10 +2555,86 @@ class ChannelManager:
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0, initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
), ),
) )
return
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
return
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_SUCCESSFUL,
),
)
# Notify
server.on_connection(channel)
def on_l2cap_le_credit_based_connection_response( def on_l2cap_le_credit_based_connection_response(
self, self,
@@ -2610,11 +2643,9 @@ class ChannelManager:
response: L2CAP_LE_Credit_Based_Connection_Response, response: L2CAP_LE_Credit_Based_Connection_Response,
) -> None: ) -> None:
# Find the pending request by identifier # Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier) if not (request := self.le_coc_requests.pop(response.identifier, None)):
if request is None:
logger.warning(color('!!! received response for unknown request', 'red')) logger.warning(color('!!! received response for unknown request', 'red'))
return return
del self.le_coc_requests[response.identifier]
# Find the channel for this request # Find the channel for this request
channel = self.find_channel(connection.handle, request.source_cid) channel = self.find_channel(connection.handle, request.source_cid)
@@ -2631,6 +2662,147 @@ class ChannelManager:
# Process the response # Process the response
channel.on_connection_response(response) channel.on_connection_response(response)
def on_l2cap_credit_based_connection_request(
self,
connection: Connection,
cid: int,
request: L2CAP_Credit_Based_Connection_Request,
) -> None:
if not (server := self.le_coc_servers.get(request.spsm)):
logger.info(
'No LE server for connection 0x%04X ' 'on PSM %d',
connection.handle,
request.spsm,
)
self.send_control_frame(
connection,
cid,
L2CAP_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=[],
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
result=L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_REFUSED_SPSM_NOT_SUPPORTED,
),
)
return
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
if cid_in_use := set(request.source_cid).intersection(
set(le_connection_channels)
):
logger.warning('source CID already in use: %s', cid_in_use)
self.send_control_frame(
connection,
cid,
L2CAP_Credit_Based_Connection_Response(
identifier=request.identifier,
mtu=server.mtu,
mps=server.mps,
initial_credits=0,
result=L2CAP_Credit_Based_Connection_Response.Result.SOME_CONNECTIONS_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
destination_cid=[],
),
)
return
# Find free CIDs for new channels
connection_channels = self.channels.setdefault(connection.handle, {})
source_cids = self.find_free_le_cids(
connection_channels, len(request.source_cid)
)
if not source_cids:
self.send_control_frame(
connection,
cid,
L2CAP_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=[],
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
result=L2CAP_Credit_Based_Connection_Response.Result.SOME_CONNECTIONS_REFUSED_INSUFFICIENT_RESOURCES_AVAILABLE,
),
)
return
for destination_cid in request.source_cid:
# TODO: Handle Classic channels.
if not (source_cid := self.find_free_le_cid(connection_channels)):
logger.warning("No free CIDs available")
break
# Create a new channel
logger.debug(
'creating LE CoC server channel with cid=%s for psm %s',
source_cid,
request.spsm,
)
channel = LeCreditBasedChannel(
self,
connection,
request.spsm,
source_cid,
destination_cid,
server.mtu,
server.mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
True,
)
connection_channels[source_cid] = channel
le_connection_channels[source_cid] = channel
server.on_connection(channel)
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cids,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
result=L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL,
),
)
def on_l2cap_credit_based_connection_response(
self,
connection: Connection,
_cid: int,
response: L2CAP_Credit_Based_Connection_Response,
) -> None:
# Find the pending request by identifier
pending_connections = self.pending_credit_based_connections.setdefault(
connection.handle, {}
)
if not (
pending_connection := pending_connections.pop(response.identifier, None)
):
logger.warning(color('!!! received response for unknown request', 'red'))
return
connection_result, channels = pending_connection
# Process the response
for channel, destination_cid in zip(channels, response.destination_cid):
channel.on_enhanced_connection_response(destination_cid, response)
if (
response.result
== L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL
):
connection_result.set_result(None)
else:
connection_result.set_exception(
L2capError(response.result, response.result.name)
)
def on_l2cap_le_flow_control_credit( def on_l2cap_le_flow_control_credit(
self, connection: Connection, _cid: int, credit: L2CAP_LE_Flow_Control_Credit self, connection: Connection, _cid: int, credit: L2CAP_LE_Flow_Control_Credit
) -> None: ) -> None:
@@ -2666,7 +2838,7 @@ class ChannelManager:
channel = LeCreditBasedChannel( channel = LeCreditBasedChannel(
manager=self, manager=self,
connection=connection, connection=connection,
le_psm=spec.psm, psm=spec.psm,
source_cid=source_cid, source_cid=source_cid,
destination_cid=0, destination_cid=0,
mtu=spec.mtu, mtu=spec.mtu,
@@ -2730,6 +2902,79 @@ class ChannelManager:
return channel return channel
async def create_enhanced_credit_based_channels(
self,
connection: Connection,
spec: LeCreditBasedChannelSpec,
count: int,
) -> list[LeCreditBasedChannel]:
# Find a free CID for the new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cids = self.find_free_le_cids(connection_channels, count)
if not source_cids: # Should never happen!
raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None:
raise InvalidArgumentError('PSM cannot be None')
# Create the channel
logger.debug(
'creating coc channel with cid=%s for psm %s', source_cids, spec.psm
)
channels: list[LeCreditBasedChannel] = []
for source_cid in source_cids:
channel = LeCreditBasedChannel(
manager=self,
connection=connection,
psm=spec.psm,
source_cid=source_cid,
destination_cid=0,
mtu=spec.mtu,
mps=spec.mps,
credits=0,
peer_mtu=0,
peer_mps=0,
peer_credits=spec.max_credits,
connected=False,
)
connection_channels[source_cid] = channel
channels.append(channel)
identifier = self.next_identifier(connection)
request = L2CAP_Credit_Based_Connection_Request(
identifier=identifier,
spsm=spec.psm,
mtu=spec.mtu,
mps=spec.mps,
initial_credits=spec.max_credits,
source_cid=source_cids,
)
connection_result = asyncio.get_running_loop().create_future()
pending_connections = self.pending_credit_based_connections.setdefault(
connection.handle, {}
)
pending_connections[identifier] = (connection_result, channels)
self.send_control_frame(
connection,
L2CAP_LE_SIGNALING_CID,
request,
)
# Connect
try:
await connection_result
except Exception:
logger.exception('connection failed')
for cid in source_cids:
del connection_channels[cid]
raise
# Remember the channel by source CID and destination CID
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
for channel in channels:
le_connection_channels[channel.destination_cid] = channel
return channels
@classmethod @classmethod
def make_mode_processor( def make_mode_processor(
self, self,

View File

@@ -16,15 +16,17 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import itertools
import logging import logging
import os import os
import random import random
import struct import struct
from collections.abc import Sequence
from unittest import mock
import pytest import pytest
from bumble import l2cap from bumble import core, l2cap
from bumble.core import ProtocolError
from .test_utils import TwoDevices, async_barrier from .test_utils import TwoDevices, async_barrier
@@ -143,7 +145,7 @@ async def test_basic_connection():
psm = 1234 psm = 1234
# Check that if there's no one listening, we can't connect # Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError): with pytest.raises(core.ProtocolError):
l2cap_channel = await devices.connections[0].create_l2cap_channel( l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm) spec=l2cap.LeCreditBasedChannelSpec(psm)
) )
@@ -231,48 +233,63 @@ async def test_l2cap_information_request(monkeypatch, info_type):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps): async def transfer_payload(
devices = TwoDevices() channels: Sequence[l2cap.ClassicChannel | l2cap.LeCreditBasedChannel],
await devices.setup_connection() ):
received = asyncio.Queue[bytes]()
channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523)
received = [] if isinstance(channels[1], l2cap.LeCreditBasedChannel):
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
def on_coc(channel): messages = [
def on_data(data): bytes([i % 8 for i in range(sdu_length)])
received.append(data) for sdu_length in sdu_lengths
if sdu_length <= mps
channel.sink = on_data ]
server = devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=on_coc,
)
l2cap_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
for message in messages: for message in messages:
l2cap_channel.write(message) channels[0].write(message)
await asyncio.sleep(0) if isinstance(channels[0], l2cap.LeCreditBasedChannel):
if random.randint(0, 5) == 1: if random.randint(0, 5) == 1:
await l2cap_channel.drain() await channels[0].drain()
await l2cap_channel.drain() if isinstance(channels[0], l2cap.LeCreditBasedChannel):
await l2cap_channel.disconnect() await channels[0].drain()
sent_bytes = b''.join(messages) sent_bytes = b''.join(messages)
received_bytes = b''.join(received) received_bytes = b''
while len(received_bytes) < len(sent_bytes):
received_bytes += await received.get()
assert sent_bytes == received_bytes assert sent_bytes == received_bytes
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transfer(): @pytest.mark.parametrize(
for max_credits in (1, 10, 100, 10000): "max_credits, mtu, mps",
for mtu in (50, 255, 256, 1000): itertools.product((1, 10, 100, 10000), (50, 255, 256, 1000), (50, 255, 256, 1000)),
for mps in (50, 255, 256, 1000): )
# print(max_credits, mtu, mps) async def test_transfer(max_credits: int, mtu: int, mps: int):
await transfer_payload(max_credits, mtu, mps) devices = await TwoDevices.create_with_connection()
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server = devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
handler=server_channels.put_nowait,
)
assert (connection := devices.connections[0])
client = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm)
)
server_channel = await server_channels.get()
await transfer_payload((client, server_channel))
await client.disconnect()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -281,45 +298,18 @@ async def test_bidirectional_transfer():
devices = TwoDevices() devices = TwoDevices()
await devices.setup_connection() await devices.setup_connection()
client_received = [] server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server_received = []
server_channel = None
def on_server_coc(channel):
nonlocal server_channel
server_channel = channel
def on_server_data(data):
server_received.append(data)
channel.sink = on_server_data
def on_client_data(data):
client_received.append(data)
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc spec=l2cap.LeCreditBasedChannelSpec(),
handler=server_channels.put_nowait,
) )
client_channel = await devices.connections[0].create_l2cap_channel( client = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(server.psm) spec=l2cap.LeCreditBasedChannelSpec(server.psm)
) )
client_channel.sink = on_client_data server_channel = await server_channels.get()
await transfer_payload((client, server_channel))
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)] await transfer_payload((server_channel, client))
for message in messages: await client.disconnect()
client_channel.write(message)
await client_channel.drain()
await asyncio.sleep(0)
server_channel.write(message)
await server_channel.drain()
await client_channel.disconnect()
message_bytes = b''.join(messages)
client_received_bytes = b''.join(client_received)
server_received_bytes = b''.join(server_received)
assert client_received_bytes == message_bytes
assert server_received_bytes == message_bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -363,18 +353,8 @@ async def test_enhanced_retransmission_mode():
) )
server_channel = await server_channels.get() server_channel = await server_channels.get()
sinks = [asyncio.Queue[bytes]() for _ in range(2)] await transfer_payload((client_channel, server_channel))
server_channel.sink = sinks[0].put_nowait await transfer_payload((server_channel, client_channel))
client_channel.sink = sinks[1].put_nowait
for i in range(128):
server_channel.write(struct.pack('<I', i))
for i in range(128):
assert (await sinks[1].get()) == struct.pack('<I', i)
for i in range(128):
client_channel.write(struct.pack('<I', i))
for i in range(128):
assert (await sinks[0].get()) == struct.pack('<I', i)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -399,6 +379,78 @@ async def test_mode_mismatching(server_mode, client_mode):
) )
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection():
devices = await TwoDevices.create_with_connection()
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
server = devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(), handler=server_channels.put_nowait
)
client_channels = await devices[
0
].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(psm=server.psm), count=5
)
assert len(client_channels) == 5
for client_channel in client_channels:
server_channel = await server_channels.get()
await transfer_payload((client_channel, server_channel))
await transfer_payload((server_channel, client_channel))
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_no_psm():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(psm=12345), count=5
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_insufficient_resource_client_side():
devices = await TwoDevices.create_with_connection()
server = devices[1].create_l2cap_server(spec=l2cap.LeCreditBasedChannelSpec())
with pytest.raises(core.OutOfResourcesError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0],
l2cap.LeCreditBasedChannelSpec(server.psm),
count=(
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_END
- l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_START
)
* 2,
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_credit_based_flow_control_connection_failure_insufficient_resource_server_side():
devices = await TwoDevices.create_with_connection()
server = devices[1].create_l2cap_server(spec=l2cap.LeCreditBasedChannelSpec())
# Simulate that the server side has no available CID.
channels = {
cid: mock.Mock()
for cid in range(
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_START,
l2cap.L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1,
)
}
devices[1].l2cap_channel_manager.channels[devices.connections[1].handle] = channels
with pytest.raises(l2cap.L2capError):
await devices[0].l2cap_channel_manager.create_enhanced_credit_based_channels(
devices.connections[0], l2cap.LeCreditBasedChannelSpec(server.psm), count=1
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.parametrize( @pytest.mark.parametrize(
'cid, payload, expected', 'cid, payload, expected',