Fix wrong with_connection_from_address parameter

This commit is contained in:
Josh Wu
2025-09-23 17:17:36 +08:00
parent f8223ca81f
commit 85215df2c3
2 changed files with 33 additions and 9 deletions

View File

@@ -2171,7 +2171,7 @@ def with_connection_from_address(function):
@functools.wraps(function)
def wrapper(device: Device, address: hci.Address, *args, **kwargs):
if connection := device.pending_connections.get(address):
return function(device, connection, address, *args, **kwargs)
return function(device, connection, *args, **kwargs)
for connection in device.connections.values():
if connection.peer_address == address:
return function(device, connection, *args, **kwargs)
@@ -6443,18 +6443,14 @@ class Device(utils.CompositeEventEmitter):
# [Classic only]
@host_event_handler
@try_with_connection_from_address
@with_connection_from_address
def on_role_change(
self,
connection: Optional[Connection],
peer_address: hci.Address,
connection: Connection,
new_role: hci.Role,
):
if connection:
connection.role = new_role
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
else:
logger.warning("Role change to unknown connection %s", peer_address)
connection.role = new_role
connection.emit(connection.EVENT_ROLE_CHANGE, new_role)
# [Classic only]
@host_event_handler

View File

@@ -761,6 +761,34 @@ async def test_inquiry_result_with_rssi():
m.assert_called_with(hci.Address("00:11:22:33:44:55/P"), 3, mock.ANY, 5)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"roles",
(
(hci.Role.PERIPHERAL, hci.Role.CENTRAL),
(hci.Role.CENTRAL, hci.Role.PERIPHERAL),
),
)
@pytest.mark.asyncio
async def test_accept_classic_connection(roles: tuple[hci.Role, hci.Role]):
devices = TwoDevices()
devices[0].classic_enabled = True
devices[1].classic_enabled = True
await devices[0].power_on()
await devices[1].power_on()
accept_task = asyncio.create_task(devices[1].accept(role=roles[1]))
await devices[0].connect(
devices[1].public_address, transport=PhysicalTransport.BR_EDR
)
await accept_task
assert devices.connections[0]
assert devices.connections[0].role == roles[0]
assert devices.connections[1]
assert devices.connections[1].role == roles[1]
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()