Merge pull request #827 from zxzxwu/emu

Implement extended advertising emulation
This commit is contained in:
zxzxwu
2025-12-01 15:57:42 +08:00
committed by GitHub
8 changed files with 954 additions and 611 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -265,12 +265,22 @@ class ExtendedAdvertisement(Advertisement):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AdvertisementDataAccumulator: class AdvertisementDataAccumulator:
last_advertisement: Advertisement | None
last_data: bytes
passive: bool
def __init__(self, passive: bool = False): def __init__(self, passive: bool = False):
self.passive = passive self.passive = passive
self.last_advertisement = None self.last_advertisement = None
self.last_data = b'' self.last_data = b''
def update(self, report): def update(
self,
report: (
hci.HCI_LE_Advertising_Report_Event.Report
| hci.HCI_LE_Extended_Advertising_Report_Event.Report
),
) -> Advertisement | None:
advertisement = Advertisement.from_advertising_report(report) advertisement = Advertisement.from_advertising_report(report)
if advertisement is None: if advertisement is None:
return None return None
@@ -283,10 +293,12 @@ class AdvertisementDataAccumulator:
and not self.last_advertisement.is_scan_response and not self.last_advertisement.is_scan_response
): ):
# This is the response to a scannable advertisement # This is the response to a scannable advertisement
result = Advertisement.from_advertising_report(report) if result := Advertisement.from_advertising_report(report):
result.is_connectable = self.last_advertisement.is_connectable result.is_connectable = self.last_advertisement.is_connectable
result.is_scannable = True result.is_scannable = True
result.data = AdvertisingData.from_bytes(self.last_data + report.data) result.data = AdvertisingData.from_bytes(
self.last_data + report.data
)
self.last_data = b'' self.last_data = b''
else: else:
if ( if (
@@ -3333,7 +3345,13 @@ class Device(utils.CompositeEventEmitter):
return self.scanning return self.scanning
@host_event_handler @host_event_handler
def on_advertising_report(self, report): def on_advertising_report(
self,
report: (
hci.HCI_LE_Advertising_Report_Event.Report
| hci.HCI_LE_Extended_Advertising_Report_Event.Report
),
) -> None:
if not (accumulator := self.advertisement_accumulators.get(report.address)): if not (accumulator := self.advertisement_accumulators.get(report.address)):
accumulator = AdvertisementDataAccumulator(passive=self.scanning_is_passive) accumulator = AdvertisementDataAccumulator(passive=self.scanning_is_passive)
self.advertisement_accumulators[report.address] = accumulator self.advertisement_accumulators[report.address] = accumulator

View File

@@ -19,9 +19,12 @@ import asyncio
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from bumble import controller, core, hci, lmp from bumble import core, hci, ll, lmp
if TYPE_CHECKING:
from bumble import controller
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -29,11 +32,6 @@ from bumble import controller, core, hci, lmp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges # TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU) # (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
@@ -47,7 +45,6 @@ class LocalLink:
def __init__(self): def __init__(self):
self.controllers = set() self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None self.pending_classic_connection = None
############################################################ ############################################################
@@ -61,10 +58,11 @@ class LocalLink:
def remove_controller(self, controller: controller.Controller): def remove_controller(self, controller: controller.Controller):
self.controllers.remove(controller) self.controllers.remove(controller)
def find_controller(self, address: hci.Address) -> controller.Controller | None: def find_le_controller(self, address: hci.Address) -> controller.Controller | None:
for controller in self.controllers: for controller in self.controllers:
if controller.random_address == address: for connection in controller.le_connections.values():
return controller if connection.self_address == address:
return controller
return None return None
def find_classic_controller( def find_classic_controller(
@@ -75,9 +73,6 @@ class LocalLink:
return controller return controller
return None return None
def get_pending_connection(self):
return self.pending_connection
############################################################ ############################################################
# LE handlers # LE handlers
############################################################ ############################################################
@@ -85,12 +80,6 @@ class LocalLink:
def on_address_changed(self, controller): def on_address_changed(self, controller):
pass pass
def send_advertising_data(self, sender_address: hci.Address, data: bytes):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
def send_acl_data( def send_acl_data(
self, self,
sender_controller: controller.Controller, sender_controller: controller.Controller,
@@ -100,7 +89,7 @@ class LocalLink:
): ):
# Send the data to the first controller with a matching address # Send the data to the first controller with a matching address
if transport == core.PhysicalTransport.LE: if transport == core.PhysicalTransport.LE:
destination_controller = self.find_controller(destination_address) destination_controller = self.find_le_controller(destination_address)
source_address = sender_controller.random_address source_address = sender_controller.random_address
elif transport == core.PhysicalTransport.BR_EDR: elif transport == core.PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address) destination_controller = self.find_classic_controller(destination_address)
@@ -115,152 +104,30 @@ class LocalLink:
) )
) )
def on_connection_complete(self) -> None: def send_advertising_pdu(
# Check that we expect this call self,
if not self.pending_connection: sender_controller: controller.Controller,
logger.warning('on_connection_complete with no pending connection') packet: ll.AdvertisingPdu,
return ):
loop = asyncio.get_running_loop()
for c in self.controllers:
if c != sender_controller:
loop.call_soon(c.on_ll_advertising_pdu, packet)
central_address, le_create_connection_command = self.pending_connection def send_ll_control_pdu(
self.pending_connection = None self,
sender_address: hci.Address,
# Find the controller that initiated the connection receiver_address: hci.Address,
if not (central_controller := self.find_controller(central_address)): packet: ll.ControlPdu,
logger.warning('!!! Initiating controller not found') ):
return if not (receiver_controller := self.find_le_controller(receiver_address)):
raise core.InvalidArgumentError(
# Connect to the first controller with a matching address f"Unable to find controller for address {receiver_address}"
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, hci.HCI_SUCCESS
) )
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, hci.HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(
self,
central_address: hci.Address,
le_create_connection_command: hci.HCI_LE_Create_Connection_Command,
):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self,
initiating_address: hci.Address,
target_address: hci.Address,
disconnect_command: hci.HCI_Disconnect_Command,
):
# Find the controller that initiated the disconnection
if not (initiating_controller := self.find_controller(initiating_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if target_controller := self.find_controller(target_address):
target_controller.on_link_disconnected(
initiating_address, disconnect_command.reason
)
initiating_controller.on_link_disconnection_complete(
disconnect_command, hci.HCI_SUCCESS
)
def disconnect(
self,
initiating_address: hci.Address,
target_address: hci.Address,
disconnect_command: hci.HCI_Disconnect_Command,
):
logger.debug(
f'$$$ DISCONNECTION {initiating_address} -> '
f'{target_address}: reason = {disconnect_command.reason}'
)
asyncio.get_running_loop().call_soon( asyncio.get_running_loop().call_soon(
lambda: self.on_disconnection_complete( lambda: receiver_controller.on_ll_control_pdu(sender_address, packet)
initiating_address, target_address, disconnect_command
)
) )
def on_connection_encrypted(
self,
central_address: hci.Address,
peripheral_address: hci.Address,
rand: bytes,
ediv: int,
ltk: bytes,
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
loop = asyncio.get_running_loop()
loop.call_soon(central_controller.on_link_cis_established, cig_id, cis_id)
loop.call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: hci.Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
loop = asyncio.get_running_loop()
loop.call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
loop.call_soon(peer_controller.on_link_cis_disconnected, cig_id, cis_id)
############################################################ ############################################################
# Classic handlers # Classic handlers
############################################################ ############################################################

200
bumble/ll.py Normal file
View File

@@ -0,0 +1,200 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import ClassVar
from bumble import hci
# -----------------------------------------------------------------------------
# Advertising PDU
# -----------------------------------------------------------------------------
class AdvertisingPdu:
"""Base Advertising Physical Channel PDU class.
See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
@dataclasses.dataclass
class ConnectInd(AdvertisingPdu):
initiator_address: hci.Address
advertiser_address: hci.Address
interval: int
latency: int
timeout: int
@dataclasses.dataclass
class AdvInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvDirectInd(AdvertisingPdu):
advertiser_address: hci.Address
target_address: hci.Address
@dataclasses.dataclass
class AdvNonConnInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvExtInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
target_address: hci.Address | None = None
adi: int | None = None
tx_power: int | None = None
# -----------------------------------------------------------------------------
# LL Control PDU
# -----------------------------------------------------------------------------
class ControlPdu:
"""Base LL Control PDU Class.
See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
class Opcode(hci.SpecableEnum):
LL_CONNECTION_UPDATE_IND = 0x00
LL_CHANNEL_MAP_IND = 0x01
LL_TERMINATE_IND = 0x02
LL_ENC_REQ = 0x03
LL_ENC_RSP = 0x04
LL_START_ENC_REQ = 0x05
LL_START_ENC_RSP = 0x06
LL_UNKNOWN_RSP = 0x07
LL_FEATURE_REQ = 0x08
LL_FEATURE_RSP = 0x09
LL_PAUSE_ENC_REQ = 0x0A
LL_PAUSE_ENC_RSP = 0x0B
LL_VERSION_IND = 0x0C
LL_REJECT_IND = 0x0D
LL_PERIPHERAL_FEATURE_REQ = 0x0E
LL_CONNECTION_PARAM_REQ = 0x0F
LL_CONNECTION_PARAM_RSP = 0x10
LL_REJECT_EXT_IND = 0x11
LL_PING_REQ = 0x12
LL_PING_RSP = 0x13
LL_LENGTH_REQ = 0x14
LL_LENGTH_RSP = 0x15
LL_PHY_REQ = 0x16
LL_PHY_RSP = 0x17
LL_PHY_UPDATE_IND = 0x18
LL_MIN_USED_CHANNELS_IND = 0x19
LL_CTE_REQ = 0x1A
LL_CTE_RSP = 0x1B
LL_PERIODIC_SYNC_IND = 0x1C
LL_CLOCK_ACCURACY_REQ = 0x1D
LL_CLOCK_ACCURACY_RSP = 0x1E
LL_CIS_REQ = 0x1F
LL_CIS_RSP = 0x20
LL_CIS_IND = 0x21
LL_CIS_TERMINATE_IND = 0x22
LL_POWER_CONTROL_REQ = 0x23
LL_POWER_CONTROL_RSP = 0x24
LL_POWER_CHANGE_IND = 0x25
LL_SUBRATE_REQ = 0x26
LL_SUBRATE_IND = 0x27
LL_CHANNEL_REPORTING_IND = 0x28
LL_CHANNEL_STATUS_IND = 0x29
LL_PERIODIC_SYNC_WR_IND = 0x2A
LL_FEATURE_EXT_REQ = 0x2B
LL_FEATURE_EXT_RSP = 0x2C
LL_CS_SEC_RSP = 0x2D
LL_CS_CAPABILITIES_REQ = 0x2E
LL_CS_CAPABILITIES_RSP = 0x2F
LL_CS_CONFIG_REQ = 0x30
LL_CS_CONFIG_RSP = 0x31
LL_CS_REQ = 0x32
LL_CS_RSP = 0x33
LL_CS_IND = 0x34
LL_CS_TERMINATE_REQ = 0x35
LL_CS_FAE_REQ = 0x36
LL_CS_FAE_RSP = 0x37
LL_CS_CHANNEL_MAP_IND = 0x38
LL_CS_SEC_REQ = 0x39
LL_CS_TERMINATE_RSP = 0x3A
LL_FRAME_SPACE_REQ = 0x3B
LL_FRAME_SPACE_RSP = 0x3C
opcode: ClassVar[Opcode]
@dataclasses.dataclass
class TerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_TERMINATE_IND
error_code: int
@dataclasses.dataclass
class EncReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_ENC_REQ
rand: bytes
ediv: int
ltk: bytes
@dataclasses.dataclass
class CisReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisTerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND
cig_id: int
cis_id: int
error_code: int

View File

@@ -284,52 +284,51 @@ async def test_legacy_advertising():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart): async def test_legacy_advertising_disconnection(auto_restart):
devices = TwoDevices() devices = TwoDevices()
device = devices[0] for controller in devices.controllers:
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff') controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
await device.power_on() for dev in devices:
peer_address = Address('F0:F1:F2:F3:F4:F5') await dev.power_on()
await device.start_advertising(auto_restart=auto_restart) await devices[0].start_advertising(
device.on_le_connection( auto_restart=auto_restart, advertising_interval_min=1.0
0x0001,
peer_address,
None,
None,
Role.PERIPHERAL,
0,
0,
0,
) )
connecion = await devices[1].connect(devices[0].random_address)
device.on_advertising_set_termination( await connecion.disconnect()
HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, 0x0001, 0
)
device.on_disconnection(0x0001, 0)
await async_barrier() await async_barrier()
await async_barrier() await async_barrier()
if auto_restart: if auto_restart:
assert device.legacy_advertising_set assert devices[0].legacy_advertising_set
started = asyncio.Event() started = asyncio.Event()
if not device.is_advertising: if not devices[0].is_advertising:
device.legacy_advertising_set.once('start', started.set) devices[0].legacy_advertising_set.once('start', started.set)
await asyncio.wait_for(started.wait(), _TIMEOUT) await asyncio.wait_for(started.wait(), _TIMEOUT)
assert device.is_advertising assert devices[0].is_advertising
else: else:
assert not device.is_advertising assert not devices[0].is_advertising
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extended_advertising(): async def test_advertising_and_scanning():
device = TwoDevices()[0] devices = TwoDevices()
await device.power_on() for dev in devices:
await dev.power_on()
# Start scanning
advertisements = asyncio.Queue[device.Advertisement]()
devices[1].on(devices[1].EVENT_ADVERTISEMENT, advertisements.put_nowait)
await devices[1].start_scanning()
# Start advertising # Start advertising
advertising_set = await device.create_advertising_set() advertising_set = await devices[0].create_advertising_set(advertising_data=b'123')
assert device.extended_advertising_sets assert devices[0].extended_advertising_sets
assert advertising_set.enabled assert advertising_set.enabled
advertisement = await asyncio.wait_for(advertisements.get(), _TIMEOUT)
assert advertisement.data_bytes == b'123'
# Stop advertising # Stop advertising
await advertising_set.stop() await advertising_set.stop()
assert not advertising_set.enabled assert not advertising_set.enabled
@@ -342,33 +341,33 @@ async def test_extended_advertising():
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type): async def test_extended_advertising_connection(own_address_type):
device = TwoDevices()[0] devices = TwoDevices()
await device.power_on() for dev in devices:
peer_address = Address('F0:F1:F2:F3:F4:F5') await dev.power_on()
advertising_set = await device.create_advertising_set( advertising_set = await devices[0].create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) advertising_parameters=AdvertisingParameters(
own_address_type=own_address_type, primary_advertising_interval_min=1.0
)
) )
device.on_le_connection( await asyncio.wait_for(
0x0001, devices[1].connect(advertising_set.random_address or devices[0].public_address),
peer_address, _TIMEOUT,
None,
None,
Role.PERIPHERAL,
0,
0,
0,
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertising_set.advertising_handle,
0x0001,
0,
) )
await async_barrier()
# Advertising set should be terminated after connected.
assert not advertising_set.enabled
if own_address_type == OwnAddressType.PUBLIC: if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address assert (
devices[0].lookup_connection(0x0001).self_address
== devices[0].public_address
)
else: else:
assert device.lookup_connection(0x0001).self_address == device.random_address assert (
devices[0].lookup_connection(0x0001).self_address
== devices[0].random_address
)
await async_barrier() await async_barrier()
@@ -382,7 +381,7 @@ async def test_extended_advertising_connection(own_address_type):
async def test_extended_advertising_connection_out_of_order(own_address_type): async def test_extended_advertising_connection_out_of_order(own_address_type):
devices = TwoDevices() devices = TwoDevices()
device = devices[0] device = devices[0]
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff') devices.controllers[0].le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
await device.power_on() await device.power_on()
advertising_set = await device.create_advertising_set( advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)

View File

@@ -69,7 +69,7 @@ from bumble.host import Host
from bumble.link import LocalLink from bumble.link import LocalLink
from bumble.transport.common import AsyncPipeSink from bumble.transport.common import AsyncPipeSink
from .test_utils import async_barrier from .test_utils import Devices, TwoDevices, async_barrier
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -160,7 +160,8 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes): def decode_value(self, value_bytes):
return value_bytes[0] return value_bytes[0]
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic = Characteristic( characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -189,9 +190,7 @@ async def test_characteristic_encoding():
) )
server.add_service(service) server.add_service(service)
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -279,7 +278,8 @@ async def test_characteristic_encoding():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_attribute_getters(): async def test_attribute_getters():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806') characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806')
characteristic = Characteristic( characteristic = Characteristic(
@@ -629,39 +629,11 @@ async def test_CharacteristicValue_async():
m.assert_called_once_with(z, b) m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
class LinkedDevices:
def __init__(self):
self.connections = [None, None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link=self.link),
Controller('C2', link=self.link),
Controller('C3', link=self.link),
]
self.devices = [
Device(
address='F0:F1:F2:F3:F4:F5',
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address='F1:F2:F3:F4:F5:F6',
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
Device(
address='F2:F3:F4:F5:F6:F7',
host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])),
),
]
self.paired = [None, None, None]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_read_write(): async def test_read_write():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -694,9 +666,7 @@ async def test_read_write():
) )
server.add_services([service1]) server.add_services([service1])
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -740,7 +710,8 @@ async def test_read_write():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_read_write2(): async def test_read_write2():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
v = bytes([0x11, 0x22, 0x33, 0x44]) v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic( characteristic1 = Characteristic(
@@ -753,9 +724,7 @@ async def test_read_write2():
service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1]) service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1])
server.add_services([service1]) server.add_services([service1])
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -785,7 +754,8 @@ async def test_read_write2():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscribe_notify(): async def test_subscribe_notify():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -855,9 +825,7 @@ async def test_subscribe_notify():
server.on('characteristic_subscription', on_characteristic_subscription) server.on('characteristic_subscription', on_characteristic_subscription)
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -1006,7 +974,8 @@ async def test_subscribe_notify():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unsubscribe(): async def test_unsubscribe():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1032,9 +1001,7 @@ async def test_unsubscribe():
mock2 = Mock() mock2 = Mock()
characteristic2.on('subscription', mock2) characteristic2.on('subscription', mock2)
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -1094,7 +1061,8 @@ async def test_unsubscribe():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_discover_all(): async def test_discover_all():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1120,9 +1088,7 @@ async def test_discover_all():
service2 = Service('1111', []) service2 = Service('1111', [])
server.add_services([service1, service2]) server.add_services([service1, service2])
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_all() await peer.discover_all()
@@ -1146,7 +1112,10 @@ async def test_discover_all():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mtu_exchange(): async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3] devices = Devices(3)
for dev in devices:
await dev.power_on()
[d1, d2, d3] = devices
d3.gatt_server.max_mtu = 100 d3.gatt_server.max_mtu = 100
@@ -1160,11 +1129,15 @@ async def test_mtu_exchange():
await d2.power_on() await d2.power_on()
await d3.power_on() await d3.power_on()
await d3.start_advertising(advertising_interval_min=1.0)
d1_connection = await d1.connect(d3.random_address) d1_connection = await d1.connect(d3.random_address)
await async_barrier()
assert len(d3_connections) == 1 assert len(d3_connections) == 1
assert d3_connections[0] is not None assert d3_connections[0] is not None
await d3.start_advertising(advertising_interval_min=1.0)
d2_connection = await d2.connect(d3.random_address) d2_connection = await d2.connect(d3.random_address)
await async_barrier()
assert len(d3_connections) == 2 assert len(d3_connections) == 2
assert d3_connections[1] is not None assert d3_connections[1] is not None
@@ -1233,7 +1206,8 @@ Got: BROADCAST,HELLO"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_string(): async def test_server_string():
[_, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[_, server] = devices
characteristic = Characteristic( characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
@@ -1422,7 +1396,8 @@ def test_get_attribute_group():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_characteristics_by_uuid(): async def test_get_characteristics_by_uuid():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
characteristic1 = Characteristic( characteristic1 = Characteristic(
'1234', '1234',
@@ -1447,9 +1422,7 @@ async def test_get_characteristics_by_uuid():
server.add_services([service1, service2]) server.add_services([service1, service2])
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
peer = Peer(connection) peer = Peer(connection)
await peer.discover_services() await peer.discover_services()
@@ -1472,7 +1445,8 @@ async def test_get_characteristics_by_uuid():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_write_return_error(): async def test_write_return_error():
[client, server] = LinkedDevices().devices[:2] devices = await TwoDevices.create_with_connection()
[client, server] = devices
on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED)) on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED))
characteristic = Characteristic( characteristic = Characteristic(
@@ -1484,9 +1458,7 @@ async def test_write_return_error():
service = Service('ABCD', [characteristic]) service = Service('ABCD', [characteristic])
server.add_service(service) server.add_service(service)
await client.power_on() connection = devices.connections[0]
await server.power_on()
connection = await client.connect(server.random_address)
async with Peer(connection) as peer: async with Peer(connection) as peer:
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0] c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0]

View File

@@ -35,7 +35,7 @@ from bumble.smp import (
OobLegacyContext, OobLegacyContext,
) )
from .test_utils import TwoDevices from .test_utils import TwoDevices, async_barrier
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -56,12 +56,14 @@ async def test_self_disconnection():
two_devices = TwoDevices() two_devices = TwoDevices()
await two_devices.setup_connection() await two_devices.setup_connection()
await two_devices.connections[0].disconnect() await two_devices.connections[0].disconnect()
await async_barrier()
assert two_devices.connections[0] is None assert two_devices.connections[0] is None
assert two_devices.connections[1] is None assert two_devices.connections[1] is None
two_devices = TwoDevices() two_devices = TwoDevices()
await two_devices.setup_connection() await two_devices.setup_connection()
await two_devices.connections[1].disconnect() await two_devices.connections[1].disconnect()
await async_barrier()
assert two_devices.connections[0] is None assert two_devices.connections[0] is None
assert two_devices.connections[1] is None assert two_devices.connections[1] is None
@@ -80,7 +82,8 @@ async def test_self_classic_connection(responder_role):
two_devices.devices[1].classic_enabled = True two_devices.devices[1].classic_enabled = True
# Start # Start
await two_devices.setup_connection() for dev in two_devices.devices:
await dev.power_on()
# Connect the two devices # Connect the two devices
await asyncio.gather( await asyncio.gather(
@@ -418,8 +421,9 @@ async def test_self_smp_over_classic():
two_devices.devices[1].classic_enabled = True two_devices.devices[1].classic_enabled = True
# Connect the two devices # Connect the two devices
await two_devices.devices[0].power_on() for dev in two_devices.devices:
await two_devices.devices[1].power_on() await dev.power_on()
await asyncio.gather( await asyncio.gather(
two_devices.devices[0].connect( two_devices.devices[0].connect(
two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import functools
from typing import Optional from typing import Optional
from typing_extensions import Self from typing_extensions import Self
@@ -30,39 +31,34 @@ from bumble.transport.common import AsyncPipeSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TwoDevices: class Devices:
connections: list[Optional[Connection]] connections: list[Optional[Connection]]
def __init__(self) -> None: def __init__(self, num_devices: int) -> None:
self.connections = [None, None] self.connections = [None for _ in range(num_devices)]
self.link = LocalLink() self.link = LocalLink()
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)]
self.controllers = [ self.controllers = [
Controller('C1', link=self.link, public_address=addresses[0]), Controller(f'C{i+i}', link=self.link, public_address=addresses[i])
Controller('C2', link=self.link, public_address=addresses[1]), for i in range(num_devices)
] ]
self.devices = [ self.devices = [
Device( Device(
address=Address(addresses[0]), address=Address(addresses[i]),
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), host=Host(self.controllers[i], AsyncPipeSink(self.controllers[i])),
), )
Device( for i in range(num_devices)
address=Address(addresses[1]),
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
] ]
self.devices[0].on( for i in range(num_devices):
'connection', lambda connection: self.on_connection(0, connection) self.devices[i].on(
) self.devices[i].EVENT_CONNECTION,
self.devices[1].on( functools.partial(self.on_connection, i),
'connection', lambda connection: self.on_connection(1, connection) )
)
self.paired = [ self.paired = [
asyncio.get_event_loop().create_future(), asyncio.get_event_loop().create_future() for _ in range(num_devices)
asyncio.get_event_loop().create_future(),
] ]
def on_connection(self, which, connection): def on_connection(self, which, connection):
@@ -77,19 +73,26 @@ class TwoDevices:
async def setup_connection(self) -> None: async def setup_connection(self) -> None:
# Start # Start
await self.devices[0].power_on() for dev in self.devices:
await self.devices[1].power_on() await dev.power_on()
# Connect the two devices # Connect devices
await self.devices[0].connect(self.devices[1].random_address) for dev in self.devices[1:]:
connection_future = asyncio.get_running_loop().create_future()
# Check the post conditions dev.once(dev.EVENT_CONNECTION, connection_future.set_result)
assert self.connections[0] is not None await dev.start_advertising(advertising_interval_min=1.0)
assert self.connections[1] is not None await self.devices[0].connect(dev.random_address)
await connection_future
def __getitem__(self, index: int) -> Device: def __getitem__(self, index: int) -> Device:
return self.devices[index] return self.devices[index]
# -----------------------------------------------------------------------------
class TwoDevices(Devices):
def __init__(self) -> None:
super().__init__(2)
@classmethod @classmethod
async def create_with_connection(cls: type[Self]) -> Self: async def create_with_connection(cls: type[Self]) -> Self:
devices = cls() devices = cls()