Merge pull request #827 from zxzxwu/emu

Implement extended advertising emulation
This commit is contained in:
zxzxwu
2025-12-01 15:57:42 +08:00
committed by GitHub
8 changed files with 954 additions and 611 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -265,12 +265,22 @@ class ExtendedAdvertisement(Advertisement):
# -----------------------------------------------------------------------------
class AdvertisementDataAccumulator:
last_advertisement: Advertisement | None
last_data: bytes
passive: bool
def __init__(self, passive: bool = False):
self.passive = passive
self.last_advertisement = None
self.last_data = b''
def update(self, report):
def update(
self,
report: (
hci.HCI_LE_Advertising_Report_Event.Report
| hci.HCI_LE_Extended_Advertising_Report_Event.Report
),
) -> Advertisement | None:
advertisement = Advertisement.from_advertising_report(report)
if advertisement is None:
return None
@@ -283,10 +293,12 @@ class AdvertisementDataAccumulator:
and not self.last_advertisement.is_scan_response
):
# This is the response to a scannable advertisement
result = Advertisement.from_advertising_report(report)
result.is_connectable = self.last_advertisement.is_connectable
result.is_scannable = True
result.data = AdvertisingData.from_bytes(self.last_data + report.data)
if result := Advertisement.from_advertising_report(report):
result.is_connectable = self.last_advertisement.is_connectable
result.is_scannable = True
result.data = AdvertisingData.from_bytes(
self.last_data + report.data
)
self.last_data = b''
else:
if (
@@ -3333,7 +3345,13 @@ class Device(utils.CompositeEventEmitter):
return self.scanning
@host_event_handler
def on_advertising_report(self, report):
def on_advertising_report(
self,
report: (
hci.HCI_LE_Advertising_Report_Event.Report
| hci.HCI_LE_Extended_Advertising_Report_Event.Report
),
) -> None:
if not (accumulator := self.advertisement_accumulators.get(report.address)):
accumulator = AdvertisementDataAccumulator(passive=self.scanning_is_passive)
self.advertisement_accumulators[report.address] = accumulator

View File

@@ -19,9 +19,12 @@ import asyncio
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from bumble import controller, core, hci, lmp
from bumble import core, hci, ll, lmp
if TYPE_CHECKING:
from bumble import controller
# -----------------------------------------------------------------------------
# Logging
@@ -29,11 +32,6 @@ from bumble import controller, core, hci, lmp
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
@@ -47,7 +45,6 @@ class LocalLink:
def __init__(self):
self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None
############################################################
@@ -61,10 +58,11 @@ class LocalLink:
def remove_controller(self, controller: controller.Controller):
self.controllers.remove(controller)
def find_controller(self, address: hci.Address) -> controller.Controller | None:
def find_le_controller(self, address: hci.Address) -> controller.Controller | None:
for controller in self.controllers:
if controller.random_address == address:
return controller
for connection in controller.le_connections.values():
if connection.self_address == address:
return controller
return None
def find_classic_controller(
@@ -75,9 +73,6 @@ class LocalLink:
return controller
return None
def get_pending_connection(self):
return self.pending_connection
############################################################
# LE handlers
############################################################
@@ -85,12 +80,6 @@ class LocalLink:
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address: hci.Address, data: bytes):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
def send_acl_data(
self,
sender_controller: controller.Controller,
@@ -100,7 +89,7 @@ class LocalLink:
):
# Send the data to the first controller with a matching address
if transport == core.PhysicalTransport.LE:
destination_controller = self.find_controller(destination_address)
destination_controller = self.find_le_controller(destination_address)
source_address = sender_controller.random_address
elif transport == core.PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address)
@@ -115,152 +104,30 @@ class LocalLink:
)
)
def on_connection_complete(self) -> None:
# Check that we expect this call
if not self.pending_connection:
logger.warning('on_connection_complete with no pending connection')
return
def send_advertising_pdu(
self,
sender_controller: controller.Controller,
packet: ll.AdvertisingPdu,
):
loop = asyncio.get_running_loop()
for c in self.controllers:
if c != sender_controller:
loop.call_soon(c.on_ll_advertising_pdu, packet)
central_address, le_create_connection_command = self.pending_connection
self.pending_connection = None
# Find the controller that initiated the connection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, hci.HCI_SUCCESS
def send_ll_control_pdu(
self,
sender_address: hci.Address,
receiver_address: hci.Address,
packet: ll.ControlPdu,
):
if not (receiver_controller := self.find_le_controller(receiver_address)):
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}"
)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, hci.HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(
self,
central_address: hci.Address,
le_create_connection_command: hci.HCI_LE_Create_Connection_Command,
):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self,
initiating_address: hci.Address,
target_address: hci.Address,
disconnect_command: hci.HCI_Disconnect_Command,
):
# Find the controller that initiated the disconnection
if not (initiating_controller := self.find_controller(initiating_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if target_controller := self.find_controller(target_address):
target_controller.on_link_disconnected(
initiating_address, disconnect_command.reason
)
initiating_controller.on_link_disconnection_complete(
disconnect_command, hci.HCI_SUCCESS
)
def disconnect(
self,
initiating_address: hci.Address,
target_address: hci.Address,
disconnect_command: hci.HCI_Disconnect_Command,
):
logger.debug(
f'$$$ DISCONNECTION {initiating_address} -> '
f'{target_address}: reason = {disconnect_command.reason}'
)
asyncio.get_running_loop().call_soon(
lambda: self.on_disconnection_complete(
initiating_address, target_address, disconnect_command
)
lambda: receiver_controller.on_ll_control_pdu(sender_address, packet)
)
def on_connection_encrypted(
self,
central_address: hci.Address,
peripheral_address: hci.Address,
rand: bytes,
ediv: int,
ltk: bytes,
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
loop = asyncio.get_running_loop()
loop.call_soon(central_controller.on_link_cis_established, cig_id, cis_id)
loop.call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
loop = asyncio.get_running_loop()
loop.call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
loop.call_soon(peer_controller.on_link_cis_disconnected, cig_id, cis_id)
############################################################
# Classic handlers
############################################################

200
bumble/ll.py Normal file
View File

@@ -0,0 +1,200 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import ClassVar
from bumble import hci
# -----------------------------------------------------------------------------
# Advertising PDU
# -----------------------------------------------------------------------------
class AdvertisingPdu:
"""Base Advertising Physical Channel PDU class.
See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
@dataclasses.dataclass
class ConnectInd(AdvertisingPdu):
initiator_address: hci.Address
advertiser_address: hci.Address
interval: int
latency: int
timeout: int
@dataclasses.dataclass
class AdvInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvDirectInd(AdvertisingPdu):
advertiser_address: hci.Address
target_address: hci.Address
@dataclasses.dataclass
class AdvNonConnInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvExtInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
target_address: hci.Address | None = None
adi: int | None = None
tx_power: int | None = None
# -----------------------------------------------------------------------------
# LL Control PDU
# -----------------------------------------------------------------------------
class ControlPdu:
"""Base LL Control PDU Class.
See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
class Opcode(hci.SpecableEnum):
LL_CONNECTION_UPDATE_IND = 0x00
LL_CHANNEL_MAP_IND = 0x01
LL_TERMINATE_IND = 0x02
LL_ENC_REQ = 0x03
LL_ENC_RSP = 0x04
LL_START_ENC_REQ = 0x05
LL_START_ENC_RSP = 0x06
LL_UNKNOWN_RSP = 0x07
LL_FEATURE_REQ = 0x08
LL_FEATURE_RSP = 0x09
LL_PAUSE_ENC_REQ = 0x0A
LL_PAUSE_ENC_RSP = 0x0B
LL_VERSION_IND = 0x0C
LL_REJECT_IND = 0x0D
LL_PERIPHERAL_FEATURE_REQ = 0x0E
LL_CONNECTION_PARAM_REQ = 0x0F
LL_CONNECTION_PARAM_RSP = 0x10
LL_REJECT_EXT_IND = 0x11
LL_PING_REQ = 0x12
LL_PING_RSP = 0x13
LL_LENGTH_REQ = 0x14
LL_LENGTH_RSP = 0x15
LL_PHY_REQ = 0x16
LL_PHY_RSP = 0x17
LL_PHY_UPDATE_IND = 0x18
LL_MIN_USED_CHANNELS_IND = 0x19
LL_CTE_REQ = 0x1A
LL_CTE_RSP = 0x1B
LL_PERIODIC_SYNC_IND = 0x1C
LL_CLOCK_ACCURACY_REQ = 0x1D
LL_CLOCK_ACCURACY_RSP = 0x1E
LL_CIS_REQ = 0x1F
LL_CIS_RSP = 0x20
LL_CIS_IND = 0x21
LL_CIS_TERMINATE_IND = 0x22
LL_POWER_CONTROL_REQ = 0x23
LL_POWER_CONTROL_RSP = 0x24
LL_POWER_CHANGE_IND = 0x25
LL_SUBRATE_REQ = 0x26
LL_SUBRATE_IND = 0x27
LL_CHANNEL_REPORTING_IND = 0x28
LL_CHANNEL_STATUS_IND = 0x29
LL_PERIODIC_SYNC_WR_IND = 0x2A
LL_FEATURE_EXT_REQ = 0x2B
LL_FEATURE_EXT_RSP = 0x2C
LL_CS_SEC_RSP = 0x2D
LL_CS_CAPABILITIES_REQ = 0x2E
LL_CS_CAPABILITIES_RSP = 0x2F
LL_CS_CONFIG_REQ = 0x30
LL_CS_CONFIG_RSP = 0x31
LL_CS_REQ = 0x32
LL_CS_RSP = 0x33
LL_CS_IND = 0x34
LL_CS_TERMINATE_REQ = 0x35
LL_CS_FAE_REQ = 0x36
LL_CS_FAE_RSP = 0x37
LL_CS_CHANNEL_MAP_IND = 0x38
LL_CS_SEC_REQ = 0x39
LL_CS_TERMINATE_RSP = 0x3A
LL_FRAME_SPACE_REQ = 0x3B
LL_FRAME_SPACE_RSP = 0x3C
opcode: ClassVar[Opcode]
@dataclasses.dataclass
class TerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_TERMINATE_IND
error_code: int
@dataclasses.dataclass
class EncReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_ENC_REQ
rand: bytes
ediv: int
ltk: bytes
@dataclasses.dataclass
class CisReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisTerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND
cig_id: int
cis_id: int
error_code: int

View File

@@ -284,52 +284,51 @@ async def test_legacy_advertising():
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
devices = TwoDevices()
device = devices[0]
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff')
await device.power_on()
peer_address = Address('F0:F1:F2:F3:F4:F5')
await device.start_advertising(auto_restart=auto_restart)
device.on_le_connection(
0x0001,
peer_address,
None,
None,
Role.PERIPHERAL,
0,
0,
0,
for controller in devices.controllers:
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
for dev in devices:
await dev.power_on()
await devices[0].start_advertising(
auto_restart=auto_restart, advertising_interval_min=1.0
)
connecion = await devices[1].connect(devices[0].random_address)
device.on_advertising_set_termination(
HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, 0x0001, 0
)
await connecion.disconnect()
device.on_disconnection(0x0001, 0)
await async_barrier()
await async_barrier()
if auto_restart:
assert device.legacy_advertising_set
assert devices[0].legacy_advertising_set
started = asyncio.Event()
if not device.is_advertising:
device.legacy_advertising_set.once('start', started.set)
if not devices[0].is_advertising:
devices[0].legacy_advertising_set.once('start', started.set)
await asyncio.wait_for(started.wait(), _TIMEOUT)
assert device.is_advertising
assert devices[0].is_advertising
else:
assert not device.is_advertising
assert not devices[0].is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
device = TwoDevices()[0]
await device.power_on()
async def test_advertising_and_scanning():
devices = TwoDevices()
for dev in devices:
await dev.power_on()
# Start scanning
advertisements = asyncio.Queue[device.Advertisement]()
devices[1].on(devices[1].EVENT_ADVERTISEMENT, advertisements.put_nowait)
await devices[1].start_scanning()
# Start advertising
advertising_set = await device.create_advertising_set()
assert device.extended_advertising_sets
advertising_set = await devices[0].create_advertising_set(advertising_data=b'123')
assert devices[0].extended_advertising_sets
assert advertising_set.enabled
advertisement = await asyncio.wait_for(advertisements.get(), _TIMEOUT)
assert advertisement.data_bytes == b'123'
# Stop advertising
await advertising_set.stop()
assert not advertising_set.enabled
@@ -342,33 +341,33 @@ async def test_extended_advertising():
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
device = TwoDevices()[0]
await device.power_on()
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)
devices = TwoDevices()
for dev in devices:
await dev.power_on()
advertising_set = await devices[0].create_advertising_set(
advertising_parameters=AdvertisingParameters(
own_address_type=own_address_type, primary_advertising_interval_min=1.0
)
)
device.on_le_connection(
0x0001,
peer_address,
None,
None,
Role.PERIPHERAL,
0,
0,
0,
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertising_set.advertising_handle,
0x0001,
0,
await asyncio.wait_for(
devices[1].connect(advertising_set.random_address or devices[0].public_address),
_TIMEOUT,
)
await async_barrier()
# Advertising set should be terminated after connected.
assert not advertising_set.enabled
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
assert (
devices[0].lookup_connection(0x0001).self_address
== devices[0].public_address
)
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
assert (
devices[0].lookup_connection(0x0001).self_address
== devices[0].random_address
)
await async_barrier()
@@ -382,7 +381,7 @@ async def test_extended_advertising_connection(own_address_type):
async def test_extended_advertising_connection_out_of_order(own_address_type):
devices = TwoDevices()
device = devices[0]
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff')
devices.controllers[0].le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
await device.power_on()
advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)

View File

@@ -69,7 +69,7 @@ from bumble.host import Host
from bumble.link import LocalLink
from bumble.transport.common import AsyncPipeSink
from .test_utils import async_barrier
from .test_utils import Devices, TwoDevices, async_barrier
# -----------------------------------------------------------------------------
@@ -160,7 +160,8 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes):
return value_bytes[0]
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -189,9 +190,7 @@ async def test_characteristic_encoding():
)
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -279,7 +278,8 @@ async def test_characteristic_encoding():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_attribute_getters():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806')
characteristic = Characteristic(
@@ -629,39 +629,11 @@ async def test_CharacteristicValue_async():
m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
class LinkedDevices:
def __init__(self):
self.connections = [None, None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link=self.link),
Controller('C2', link=self.link),
Controller('C3', link=self.link),
]
self.devices = [
Device(
address='F0:F1:F2:F3:F4:F5',
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address='F1:F2:F3:F4:F5:F6',
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
Device(
address='F2:F3:F4:F5:F6:F7',
host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])),
),
]
self.paired = [None, None, None]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -694,9 +666,7 @@ async def test_read_write():
)
server.add_services([service1])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -740,7 +710,8 @@ async def test_read_write():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write2():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic(
@@ -753,9 +724,7 @@ async def test_read_write2():
service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1])
server.add_services([service1])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -785,7 +754,8 @@ async def test_read_write2():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_notify():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -855,9 +825,7 @@ async def test_subscribe_notify():
server.on('characteristic_subscription', on_characteristic_subscription)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -1006,7 +974,8 @@ async def test_subscribe_notify():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1032,9 +1001,7 @@ async def test_unsubscribe():
mock2 = Mock()
characteristic2.on('subscription', mock2)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -1094,7 +1061,8 @@ async def test_unsubscribe():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_discover_all():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1120,9 +1088,7 @@ async def test_discover_all():
service2 = Service('1111', [])
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_all()
@@ -1146,7 +1112,10 @@ async def test_discover_all():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]
devices = Devices(3)
for dev in devices:
await dev.power_on()
[d1, d2, d3] = devices
d3.gatt_server.max_mtu = 100
@@ -1160,11 +1129,15 @@ async def test_mtu_exchange():
await d2.power_on()
await d3.power_on()
await d3.start_advertising(advertising_interval_min=1.0)
d1_connection = await d1.connect(d3.random_address)
await async_barrier()
assert len(d3_connections) == 1
assert d3_connections[0] is not None
await d3.start_advertising(advertising_interval_min=1.0)
d2_connection = await d2.connect(d3.random_address)
await async_barrier()
assert len(d3_connections) == 2
assert d3_connections[1] is not None
@@ -1233,7 +1206,8 @@ Got: BROADCAST,HELLO"""
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_server_string():
[_, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[_, server] = devices
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1422,7 +1396,8 @@ def test_get_attribute_group():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_characteristics_by_uuid():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic(
'1234',
@@ -1447,9 +1422,7 @@ async def test_get_characteristics_by_uuid():
server.add_services([service1, service2])
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
peer = Peer(connection)
await peer.discover_services()
@@ -1472,7 +1445,8 @@ async def test_get_characteristics_by_uuid():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_return_error():
[client, server] = LinkedDevices().devices[:2]
devices = await TwoDevices.create_with_connection()
[client, server] = devices
on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED))
characteristic = Characteristic(
@@ -1484,9 +1458,7 @@ async def test_write_return_error():
service = Service('ABCD', [characteristic])
server.add_service(service)
await client.power_on()
await server.power_on()
connection = await client.connect(server.random_address)
connection = devices.connections[0]
async with Peer(connection) as peer:
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0]

View File

@@ -35,7 +35,7 @@ from bumble.smp import (
OobLegacyContext,
)
from .test_utils import TwoDevices
from .test_utils import TwoDevices, async_barrier
# -----------------------------------------------------------------------------
# Logging
@@ -56,12 +56,14 @@ async def test_self_disconnection():
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[0].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[1].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
@@ -80,7 +82,8 @@ async def test_self_classic_connection(responder_role):
two_devices.devices[1].classic_enabled = True
# Start
await two_devices.setup_connection()
for dev in two_devices.devices:
await dev.power_on()
# Connect the two devices
await asyncio.gather(
@@ -418,8 +421,9 @@ async def test_self_smp_over_classic():
two_devices.devices[1].classic_enabled = True
# Connect the two devices
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
for dev in two_devices.devices:
await dev.power_on()
await asyncio.gather(
two_devices.devices[0].connect(
two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR

View File

@@ -16,6 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import functools
from typing import Optional
from typing_extensions import Self
@@ -30,39 +31,34 @@ from bumble.transport.common import AsyncPipeSink
# -----------------------------------------------------------------------------
class TwoDevices:
class Devices:
connections: list[Optional[Connection]]
def __init__(self) -> None:
self.connections = [None, None]
def __init__(self, num_devices: int) -> None:
self.connections = [None for _ in range(num_devices)]
self.link = LocalLink()
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)]
self.controllers = [
Controller('C1', link=self.link, public_address=addresses[0]),
Controller('C2', link=self.link, public_address=addresses[1]),
Controller(f'C{i+i}', link=self.link, public_address=addresses[i])
for i in range(num_devices)
]
self.devices = [
Device(
address=Address(addresses[0]),
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address=Address(addresses[1]),
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
address=Address(addresses[i]),
host=Host(self.controllers[i], AsyncPipeSink(self.controllers[i])),
)
for i in range(num_devices)
]
self.devices[0].on(
'connection', lambda connection: self.on_connection(0, connection)
)
self.devices[1].on(
'connection', lambda connection: self.on_connection(1, connection)
)
for i in range(num_devices):
self.devices[i].on(
self.devices[i].EVENT_CONNECTION,
functools.partial(self.on_connection, i),
)
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future() for _ in range(num_devices)
]
def on_connection(self, which, connection):
@@ -77,19 +73,26 @@ class TwoDevices:
async def setup_connection(self) -> None:
# Start
await self.devices[0].power_on()
await self.devices[1].power_on()
for dev in self.devices:
await dev.power_on()
# Connect the two devices
await self.devices[0].connect(self.devices[1].random_address)
# Check the post conditions
assert self.connections[0] is not None
assert self.connections[1] is not None
# Connect devices
for dev in self.devices[1:]:
connection_future = asyncio.get_running_loop().create_future()
dev.once(dev.EVENT_CONNECTION, connection_future.set_result)
await dev.start_advertising(advertising_interval_min=1.0)
await self.devices[0].connect(dev.random_address)
await connection_future
def __getitem__(self, index: int) -> Device:
return self.devices[index]
# -----------------------------------------------------------------------------
class TwoDevices(Devices):
def __init__(self) -> None:
super().__init__(2)
@classmethod
async def create_with_connection(cls: type[Self]) -> Self:
devices = cls()