Compare commits

..

3 Commits

Author SHA1 Message Date
uier 8f00e82fac l2cap: Simplify abort and disconnection response handling
Also includes the test_disconnection_collision unit test.
2026-06-17 04:37:59 +00:00
uier 5bff934868 l2cap: Resolve teardown hang on disconnect collision or abort
Resolve a teardown hang in LeCreditBasedChannel. When a disconnection
collision occurs (both DUT and peer call disconnect simultaneously) or the
channel is aborted during disconnection, the connection state transitions
to DISCONNECTED before the peer's response arrives (or is ignored).
In these cases, the `disconnection_result` future remained unresolved,
causing any awaiting teardown task to hang.

This patch ensures that calling abort() or receiving a disconnection request
while in the DISCONNECTING state correctly resolves `disconnection_result` and
cleans up the channel.

Verification:
Verified with a new unit test `test_abort_while_disconnecting` added to
`tests/l2cap_test.py` that stubs a non-responsive peer and calls abort()
during the DISCONNECTING state transition, confirming it completes immediately.
2026-06-16 18:03:26 +00:00
uier 3c8fe5637d tests: Fix pytest 9.1 compat by removing trailing commas in parametrize and setting asyncio_mode 2026-06-16 12:18:12 +00:00
10 changed files with 128 additions and 41 deletions
+34 -23
View File
@@ -1683,6 +1683,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
if identifier in self.manager.le_coc_requests:
raise InvalidStateError('too many concurrent connection requests')
# Create a future to wait for the response
connection_result = asyncio.get_running_loop().create_future()
self.connection_result = connection_result
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
@@ -1695,17 +1699,19 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.manager.le_coc_requests[identifier] = request
self.send_control_frame(request)
# Create a future to wait for the response
self.connection_result = asyncio.get_running_loop().create_future()
# Wait for the connection to succeed or fail
return await self.connection_result
return await connection_result
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
# Create a future to wait for the state machine to get to a success or error
# state
disconnection_result = asyncio.get_running_loop().create_future()
self.disconnection_result = disconnection_result
self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
@@ -1716,17 +1722,18 @@ class LeCreditBasedChannel(utils.EventEmitter):
)
)
# Create a future to wait for the state machine to get to a success or error
# state
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
return await disconnection_result
def abort(self) -> None:
if self.state == self.State.CONNECTED:
if self.state in (self.State.CONNECTED, self.State.DISCONNECTING):
self._change_state(self.State.DISCONNECTED)
if self.state == self.State.CONNECTING:
if self.connection_result is not None:
self.connection_result.cancel()
self.manager.on_channel_closed(self)
if self.connection_result is not None:
self.connection_result.cancel()
self.connection_result = None
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
@@ -1861,6 +1868,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
)
)
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
self.flush_output()
def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None:
@@ -1876,6 +1887,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
return
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -2461,12 +2473,9 @@ class ChannelManager:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
logger.warning(
color(
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
logger.debug(
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}'
)
return
@@ -2879,11 +2888,13 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel: ClassicChannel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None:
if classic_connection_channels := self.channels.get(channel.connection.handle):
classic_connection_channels.pop(channel.source_cid, None)
elif le_connection_channels := self.le_coc_channels.get(
channel.connection.handle
):
le_connection_channels.pop(channel.destination_cid, None)
async def create_le_credit_based_channel(
self,
+1
View File
@@ -132,6 +132,7 @@ write_to = "bumble/_version.py"
[tool.pytest.ini_options]
pythonpath = "."
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.pylint.master]
init-hook = 'import sys; sys.path.append(".")'
+3 -3
View File
@@ -46,7 +46,7 @@ class TwoDevices(test_utils.TwoDevices):
@pytest.mark.parametrize(
"command,",
"command",
[
avrcp.GetPlayStatusCommand(),
avrcp.GetCapabilitiesCommand(
@@ -132,7 +132,7 @@ def test_command(command: avrcp.Command):
@pytest.mark.parametrize(
"event,",
"event",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(uid=12356),
@@ -159,7 +159,7 @@ def test_event(event: avrcp.Event):
@pytest.mark.parametrize(
"response,",
"response",
[
avrcp.GetPlayStatusResponse(
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
+1 -1
View File
@@ -72,7 +72,7 @@ def test_sef():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
'sirk_type', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
)
async def test_csis(sirk_type):
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
+3 -3
View File
@@ -278,7 +278,7 @@ async def test_legacy_advertising():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
'auto_restart',
(True, False),
)
@pytest.mark.asyncio
@@ -357,7 +357,7 @@ async def test_advertising_and_scanning():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
'own_address_type',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
@@ -395,7 +395,7 @@ async def test_extended_advertising_connection(own_address_type):
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
'own_address_type',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
+3 -3
View File
@@ -297,7 +297,7 @@ def test_custom_le_meta_event():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"clazz,",
"clazz",
[
clazz[1]
for clazz in inspect.getmembers(hci)
@@ -313,7 +313,7 @@ def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"clazz,",
"clazz",
[
clazz[1]
for clazz in inspect.getmembers(hci)
@@ -330,7 +330,7 @@ def test_hci_event_subclasses_event_code(clazz: type[hci.HCI_Event]):
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"clazz,",
"clazz",
[
clazz[1]
for clazz in inspect.getmembers(hci)
+2 -2
View File
@@ -333,7 +333,7 @@ async def test_query_calls_with_calls(
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"operation,",
"operation",
(
hfp.CallHoldOperation.RELEASE_ALL_HELD_CALLS,
hfp.CallHoldOperation.RELEASE_ALL_ACTIVE_CALLS,
@@ -358,7 +358,7 @@ async def test_hold_call_without_call_index(
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"operation,",
"operation",
(
hfp.CallHoldOperation.RELEASE_SPECIFIC_CALL,
hfp.CallHoldOperation.HOLD_ALL_CALLS_EXCEPT,
+77 -2
View File
@@ -197,7 +197,7 @@ async def test_basic_connection():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType))
@pytest.mark.parametrize("info_type", list(l2cap.L2CAP_Information_Request.InfoType))
async def test_l2cap_information_request(monkeypatch, info_type):
# TODO: Replace handlers with API when implemented
devices = await TwoDevices.create_with_connection()
@@ -321,7 +321,7 @@ async def test_mtu():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
@pytest.mark.parametrize("mtu", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices()
await devices.setup_connection()
@@ -457,6 +457,81 @@ def test_fcs(cid: int, payload: str, expected: str):
assert pdu.to_bytes(with_fcs=True) == bytes.fromhex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_abort_while_disconnecting():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
server_channel = await server_channels.get()
# Stub the server's request handler to ignore the disconnection request.
# This keeps the client channel in the DISCONNECTING state, waiting for a response,
# so we can simulate calling abort() during an active disconnection.
server_channel.on_disconnection_request = lambda request: None
# Intercept state change to DISCONNECTING and call abort()
original_change_state = client_channel._change_state
abort_called = False
def intercept_change_state_and_abort(new_state):
nonlocal abort_called
original_change_state(new_state)
if (
new_state == l2cap.LeCreditBasedChannel.State.DISCONNECTING
and not abort_called
):
abort_called = True
client_channel.abort()
client_channel._change_state = intercept_change_state_and_abort
# Start disconnection and wait with a timeout. It should resolve immediately due to the abort.
await asyncio.wait_for(client_channel.disconnect(), timeout=1.0)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_disconnection_collision():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
server_channel = await server_channels.get()
# Trigger disconnection from both sides concurrently to cause a collision.
# Both channels will transition to DISCONNECTING and send DISCONNECTION_REQUESTs.
# When each side receives the peer's request, it will handle it and resolve the
# disconnection_result future.
await asyncio.gather(
client_channel.disconnect(),
server_channel.disconnect(),
)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
assert server_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
async def run():
test_helpers()
+1 -1
View File
@@ -68,7 +68,7 @@ async def test_self_disconnection():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'responder_role,',
'responder_role',
(Role.CENTRAL, Role.PERIPHERAL),
)
async def test_self_classic_connection(responder_role):
+3 -3
View File
@@ -102,7 +102,7 @@ def test_parser_extensions():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
"address",
("127.0.0.1", "::1"),
)
async def test_tcp_connection(address):
@@ -205,7 +205,7 @@ async def test_unix_connection_abstract():
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"address,",
"address",
("127.0.0.1", "[::1]"),
)
async def test_android_netsim_connection(address):
@@ -228,7 +228,7 @@ async def test_android_netsim_connection(address):
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"spec,",
"spec",
(
"android-netsim:[::1]:{port},mode=host[a=b,c=d]",
"android-netsim:localhost:{port},mode=host[a=b,c=d]",