From c6b3deb8df80a8b375476dc76e28162adb6adc01 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 9 Aug 2024 23:46:10 +0800 Subject: [PATCH] Fix CTKD failure introduced by Host RPA generation --- bumble/device.py | 66 ++++++++++++++++++++++++-------------------- bumble/smp.py | 18 ++++++++---- tests/device_test.py | 10 +++++++ tests/smp_test.py | 55 ++++++++++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 35 deletions(-) diff --git a/bumble/device.py b/bumble/device.py index ff3c3493..03b18624 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -203,13 +203,13 @@ from .keys import ( KeyStore, PairingKeys, ) -from .pairing import PairingConfig -from . import gatt_client -from . import gatt_server -from . import smp -from . import sdp -from . import l2cap -from . import core +from bumble import pairing +from bumble import gatt_client +from bumble import gatt_server +from bumble import smp +from bumble import sdp +from bumble import l2cap +from bumble import core if TYPE_CHECKING: from .transport.common import TransportSource, TransportSink @@ -1595,6 +1595,8 @@ class DeviceConfiguration: address_resolution_offload: bool = False address_generation_offload: bool = False cis_enabled: bool = False + identity_address_type: Optional[int] = None + io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT def __post_init__(self) -> None: self.gatt_services: List[Dict[str, Any]] = [] @@ -1959,7 +1961,19 @@ class Device(CompositeEventEmitter): # Setup SMP self.smp_manager = smp.Manager( - self, pairing_config_factory=lambda connection: PairingConfig() + self, + pairing_config_factory=lambda connection: pairing.PairingConfig( + identity_address_type=( + pairing.PairingConfig.AddressType(self.config.identity_address_type) + if self.config.identity_address_type + else None + ), + delegate=pairing.PairingDelegate( + io_capability=pairing.PairingDelegate.IoCapability( + self.config.io_capability + ) + ), + ), ) self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu) @@ -2183,10 +2197,15 @@ class Device(CompositeEventEmitter): HCI_Write_LE_Host_Support_Command( le_supported_host=int(self.le_enabled), simultaneous_le_host=int(self.le_simultaneous_enabled), - ) + ), + check_result=True, ) if self.le_enabled: + # Generate a random address if not set. + if self.static_address == Address.ANY_RANDOM: + self.static_address = Address.generate_static_address() + # If LE Privacy is enabled, generate an RPA if self.le_privacy_enabled: self.random_address = Address.generate_private_address(self.irk) @@ -2196,23 +2215,8 @@ class Device(CompositeEventEmitter): self.le_rpa_periodic_update_task = asyncio.create_task( self._run_rpa_periodic_update() ) - - # Set the controller address - if self.random_address == Address.ANY_RANDOM: - # Try to use an address generated at random by the controller - if self.host.supports_command(HCI_LE_RAND_COMMAND): - # Get 8 random bytes - response = await self.send_command( - HCI_LE_Rand_Command(), check_result=True - ) - - # Ensure the address bytes can be a static random address - address_bytes = response.return_parameters.random_number[ - :5 - ] + bytes([response.return_parameters.random_number[5] | 0xC0]) - - # Create a static random address from the random bytes - self.random_address = Address(address_bytes) + else: + self.random_address = self.static_address if self.random_address != Address.ANY_RANDOM: logger.debug( @@ -2237,7 +2241,8 @@ class Device(CompositeEventEmitter): await self.send_command( HCI_LE_Set_Address_Resolution_Enable_Command( address_resolution_enable=1 - ) + ), + check_result=True, ) if self.cis_enabled: @@ -2245,7 +2250,8 @@ class Device(CompositeEventEmitter): HCI_LE_Set_Host_Feature_Command( bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM, bit_value=1, - ) + ), + check_result=True, ) if self.classic_enabled: @@ -3572,12 +3578,12 @@ class Device(CompositeEventEmitter): await self.stop_discovery() @property - def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: + def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]: return self.smp_manager.pairing_config_factory @pairing_config_factory.setter def pairing_config_factory( - self, pairing_config_factory: Callable[[Connection], PairingConfig] + self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig] ) -> None: self.smp_manager.pairing_config_factory = pairing_config_factory diff --git a/bumble/smp.py b/bumble/smp.py index 9eba42dc..5d6bcc5f 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -1078,11 +1078,19 @@ class Session: ) def send_identity_address_command(self) -> None: - identity_address = { - None: self.manager.device.static_address, - Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address, - Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address, - }[self.pairing_config.identity_address_type] + if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS: + identity_address = self.manager.device.public_address + elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS: + identity_address = self.manager.device.static_address + else: + # No identity address type set. If the controller has a public address, it + # will be more responsible to be the identity address. + if self.manager.device.public_address != Address.ANY: + logger.debug("No identity address type set, using PUBLIC") + identity_address = self.manager.device.public_address + else: + logger.debug("No identity address type set, using RANDOM") + identity_address = self.manager.device.static_address self.send_command( SMP_Identity_Address_Information_Command( addr_type=identity_address.address_type, diff --git a/tests/device_test.py b/tests/device_test.py index 3b30f601..45b84ce1 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -536,6 +536,16 @@ async def test_cis_setup_failure(): await asyncio.wait_for(cis_create_task, _TIMEOUT) +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_power_on_default_static_address_should_not_be_any(): + devices = TwoDevices() + devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM + await devices[0].power_on() + + assert devices[0].static_address != Address.ANY_RANDOM + + # ----------------------------------------------------------------------------- def test_gatt_services_with_gas(): device = Device(host=Host(None, None)) diff --git a/tests/smp_test.py b/tests/smp_test.py index 7a32b23c..7f17bc28 100644 --- a/tests/smp_test.py +++ b/tests/smp_test.py @@ -17,13 +17,17 @@ # ----------------------------------------------------------------------------- import pytest +from unittest import mock from bumble import smp +from bumble import pairing from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1 from bumble.pairing import OobData, OobSharedData, LeRole from bumble.hci import Address from bumble.core import AdvertisingData +from bumble.device import Device +from typing import Optional # ----------------------------------------------------------------------------- # pylint: disable=invalid-name @@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str): assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected) +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + 'identity_address_type, public_address, random_address, expected_identity_address', + [ + ( + None, + Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS), + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS), + ), + ( + None, + Address.ANY, + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + ), + ( + pairing.PairingConfig.AddressType.PUBLIC, + Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS), + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS), + ), + ( + pairing.PairingConfig.AddressType.RANDOM, + Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS), + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS), + ), + ], +) +@pytest.mark.asyncio +async def test_send_identity_address_command( + identity_address_type: Optional[pairing.PairingConfig.AddressType], + public_address: Address, + random_address: Address, + expected_identity_address: Address, +): + device = Device() + device.public_address = public_address + device.static_address = random_address + pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type) + session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True) + + with mock.patch.object(session, 'send_command') as mock_method: + session.send_identity_address_command() + + actual_command = mock_method.call_args.args[0] + assert actual_command.addr_type == expected_identity_address.address_type + assert actual_command.bd_addr == expected_identity_address + + # ----------------------------------------------------------------------------- if __name__ == '__main__': test_ecc()