Compare commits

..

8 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
7523118581 typing surrport for HCI commands return parameters 2026-01-17 13:19:36 -08:00
Gilles Boccon-Gibod
2cad743f8c Merge pull request #854 from TinyServal/rtl8761cu
Add support for RTL8761CU
2026-01-08 18:37:21 -08:00
zxzxwu
255414f315 Merge pull request #857 from zxzxwu/testing
Add test for Heart Rate and Battery Service
2026-01-08 17:52:12 +08:00
Josh Wu
d2df76f6f4 Add test for Heart Rate and Battery Service 2026-01-08 16:42:05 +08:00
zxzxwu
884b1c20e4 Merge pull request #856 from zxzxwu/typing
Add annotation for Heart Rate and Battery Service
2026-01-08 15:29:50 +08:00
Josh Wu
91a2b4f676 Add annotation for Heart Rate and Battery Service 2026-01-08 14:43:27 +08:00
Bowen Yan
5831f79d62 Add support for the RTL8761CU 2026-01-08 16:50:11 +11:00
Bowen Yan
054dc70f3f Exclude macOS xattr files 2026-01-07 15:00:21 +11:00
25 changed files with 2366 additions and 1725 deletions

3
.gitignore vendored
View File

@@ -17,3 +17,6 @@ venv/
.venv/
# snoop logs
out/
# macOS
.DS_Store
._*

View File

@@ -34,11 +34,7 @@ from bumble.hci import (
HCI_READ_BD_ADDR_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
CodecID,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
@@ -59,34 +55,23 @@ from bumble.host import Host
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
print()
print(
color('Public Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command())
print()
print(
color('Public Address:', 'yellow'),
response1.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if command_succeeded(response):
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
response2 = await host.send_sync_command(HCI_Read_Local_Name_Command())
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response2.local_name),
)
# -----------------------------------------------------------------------------
@@ -94,52 +79,50 @@ async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
response1 = await host.send_sync_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response1.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
response2 = await host.send_sync_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response2.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
response3 = await host.send_sync_command(
HCI_LE_Read_Maximum_Data_Length_Command()
)
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response3.supported_max_tx_octets}/'
f'{response3.supported_max_tx_time}, '
f'rx:{response3.supported_max_rx_octets}/'
f'{response3.supported_max_rx_time}'
),
'\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
response4 = await host.send_sync_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response4.suggested_max_tx_octets}/'
f'{response4.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
@@ -151,37 +134,31 @@ async def get_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command())
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
f'{response1.hc_total_num_acl_data_packets} '
f'packets of size {response1.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response2.total_num_le_acl_data_packets} '
f'packets of size {response2.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
f'{response2.total_num_iso_data_packets} '
f'packets of size {response2.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response3.total_num_le_acl_data_packets} '
f'packets of size {response3.le_acl_data_packet_length}',
)
@@ -190,52 +167,44 @@ async def get_codecs_info(host: Host) -> None:
print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
response1 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_V2_Command()
)
print(color('Codecs:', 'yellow'))
for codec_id, transport in zip(
response.return_parameters.standard_codec_ids,
response.return_parameters.standard_codec_transports,
response1.standard_codec_ids,
response1.standard_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
print(f' {codec_id.name} - {transport.name}')
for codec_id, transport in zip(
response.return_parameters.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports,
for vendor_codec_id, vendor_transport in zip(
response1.vendor_specific_codec_ids,
response1.vendor_specific_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}')
if not response.return_parameters.standard_codec_ids:
if not response1.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response1.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
response2 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_Command()
)
print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids:
codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for codec_id in response2.standard_codec_ids:
print(f' {codec_id.name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}')
for vendor_codec_id in response2.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids:
if not response2.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')

View File

@@ -85,7 +85,7 @@ class Loopback:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as (
@@ -100,11 +100,15 @@ class Loopback:
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
packet_queue = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
)
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
@@ -128,20 +132,18 @@ class Loopback:
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
await host.send_sync_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
if response.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))

View File

@@ -421,7 +421,7 @@ class Controller:
hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=command.op_code,
return_parameters=result,
return_parameters=hci.HCI_GenericReturnParameters(data=result),
)
)

View File

@@ -923,7 +923,7 @@ class DeviceClass:
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
def split_class_of_device(class_of_device: int) -> tuple[int, int, int]:
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return (

File diff suppressed because it is too large Load Diff

View File

@@ -89,51 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
tlv: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters]
):
param0: int = dataclasses.field(metadata=hci.metadata(1))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("tlv", "*"),
]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters)
@dataclasses.dataclass
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
class Hci_Intel_Secure_Send_Command(
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters]
):
data_type: int = dataclasses.field(metadata=hci.metadata(1))
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
("status", 1),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_Command):
class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters):
data: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]):
reset_type: int = dataclasses.field(metadata=hci.metadata(1))
patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4))
return_parameters_fields = [
("data", "*"),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
params: bytes = hci.field(metadata=hci.metadata('*'))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("params", "*"),
]
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# -----------------------------------------------------------------------------
@@ -402,7 +405,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
if not event.return_parameters.status == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -641,8 +644,8 @@ class Driver(common.Driver):
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
await self.host.send_sync_command(
HCI_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
@@ -660,31 +663,26 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
response1 = await self.host.send_sync_command(
hci.HCI_Reset_Command(), check_status=False
)
if response1.status not in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
logger.warning(f"unexpected response: {response1}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF)
response2 = await self.host.send_sync_command(
HCI_Intel_Read_Version_Command(param0=0xFF), check_status=False
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
if response2.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
tlvs = _parse_tlv(response2.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.

View File

@@ -16,6 +16,7 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`)
"""
from __future__ import annotations
import asyncio
import enum
@@ -31,10 +32,14 @@ import weakref
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci
from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -77,6 +82,7 @@ class RtlProjectId(enum.IntEnum):
PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25
PROJECT_ID_8761C = 51
RTK_PROJECT_ID_TO_ROM = {
@@ -92,6 +98,7 @@ RTK_PROJECT_ID_TO_ROM = {
18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A,
51: RTK_ROM_LMP_8761A,
}
# List of USB (VendorID, ProductID) for Realtek-based devices.
@@ -123,6 +130,10 @@ RTK_USB_PRODUCTS = {
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
# Realtek 8761CUV
(0x0B05, 0x1BF6),
(0x0BDA, 0xC761),
(0x7392, 0xF611),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
@@ -182,23 +193,36 @@ HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclass
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command):
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)]
class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
version: int = field(metadata=hci.metadata(1))
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_Command):
class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass
@dataclass
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters):
index: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command):
class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass
@@ -363,6 +387,15 @@ class Driver(common.Driver):
fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin",
),
# 8761CU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0E, 0x00),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761cu_fw.bin",
config_name="rtl8761cu_config.bin",
),
# 8822C
DriverInfo(
rom=RTK_ROM_LMP_8822B,
@@ -420,9 +453,17 @@ class Driver(common.Driver):
@staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == (
hci_subversion,
hci_version,
if driver_info.rom == lmp_subversion and (
driver_info.hci
== (
hci_subversion,
hci_version,
)
or driver_info.hci
== (
hci_subversion,
0x0,
)
):
return driver_info
@@ -467,7 +508,7 @@ class Driver(common.Driver):
return None
@staticmethod
def check(host):
def check(host: Host) -> bool:
if not host.hci_metadata:
logger.debug("USB metadata not found")
return False
@@ -491,41 +532,39 @@ class Driver(common.Driver):
return True
@staticmethod
async def get_loaded_firmware_version(host):
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command())
async def get_loaded_firmware_version(host: Host) -> int | None:
response1 = await host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response1.status != hci.HCI_SUCCESS:
return None
response = await host.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
)
return (
response.return_parameters.hci_subversion << 16
| response.return_parameters.lmp_subversion
response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod
async def driver_info_for_host(cls, host):
async def driver_info_for_host(cls, host: Host) -> DriverInfo | None:
try:
await host.send_command(
await host.send_sync_command(
hci.HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(hci.HCI_Reset_Command(), check_result=True)
await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
response = await host.send_sync_command(command, check_status=False)
if response.status != hci.HCI_SUCCESS:
logger.error("failed to probe local version information")
return None
local_version = response.return_parameters
local_version = response
logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
@@ -546,7 +585,7 @@ class Driver(common.Driver):
return driver_info
@classmethod
async def for_host(cls, host, force=False):
async def for_host(cls, host: Host, force: bool = False):
# Check that a driver is needed for this host
if not force and not cls.check(host):
return None
@@ -603,13 +642,13 @@ class Driver(common.Driver):
async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version:
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response1 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response1.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version")
return None
rom_version = response.return_parameters.version
rom_version = response1.version
logger.debug(f"ROM version before download: {rom_version:04X}")
else:
rom_version = 0
@@ -644,21 +683,20 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment),
check_result=True,
await self.host.send_sync_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment)
)
logger.debug("download complete!")
# Read the version again
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response2 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response2.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version")
else:
rom_version = response.return_parameters.version
rom_version = response2.version
logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version
@@ -680,7 +718,7 @@ class Driver(common.Driver):
async def init_controller(self):
await self.download_firmware()
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.host.send_sync_command(hci.HCI_Reset_Command())
logger.info(f"loaded FW image {self.driver_info.fw_name}")

File diff suppressed because it is too large Load Diff

View File

@@ -23,11 +23,15 @@ import dataclasses
import logging
import struct
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from bumble import drivers, hci, utils
from bumble.colors import color
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport
from bumble.core import (
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError
@@ -35,7 +39,6 @@ from bumble.transport.common import TransportLostError
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -236,6 +239,9 @@ class IsoLink:
# -----------------------------------------------------------------------------
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
class Host(utils.EventEmitter):
connections: dict[int, Connection]
cis_links: dict[int, IsoLink]
@@ -264,11 +270,13 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: asyncio.Future[Any] | None = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
self.local_version: (
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
) = None
self.local_supported_commands = 0
self.local_le_features = 0
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
@@ -312,7 +320,7 @@ class Host(utils.EventEmitter):
self.emit('flush')
self.command_semaphore.release()
async def reset(self, driver_factory=drivers.get_driver_for_host):
async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
if self.ready:
self.ready = False
await self.flush()
@@ -330,57 +338,53 @@ class Host(utils.EventEmitter):
# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.send_sync_command(hci.HCI_Reset_Command())
self.ready = True
response = await self.send_command(
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True
response1 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Commands_Command()
)
self.local_supported_commands = int.from_bytes(
response.return_parameters.supported_commands, 'little'
response1.supported_commands, 'little'
)
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
response2 = await self.send_sync_command(
hci.HCI_LE_Read_Local_Supported_Features_Command()
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
self.local_le_features = struct.unpack('<Q', response2.le_features)[0]
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
self.local_version = await self.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
self.local_version = response.return_parameters
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
max_page_number = 0
page_number = 0
lmp_features = 0
while page_number <= max_page_number:
response = await self.send_command(
response4 = await self.send_sync_command(
hci.HCI_Read_Local_Extended_Features_Command(
page_number=page_number
),
check_result=True,
)
)
lmp_features |= int.from_bytes(
response.return_parameters.extended_lmp_features, 'little'
response4.extended_lmp_features, 'little'
) << (64 * page_number)
max_page_number = response.return_parameters.maximum_page_number
max_page_number = response4.maximum_page_number
page_number += 1
self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True
response5 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Features_Command()
)
self.local_lmp_features = hci.LmpFeatureMask(
int.from_bytes(response.return_parameters.lmp_features, 'little')
int.from_bytes(response5.lmp_features, 'little')
)
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Command(
event_mask=hci.HCI_Set_Event_Mask_Command.mask(
[
@@ -437,7 +441,7 @@ class Host(utils.EventEmitter):
)
)
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
@@ -499,20 +503,14 @@ class Host(utils.EventEmitter):
]
)
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_Read_Buffer_Size_Command(), check_result=True
)
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
hc_acl_data_packet_length = response6.hc_acl_data_packet_length
hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
logger.debug(
'HCI ACL flow control: '
@@ -531,19 +529,13 @@ class Host(utils.EventEmitter):
iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
response7 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command()
)
le_acl_data_packet_length = response7.le_acl_data_packet_length
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
iso_data_packet_length = response7.iso_data_packet_length
total_num_iso_data_packets = response7.total_num_iso_data_packets
logger.debug(
'HCI LE flow control: '
@@ -553,15 +545,11 @@ class Host(utils.EventEmitter):
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
response8 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_Command()
)
le_acl_data_packet_length = response8.le_acl_data_packet_length
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
logger.debug(
'HCI LE ACL flow control: '
@@ -592,16 +580,16 @@ class Host(utils.EventEmitter):
) and self.supports_command(
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
):
response = await self.send_command(
response9 = await self.send_sync_command(
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
suggested_max_tx_octets = response9.suggested_max_tx_octets
suggested_max_tx_time = response9.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
@@ -611,23 +599,21 @@ class Host(utils.EventEmitter):
if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(),
check_result=True,
response10 = await self.send_sync_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
self.number_of_supported_advertising_sets = (
response.return_parameters.num_supported_advertising_sets
response10.num_supported_advertising_sets
)
if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(),
check_result=True,
response11 = await self.send_sync_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
self.maximum_advertising_data_length = (
response.return_parameters.max_advertising_data_length
response11.max_advertising_data_length
)
@property
@@ -654,9 +640,11 @@ class Host(utils.EventEmitter):
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: int | None = None
):
async def _send_command(
self,
command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -668,29 +656,9 @@ class Host(utils.EventEmitter):
try:
self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
response = self.pending_response.result()
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
return await asyncio.wait_for(
self.pending_response, timeout=response_timeout
)
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
@@ -698,12 +666,107 @@ class Host(utils.EventEmitter):
self.pending_command = None
self.pending_response = None
# Use this method to send a command from a task
def send_command_sync(self, command: hci.HCI_Command) -> None:
async def send_command(command: hci.HCI_Command) -> None:
await self.send_command(command)
@overload
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]: ...
asyncio.create_task(send_command(command))
@overload
async def send_command(
self,
command: hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Status_Event: ...
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
response = await self._send_command(command, response_timeout)
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
elif isinstance(
response.return_parameters, hci.HCI_GenericReturnParameters
):
# FIXME: temporary workaround
# NO STATUS
status = hci.HCI_SUCCESS
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
async def send_sync_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_status: bool = True,
response_timeout: float | None = None,
) -> _RP:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event)
return_parameters: _RP = response.return_parameters
assert isinstance(return_parameters, command.return_parameters_class)
# Check the return parameters if required
if check_status:
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_async_command(
self,
command: hci.HCI_AsyncCommand,
check_status: bool = True,
response_timeout: float | None = None,
) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Status_Event)
# Check the return parameters if required
status = response.status
if check_status:
if status != hci.HCI_CommandStatus.PENDING:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return hci.HCI_ErrorCode(status)
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
utils.AsyncRunner.spawn(self.send_async_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
@@ -1338,15 +1401,17 @@ class Host(utils.EventEmitter):
# For now, just accept everything
# TODO: delegate the decision
self.send_command_sync(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
utils.AsyncRunner.spawn(
self.send_sync_command(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
)
)
)
@@ -1382,9 +1447,9 @@ class Host(utils.EventEmitter):
connection_handle=event.connection_handle
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_long_term_key())
utils.AsyncRunner.spawn(send_long_term_key())
def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event
@@ -1583,9 +1648,9 @@ class Host(utils.EventEmitter):
bd_addr=event.bd_addr
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_link_key())
utils.AsyncRunner.spawn(send_link_key())
def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event

View File

@@ -16,35 +16,28 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from collections.abc import Callable
from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_BATTERY_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from bumble import device, gatt, gatt_adapters, gatt_client
# -----------------------------------------------------------------------------
class BatteryService(TemplateService):
UUID = GATT_BATTERY_SERVICE
class BatteryService(gatt.TemplateService):
UUID = gatt.GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: Characteristic[int]
battery_level_characteristic: gatt.Characteristic[int]
def __init__(self, read_battery_level):
self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
def __init__(self, read_battery_level: Callable[[device.Connection], int]) -> None:
self.battery_level_characteristic = gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.READABLE,
value=gatt.CharacteristicValue(read=read_battery_level),
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
@@ -52,19 +45,17 @@ class BatteryService(TemplateService):
# -----------------------------------------------------------------------------
class BatteryServiceProxy(ProfileServiceProxy):
class BatteryServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BatteryService
battery_level: CharacteristicProxy[int] | None
battery_level: gatt_client.CharacteristicProxy[int]
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicProxyAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None
self.battery_level = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)

View File

@@ -18,40 +18,30 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
import struct
from enum import IntEnum
from collections.abc import Callable, Sequence
from typing import Any
from bumble import core
from bumble.att import ATT_Error
from bumble.gatt import (
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_HEART_RATE_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from typing_extensions import Self
from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
class HeartRateService(gatt.TemplateService):
UUID = gatt.GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: Characteristic[int]
heart_rate_measurement_characteristic: gatt.Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: gatt.Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: gatt.Characteristic[int]
class BodySensorLocation(IntEnum):
class BodySensorLocation(utils.OpenIntEnum):
OTHER = 0
CHEST = 1
WRIST = 2
@@ -60,82 +50,90 @@ class HeartRateService(TemplateService):
EAR_LOBE = 5
FOOT = 6
@dataclasses.dataclass
class HeartRateMeasurement:
def __init__(
self,
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
heart_rate: int
sensor_contact_detected: bool | None = None
energy_expended: int | None = None
rr_intervals: Sequence[float] | None = None
class Flag(enum.IntFlag):
INT16_HEART_RATE = 1 << 0
SENSOR_CONTACT_DETECTED = 1 << 1
SENSOR_CONTACT_SUPPORTED = 1 << 2
ENERGY_EXPENDED_STATUS = 1 << 3
RR_INTERVAL = 1 << 4
def __post_init__(self) -> None:
if self.heart_rate < 0 or self.heart_rate > 0xFFFF:
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
if self.energy_expended is not None and (
self.energy_expended < 0 or self.energy_expended > 0xFFFF
):
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if self.rr_intervals:
for rr_interval in self.rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
def from_bytes(cls, data: bytes) -> Self:
flags = data[0]
offset = 1
if flags & 1:
hr = struct.unpack_from('<H', data, offset)[0]
if flags & cls.Flag.INT16_HEART_RATE:
heart_rate = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
hr = struct.unpack_from('B', data, offset)[0]
heart_rate = struct.unpack_from('B', data, offset)[0]
offset += 1
if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0
if flags & cls.Flag.SENSOR_CONTACT_SUPPORTED:
sensor_contact_detected = flags & cls.Flag.SENSOR_CONTACT_DETECTED != 0
else:
sensor_contact_detected = None
if flags & (1 << 3):
if flags & cls.Flag.ENERGY_EXPENDED_STATUS:
energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
energy_expended = None
if flags & (1 << 4):
rr_intervals: Sequence[float] | None = None
if flags & cls.Flag.RR_INTERVAL:
rr_intervals = tuple(
struct.unpack_from('<H', data, offset + i * 2)[0] / 1024
for i in range((len(data) - offset) // 2)
struct.unpack_from('<H', data, i)[0] / 1024
for i in range(offset, len(data), 2)
)
else:
rr_intervals = ()
return cls(hr, sensor_contact_detected, energy_expended, rr_intervals)
return cls(
heart_rate=heart_rate,
sensor_contact_detected=sensor_contact_detected,
energy_expended=energy_expended,
rr_intervals=rr_intervals,
)
def __bytes__(self):
def __bytes__(self) -> bytes:
flags = 0
if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate)
else:
flags = 1
flags |= self.Flag.INT16_HEART_RATE
data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None:
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
flags |= self.Flag.SENSOR_CONTACT_SUPPORTED
if self.sensor_contact_detected:
flags |= self.Flag.SENSOR_CONTACT_DETECTED
if self.energy_expended is not None:
flags |= 1 << 3
flags |= self.Flag.ENERGY_EXPENDED_STATUS
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= 1 << 4
if self.rr_intervals is not None:
flags |= self.Flag.RR_INTERVAL
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
@@ -145,57 +143,67 @@ class HeartRateService(TemplateService):
return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None,
read_heart_rate_measurement: Callable[
[device.Connection], HeartRateMeasurement
],
body_sensor_location: HeartRateService.BodySensorLocation | None = None,
reset_energy_expended: Callable[[device.Connection], Any] | None = None,
):
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
self.heart_rate_measurement_characteristic = (
gatt_adapters.SerializableCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions(0),
value=gatt.CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
)
)
characteristics = [self.heart_rate_measurement_characteristic]
characteristics: list[gatt.Characteristic] = [
self.heart_rate_measurement_characteristic
]
if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)]),
self.body_sensor_location_characteristic = (
gatt_adapters.EnumCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.READABLE,
value=body_sensor_location,
),
cls=self.BodySensorLocation,
length=1,
)
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
def write_heart_rate_control_point_value(
connection: device.Connection, value: bytes
) -> None:
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
reset_energy_expended(connection)
else:
raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
raise att.ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point_characteristic = (
gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.WRITEABLE,
value=gatt.CharacteristicValue(
write=write_heart_rate_control_point_value
),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -203,50 +211,51 @@ class HeartRateService(TemplateService):
# -----------------------------------------------------------------------------
class HeartRateServiceProxy(ProfileServiceProxy):
class HeartRateServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
heart_rate_measurement: (
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None
)
heart_rate_measurement: gatt_client.CharacteristicProxy[
HeartRateService.HeartRateMeasurement
]
body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None
gatt_client.CharacteristicProxy[HeartRateService.BodySensorLocation] | None
)
heart_rate_control_point: CharacteristicProxy[int] | None
heart_rate_control_point: gatt_client.CharacteristicProxy[int] | None
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], HeartRateService.HeartRateMeasurement
self.heart_rate_measurement = (
gatt_adapters.SerializableCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
),
HeartRateService.HeartRateMeasurement,
)
else:
self.heart_rate_measurement = None
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
self.body_sensor_location = gatt_adapters.EnumCharacteristicProxyAdapter(
characteristics[0], cls=HeartRateService.BodySensorLocation, length=1
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point = (
gatt_adapters.PackedCharacteristicProxyAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
async def reset_energy_expended(self) -> None:
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED

View File

@@ -43,44 +43,53 @@ hci.HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
class HCI_LE_Get_Vendor_Capabilities_ReturnParameters(hci.HCI_StatusReturnParameters):
max_advt_instances: int = field(metadata=hci.metadata(1), default=0)
offloaded_resolution_of_private_address: int = field(
metadata=hci.metadata(1), default=0
)
total_scan_results_storage: int = field(metadata=hci.metadata(2), default=0)
max_irk_list_sz: int = field(metadata=hci.metadata(1), default=0)
filtering_support: int = field(metadata=hci.metadata(1), default=0)
max_filter: int = field(metadata=hci.metadata(1), default=0)
activity_energy_info_support: int = field(metadata=hci.metadata(1), default=0)
version_supported: int = field(metadata=hci.metadata(2), default=0)
total_num_of_advt_tracked: int = field(metadata=hci.metadata(2), default=0)
extended_scan_support: int = field(metadata=hci.metadata(1), default=0)
debug_logging_supported: int = field(metadata=hci.metadata(1), default=0)
le_address_generation_offloading_support: int = field(
metadata=hci.metadata(1), default=0
)
a2dp_source_offload_capability_mask: int = field(
metadata=hci.metadata(4), default=0
)
bluetooth_quality_report_support: int = field(metadata=hci.metadata(1), default=0)
dynamic_audio_buffer_support: int = field(metadata=hci.metadata(4), default=0)
@hci.HCI_SyncCommand.sync_command(HCI_LE_Get_Vendor_Capabilities_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(
hci.HCI_SyncCommand[HCI_LE_Get_Vendor_Capabilities_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = hci.HCI_Object(cls.return_parameters_fields, **nones)
# there are no more bytes to parse, and leave un-signaled parameters set to
# 0
return_parameters = HCI_LE_Get_Vendor_Capabilities_ReturnParameters(
hci.HCI_ErrorCode.SUCCESS
)
try:
offset = 0
for field in cls.return_parameters_fields:
for field in cls.return_parameters_class.fields:
field_name, field_type = field
field_value, field_size = hci.HCI_Object.parse_field(
parameters, offset, field_type
@@ -94,9 +103,30 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# APCF Subcommands
class LeApcfOpcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_Command):
class HCI_LE_APCF_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_LE_APCF_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_SyncCommand[HCI_LE_APCF_ReturnParameters]):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
@@ -105,52 +135,52 @@ class HCI_LE_APCF_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
class Opcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command):
class HCI_Get_Controller_Activity_Energy_Info_ReturnParameters(
hci.HCI_StatusReturnParameters
):
total_tx_time_ms: int = field(metadata=hci.metadata(4))
total_rx_time_ms: int = field(metadata=hci.metadata(4))
total_idle_time_ms: int = field(metadata=hci.metadata(4))
total_energy_used: int = field(metadata=hci.metadata(4))
@hci.HCI_SyncCommand.sync_command(
HCI_Get_Controller_Activity_Energy_Info_ReturnParameters
)
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(
hci.HCI_SyncCommand[HCI_Get_Controller_Activity_Energy_Info_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# A2DP Hardware Offload Subcommands
class A2dpHardwareOffloadOpcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
class HCI_A2DP_Hardware_Offload_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_A2DP_Hardware_Offload_ReturnParameters)
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(
hci.HCI_SyncCommand[HCI_A2DP_Hardware_Offload_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
@@ -159,25 +189,27 @@ class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
class Opcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# Dynamic Audio Buffer Subcommands
class DynamicAudioBufferOpcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
class HCI_Dynamic_Audio_Buffer_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_Dynamic_Audio_Buffer_ReturnParameters)
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(
hci.HCI_SyncCommand[HCI_Dynamic_Audio_Buffer_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
@@ -186,19 +218,9 @@ class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
class Opcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):

View File

@@ -46,9 +46,19 @@ class TX_Power_Level_Command:
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Write_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
selected_tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Write_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Write_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -61,18 +71,21 @@ class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Read_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Read_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Read_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -83,10 +96,3 @@ class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
handle_type: int = dataclasses.field(metadata=hci.metadata(1))
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
]

View File

@@ -37,15 +37,16 @@ The vendor specific HCI commands to read and write TX power are defined in
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB
response = await host.send_command(
response = await host.send_sync_command(
HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0,
tx_power_level=-4,
)
),
check_status=False
)
if response.return_parameters.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}")
if response.status == HCI_SUCCESS:
print(f"TX power set to {response.selected_tx_power_level}")
```

View File

@@ -71,8 +71,8 @@ async def main() -> None:
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
random.randint(900, 1100) // 1000,
random.randint(900, 1100) // 1000,
),
None,
)

View File

@@ -17,6 +17,6 @@ use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len());
assert_eq!(13, DriverInfo::all_drivers()?.len());
Ok(())
}

View File

@@ -0,0 +1,34 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from bumble import device as device_module
from bumble.profiles import battery_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_battery_level():
devices = await test_utils.TwoDevices.create_with_connection()
service = battery_service.BatteryService(lambda _: 1)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(battery_service.BatteryServiceProxy)
assert client
assert await client.battery_level.read_value() == 1

View File

@@ -42,7 +42,6 @@ from bumble.hci import (
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
HCI_Connection_Request_Event,
@@ -154,10 +153,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
@@ -188,10 +187,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)

View File

@@ -20,7 +20,7 @@ import struct
import pytest
from bumble import hci
from bumble import hci, utils
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -136,43 +136,25 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
# -----------------------------------------------------------------------------
def test_HCI_Command_Complete_Event():
# With a serializable object
event = hci.HCI_Command_Complete_Event(
event1 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=34,
command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.return_parameters_class(
status=0,
le_acl_data_packet_length=1234,
total_num_le_acl_data_packets=56,
),
)
basic_check(event)
# With an arbitrary byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = hci.HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7
basic_check(event1)
# With a simple status as an integer status
event = hci.HCI_Command_Complete_Event(
event3 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=9,
return_parameters=hci.HCI_StatusReturnParameters(hci.HCI_ErrorCode(9)),
)
basic_check(event)
assert event.return_parameters == 9
basic_check(event3)
assert event3.return_parameters.status == 9
# -----------------------------------------------------------------------------
@@ -229,6 +211,28 @@ def test_HCI_Vendor_Event():
assert isinstance(parsed, hci.HCI_Vendor_Event)
# -----------------------------------------------------------------------------
def test_return_parameters() -> None:
params = hci.HCI_Reset_Command.parse_return_parameters(bytes.fromhex('3C'))
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('3C001122334455')
)
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address)
params = hci.HCI_Read_Local_Name_Command.parse_return_parameters(
bytes.fromhex('0068656c6c6f') + bytes(248 - 5)
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.local_name, bytes)
assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# -----------------------------------------------------------------------------
def test_HCI_Command():
command = hci.HCI_Command(op_code=0x5566)
@@ -291,7 +295,7 @@ def test_custom_le_meta_event():
for clazz in inspect.getmembers(hci)
if isinstance(clazz[1], type)
and issubclass(clazz[1], hci.HCI_Command)
and clazz[1] is not hci.HCI_Command
and clazz[1] not in (hci.HCI_Command, hci.HCI_SyncCommand, hci.HCI_AsyncCommand)
],
)
def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
@@ -620,21 +624,19 @@ def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = (
hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
0,
]
)
returned_parameters = hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
0,
]
)
)
assert returned_parameters.standard_codec_ids == [
@@ -643,9 +645,9 @@ def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
hci.CodecID.LINEAR_PCM,
]
assert returned_parameters.standard_codec_transports == [
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
]
@@ -737,6 +739,7 @@ def run_test_commands():
if __name__ == '__main__':
run_test_events()
run_test_commands()
test_return_parameters()
test_address()
test_custom()
test_iso_data_packet()

View File

@@ -0,0 +1,89 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import itertools
from collections.abc import Sequence
import pytest
from bumble import device as device_module
from bumble.profiles import heart_rate_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"heart_rate, sensor_contact_detected, energy_expanded, rr_intervals",
itertools.product(
(1, 1000), (True, False, None), (2, None), ((3.0, 4.0, 5.0), None)
),
)
async def test_read_measurement(
heart_rate: int,
sensor_contact_detected: bool | None,
energy_expanded: int | None,
rr_intervals: Sequence[int] | None,
):
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(
heart_rate, sensor_contact_detected, energy_expanded, rr_intervals
)
service = heart_rate_service.HeartRateService(lambda _: measurement)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert await client.heart_rate_measurement.read_value() == measurement
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_body_sensor_location():
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(0)
location = heart_rate_service.HeartRateService.BodySensorLocation.FINGER
service = heart_rate_service.HeartRateService(
lambda _: measurement,
body_sensor_location=location,
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert client.body_sensor_location
assert await client.body_sensor_location.read_value() == location
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reset_energy_expended() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(1)
reset_energy_expended = asyncio.Queue[None]()
service = heart_rate_service.HeartRateService(
lambda _: measurement,
reset_energy_expended=lambda _: reset_energy_expended.put_nowait(None),
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
await client.reset_energy_expended()
await reset_energy_expended.get()

View File

@@ -15,6 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import unittest
import unittest.mock
@@ -22,9 +23,17 @@ import unittest.mock
import pytest
from bumble.controller import Controller
from bumble.hci import HCI_AclDataPacket
from bumble.hci import (
HCI_AclDataPacket,
HCI_Command_Complete_Event,
HCI_Error,
HCI_ErrorCode,
HCI_Event,
HCI_Reset_Command,
HCI_StatusReturnParameters,
)
from bumble.host import DataPacketQueue, Host
from bumble.transport.common import AsyncPipeSink
from bumble.transport.common import AsyncPipeSink, TransportSink
# -----------------------------------------------------------------------------
# Logging
@@ -151,3 +160,58 @@ def test_data_packet_queue():
assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15
assert queue.completed == 15
# -----------------------------------------------------------------------------
class Source:
terminated: asyncio.Future[None]
sink: TransportSink
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
class Sink:
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
async def test_send_sync_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.SUCCESS),
),
)
host = Host(source, sink)
# Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command())
assert response1.status == HCI_ErrorCode.SUCCESS
# Sync command with error status should raise
error_response = HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.COMMAND_DISALLOWED_ERROR),
)
sink.response = error_response
with pytest.raises(HCI_Error) as excinfo:
await host.send_sync_command(HCI_Reset_Command())
assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with error status should not raise when `check_status` is False
response2 = await host.send_sync_command(HCI_Reset_Command(), check_status=False)
assert response2.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR

View File

@@ -89,7 +89,7 @@ class HeartRateMonitor:
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('### Monitor stopped')

View File

@@ -60,7 +60,7 @@ class Scanner(utils.EventEmitter):
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('### Scanner stopped')

View File

@@ -311,7 +311,7 @@ class Speaker:
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('Speaker stopped')