From 8f00e82fac558c38c1fc918adb1322b57ffa8fd4 Mon Sep 17 00:00:00 2001 From: uier Date: Tue, 16 Jun 2026 18:10:53 +0000 Subject: [PATCH] l2cap: Simplify abort and disconnection response handling Also includes the test_disconnection_collision unit test. --- bumble/l2cap.py | 33 ++++++++++++++------------------- tests/l2cap_test.py | 43 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 7d97cb9..169371c 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -1725,16 +1725,15 @@ class LeCreditBasedChannel(utils.EventEmitter): return await disconnection_result def abort(self) -> None: - if self.state == self.State.CONNECTED or self.state == self.State.DISCONNECTING: - was_disconnecting = self.state == self.State.DISCONNECTING + if self.state in (self.State.CONNECTED, self.State.DISCONNECTING): self._change_state(self.State.DISCONNECTED) self.manager.on_channel_closed(self) - if was_disconnecting and self.disconnection_result: - self.disconnection_result.set_result(None) - self.disconnection_result = None - if self.state == self.State.CONNECTING: - if self.connection_result is not None: - self.connection_result.cancel() + if self.connection_result is not None: + self.connection_result.cancel() + self.connection_result = None + if self.disconnection_result is not None: + self.disconnection_result.set_result(None) + self.disconnection_result = None def on_pdu(self, pdu: bytes) -> None: if self.sink is None: @@ -1868,10 +1867,9 @@ class LeCreditBasedChannel(utils.EventEmitter): source_cid=request.source_cid, ) ) - was_disconnecting = self.state == self.State.DISCONNECTING self._change_state(self.State.DISCONNECTED) self.manager.on_channel_closed(self) - if was_disconnecting and self.disconnection_result: + if self.disconnection_result is not None: self.disconnection_result.set_result(None) self.disconnection_result = None self.flush_output() @@ -2891,15 +2889,12 @@ class ChannelManager: channel.on_credits(credit.credits) def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None: - connection_channels = self.channels.get(channel.connection.handle) - if connection_channels: - if connection_channels.get(channel.source_cid) is channel: - del connection_channels[channel.source_cid] - if isinstance(channel, LeCreditBasedChannel): - le_connection_channels = self.le_coc_channels.get(channel.connection.handle) - if le_connection_channels: - if le_connection_channels.get(channel.destination_cid) is channel: - del le_connection_channels[channel.destination_cid] + if classic_connection_channels := self.channels.get(channel.connection.handle): + classic_connection_channels.pop(channel.source_cid, None) + elif le_connection_channels := self.le_coc_channels.get( + channel.connection.handle + ): + le_connection_channels.pop(channel.destination_cid, None) async def create_le_credit_based_channel( self, diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index 2258e9b..9cc3c83 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -460,6 +460,7 @@ def test_fcs(cid: int, payload: str, expected: str): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_abort_while_disconnecting(): + # Setup client and server channels devices = TwoDevices() await devices.setup_connection() psm = 1234 @@ -474,15 +475,16 @@ async def test_abort_while_disconnecting(): ) server_channel = await server_channels.get() - # Stub server channel's on_disconnection_request to ignore the request, - # simulating a lost packet or unresponsive peer. + # Stub the server's request handler to ignore the disconnection request. + # This keeps the client channel in the DISCONNECTING state, waiting for a response, + # so we can simulate calling abort() during an active disconnection. server_channel.on_disconnection_request = lambda request: None # Intercept state change to DISCONNECTING and call abort() original_change_state = client_channel._change_state abort_called = False - def my_change_state(new_state): + def intercept_change_state_and_abort(new_state): nonlocal abort_called original_change_state(new_state) if ( @@ -492,13 +494,44 @@ async def test_abort_while_disconnecting(): abort_called = True client_channel.abort() - client_channel._change_state = my_change_state + client_channel._change_state = intercept_change_state_and_abort - # Start disconnection and wait with a timeout + # Start disconnection and wait with a timeout. It should resolve immediately due to the abort. await asyncio.wait_for(client_channel.disconnect(), timeout=1.0) assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_disconnection_collision(): + # Setup client and server channels + devices = TwoDevices() + await devices.setup_connection() + psm = 1234 + + server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]() + devices.devices[1].create_l2cap_server( + spec=l2cap.LeCreditBasedChannelSpec(psm=psm), + handler=server_channels.put_nowait, + ) + client_channel = await devices.connections[0].create_l2cap_channel( + spec=l2cap.LeCreditBasedChannelSpec(psm) + ) + server_channel = await server_channels.get() + + # Trigger disconnection from both sides concurrently to cause a collision. + # Both channels will transition to DISCONNECTING and send DISCONNECTION_REQUESTs. + # When each side receives the peer's request, it will handle it and resolve the + # disconnection_result future. + await asyncio.gather( + client_channel.disconnect(), + server_channel.disconnect(), + ) + + assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED + assert server_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED + + # ----------------------------------------------------------------------------- async def run(): test_helpers()