mirror of
https://github.com/google/bumble.git
synced 2026-06-17 10:02:27 +00:00
l2cap: Resolve teardown hang on disconnect collision or abort
Resolve a teardown hang in LeCreditBasedChannel. When a disconnection collision occurs (both DUT and peer call disconnect simultaneously) or the channel is aborted during disconnection, the connection state transitions to DISCONNECTED before the peer's response arrives (or is ignored). In these cases, the `disconnection_result` future remained unresolved, causing any awaiting teardown task to hang. This patch ensures that calling abort() or receiving a disconnection request while in the DISCONNECTING state correctly resolves `disconnection_result` and cleans up the channel. Verification: Verified with a new unit test `test_abort_while_disconnecting` added to `tests/l2cap_test.py` that stubs a non-responsive peer and calls abort() during the DISCONNECTING state transition, confirming it completes immediately.
This commit is contained in:
+33
-17
@@ -1683,6 +1683,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise InvalidStateError('too many concurrent connection requests')
|
||||
|
||||
# Create a future to wait for the response
|
||||
connection_result = asyncio.get_running_loop().create_future()
|
||||
self.connection_result = connection_result
|
||||
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
identifier=identifier,
|
||||
@@ -1695,17 +1699,19 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
self.manager.le_coc_requests[identifier] = request
|
||||
self.send_control_frame(request)
|
||||
|
||||
# Create a future to wait for the response
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
return await self.connection_result
|
||||
return await connection_result
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Check that we're connected
|
||||
if self.state != self.State.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error
|
||||
# state
|
||||
disconnection_result = asyncio.get_running_loop().create_future()
|
||||
self.disconnection_result = disconnection_result
|
||||
|
||||
self._change_state(self.State.DISCONNECTING)
|
||||
self.flush_output()
|
||||
self.send_control_frame(
|
||||
@@ -1716,14 +1722,16 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error
|
||||
# state
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
return await self.disconnection_result
|
||||
return await disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.State.CONNECTED:
|
||||
if self.state == self.State.CONNECTED or self.state == self.State.DISCONNECTING:
|
||||
was_disconnecting = self.state == self.State.DISCONNECTING
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.manager.on_channel_closed(self)
|
||||
if was_disconnecting and self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
if self.state == self.State.CONNECTING:
|
||||
if self.connection_result is not None:
|
||||
self.connection_result.cancel()
|
||||
@@ -1860,7 +1868,12 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
source_cid=request.source_cid,
|
||||
)
|
||||
)
|
||||
was_disconnecting = self.state == self.State.DISCONNECTING
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.manager.on_channel_closed(self)
|
||||
if was_disconnecting and self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None:
|
||||
@@ -1876,6 +1889,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
return
|
||||
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.manager.on_channel_closed(self)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -2461,12 +2475,9 @@ class ChannelManager:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, response.source_cid)
|
||||
) is None:
|
||||
logger.warning(
|
||||
color(
|
||||
f'channel {response.source_cid} not found for '
|
||||
f'0x{connection.handle:04X}:{cid}',
|
||||
'red',
|
||||
)
|
||||
logger.debug(
|
||||
f'channel {response.source_cid} not found for '
|
||||
f'0x{connection.handle:04X}:{cid}'
|
||||
)
|
||||
return
|
||||
|
||||
@@ -2879,11 +2890,16 @@ class ChannelManager:
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel: ClassicChannel) -> None:
|
||||
def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None:
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
if connection_channels:
|
||||
if channel.source_cid in connection_channels:
|
||||
if connection_channels.get(channel.source_cid) is channel:
|
||||
del connection_channels[channel.source_cid]
|
||||
if isinstance(channel, LeCreditBasedChannel):
|
||||
le_connection_channels = self.le_coc_channels.get(channel.connection.handle)
|
||||
if le_connection_channels:
|
||||
if le_connection_channels.get(channel.destination_cid) is channel:
|
||||
del le_connection_channels[channel.destination_cid]
|
||||
|
||||
async def create_le_credit_based_channel(
|
||||
self,
|
||||
|
||||
@@ -457,6 +457,48 @@ def test_fcs(cid: int, payload: str, expected: str):
|
||||
assert pdu.to_bytes(with_fcs=True) == bytes.fromhex(expected)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_while_disconnecting():
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
psm = 1234
|
||||
|
||||
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
|
||||
devices.devices[1].create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
|
||||
handler=server_channels.put_nowait,
|
||||
)
|
||||
client_channel = await devices.connections[0].create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm)
|
||||
)
|
||||
server_channel = await server_channels.get()
|
||||
|
||||
# Stub server channel's on_disconnection_request to ignore the request,
|
||||
# simulating a lost packet or unresponsive peer.
|
||||
server_channel.on_disconnection_request = lambda request: None
|
||||
|
||||
# Intercept state change to DISCONNECTING and call abort()
|
||||
original_change_state = client_channel._change_state
|
||||
abort_called = False
|
||||
|
||||
def my_change_state(new_state):
|
||||
nonlocal abort_called
|
||||
original_change_state(new_state)
|
||||
if (
|
||||
new_state == l2cap.LeCreditBasedChannel.State.DISCONNECTING
|
||||
and not abort_called
|
||||
):
|
||||
abort_called = True
|
||||
client_channel.abort()
|
||||
|
||||
client_channel._change_state = my_change_state
|
||||
|
||||
# Start disconnection and wait with a timeout
|
||||
await asyncio.wait_for(client_channel.disconnect(), timeout=1.0)
|
||||
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_helpers()
|
||||
|
||||
Reference in New Issue
Block a user