mirror of
https://github.com/google/bumble.git
synced 2026-05-08 03:58:01 +00:00
Fix CTKD failure introduced by Host RPA generation
This commit is contained in:
@@ -203,13 +203,13 @@ from .keys import (
|
|||||||
KeyStore,
|
KeyStore,
|
||||||
PairingKeys,
|
PairingKeys,
|
||||||
)
|
)
|
||||||
from .pairing import PairingConfig
|
from bumble import pairing
|
||||||
from . import gatt_client
|
from bumble import gatt_client
|
||||||
from . import gatt_server
|
from bumble import gatt_server
|
||||||
from . import smp
|
from bumble import smp
|
||||||
from . import sdp
|
from bumble import sdp
|
||||||
from . import l2cap
|
from bumble import l2cap
|
||||||
from . import core
|
from bumble import core
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .transport.common import TransportSource, TransportSink
|
from .transport.common import TransportSource, TransportSink
|
||||||
@@ -1595,6 +1595,8 @@ class DeviceConfiguration:
|
|||||||
address_resolution_offload: bool = False
|
address_resolution_offload: bool = False
|
||||||
address_generation_offload: bool = False
|
address_generation_offload: bool = False
|
||||||
cis_enabled: bool = False
|
cis_enabled: bool = False
|
||||||
|
identity_address_type: Optional[int] = None
|
||||||
|
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.gatt_services: List[Dict[str, Any]] = []
|
self.gatt_services: List[Dict[str, Any]] = []
|
||||||
@@ -1959,7 +1961,19 @@ class Device(CompositeEventEmitter):
|
|||||||
|
|
||||||
# Setup SMP
|
# Setup SMP
|
||||||
self.smp_manager = smp.Manager(
|
self.smp_manager = smp.Manager(
|
||||||
self, pairing_config_factory=lambda connection: PairingConfig()
|
self,
|
||||||
|
pairing_config_factory=lambda connection: pairing.PairingConfig(
|
||||||
|
identity_address_type=(
|
||||||
|
pairing.PairingConfig.AddressType(self.config.identity_address_type)
|
||||||
|
if self.config.identity_address_type
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
delegate=pairing.PairingDelegate(
|
||||||
|
io_capability=pairing.PairingDelegate.IoCapability(
|
||||||
|
self.config.io_capability
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
|
||||||
@@ -2183,10 +2197,15 @@ class Device(CompositeEventEmitter):
|
|||||||
HCI_Write_LE_Host_Support_Command(
|
HCI_Write_LE_Host_Support_Command(
|
||||||
le_supported_host=int(self.le_enabled),
|
le_supported_host=int(self.le_enabled),
|
||||||
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.le_enabled:
|
if self.le_enabled:
|
||||||
|
# Generate a random address if not set.
|
||||||
|
if self.static_address == Address.ANY_RANDOM:
|
||||||
|
self.static_address = Address.generate_static_address()
|
||||||
|
|
||||||
# If LE Privacy is enabled, generate an RPA
|
# If LE Privacy is enabled, generate an RPA
|
||||||
if self.le_privacy_enabled:
|
if self.le_privacy_enabled:
|
||||||
self.random_address = Address.generate_private_address(self.irk)
|
self.random_address = Address.generate_private_address(self.irk)
|
||||||
@@ -2196,23 +2215,8 @@ class Device(CompositeEventEmitter):
|
|||||||
self.le_rpa_periodic_update_task = asyncio.create_task(
|
self.le_rpa_periodic_update_task = asyncio.create_task(
|
||||||
self._run_rpa_periodic_update()
|
self._run_rpa_periodic_update()
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
# Set the controller address
|
self.random_address = self.static_address
|
||||||
if self.random_address == Address.ANY_RANDOM:
|
|
||||||
# Try to use an address generated at random by the controller
|
|
||||||
if self.host.supports_command(HCI_LE_RAND_COMMAND):
|
|
||||||
# Get 8 random bytes
|
|
||||||
response = await self.send_command(
|
|
||||||
HCI_LE_Rand_Command(), check_result=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the address bytes can be a static random address
|
|
||||||
address_bytes = response.return_parameters.random_number[
|
|
||||||
:5
|
|
||||||
] + bytes([response.return_parameters.random_number[5] | 0xC0])
|
|
||||||
|
|
||||||
# Create a static random address from the random bytes
|
|
||||||
self.random_address = Address(address_bytes)
|
|
||||||
|
|
||||||
if self.random_address != Address.ANY_RANDOM:
|
if self.random_address != Address.ANY_RANDOM:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -2237,7 +2241,8 @@ class Device(CompositeEventEmitter):
|
|||||||
await self.send_command(
|
await self.send_command(
|
||||||
HCI_LE_Set_Address_Resolution_Enable_Command(
|
HCI_LE_Set_Address_Resolution_Enable_Command(
|
||||||
address_resolution_enable=1
|
address_resolution_enable=1
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cis_enabled:
|
if self.cis_enabled:
|
||||||
@@ -2245,7 +2250,8 @@ class Device(CompositeEventEmitter):
|
|||||||
HCI_LE_Set_Host_Feature_Command(
|
HCI_LE_Set_Host_Feature_Command(
|
||||||
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
|
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
|
||||||
bit_value=1,
|
bit_value=1,
|
||||||
)
|
),
|
||||||
|
check_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.classic_enabled:
|
if self.classic_enabled:
|
||||||
@@ -3572,12 +3578,12 @@ class Device(CompositeEventEmitter):
|
|||||||
await self.stop_discovery()
|
await self.stop_discovery()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
|
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
|
||||||
return self.smp_manager.pairing_config_factory
|
return self.smp_manager.pairing_config_factory
|
||||||
|
|
||||||
@pairing_config_factory.setter
|
@pairing_config_factory.setter
|
||||||
def pairing_config_factory(
|
def pairing_config_factory(
|
||||||
self, pairing_config_factory: Callable[[Connection], PairingConfig]
|
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.smp_manager.pairing_config_factory = pairing_config_factory
|
self.smp_manager.pairing_config_factory = pairing_config_factory
|
||||||
|
|
||||||
|
|||||||
@@ -1078,11 +1078,19 @@ class Session:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def send_identity_address_command(self) -> None:
|
def send_identity_address_command(self) -> None:
|
||||||
identity_address = {
|
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
|
||||||
None: self.manager.device.static_address,
|
identity_address = self.manager.device.public_address
|
||||||
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
|
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
|
||||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
|
identity_address = self.manager.device.static_address
|
||||||
}[self.pairing_config.identity_address_type]
|
else:
|
||||||
|
# No identity address type set. If the controller has a public address, it
|
||||||
|
# will be more responsible to be the identity address.
|
||||||
|
if self.manager.device.public_address != Address.ANY:
|
||||||
|
logger.debug("No identity address type set, using PUBLIC")
|
||||||
|
identity_address = self.manager.device.public_address
|
||||||
|
else:
|
||||||
|
logger.debug("No identity address type set, using RANDOM")
|
||||||
|
identity_address = self.manager.device.static_address
|
||||||
self.send_command(
|
self.send_command(
|
||||||
SMP_Identity_Address_Information_Command(
|
SMP_Identity_Address_Information_Command(
|
||||||
addr_type=identity_address.address_type,
|
addr_type=identity_address.address_type,
|
||||||
|
|||||||
@@ -536,6 +536,16 @@ async def test_cis_setup_failure():
|
|||||||
await asyncio.wait_for(cis_create_task, _TIMEOUT)
|
await asyncio.wait_for(cis_create_task, _TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_power_on_default_static_address_should_not_be_any():
|
||||||
|
devices = TwoDevices()
|
||||||
|
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
|
||||||
|
await devices[0].power_on()
|
||||||
|
|
||||||
|
assert devices[0].static_address != Address.ANY_RANDOM
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_gatt_services_with_gas():
|
def test_gatt_services_with_gas():
|
||||||
device = Device(host=Host(None, None))
|
device = Device(host=Host(None, None))
|
||||||
|
|||||||
@@ -17,13 +17,17 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from bumble import smp
|
from bumble import smp
|
||||||
|
from bumble import pairing
|
||||||
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
|
||||||
from bumble.pairing import OobData, OobSharedData, LeRole
|
from bumble.pairing import OobData, OobSharedData, LeRole
|
||||||
from bumble.hci import Address
|
from bumble.hci import Address
|
||||||
from bumble.core import AdvertisingData
|
from bumble.core import AdvertisingData
|
||||||
|
from bumble.device import Device
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
@@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
|
|||||||
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
|
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'identity_address_type, public_address, random_address, expected_identity_address',
|
||||||
|
[
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
Address.ANY,
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
pairing.PairingConfig.AddressType.PUBLIC,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
pairing.PairingConfig.AddressType.RANDOM,
|
||||||
|
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_identity_address_command(
|
||||||
|
identity_address_type: Optional[pairing.PairingConfig.AddressType],
|
||||||
|
public_address: Address,
|
||||||
|
random_address: Address,
|
||||||
|
expected_identity_address: Address,
|
||||||
|
):
|
||||||
|
device = Device()
|
||||||
|
device.public_address = public_address
|
||||||
|
device.static_address = random_address
|
||||||
|
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
|
||||||
|
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
|
||||||
|
|
||||||
|
with mock.patch.object(session, 'send_command') as mock_method:
|
||||||
|
session.send_identity_address_command()
|
||||||
|
|
||||||
|
actual_command = mock_method.call_args.args[0]
|
||||||
|
assert actual_command.addr_type == expected_identity_address.address_type
|
||||||
|
assert actual_command.bd_addr == expected_identity_address
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_ecc()
|
test_ecc()
|
||||||
|
|||||||
Reference in New Issue
Block a user