diff --git a/bumble/device.py b/bumble/device.py index 2762daf8..d1b4f31c 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -2171,7 +2171,7 @@ def with_connection_from_address(function): @functools.wraps(function) def wrapper(device: Device, address: hci.Address, *args, **kwargs): if connection := device.pending_connections.get(address): - return function(device, connection, address, *args, **kwargs) + return function(device, connection, *args, **kwargs) for connection in device.connections.values(): if connection.peer_address == address: return function(device, connection, *args, **kwargs) @@ -6443,18 +6443,14 @@ class Device(utils.CompositeEventEmitter): # [Classic only] @host_event_handler - @try_with_connection_from_address + @with_connection_from_address def on_role_change( self, - connection: Optional[Connection], - peer_address: hci.Address, + connection: Connection, new_role: hci.Role, ): - if connection: - connection.role = new_role - connection.emit(connection.EVENT_ROLE_CHANGE, new_role) - else: - logger.warning("Role change to unknown connection %s", peer_address) + connection.role = new_role + connection.emit(connection.EVENT_ROLE_CHANGE, new_role) # [Classic only] @host_event_handler diff --git a/tests/device_test.py b/tests/device_test.py index cc3e0631..d7717651 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -761,6 +761,34 @@ async def test_inquiry_result_with_rssi(): m.assert_called_with(hci.Address("00:11:22:33:44:55/P"), 3, mock.ANY, 5) +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "roles", + ( + (hci.Role.PERIPHERAL, hci.Role.CENTRAL), + (hci.Role.CENTRAL, hci.Role.PERIPHERAL), + ), +) +@pytest.mark.asyncio +async def test_accept_classic_connection(roles: tuple[hci.Role, hci.Role]): + devices = TwoDevices() + devices[0].classic_enabled = True + devices[1].classic_enabled = True + await devices[0].power_on() + await devices[1].power_on() + + accept_task = asyncio.create_task(devices[1].accept(role=roles[1])) + await devices[0].connect( + devices[1].public_address, transport=PhysicalTransport.BR_EDR + ) + await accept_task + + assert devices.connections[0] + assert devices.connections[0].role == roles[0] + assert devices.connections[1] + assert devices.connections[1].role == roles[1] + + # ----------------------------------------------------------------------------- async def run_test_device(): await test_device_connect_parallel()