mirror of
https://github.com/google/bumble.git
synced 2026-06-17 10:02:27 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8f00e82fac | |||
| 5bff934868 | |||
| 3c8fe5637d |
+34
-23
@@ -1683,6 +1683,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise InvalidStateError('too many concurrent connection requests')
|
||||
|
||||
# Create a future to wait for the response
|
||||
connection_result = asyncio.get_running_loop().create_future()
|
||||
self.connection_result = connection_result
|
||||
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
identifier=identifier,
|
||||
@@ -1695,17 +1699,19 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
self.manager.le_coc_requests[identifier] = request
|
||||
self.send_control_frame(request)
|
||||
|
||||
# Create a future to wait for the response
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
return await self.connection_result
|
||||
return await connection_result
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Check that we're connected
|
||||
if self.state != self.State.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error
|
||||
# state
|
||||
disconnection_result = asyncio.get_running_loop().create_future()
|
||||
self.disconnection_result = disconnection_result
|
||||
|
||||
self._change_state(self.State.DISCONNECTING)
|
||||
self.flush_output()
|
||||
self.send_control_frame(
|
||||
@@ -1716,17 +1722,18 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error
|
||||
# state
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
return await self.disconnection_result
|
||||
return await disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.State.CONNECTED:
|
||||
if self.state in (self.State.CONNECTED, self.State.DISCONNECTING):
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
if self.state == self.State.CONNECTING:
|
||||
if self.connection_result is not None:
|
||||
self.connection_result.cancel()
|
||||
self.manager.on_channel_closed(self)
|
||||
if self.connection_result is not None:
|
||||
self.connection_result.cancel()
|
||||
self.connection_result = None
|
||||
if self.disconnection_result is not None:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.sink is None:
|
||||
@@ -1861,6 +1868,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
)
|
||||
)
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.manager.on_channel_closed(self)
|
||||
if self.disconnection_result is not None:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> None:
|
||||
@@ -1876,6 +1887,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
|
||||
return
|
||||
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.manager.on_channel_closed(self)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -2461,12 +2473,9 @@ class ChannelManager:
|
||||
if (
|
||||
channel := self.find_channel(connection.handle, response.source_cid)
|
||||
) is None:
|
||||
logger.warning(
|
||||
color(
|
||||
f'channel {response.source_cid} not found for '
|
||||
f'0x{connection.handle:04X}:{cid}',
|
||||
'red',
|
||||
)
|
||||
logger.debug(
|
||||
f'channel {response.source_cid} not found for '
|
||||
f'0x{connection.handle:04X}:{cid}'
|
||||
)
|
||||
return
|
||||
|
||||
@@ -2879,11 +2888,13 @@ class ChannelManager:
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel: ClassicChannel) -> None:
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
if connection_channels:
|
||||
if channel.source_cid in connection_channels:
|
||||
del connection_channels[channel.source_cid]
|
||||
def on_channel_closed(self, channel: ClassicChannel | LeCreditBasedChannel) -> None:
|
||||
if classic_connection_channels := self.channels.get(channel.connection.handle):
|
||||
classic_connection_channels.pop(channel.source_cid, None)
|
||||
elif le_connection_channels := self.le_coc_channels.get(
|
||||
channel.connection.handle
|
||||
):
|
||||
le_connection_channels.pop(channel.destination_cid, None)
|
||||
|
||||
async def create_le_credit_based_channel(
|
||||
self,
|
||||
|
||||
@@ -132,6 +132,7 @@ write_to = "bumble/_version.py"
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = "."
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.pylint.master]
|
||||
init-hook = 'import sys; sys.path.append(".")'
|
||||
|
||||
+3
-3
@@ -46,7 +46,7 @@ class TwoDevices(test_utils.TwoDevices):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"command,",
|
||||
"command",
|
||||
[
|
||||
avrcp.GetPlayStatusCommand(),
|
||||
avrcp.GetCapabilitiesCommand(
|
||||
@@ -132,7 +132,7 @@ def test_command(command: avrcp.Command):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"event,",
|
||||
"event",
|
||||
[
|
||||
avrcp.UidsChangedEvent(uid_counter=7),
|
||||
avrcp.TrackChangedEvent(uid=12356),
|
||||
@@ -159,7 +159,7 @@ def test_event(event: avrcp.Event):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response,",
|
||||
"response",
|
||||
[
|
||||
avrcp.GetPlayStatusResponse(
|
||||
song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED
|
||||
|
||||
+1
-1
@@ -72,7 +72,7 @@ def test_sef():
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
|
||||
'sirk_type', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
|
||||
)
|
||||
async def test_csis(sirk_type):
|
||||
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
|
||||
@@ -278,7 +278,7 @@ async def test_legacy_advertising():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
'auto_restart',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -357,7 +357,7 @@ async def test_advertising_and_scanning():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
'own_address_type',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,7 +395,7 @@ async def test_extended_advertising_connection(own_address_type):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
'own_address_type',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
||||
+3
-3
@@ -297,7 +297,7 @@ def test_custom_le_meta_event():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
@@ -313,7 +313,7 @@ def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
@@ -330,7 +330,7 @@ def test_hci_event_subclasses_event_code(clazz: type[hci.HCI_Event]):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"clazz,",
|
||||
"clazz",
|
||||
[
|
||||
clazz[1]
|
||||
for clazz in inspect.getmembers(hci)
|
||||
|
||||
+2
-2
@@ -333,7 +333,7 @@ async def test_query_calls_with_calls(
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"operation,",
|
||||
"operation",
|
||||
(
|
||||
hfp.CallHoldOperation.RELEASE_ALL_HELD_CALLS,
|
||||
hfp.CallHoldOperation.RELEASE_ALL_ACTIVE_CALLS,
|
||||
@@ -358,7 +358,7 @@ async def test_hold_call_without_call_index(
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"operation,",
|
||||
"operation",
|
||||
(
|
||||
hfp.CallHoldOperation.RELEASE_SPECIFIC_CALL,
|
||||
hfp.CallHoldOperation.HOLD_ALL_CALLS_EXCEPT,
|
||||
|
||||
+77
-2
@@ -197,7 +197,7 @@ async def test_basic_connection():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize("info_type,", list(l2cap.L2CAP_Information_Request.InfoType))
|
||||
@pytest.mark.parametrize("info_type", list(l2cap.L2CAP_Information_Request.InfoType))
|
||||
async def test_l2cap_information_request(monkeypatch, info_type):
|
||||
# TODO: Replace handlers with API when implemented
|
||||
devices = await TwoDevices.create_with_connection()
|
||||
@@ -321,7 +321,7 @@ async def test_mtu():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
|
||||
@pytest.mark.parametrize("mtu", (50, 255, 256, 1000))
|
||||
async def test_enhanced_retransmission_mode(mtu: int):
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
@@ -457,6 +457,81 @@ def test_fcs(cid: int, payload: str, expected: str):
|
||||
assert pdu.to_bytes(with_fcs=True) == bytes.fromhex(expected)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_while_disconnecting():
|
||||
# Setup client and server channels
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
psm = 1234
|
||||
|
||||
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
|
||||
devices.devices[1].create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
|
||||
handler=server_channels.put_nowait,
|
||||
)
|
||||
client_channel = await devices.connections[0].create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm)
|
||||
)
|
||||
server_channel = await server_channels.get()
|
||||
|
||||
# Stub the server's request handler to ignore the disconnection request.
|
||||
# This keeps the client channel in the DISCONNECTING state, waiting for a response,
|
||||
# so we can simulate calling abort() during an active disconnection.
|
||||
server_channel.on_disconnection_request = lambda request: None
|
||||
|
||||
# Intercept state change to DISCONNECTING and call abort()
|
||||
original_change_state = client_channel._change_state
|
||||
abort_called = False
|
||||
|
||||
def intercept_change_state_and_abort(new_state):
|
||||
nonlocal abort_called
|
||||
original_change_state(new_state)
|
||||
if (
|
||||
new_state == l2cap.LeCreditBasedChannel.State.DISCONNECTING
|
||||
and not abort_called
|
||||
):
|
||||
abort_called = True
|
||||
client_channel.abort()
|
||||
|
||||
client_channel._change_state = intercept_change_state_and_abort
|
||||
|
||||
# Start disconnection and wait with a timeout. It should resolve immediately due to the abort.
|
||||
await asyncio.wait_for(client_channel.disconnect(), timeout=1.0)
|
||||
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnection_collision():
|
||||
# Setup client and server channels
|
||||
devices = TwoDevices()
|
||||
await devices.setup_connection()
|
||||
psm = 1234
|
||||
|
||||
server_channels = asyncio.Queue[l2cap.LeCreditBasedChannel]()
|
||||
devices.devices[1].create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=psm),
|
||||
handler=server_channels.put_nowait,
|
||||
)
|
||||
client_channel = await devices.connections[0].create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm)
|
||||
)
|
||||
server_channel = await server_channels.get()
|
||||
|
||||
# Trigger disconnection from both sides concurrently to cause a collision.
|
||||
# Both channels will transition to DISCONNECTING and send DISCONNECTION_REQUESTs.
|
||||
# When each side receives the peer's request, it will handle it and resolve the
|
||||
# disconnection_result future.
|
||||
await asyncio.gather(
|
||||
client_channel.disconnect(),
|
||||
server_channel.disconnect(),
|
||||
)
|
||||
|
||||
assert client_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
|
||||
assert server_channel.state == l2cap.LeCreditBasedChannel.State.DISCONNECTED
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_helpers()
|
||||
|
||||
+1
-1
@@ -68,7 +68,7 @@ async def test_self_disconnection():
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'responder_role,',
|
||||
'responder_role',
|
||||
(Role.CENTRAL, Role.PERIPHERAL),
|
||||
)
|
||||
async def test_self_classic_connection(responder_role):
|
||||
|
||||
@@ -102,7 +102,7 @@ def test_parser_extensions():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
"address",
|
||||
("127.0.0.1", "::1"),
|
||||
)
|
||||
async def test_tcp_connection(address):
|
||||
@@ -205,7 +205,7 @@ async def test_unix_connection_abstract():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"address,",
|
||||
"address",
|
||||
("127.0.0.1", "[::1]"),
|
||||
)
|
||||
async def test_android_netsim_connection(address):
|
||||
@@ -228,7 +228,7 @@ async def test_android_netsim_connection(address):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
"spec,",
|
||||
"spec",
|
||||
(
|
||||
"android-netsim:[::1]:{port},mode=host[a=b,c=d]",
|
||||
"android-netsim:localhost:{port},mode=host[a=b,c=d]",
|
||||
|
||||
Reference in New Issue
Block a user