diff --git a/bumble/l2cap.py b/bumble/l2cap.py index c5e617f..7d97cb9 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -1683,6 +1683,10 @@ class LeCreditBasedChannel(utils.EventEmitter): if identifier in self.manager.le_coc_requests: raise InvalidStateError('too many concurrent connection requests') + # Create a future to wait for the response + connection_result = asyncio.get_running_loop().create_future() + self.connection_result = connection_result + self._change_state(self.State.CONNECTING) request = L2CAP_LE_Credit_Based_Connection_Request( identifier=identifier, @@ -1695,17 +1699,19 @@ class LeCreditBasedChannel(utils.EventEmitter): self.manager.le_coc_requests[identifier] = request self.send_control_frame(request) - # Create a future to wait for the response - self.connection_result = asyncio.get_running_loop().create_future() - # Wait for the connection to succeed or fail - return await self.connection_result + return await connection_result async def disconnect(self) -> None: # Check that we're connected if self.state != self.State.CONNECTED: raise InvalidStateError('not connected') + # Create a future to wait for the state machine to get to a success or error + # state + disconnection_result = asyncio.get_running_loop().create_future() + self.disconnection_result = disconnection_result + self._change_state(self.State.DISCONNECTING) self.flush_output() self.send_control_frame( @@ -1716,14 +1722,16 @@ class LeCreditBasedChannel(utils.EventEmitter): ) ) - # Create a future to wait for the state machine to get to a success or error - # state - self.disconnection_result = asyncio.get_running_loop().create_future() - return await self.disconnection_result + return await disconnection_result def abort(self) -> None: - if self.state == self.State.CONNECTED: + if self.state == self.State.CONNECTED or self.state == self.State.DISCONNECTING: + was_disconnecting = self.state == self.State.DISCONNECTING self._change_state(self.State.DISCONNECTED) + self.manager.on_channel_closed(self) + if was_disconnecting and self.disconnection_result: + self.disconnection_result.set_result(None) + self.disconnection_result = None if self.state == self.State.CONNECTING: if self.connection_result is not None: self.connection_result.cancel() @@ -1860,7 +1868,12 @@ class LeCreditBasedChannel(utils.EventEmitter): source_cid=request.source_cid, ) ) + was_disconnecting = self.state == self.State.DISCONNECTING self._change_state(self.State.DISCONNECTED) + self.manager.on_channel_closed(self) + if was_disconnecting and self.disconnection_result: + self.disconnection_result.set_result(None) + self.disconnection_result = None self.flush_output() def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None: @@ -1876,6 +1889,7 @@ class LeCreditBasedChannel(utils.EventEmitter): return self._change_state(self.State.DISCONNECTED) + self.manager.on_channel_closed(self) if self.disconnection_result: self.disconnection_result.set_result(None) self.disconnection_result = None @@ -2461,12 +2475,9 @@ class ChannelManager: if ( channel := self.find_channel(connection.handle, response.source_cid) ) is None: - logger.warning( - color( - f'channel {response.source_cid} not found for ' - f'0x{connection.handle:04X}:{cid}', - 'red', - ) + logger.debug( + f'channel {response.source_cid} not found for ' + f'0x{connection.handle:04X}:{cid}' ) return @@ -2879,11 +2890,16 @@ class ChannelManager: channel.on_credits(credit.credits) - def on_channel_closed(self, channel: ClassicChannel) -> None: + def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None: connection_channels = self.channels.get(channel.connection.handle) if connection_channels: - if channel.source_cid in connection_channels: + if connection_channels.get(channel.source_cid) is channel: del connection_channels[channel.source_cid] + if isinstance(channel, LeCreditBasedChannel): + le_connection_channels = self.le_coc_channels.get(channel.connection.handle) + if le_connection_channels: + if le_connection_channels.get(channel.destination_cid) is channel: + del le_connection_channels[channel.destination_cid] async def create_le_credit_based_channel( self, diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py index f5387c5..2258e9b 100644 --- a/tests/l2cap_test.py +++ b/tests/l2cap_test.py @@ -457,6 +457,48 @@ def test_fcs(cid: int, payload: str, expected: str): assert pdu.to_bytes(with_fcs=True) == bytes.fromhex(expected) +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_abort_while_disconnecting(): + devices = TwoDevices() + await devices.setup_connection() + psm = 1234 + + server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]() + devices.devices[1].create_l2cap_server( + spec=l2cap.LeCreditBasedChannelSpec(psm=psm), + handler=server_channels.put_nowait, + ) + client_channel = await devices.connections[0].create_l2cap_channel( + spec=l2cap.LeCreditBasedChannelSpec(psm) + ) + server_channel = await server_channels.get() + + # Stub server channel's on_disconnection_request to ignore the request, + # simulating a lost packet or unresponsive peer. + server_channel.on_disconnection_request = lambda request: None + + # Intercept state change to DISCONNECTING and call abort() + original_change_state = client_channel._change_state + abort_called = False + + def my_change_state(new_state): + nonlocal abort_called + original_change_state(new_state) + if ( + new_state == l2cap.LeCreditBasedChannel.State.DISCONNECTING + and not abort_called + ): + abort_called = True + client_channel.abort() + + client_channel._change_state = my_change_state + + # Start disconnection and wait with a timeout + await asyncio.wait_for(client_channel.disconnect(), timeout=1.0) + assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED + + # ----------------------------------------------------------------------------- async def run(): test_helpers()