diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 72b39f7e..cae5ff24 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -829,7 +829,7 @@ class L2CAP_Credit_Based_Connection_Response(L2CAP_Control_Frame): mtu: int = dataclasses.field(metadata=hci.metadata(2)) mps: int = dataclasses.field(metadata=hci.metadata(2)) initial_credits: int = dataclasses.field(metadata=hci.metadata(2)) - result: int = dataclasses.field(metadata=Result.type_metadata(2)) + result: Result = dataclasses.field(metadata=Result.type_metadata(2)) destination_cid: Sequence[int] = dataclasses.field( metadata=L2CAP_Credit_Based_Connection_Request.CID_METADATA ) @@ -1559,7 +1559,7 @@ class LeCreditBasedChannel(utils.EventEmitter): self, manager: ChannelManager, connection: Connection, - le_psm: int, + psm: int, source_cid: int, destination_cid: int, mtu: int, @@ -1573,7 +1573,7 @@ class LeCreditBasedChannel(utils.EventEmitter): super().__init__() self.manager = manager self.connection = connection - self.le_psm = le_psm + self.psm = psm self.source_cid = source_cid self.destination_cid = destination_cid self.mtu = mtu @@ -1629,7 +1629,7 @@ class LeCreditBasedChannel(utils.EventEmitter): self._change_state(self.State.CONNECTING) request = L2CAP_LE_Credit_Based_Connection_Request( identifier=identifier, - le_psm=self.le_psm, + le_psm=self.psm, source_cid=self.source_cid, mtu=self.mtu, mps=self.mps, @@ -1772,6 +1772,22 @@ class LeCreditBasedChannel(utils.EventEmitter): # Cleanup self.connection_result = None + def on_enhanced_connection_response( + self, destination_cid: int, response: L2CAP_Credit_Based_Connection_Response + ) -> None: + if ( + response.result + == L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL + ): + self.destination_cid = destination_cid + self.peer_mtu = response.mtu + self.peer_mps = response.mps + self.credits = response.initial_credits + self.connected = True + self._change_state(self.State.CONNECTED) + else: + self._change_state(self.State.CONNECTION_ERROR) + def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin self.credits += credits logger.debug(f'received {credits} credits, total = {self.credits}') @@ -1884,7 +1900,7 @@ class LeCreditBasedChannel(utils.EventEmitter): return ( f'CoC({self.source_cid}->{self.destination_cid}, ' f'State={self.state.name}, ' - f'PSM={self.le_psm}, ' + f'PSM={self.psm}, ' f'MTU={self.mtu}/{self.peer_mtu}, ' f'MPS={self.mps}/{self.peer_mps}, ' f'credits={self.credits}/{self.peer_credits})' @@ -1958,6 +1974,16 @@ class ChannelManager: le_coc_servers: dict[int, LeCreditBasedChannelServer] le_coc_requests: dict[int, L2CAP_LE_Credit_Based_Connection_Request] fixed_channels: dict[int, Optional[Callable[[int, bytes], Any]]] + pending_credit_based_connections: dict[ + int, + dict[ + int, + tuple[ + asyncio.Future[None], + list[LeCreditBasedChannel], + ], + ], + ] _host: Optional[Host] connection_parameters_update_response: Optional[asyncio.Future[int]] @@ -1979,6 +2005,9 @@ class ChannelManager: ) # LE CoC channels, mapped by connection and destination cid self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM self.le_coc_requests = {} # LE CoC connection requests, by identifier + self.pending_credit_based_connections = ( + {} + ) # Credit-based connection request contexts, by connection handle and identifier self.extended_features = set(extended_features) self.connectionless_mtu = connectionless_mtu self.connection_parameters_update_response = None @@ -2021,18 +2050,26 @@ class ChannelManager: raise OutOfResourcesError('no free CID available') - @staticmethod - def find_free_le_cid(channels: Iterable[int]) -> int: + @classmethod + def find_free_le_cid(cls, channels: Iterable[int]) -> int | None: + cids = cls.find_free_le_cids(channels, 1) + return cids[0] if cids else None + + @classmethod + def find_free_le_cids(cls, channels: Iterable[int], count: int) -> list[int]: # Pick the smallest valid CID that's not already in the list # (not necessarily the most efficient algorithm, but the list of CID is # very small in practice) + cids: list[int] = [] for cid in range( L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1 ): if cid not in channels: - return cid + cids.append(cid) + if len(cids) == count: + return cids - raise OutOfResourcesError('no free CID') + return [] def next_identifier(self, connection: Connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 @@ -2119,18 +2156,22 @@ class ChannelManager: return self.le_coc_servers[spec.psm] - def on_disconnection(self, connection_handle: int, _reason: int) -> None: - logger.debug(f'disconnection from {connection_handle}, cleaning up channels') - if connection_handle in self.channels: - for _, channel in self.channels[connection_handle].items(): + def on_disconnection(self, connection_handle: int, reason: int) -> None: + del reason # unused. + logger.debug('disconnection from %d, cleaning up channels', connection_handle) + if channels := self.channels.pop(connection_handle, None): + for channel in channels.values(): channel.abort() - del self.channels[connection_handle] - if connection_handle in self.le_coc_channels: - for _, channel in self.le_coc_channels[connection_handle].items(): - channel.abort() - del self.le_coc_channels[connection_handle] - if connection_handle in self.identifiers: - del self.identifiers[connection_handle] + if le_coc_channels := self.le_coc_channels.pop(connection_handle, None): + for le_coc_channel in le_coc_channels.values(): + le_coc_channel.abort() + if pending_credit_based_connections := self.pending_credit_based_connections.pop( + connection_handle, None + ): + for future, _ in pending_credit_based_connections.values(): + if not future.done(): + future.cancel("ACL disconnected") + self.identifiers.pop(connection_handle, None) def send_pdu( self, @@ -2242,7 +2283,6 @@ class ChannelManager: identifier=request.identifier, destination_cid=request.source_cid, source_cid=0, - # pylint: disable=line-too-long result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, status=0x0000, ), @@ -2273,7 +2313,6 @@ class ChannelManager: identifier=request.identifier, destination_cid=request.source_cid, source_cid=0, - # pylint: disable=line-too-long result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, status=0x0000, ), @@ -2502,89 +2541,7 @@ class ChannelManager: cid: int, request: L2CAP_LE_Credit_Based_Connection_Request, ) -> None: - if request.le_psm in self.le_coc_servers: - server = self.le_coc_servers[request.le_psm] - - # Check that the CID isn't already used - le_connection_channels = self.le_coc_channels.setdefault( - connection.handle, {} - ) - if request.source_cid in le_connection_channels: - logger.warning(f'source CID {request.source_cid} already in use') - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=server.mtu, - mps=server.mps, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, - ), - ) - return - - # Find a free CID for this new channel - connection_channels = self.channels.setdefault(connection.handle, {}) - source_cid = self.find_free_le_cid(connection_channels) - if source_cid is None: # Should never happen! - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=server.mtu, - mps=server.mps, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, - ), - ) - return - - # Create a new channel - logger.debug( - f'creating LE CoC server channel with cid={source_cid} for psm ' - f'{request.le_psm}' - ) - channel = LeCreditBasedChannel( - self, - connection, - request.le_psm, - source_cid, - request.source_cid, - server.mtu, - server.mps, - request.initial_credits, - request.mtu, - request.mps, - server.max_credits, - True, - ) - connection_channels[source_cid] = channel - le_connection_channels[request.source_cid] = channel - - # Respond - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=source_cid, - mtu=server.mtu, - mps=server.mps, - initial_credits=server.max_credits, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_SUCCESSFUL, - ), - ) - - # Notify - server.on_connection(channel) - else: + if not (server := self.le_coc_servers.get(request.le_psm)): logger.info( f'No LE server for connection 0x{connection.handle:04X} ' f'on PSM {request.le_psm}' @@ -2598,10 +2555,86 @@ class ChannelManager: mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, initial_credits=0, - # pylint: disable=line-too-long result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, ), ) + return + + # Check that the CID isn't already used + le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + if request.source_cid in le_connection_channels: + logger.warning(f'source CID {request.source_cid} already in use') + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=server.mtu, + mps=server.mps, + initial_credits=0, + result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + ), + ) + return + + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_le_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=server.mtu, + mps=server.mps, + initial_credits=0, + result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + ), + ) + return + + # Create a new channel + logger.debug( + f'creating LE CoC server channel with cid={source_cid} for psm ' + f'{request.le_psm}' + ) + channel = LeCreditBasedChannel( + self, + connection, + request.le_psm, + source_cid, + request.source_cid, + server.mtu, + server.mps, + request.initial_credits, + request.mtu, + request.mps, + server.max_credits, + True, + ) + connection_channels[source_cid] = channel + le_connection_channels[request.source_cid] = channel + + # Respond + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=source_cid, + mtu=server.mtu, + mps=server.mps, + initial_credits=server.max_credits, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.Result.CONNECTION_SUCCESSFUL, + ), + ) + + # Notify + server.on_connection(channel) def on_l2cap_le_credit_based_connection_response( self, @@ -2610,11 +2643,9 @@ class ChannelManager: response: L2CAP_LE_Credit_Based_Connection_Response, ) -> None: # Find the pending request by identifier - request = self.le_coc_requests.get(response.identifier) - if request is None: + if not (request := self.le_coc_requests.pop(response.identifier, None)): logger.warning(color('!!! received response for unknown request', 'red')) return - del self.le_coc_requests[response.identifier] # Find the channel for this request channel = self.find_channel(connection.handle, request.source_cid) @@ -2631,6 +2662,147 @@ class ChannelManager: # Process the response channel.on_connection_response(response) + def on_l2cap_credit_based_connection_request( + self, + connection: Connection, + cid: int, + request: L2CAP_Credit_Based_Connection_Request, + ) -> None: + if not (server := self.le_coc_servers.get(request.spsm)): + logger.info( + 'No LE server for connection 0x%04X ' 'on PSM %d', + connection.handle, + request.spsm, + ) + self.send_control_frame( + connection, + cid, + L2CAP_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=[], + mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + initial_credits=0, + result=L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_REFUSED_SPSM_NOT_SUPPORTED, + ), + ) + return + + # Check that the CID isn't already used + le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + if cid_in_use := set(request.source_cid).intersection( + set(le_connection_channels) + ): + logger.warning('source CID already in use: %s', cid_in_use) + self.send_control_frame( + connection, + cid, + L2CAP_Credit_Based_Connection_Response( + identifier=request.identifier, + mtu=server.mtu, + mps=server.mps, + initial_credits=0, + result=L2CAP_Credit_Based_Connection_Response.Result.SOME_CONNECTIONS_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + destination_cid=[], + ), + ) + return + + # Find free CIDs for new channels + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cids = self.find_free_le_cids( + connection_channels, len(request.source_cid) + ) + if not source_cids: + self.send_control_frame( + connection, + cid, + L2CAP_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=[], + mtu=server.mtu, + mps=server.mps, + initial_credits=server.max_credits, + result=L2CAP_Credit_Based_Connection_Response.Result.SOME_CONNECTIONS_REFUSED_INSUFFICIENT_RESOURCES_AVAILABLE, + ), + ) + return + + for destination_cid in request.source_cid: + # TODO: Handle Classic channels. + if not (source_cid := self.find_free_le_cid(connection_channels)): + logger.warning("No free CIDs available") + break + # Create a new channel + logger.debug( + 'creating LE CoC server channel with cid=%s for psm %s', + source_cid, + request.spsm, + ) + channel = LeCreditBasedChannel( + self, + connection, + request.spsm, + source_cid, + destination_cid, + server.mtu, + server.mps, + request.initial_credits, + request.mtu, + request.mps, + server.max_credits, + True, + ) + connection_channels[source_cid] = channel + le_connection_channels[source_cid] = channel + server.on_connection(channel) + + # Respond + self.send_control_frame( + connection, + cid, + L2CAP_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=source_cids, + mtu=server.mtu, + mps=server.mps, + initial_credits=server.max_credits, + result=L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL, + ), + ) + + def on_l2cap_credit_based_connection_response( + self, + connection: Connection, + _cid: int, + response: L2CAP_Credit_Based_Connection_Response, + ) -> None: + # Find the pending request by identifier + pending_connections = self.pending_credit_based_connections.setdefault( + connection.handle, {} + ) + if not ( + pending_connection := pending_connections.pop(response.identifier, None) + ): + logger.warning(color('!!! received response for unknown request', 'red')) + return + + connection_result, channels = pending_connection + + # Process the response + for channel, destination_cid in zip(channels, response.destination_cid): + channel.on_enhanced_connection_response(destination_cid, response) + + if ( + response.result + == L2CAP_Credit_Based_Connection_Response.Result.ALL_CONNECTIONS_SUCCESSFUL + ): + connection_result.set_result(None) + else: + connection_result.set_exception( + L2capError(response.result, response.result.name) + ) + def on_l2cap_le_flow_control_credit( self, connection: Connection, _cid: int, credit: L2CAP_LE_Flow_Control_Credit ) -> None: @@ -2666,7 +2838,7 @@ class ChannelManager: channel = LeCreditBasedChannel( manager=self, connection=connection, - le_psm=spec.psm, + psm=spec.psm, source_cid=source_cid, destination_cid=0, mtu=spec.mtu, @@ -2730,6 +2902,79 @@ class ChannelManager: return channel + async def create_enhanced_credit_based_channels( + self, + connection: Connection, + spec: LeCreditBasedChannelSpec, + count: int, + ) -> list[LeCreditBasedChannel]: + # Find a free CID for the new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cids = self.find_free_le_cids(connection_channels, count) + if not source_cids: # Should never happen! + raise OutOfResourcesError('all CIDs already in use') + + if spec.psm is None: + raise InvalidArgumentError('PSM cannot be None') + + # Create the channel + logger.debug( + 'creating coc channel with cid=%s for psm %s', source_cids, spec.psm + ) + channels: list[LeCreditBasedChannel] = [] + for source_cid in source_cids: + channel = LeCreditBasedChannel( + manager=self, + connection=connection, + psm=spec.psm, + source_cid=source_cid, + destination_cid=0, + mtu=spec.mtu, + mps=spec.mps, + credits=0, + peer_mtu=0, + peer_mps=0, + peer_credits=spec.max_credits, + connected=False, + ) + connection_channels[source_cid] = channel + channels.append(channel) + + identifier = self.next_identifier(connection) + request = L2CAP_Credit_Based_Connection_Request( + identifier=identifier, + spsm=spec.psm, + mtu=spec.mtu, + mps=spec.mps, + initial_credits=spec.max_credits, + source_cid=source_cids, + ) + connection_result = asyncio.get_running_loop().create_future() + pending_connections = self.pending_credit_based_connections.setdefault( + connection.handle, {} + ) + pending_connections[identifier] = (connection_result, channels) + self.send_control_frame( + connection, + L2CAP_LE_SIGNALING_CID, + request, + ) + # Connect + try: + await connection_result + except Exception: + logger.exception('connection failed') + for cid in source_cids: + del connection_channels[cid] + raise + + # Remember the channel by source CID and destination CID + le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + for channel in channels: + le_connection_channels[channel.destination_cid] = channel + + return channels + @classmethod def make_mode_processor( self, diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index bfc566ef..9e35a75a 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -16,15 +16,17 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import itertools import logging import os import random import struct +from collections.abc import Sequence +from unittest import mock import pytest -from bumble import l2cap -from bumble.core import ProtocolError +from bumble import core, l2cap from .test_utils import TwoDevices, async_barrier @@ -143,7 +145,7 @@ async def test_basic_connection(): psm = 1234 # Check that if there's no one listening, we can't connect - with pytest.raises(ProtocolError): + with pytest.raises(core.ProtocolError): l2cap_channel = await devices.connections[0].create_l2cap_channel( spec=l2cap.LeCreditBasedChannelSpec(psm) ) @@ -231,48 +233,63 @@ async def test_l2cap_information_request(monkeypatch, info_type): # ----------------------------------------------------------------------------- -async def transfer_payload(max_credits, mtu, mps): - devices = TwoDevices() - await devices.setup_connection() +async def transfer_payload( + channels: Sequence[l2cap.ClassicChannel | l2cap.LeCreditBasedChannel], +): + received = asyncio.Queue[bytes]() + channels[1].sink = received.put_nowait + sdu_lengths = (21, 70, 700, 5523) - received = [] + if isinstance(channels[1], l2cap.LeCreditBasedChannel): + mps = channels[1].mps + elif isinstance( + processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor + ): + mps = processor.mps + else: + mps = channels[1].mtu - def on_coc(channel): - def on_data(data): - received.append(data) - - channel.sink = on_data - - server = devices.devices[1].create_l2cap_server( - spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps), - handler=on_coc, - ) - l2cap_channel = await devices.connections[0].create_l2cap_channel( - spec=l2cap.LeCreditBasedChannelSpec(server.psm) - ) - - messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)] + messages = [ + bytes([i % 8 for i in range(sdu_length)]) + for sdu_length in sdu_lengths + if sdu_length <= mps + ] for message in messages: - l2cap_channel.write(message) - await asyncio.sleep(0) - if random.randint(0, 5) == 1: - await l2cap_channel.drain() + channels[0].write(message) + if isinstance(channels[0], l2cap.LeCreditBasedChannel): + if random.randint(0, 5) == 1: + await channels[0].drain() - await l2cap_channel.drain() - await l2cap_channel.disconnect() + if isinstance(channels[0], l2cap.LeCreditBasedChannel): + await channels[0].drain() sent_bytes = b''.join(messages) - received_bytes = b''.join(received) + received_bytes = b'' + while len(received_bytes) < len(sent_bytes): + received_bytes += await received.get() assert sent_bytes == received_bytes @pytest.mark.asyncio -async def test_transfer(): - for max_credits in (1, 10, 100, 10000): - for mtu in (50, 255, 256, 1000): - for mps in (50, 255, 256, 1000): - # print(max_credits, mtu, mps) - await transfer_payload(max_credits, mtu, mps) +@pytest.mark.parametrize( + "max_credits, mtu, mps", + itertools.product((1, 10, 100, 10000), (50, 255, 256, 1000), (50, 255, 256, 1000)), +) +async def test_transfer(max_credits: int, mtu: int, mps: int): + devices = await TwoDevices.create_with_connection() + + server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]() + server = devices[1].create_l2cap_server( + spec=l2cap.LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps), + handler=server_channels.put_nowait, + ) + assert (connection := devices.connections[0]) + client = await connection.create_l2cap_channel( + spec=l2cap.LeCreditBasedChannelSpec(server.psm) + ) + server_channel = await server_channels.get() + await transfer_payload((client, server_channel)) + await client.disconnect() # ----------------------------------------------------------------------------- @@ -281,45 +298,18 @@ async def test_bidirectional_transfer(): devices = TwoDevices() await devices.setup_connection() - client_received = [] - server_received = [] - server_channel = None - - def on_server_coc(channel): - nonlocal server_channel - server_channel = channel - - def on_server_data(data): - server_received.append(data) - - channel.sink = on_server_data - - def on_client_data(data): - client_received.append(data) - + server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]() server = devices.devices[1].create_l2cap_server( - spec=l2cap.LeCreditBasedChannelSpec(), handler=on_server_coc + spec=l2cap.LeCreditBasedChannelSpec(), + handler=server_channels.put_nowait, ) - client_channel = await devices.connections[0].create_l2cap_channel( + client = await devices.connections[0].create_l2cap_channel( spec=l2cap.LeCreditBasedChannelSpec(server.psm) ) - client_channel.sink = on_client_data - - messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)] - for message in messages: - client_channel.write(message) - await client_channel.drain() - await asyncio.sleep(0) - server_channel.write(message) - await server_channel.drain() - - await client_channel.disconnect() - - message_bytes = b''.join(messages) - client_received_bytes = b''.join(client_received) - server_received_bytes = b''.join(server_received) - assert client_received_bytes == message_bytes - assert server_received_bytes == message_bytes + server_channel = await server_channels.get() + await transfer_payload((client, server_channel)) + await transfer_payload((server_channel, client)) + await client.disconnect() # ----------------------------------------------------------------------------- @@ -363,18 +353,8 @@ async def test_enhanced_retransmission_mode(): ) server_channel = await server_channels.get() - sinks = [asyncio.Queue[bytes]() for _ in range(2)] - server_channel.sink = sinks[0].put_nowait - client_channel.sink = sinks[1].put_nowait - - for i in range(128): - server_channel.write(struct.pack('