Compare commits

...

10 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod a311c3f723 hotfix for usb transport 2024-08-22 22:26:44 -07:00
zxzxwu b2bb82a432 Merge pull request #537 from zxzxwu/smp
Ignore invalid RPA
2024-08-21 13:54:02 +08:00
Josh Wu 597560ff80 Ignore invalid local resolvable address 2024-08-21 00:11:14 +08:00
Gilles Boccon-Gibod db383bb3e6 Merge pull request #531 from AlanRosenthal/btbench-scan
BtBench: Add Scan functionality
2024-08-14 11:59:13 -07:00
Alan Rosenthal ccc5bbdad4 BtBench: Scan 2024-08-14 11:26:31 -04:00
zxzxwu 11c8229017 Merge pull request #533 from zxzxwu/hid
Correct HID type annotations
2024-08-14 12:08:53 +08:00
Josh Wu 2248f9ae5e Correct HID type annotations 2024-08-13 23:13:33 +08:00
Gilles Boccon-Gibod 03c79aacb2 Merge pull request #529 from google/gbg/broadcast-assistant
basic broadcast assistant functionality
2024-08-12 13:02:50 -07:00
zxzxwu 0c31713a8e Merge pull request #528 from zxzxwu/rpa
Fix CTKD failure introduced by Host RPA generation
2024-08-13 01:30:19 +08:00
Josh Wu c6b3deb8df Fix CTKD failure introduced by Host RPA generation 2024-08-12 15:13:40 +08:00
9 changed files with 227 additions and 96 deletions
+52 -31
View File
@@ -207,13 +207,13 @@ from .keys import (
KeyStore,
PairingKeys,
)
from .pairing import PairingConfig
from . import gatt_client
from . import gatt_server
from . import smp
from . import sdp
from . import l2cap
from . import core
from bumble import pairing
from bumble import gatt_client
from bumble import gatt_server
from bumble import smp
from bumble import sdp
from bumble import l2cap
from bumble import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
@@ -1613,6 +1613,8 @@ class DeviceConfiguration:
address_resolution_offload: bool = False
address_generation_offload: bool = False
cis_enabled: bool = False
identity_address_type: Optional[int] = None
io_capability: int = pairing.PairingDelegate.IoCapability.NO_OUTPUT_NO_INPUT
def __post_init__(self) -> None:
self.gatt_services: List[Dict[str, Any]] = []
@@ -1978,7 +1980,19 @@ class Device(CompositeEventEmitter):
# Setup SMP
self.smp_manager = smp.Manager(
self, pairing_config_factory=lambda connection: PairingConfig()
self,
pairing_config_factory=lambda connection: pairing.PairingConfig(
identity_address_type=(
pairing.PairingConfig.AddressType(self.config.identity_address_type)
if self.config.identity_address_type
else None
),
delegate=pairing.PairingDelegate(
io_capability=pairing.PairingDelegate.IoCapability(
self.config.io_capability
)
),
),
)
self.l2cap_channel_manager.register_fixed_channel(smp.SMP_CID, self.on_smp_pdu)
@@ -2202,10 +2216,15 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
)
),
check_result=True,
)
if self.le_enabled:
# Generate a random address if not set.
if self.static_address == Address.ANY_RANDOM:
self.static_address = Address.generate_static_address()
# If LE Privacy is enabled, generate an RPA
if self.le_privacy_enabled:
self.random_address = Address.generate_private_address(self.irk)
@@ -2215,23 +2234,8 @@ class Device(CompositeEventEmitter):
self.le_rpa_periodic_update_task = asyncio.create_task(
self._run_rpa_periodic_update()
)
# Set the controller address
if self.random_address == Address.ANY_RANDOM:
# Try to use an address generated at random by the controller
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
)
# Ensure the address bytes can be a static random address
address_bytes = response.return_parameters.random_number[
:5
] + bytes([response.return_parameters.random_number[5] | 0xC0])
# Create a static random address from the random bytes
self.random_address = Address(address_bytes)
else:
self.random_address = self.static_address
if self.random_address != Address.ANY_RANDOM:
logger.debug(
@@ -2256,7 +2260,8 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
)
),
check_result=True,
)
if self.cis_enabled:
@@ -2264,7 +2269,8 @@ class Device(CompositeEventEmitter):
HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_value=1,
)
),
check_result=True,
)
if self.classic_enabled:
@@ -3714,12 +3720,12 @@ class Device(CompositeEventEmitter):
await self.stop_scanning()
@property
def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]:
def pairing_config_factory(self) -> Callable[[Connection], pairing.PairingConfig]:
return self.smp_manager.pairing_config_factory
@pairing_config_factory.setter
def pairing_config_factory(
self, pairing_config_factory: Callable[[Connection], PairingConfig]
self, pairing_config_factory: Callable[[Connection], pairing.PairingConfig]
) -> None:
self.smp_manager.pairing_config_factory = pairing_config_factory
@@ -4275,6 +4281,12 @@ class Device(CompositeEventEmitter):
else self.public_address
)
if advertising_set.advertising_parameters.own_address_type in (
OwnAddressType.RANDOM,
OwnAddressType.PUBLIC,
):
connection.self_resolvable_address = None
# Setup auto-restart of the advertising set if needed.
if advertising_set.auto_restart:
connection.once(
@@ -4321,7 +4333,10 @@ class Device(CompositeEventEmitter):
# Convert all-zeros addresses into None.
if self_resolvable_address == Address.ANY_RANDOM:
self_resolvable_address = None
if peer_resolvable_address == Address.ANY_RANDOM:
if (
peer_resolvable_address == Address.ANY_RANDOM
or not peer_address.is_resolved
):
peer_resolvable_address = None
logger.debug(
@@ -4355,6 +4370,7 @@ class Device(CompositeEventEmitter):
peer_address = resolved_address
self_address = None
own_address_type: Optional[int] = None
if role == HCI_CENTRAL_ROLE:
own_address_type = self.connect_own_address_type
assert own_address_type is not None
@@ -4383,6 +4399,11 @@ class Device(CompositeEventEmitter):
else self.random_address
)
# Some controllers may return local resolvable address even not using address
# generation offloading. Ignore the value to prevent SMP failure.
if own_address_type in (OwnAddressType.RANDOM, OwnAddressType.PUBLIC):
self_resolvable_address = None
# Create a connection.
connection = Connection(
self,
+24 -28
View File
@@ -23,13 +23,12 @@ import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing import Optional, Callable
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
from bumble.hci import Address
# -----------------------------------------------------------------------------
@@ -220,31 +219,27 @@ class HID(ABC, EventEmitter):
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
@@ -334,17 +329,18 @@ class Device(HID):
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
@dataclass
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
@@ -410,7 +406,6 @@ class Device(HID):
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
@@ -428,7 +423,9 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
@@ -442,7 +439,6 @@ class Device(HID):
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -453,7 +449,7 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
@@ -464,13 +460,12 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
@@ -480,13 +475,14 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
+13 -5
View File
@@ -1078,11 +1078,19 @@ class Session:
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
identity_address = self.manager.device.public_address
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
identity_address = self.manager.device.static_address
else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
+1 -1
View File
@@ -139,7 +139,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.packets.put_nowait(packet)
def transfer_callback(self, transfer):
self.acl_out_transfer_ready.release()
self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release)
status = transfer.getStatus()
# pylint: disable=no-member
+25 -29
View File
@@ -21,7 +21,7 @@ import os
import logging
import json
import websockets
from bumble.colors import color
import struct
from bumble.device import Device
from bumble.transport import open_transport_or_link
@@ -30,9 +30,7 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID,
UUID,
)
from bumble.hci import Address
from bumble.hid import (
Device as HID_Device,
HID_CONTROL_PSM,
@@ -40,20 +38,17 @@ from bumble.hid import (
Message,
)
from bumble.sdp import (
Client as SDP_Client,
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices
@@ -430,7 +425,7 @@ deviceData = DeviceData()
# -----------------------------------------------------------------------------
async def keyboard_device(hid_device):
async def keyboard_device(hid_device: HID_Device):
# Start a Websocket server to receive events from a web page
async def serve(websocket, _path):
@@ -476,9 +471,9 @@ async def keyboard_device(hid_device):
# limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y))
x_cord = x.to_bytes(signed=True)
y_cord = y.to_bytes(signed=True)
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack(
">bb", x, y
)
hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -515,7 +510,9 @@ async def main() -> None:
def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
def on_get_report_cb(
report_id: int, report_type: int, buffer_size: int
) -> HID_Device.GetSetStatus:
retValue = hid_device.GetSetStatus()
print(
"GET_REPORT report_id: "
@@ -555,8 +552,7 @@ async def main() -> None:
def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes
):
retValue = hid_device.GetSetStatus()
) -> HID_Device.GetSetStatus:
print(
"SET_REPORT report_id: "
+ str(report_id)
@@ -568,33 +564,33 @@ async def main() -> None:
+ str(data)
)
if report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
status = HID_Device.GetSetReturn.SUCCESS
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
status = HID_Device.GetSetReturn.SUCCESS
return retValue
return HID_Device.GetSetStatus(status=status)
def on_get_protocol_cb():
retValue = hid_device.GetSetStatus()
retValue.data = protocol_mode.to_bytes()
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_get_protocol_cb() -> HID_Device.GetSetStatus:
return HID_Device.GetSetStatus(
data=bytes([protocol_mode]),
status=hid_device.GetSetReturn.SUCCESS,
)
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus:
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
return HID_Device.GetSetStatus(
status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
)
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')
@@ -142,7 +142,7 @@ class MainActivity : ComponentActivity() {
::runRfcommClient,
::runRfcommServer,
::runL2capClient,
::runL2capServer
::runL2capServer,
)
}
@@ -166,6 +166,8 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> runRfcommServer()
"l2cap-client" -> runL2capClient()
"l2cap-server" -> runL2capServer()
"scan-start" -> runScan(true)
"stop-start" -> runScan(false)
}
}
}
@@ -190,6 +192,11 @@ class MainActivity : ComponentActivity() {
l2capServer?.run()
}
private fun runScan(startScan: Boolean) {
val scan = bluetoothAdapter?.let { Scan(it) }
scan?.run(startScan)
}
@SuppressLint("MissingPermission")
fun becomeDiscoverable() {
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
@@ -206,7 +213,7 @@ fun MainView(
runRfcommClient: () -> Unit,
runRfcommServer: () -> Unit,
runL2capClient: () -> Unit,
runL2capServer: () -> Unit
runL2capServer: () -> Unit,
) {
BTBenchTheme {
val scrollState = rememberScrollState()
@@ -0,0 +1,38 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.le.ScanCallback
import android.bluetooth.le.ScanResult
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.scan")
class Scan(val bluetoothAdapter: BluetoothAdapter) {
@SuppressLint("MissingPermission")
fun run(startScan: Boolean) {
var bluetoothLeScanner = bluetoothAdapter.bluetoothLeScanner
val scanCallback = object : ScanCallback() {
override fun onScanResult(callbackType: Int, result: ScanResult?) {
super.onScanResult(callbackType, result)
val device: BluetoothDevice? = result?.device
val deviceName = device?.name ?: "Unknown"
val deviceAddress = device?.address ?: "Unknown"
Log.info("Device found: $deviceName ($deviceAddress)")
}
override fun onScanFailed(errorCode: Int) {
// Handle scan failure
Log.warning("Scan failed with error code: $errorCode")
}
}
if (startScan) {
bluetoothLeScanner?.startScan(scanCallback)
} else {
bluetoothLeScanner?.stopScan(scanCallback)
}
}
}
+10
View File
@@ -536,6 +536,16 @@ async def test_cis_setup_failure():
await asyncio.wait_for(cis_create_task, _TIMEOUT)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
devices = TwoDevices()
devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
await devices[0].power_on()
assert devices[0].static_address != Address.ANY_RANDOM
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))
+55
View File
@@ -17,13 +17,17 @@
# -----------------------------------------------------------------------------
import pytest
from unittest import mock
from bumble import smp
from bumble import pairing
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.pairing import OobData, OobSharedData, LeRole
from bumble.hci import Address
from bumble.core import AdvertisingData
from bumble.device import Device
from typing import Optional
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -251,6 +255,57 @@ def test_link_key_to_ltk(ct2: bool, expected: str):
assert smp.Session.derive_ltk(LINK_KEY, ct2) == reversed_hex(expected)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'identity_address_type, public_address, random_address, expected_identity_address',
[
(
None,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
None,
Address.ANY,
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.PUBLIC,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
),
(
pairing.PairingConfig.AddressType.RANDOM,
Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
),
],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
identity_address_type: Optional[pairing.PairingConfig.AddressType],
public_address: Address,
random_address: Address,
expected_identity_address: Address,
):
device = Device()
device.public_address = public_address
device.static_address = random_address
pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)
with mock.patch.object(session, 'send_command') as mock_method:
session.send_identity_address_command()
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ecc()