Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 16dd5ae63d remove trailing commas in pytest parametrize argname 2026-06-21 13:31:03 +02:00
3 changed files with 29 additions and 116 deletions
+23 -34
View File
@@ -1683,10 +1683,6 @@ class LeCreditBasedChannel(utils.EventEmitter):
if identifier in self.manager.le_coc_requests:
raise InvalidStateError('too many concurrent connection requests')
# Create a future to wait for the response
connection_result = asyncio.get_running_loop().create_future()
self.connection_result = connection_result
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
@@ -1699,19 +1695,17 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.manager.le_coc_requests[identifier] = request
self.send_control_frame(request)
# Create a future to wait for the response
self.connection_result = asyncio.get_running_loop().create_future()
# Wait for the connection to succeed or fail
return await connection_result
return await self.connection_result
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
# Create a future to wait for the state machine to get to a success or error
# state
disconnection_result = asyncio.get_running_loop().create_future()
self.disconnection_result = disconnection_result
self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
@@ -1722,18 +1716,17 @@ class LeCreditBasedChannel(utils.EventEmitter):
)
)
return await disconnection_result
# Create a future to wait for the state machine to get to a success or error
# state
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self) -> None:
if self.state in (self.State.CONNECTED, self.State.DISCONNECTING):
if self.state == self.State.CONNECTED:
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if self.connection_result is not None:
self.connection_result.cancel()
self.connection_result = None
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
if self.state == self.State.CONNECTING:
if self.connection_result is not None:
self.connection_result.cancel()
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
@@ -1868,10 +1861,6 @@ class LeCreditBasedChannel(utils.EventEmitter):
)
)
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if self.disconnection_result is not None:
self.disconnection_result.set_result(None)
self.disconnection_result = None
self.flush_output()
def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None:
@@ -1887,7 +1876,6 @@ class LeCreditBasedChannel(utils.EventEmitter):
return
self._change_state(self.State.DISCONNECTED)
self.manager.on_channel_closed(self)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -2473,9 +2461,12 @@ class ChannelManager:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
logger.debug(
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}'
logger.warning(
color(
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
return
@@ -2888,13 +2879,11 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None:
if classic_connection_channels := self.channels.get(channel.connection.handle):
classic_connection_channels.pop(channel.source_cid, None)
elif le_connection_channels := self.le_coc_channels.get(
channel.connection.handle
):
le_connection_channels.pop(channel.destination_cid, None)
def on_channel_closed(self, channel: ClassicChannel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
async def create_le_credit_based_channel(
self,
+3 -4
View File
@@ -43,9 +43,9 @@ dependencies = [
[project.optional-dependencies]
build = ["build >= 0.7"]
test = [
"pytest >= 8.2",
"pytest-asyncio >= 0.23.5",
"pytest-html >= 3.2.0",
"pytest >= 9.0",
"pytest-asyncio >= 1.4",
"pytest-html >= 4.2",
"coverage >= 6.4",
]
development = [
@@ -132,7 +132,6 @@ write_to = "bumble/_version.py"
[tool.pytest.ini_options]
pythonpath = "."
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.pylint.master]
init-hook = 'import sys; sys.path.append(".")'
+3 -78
View File
@@ -49,19 +49,19 @@ def test_helpers():
psm = l2cap.L2CAP_Connection_Request.serialize_psm(0x242311)
assert psm == bytes([0x11, 0x23, 0x24])
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x01, 0x00, 0x44]), 1
)
assert offset == 3
assert psm == 0x01
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x23, 0x10, 0x44]), 1
)
assert offset == 3
assert psm == 0x1023
(offset, psm) = l2cap.L2CAP_Connection_Request.parse_psm(
offset, psm = l2cap.L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
)
assert offset == 4
@@ -457,81 +457,6 @@ def test_fcs(cid: int, payload: str, expected: str):
assert pdu.to_bytes(with_fcs=True) == bytes.fromhex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_abort_while_disconnecting():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
server_channel = await server_channels.get()
# Stub the server's request handler to ignore the disconnection request.
# This keeps the client channel in the DISCONNECTING state, waiting for a response,
# so we can simulate calling abort() during an active disconnection.
server_channel.on_disconnection_request = lambda request: None
# Intercept state change to DISCONNECTING and call abort()
original_change_state = client_channel._change_state
abort_called = False
def intercept_change_state_and_abort(new_state):
nonlocal abort_called
original_change_state(new_state)
if (
new_state == l2cap.LeCreditBasedChannel.State.DISCONNECTING
and not abort_called
):
abort_called = True
client_channel.abort()
client_channel._change_state = intercept_change_state_and_abort
# Start disconnection and wait with a timeout. It should resolve immediately due to the abort.
await asyncio.wait_for(client_channel.disconnect(), timeout=1.0)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_disconnection_collision():
# Setup client and server channels
devices = TwoDevices()
await devices.setup_connection()
psm = 1234
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
devices.devices[1].create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(psm)
)
server_channel = await server_channels.get()
# Trigger disconnection from both sides concurrently to cause a collision.
# Both channels will transition to DISCONNECTING and send DISCONNECTION_REQUESTs.
# When each side receives the peer's request, it will handle it and resolve the
# disconnection_result future.
await asyncio.gather(
client_channel.disconnect(),
server_channel.disconnect(),
)
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
assert server_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
# -----------------------------------------------------------------------------
async def run():
test_helpers()