Fix CTKD failure introduced by Host RPA generation

This commit is contained in:
Josh Wu
2024-08-09 23:46:10 +08:00
parent 4433184048
commit c6b3deb8df
4 changed files with 114 additions and 35 deletions

View File

@@ -203,13 +203,13 @@ from .keys import (
KeyStore,
PairingKeys,
)
from .pairing import PairingConfig
from . import gatt_client
from . import gatt_server
from . import smp
from . import sdp
from . import l2cap
from . import core
from bumble import pairing
from bumble import gatt_client
from bumble import gatt_server
from bumble import smp
from bumble import sdp
from bumble import l2cap
from bumble import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
@@ -1595,6 +1595,8 @@ class DeviceConfiguration:
address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
@@ -1959,7 +1961,19 @@ class Device(CompositeEventEmitter):
# Setup SMP
self.smp_manager = smp.Manager(
self, pairing_config_factory=lambda connection: PairingConfig()
self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
)
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2183,10 +2197,15 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
)
),
check_result=True,
)
if self.le_enabled:
# Generate a random address if not set.
if self.static_address == Address.ANY_RANDOM:
self.static_address = Address.generate_static_address()
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
@@ -2196,23 +2215,8 @@ class Device(CompositeEventEmitter):
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
# Set the controller address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
)
# Ensure the address bytes can be a static random address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
else:
self.random_address = self.static_address
if self.random_address != Address.ANY_RANDOM:
logger.debug(
@@ -2237,7 +2241,8 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
)
),
check_result=True,
)
if self.cis_enabled:
@@ -2245,7 +2250,8 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1,
)
),
check_result=True,
)
if self.classic_enabled:
@@ -3572,12 +3578,12 @@ class Device(CompositeEventEmitter):
await self.stop_discovery()
@property
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter
def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig]
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory

View File

@@ -1078,11 +1078,19 @@ class Session:
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
identity_address = self.manager.device.public_address
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
identity_address = self.manager.device.static_address
else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,