Fix CTKD failure introduced by Host RPA generation

This commit is contained in:
Josh Wu
2024-08-09 23:46:10 +08:00
parent 4433184048
commit c6b3deb8df
4 changed files with 114 additions and 35 deletions

View File

@@ -203,13 +203,13 @@ from .keys import (
KeyStore, KeyStore,
PairingKeys, PairingKeys,
) )
from .pairing import PairingConfig from bumble import pairing
from . import gatt_client from bumble import gatt_client
from . import gatt_server from bumble import gatt_server
from . import smp from bumble import smp
from . import sdp from bumble import sdp
from . import l2cap from bumble import l2cap
from . import core from bumble import core
if TYPE_CHECKING: if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink from .transport.common import TransportSource, TransportSink
@@ -1595,6 +1595,8 @@ class DeviceConfiguration:
address_resolution_offload: bool = False address_resolution_offload: bool = False
address_generation_offload: bool = False address_generation_offload: bool = False
cis_enabled: bool = False cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = [] self.gatt_services: List[Dict[str, Any]] = []
@@ -1959,7 +1961,19 @@ class Device(CompositeEventEmitter):
# Setup SMP # Setup SMP
self.smp_manager = smp.Manager( self.smp_manager = smp.Manager(
self, pairing_config_factory=lambda connection: PairingConfig() self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
) )
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu) self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2183,10 +2197,15 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command( HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled), le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled), simultaneous_le_host=int(self.le_simultaneous_enabled),
) ),
check_result=True,
) )
if self.le_enabled: if self.le_enabled:
# Generate a random address if not set.
if self.static_address == Address.ANY_RANDOM:
self.static_address = Address.generate_static_address()
# If LE Privacy is enabled, generate an RPA # If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled: if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk) self.random_address = Address.generate_private_address(self.irk)
@@ -2196,23 +2215,8 @@ class Device(CompositeEventEmitter):
self.le_rpa_periodic_update_task = asyncio.create_task( self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update() self._run_rpa_periodic_update()
) )
else:
# Set the controller address self.random_address = self.static_address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
)
# Ensure the address bytes can be a static random address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
if self.random_address != Address.ANY_RANDOM: if self.random_address != Address.ANY_RANDOM:
logger.debug( logger.debug(
@@ -2237,7 +2241,8 @@ class Device(CompositeEventEmitter):
await self.send_command( await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command( HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1 address_resolution_enable=1
) ),
check_result=True,
) )
if self.cis_enabled: if self.cis_enabled:
@@ -2245,7 +2250,8 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command( HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM, bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1, bit_value=1,
) ),
check_result=True,
) )
if self.classic_enabled: if self.classic_enabled:
@@ -3572,12 +3578,12 @@ class Device(CompositeEventEmitter):
await self.stop_discovery() await self.stop_discovery()
@property @property
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
return self.smp_manager.pairing_config_factory return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter @pairing_config_factory.setter
def pairing_config_factory( def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig] self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
) -> None: ) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory self.smp_manager.pairing_config_factory = pairing_config_factory

View File

@@ -1078,11 +1078,19 @@ class Session:
) )
def send_identity_address_command(self) -> None: def send_identity_address_command(self) -> None:
identity_address = { if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
None: self.manager.device.static_address, identity_address = self.manager.device.public_address
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address, elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address, identity_address = self.manager.device.static_address
}[self.pairing_config.identity_address_type] else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
self.send_command( self.send_command(
SMP_Identity_Address_Information_Command( SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type, addr_type=identity_address.address_type,

View File

@@ -536,6 +536,16 @@ async def test_cis_setup_failure():
await asyncio.wait_for(cis_create_task, _TIMEOUT) await asyncio.wait_for(cis_create_task, _TIMEOUT)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
devices = TwoDevices()
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
await devices[0].power_on()
assert devices[0].static_address != Address.ANY_RANDOM
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_gatt_services_with_gas(): def test_gatt_services_with_gas():
device = Device(host=Host(None, None)) device = Device(host=Host(None, None))

View File

@@ -17,13 +17,17 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import pytest import pytest
from unittest import mock
from bumble import smp from bumble import smp
from bumble import pairing
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1 from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address from bumble.hci import Address
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Device
from typing import Optional
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable=invalid-name # pylint: disable=invalid-name
@@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected) assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'identity_address_type, public_address, random_address, expected_identity_address',
[
(
None,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
None,
Address.ANY,
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.PUBLIC,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.RANDOM,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
identity_address_type: Optional[pairing.PairingConfig.AddressType],
public_address: Address,
random_address: Address,
expected_identity_address: Address,
):
device = Device()
device.public_address = public_address
device.static_address = random_address
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
with mock.patch.object(session, 'send_command') as mock_method:
session.send_identity_address_command()
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_ecc() test_ecc()