l2cap: Simplify abort and disconnection response handling

Also includes the test_disconnection_collision unit test.
This commit is contained in:
uier
2026-06-16 18:10:53 +00:00
parent 5bff934868
commit 8f00e82fac
2 changed files with 52 additions and 24 deletions
+14 -19
View File
@@ -1725,16 +1725,15 @@ class LeCreditBasedChannel(utils.EventEmitter):
return await disconnection_result
def abort(self) -> None:
if self.state == self.State.CONNECTED or self.state == self.State.DISCONNECTING:
was_disconnecting = self.state == self.State.DISCONNECTING
if self.state in (self.State.CONNECTED, self.State.DISCONNECTING):
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if was_disconnecting and self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
if self.state == self.State.CONNECTING:
if self.connection_result is not None:
self.connection_result.cancel()
if self.connection_result is not None:
self.connection_result.cancel()
self.connection_result = None
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
@@ -1868,10 +1867,9 @@ class LeCreditBasedChannel(utils.EventEmitter):
source_cid=request.source_cid,
)
)
was_disconnecting = self.state == self.State.DISCONNECTING
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if was_disconnecting and self.disconnection_result:
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
self.flush_output()
@@ -2891,15 +2889,12 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if connection_channels.get(channel.source_cid) is channel:
del connection_channels[channel.source_cid]
if isinstance(channel, LeCreditBasedChannel):
le_connection_channels = self.le_coc_channels.get(channel.connection.handle)
if le_connection_channels:
if le_connection_channels.get(channel.destination_cid) is channel:
del le_connection_channels[channel.destination_cid]
if classic_connection_channels := self.channels.get(channel.connection.handle):
classic_connection_channels.pop(channel.source_cid, None)
elif le_connection_channels := self.le_coc_channels.get(
channel.connection.handle
):
le_connection_channels.pop(channel.destination_cid, None)
async def create_le_credit_based_channel(
self,
+38 -5
View File
@@ -460,6 +460,7 @@ def test_fcs(cid: int, payload: str, expected: str):
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_abort_while_disconnecting():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
@@ -474,15 +475,16 @@ async def test_abort_while_disconnecting():
)
server_channel = await server_channels.get()
# Stub server channel's on_disconnection_request to ignore the request,
# simulating a lost packet or unresponsive peer.
# Stub the server's request handler to ignore the disconnection request.
# This keeps the client channel in the DISCONNECTING state, waiting for a response,
# so we can simulate calling abort() during an active disconnection.
server_channel.on_disconnection_request = lambda request: None
# Intercept state change to DISCONNECTING and call abort()
original_change_state = client_channel._change_state
abort_called = False
def my_change_state(new_state):
def intercept_change_state_and_abort(new_state):
nonlocal abort_called
original_change_state(new_state)
if (
@@ -492,13 +494,44 @@ async def test_abort_while_disconnecting():
abort_called = True
client_channel.abort()
client_channel._change_state = my_change_state
client_channel._change_state = intercept_change_state_and_abort
# Start disconnection and wait with a timeout
# Start disconnection and wait with a timeout. It should resolve immediately due to the abort.
await asyncio.wait_for(client_channel.disconnect(), timeout=1.0)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_disconnection_collision():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
server_channel = await server_channels.get()
# Trigger disconnection from both sides concurrently to cause a collision.
# Both channels will transition to DISCONNECTING and send DISCONNECTION_REQUESTs.
# When each side receives the peer's request, it will handle it and resolve the
# disconnection_result future.
await asyncio.gather(
client_channel.disconnect(),
server_channel.disconnect(),
)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
assert server_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
async def run():
test_helpers()