typing surrport for HCI commands return parameters

This commit is contained in:
Gilles Boccon-Gibod
2026-01-17 13:19:36 -08:00
parent 2cad743f8c
commit 7523118581
21 changed files with 2066 additions and 1582 deletions

View File

@@ -34,11 +34,7 @@ from bumble.hci import (
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
CodecID,
HCI_Command, HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_LE_Read_Buffer_Size_Command, HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Buffer_Size_V2_Command, HCI_LE_Read_Buffer_Size_V2_Command,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command, HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
@@ -59,34 +55,23 @@ from bumble.host import Host
from bumble.transport import open_transport from bumble.transport import open_transport
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None: async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response): print()
print() print(
print( color('Public Address:', 'yellow'),
color('Public Address:', 'yellow'), response1.bd_addr.to_string(False),
response.return_parameters.bd_addr.to_string(False), )
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command()) response2 = await host.send_sync_command(HCI_Read_Local_Name_Command())
if command_succeeded(response): print()
print() print(
print( color('Local Name:', 'yellow'),
color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response2.local_name),
map_null_terminated_utf8_string(response.return_parameters.local_name), )
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -94,52 +79,50 @@ async def get_le_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND): if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command( response1 = await host.send_sync_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command() HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
) )
if command_succeeded(response): print(
print( color('LE Number Of Supported Advertising Sets:', 'yellow'),
color('LE Number Of Supported Advertising Sets:', 'yellow'), response1.num_supported_advertising_sets,
response.return_parameters.num_supported_advertising_sets, '\n',
'\n', )
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command( response2 = await host.send_sync_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command() HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
) )
if command_succeeded(response): print(
print( color('LE Maximum Advertising Data Length:', 'yellow'),
color('LE Maximum Advertising Data Length:', 'yellow'), response2.max_advertising_data_length,
response.return_parameters.max_advertising_data_length, '\n',
'\n', )
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command()) response3 = await host.send_sync_command(
if command_succeeded(response): HCI_LE_Read_Maximum_Data_Length_Command()
print( )
color('Maximum Data Length:', 'yellow'), print(
( color('Maximum Data Length:', 'yellow'),
f'tx:{response.return_parameters.supported_max_tx_octets}/' (
f'{response.return_parameters.supported_max_tx_time}, ' f'tx:{response3.supported_max_tx_octets}/'
f'rx:{response.return_parameters.supported_max_rx_octets}/' f'{response3.supported_max_tx_time}, '
f'{response.return_parameters.supported_max_rx_time}' f'rx:{response3.supported_max_rx_octets}/'
), f'{response3.supported_max_rx_time}'
'\n', ),
) '\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command( response4 = await host.send_sync_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command() HCI_LE_Read_Suggested_Default_Data_Length_Command()
) )
if command_succeeded(response): print(
print( color('Suggested Default Data Length:', 'yellow'),
color('Suggested Default Data Length:', 'yellow'), f'{response4.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_octets}/' f'{response4.suggested_max_tx_time}',
f'{response.return_parameters.suggested_max_tx_time}', '\n',
'\n', )
)
print(color('LE Features:', 'yellow')) print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features: for feature in host.supported_le_features:
@@ -151,37 +134,31 @@ async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command())
HCI_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('ACL Flow Control:', 'yellow'), color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} ' f'{response1.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}', f'packets of size {response1.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command( response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command())
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} ' f'{response2.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}', f'packets of size {response2.le_acl_data_packet_length}',
) )
print( print(
color('LE ISO Flow Control:', 'yellow'), color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} ' f'{response2.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}', f'packets of size {response2.iso_data_packet_length}',
) )
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command())
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} ' f'{response3.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}', f'packets of size {response3.le_acl_data_packet_length}',
) )
@@ -190,52 +167,44 @@ async def get_codecs_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command( response1 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True HCI_Read_Local_Supported_Codecs_V2_Command()
) )
print(color('Codecs:', 'yellow')) print(color('Codecs:', 'yellow'))
for codec_id, transport in zip( for codec_id, transport in zip(
response.return_parameters.standard_codec_ids, response1.standard_codec_ids,
response.return_parameters.standard_codec_transports, response1.standard_codec_transports,
): ):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport( print(f' {codec_id.name} - {transport.name}')
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
for codec_id, transport in zip( for vendor_codec_id, vendor_transport in zip(
response.return_parameters.vendor_specific_codec_ids, response1.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports, response1.vendor_specific_codec_transports,
): ):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport( company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
transport print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}')
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
if not response.return_parameters.standard_codec_ids: if not response1.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids: if not response1.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command( response2 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True HCI_Read_Local_Supported_Codecs_Command()
) )
print(color('Codecs (BR/EDR):', 'yellow')) print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids: for codec_id in response2.standard_codec_ids:
codec_name = CodecID(codec_id).name print(f' {codec_id.name}')
print(f' {codec_name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids: for vendor_codec_id in response2.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16) company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}') print(f' {company} / {vendor_codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids: if not response2.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids: if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')

View File

@@ -85,7 +85,7 @@ class Loopback:
print(color('@@@ Received last packet', 'green')) print(color('@@@ Received last packet', 'green'))
self.done.set() self.done.set()
async def run(self): async def run(self) -> None:
"""Run a loopback throughput test""" """Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green')) print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as ( async with await open_transport(self.transport) as (
@@ -100,11 +100,15 @@ class Loopback:
# make sure data can fit in one l2cap pdu # make sure data can fit in one l2cap pdu
l2cap_header_size = 4 l2cap_header_size = 4
max_packet_size = ( packet_queue = (
host.acl_packet_queue host.acl_packet_queue
if host.acl_packet_queue if host.acl_packet_queue
else host.le_acl_packet_queue else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size )
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size: if self.packet_size > max_packet_size:
print( print(
color( color(
@@ -128,20 +132,18 @@ class Loopback:
loopback_mode = LoopbackMode.LOCAL loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue')) print(color('### Setting loopback mode', 'blue'))
await host.send_command( await host.send_sync_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL), HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
) )
print(color('### Checking loopback mode', 'blue')) print(color('### Checking loopback mode', 'blue'))
response = await host.send_command( response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
HCI_Read_Loopback_Mode_Command(), check_result=True if response.loopback_mode != loopback_mode:
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red')) print(color('!!! Loopback mode mismatch', 'red'))
return return
await self.connection_event.wait() await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan')) print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta')) print(color('=== Start sending', 'magenta'))

View File

@@ -421,7 +421,7 @@ class Controller:
hci.HCI_Command_Complete_Event( hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=command.op_code, command_opcode=command.op_code,
return_parameters=result, return_parameters=hci.HCI_GenericReturnParameters(data=result),
) )
) )

View File

@@ -923,7 +923,7 @@ class DeviceClass:
# pylint: enable=line-too-long # pylint: enable=line-too-long
@staticmethod @staticmethod
def split_class_of_device(class_of_device): def split_class_of_device(class_of_device: int) -> tuple[int, int, int]:
# Split the bit fields of the composite class of device value into: # Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class) # (service_classes, major_device_class, minor_device_class)
return ( return (

File diff suppressed because it is too large Load Diff

View File

@@ -89,51 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Intel_Read_Version_Command(hci.HCI_Command): class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
tlv: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters]
):
param0: int = dataclasses.field(metadata=hci.metadata(1)) param0: int = dataclasses.field(metadata=hci.metadata(1))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("tlv", "*"),
]
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters)
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class Hci_Intel_Secure_Send_Command(hci.HCI_Command): class Hci_Intel_Secure_Send_Command(
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters]
):
data_type: int = dataclasses.field(metadata=hci.metadata(1)) data_type: int = dataclasses.field(metadata=hci.metadata(1))
data: bytes = dataclasses.field(metadata=hci.metadata("*")) data: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
("status", 1),
]
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_Command): class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters):
data: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]):
reset_type: int = dataclasses.field(metadata=hci.metadata(1)) reset_type: int = dataclasses.field(metadata=hci.metadata(1))
patch_enable: int = dataclasses.field(metadata=hci.metadata(1)) patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1)) ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1)) boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4)) boot_address: int = dataclasses.field(metadata=hci.metadata(4))
return_parameters_fields = [
("data", "*"),
]
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command): class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
data: bytes = dataclasses.field(metadata=hci.metadata("*")) params: bytes = hci.field(metadata=hci.metadata('*'))
return_parameters_fields = [
("status", hci.STATUS_SPEC), @hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
("params", "*"), @dataclasses.dataclass
] class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -402,7 +405,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event) self.host.on_hci_event_packet(event)
return return
if not event.return_parameters == hci.HCI_SUCCESS: if not event.return_parameters.status == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error") raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets: if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -641,8 +644,8 @@ class Driver(common.Driver):
while ddc_data: while ddc_data:
ddc_len = 1 + ddc_data[0] ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len] ddc_payload = ddc_data[:ddc_len]
await self.host.send_command( await self.host.send_sync_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload) HCI_Intel_Write_Device_Config_Command(data=ddc_payload)
) )
ddc_data = ddc_data[ddc_len:] ddc_data = ddc_data[ddc_len:]
@@ -660,31 +663,26 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]: async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command()) response1 = await self.host.send_sync_command(
if not ( hci.HCI_Reset_Command(), check_status=False
isinstance(response, hci.HCI_Command_Complete_Event) )
and response.return_parameters if response1.status not in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS):
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
# When the controller is in operational mode, the response is a # When the controller is in operational mode, the response is a
# successful response. # successful response.
# When the controller is in bootloader mode, # When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything # HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure. # else is a failure.
logger.warning(f"unexpected response: {response}") logger.warning(f"unexpected response: {response1}")
raise DriverError("unexpected HCI response") raise DriverError("unexpected HCI response")
# Read the firmware version. # Read the firmware version.
response = await self.host.send_command( response2 = await self.host.send_sync_command(
HCI_Intel_Read_Version_Command(param0=0xFF) HCI_Intel_Read_Version_Command(param0=0xFF), check_status=False
) )
if not isinstance(response, hci.HCI_Command_Complete_Event): if response2.status != 0: # type: ignore
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error") raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore tlvs = _parse_tlv(response2.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type # Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once. # to appear just once.

View File

@@ -16,6 +16,7 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel. Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`) (see `drivers/bluetooth/btrtl.c`)
""" """
from __future__ import annotations
import asyncio import asyncio
import enum import enum
@@ -31,10 +32,14 @@ import weakref
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci from bumble import core, hci
from bumble.drivers import common from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -188,23 +193,36 @@ HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclass @dataclass
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command): class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)] version: int = field(metadata=hci.metadata(1))
@hci.HCI_Command.command @hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass @dataclass
class HCI_RTK_Download_Command(hci.HCI_Command): class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass
@dataclass
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters):
index: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1)) index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH)) payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
@hci.HCI_Command.command @hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass @dataclass
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command): class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass pass
@@ -490,7 +508,7 @@ class Driver(common.Driver):
return None return None
@staticmethod @staticmethod
def check(host): def check(host: Host) -> bool:
if not host.hci_metadata: if not host.hci_metadata:
logger.debug("USB metadata not found") logger.debug("USB metadata not found")
return False return False
@@ -514,41 +532,39 @@ class Driver(common.Driver):
return True return True
@staticmethod @staticmethod
async def get_loaded_firmware_version(host): async def get_loaded_firmware_version(host: Host) -> int | None:
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command()) response1 = await host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS: if response1.status != hci.HCI_SUCCESS:
return None return None
response = await host.send_command( response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True hci.HCI_Read_Local_Version_Information_Command()
)
return (
response.return_parameters.hci_subversion << 16
| response.return_parameters.lmp_subversion
) )
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod @classmethod
async def driver_info_for_host(cls, host): async def driver_info_for_host(cls, host: Host) -> DriverInfo | None:
try: try:
await host.send_command( await host.send_sync_command(
hci.HCI_Reset_Command(), hci.HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY, response_timeout=cls.POST_RESET_DELAY,
) )
host.ready = True # Needed to let the host know the controller is ready. host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError: except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying") logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(hci.HCI_Reset_Command(), check_result=True) await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command() command = hci.HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True) response = await host.send_sync_command(command, check_status=False)
if response.command_opcode != command.op_code: if response.status != hci.HCI_SUCCESS:
logger.error("failed to probe local version information") logger.error("failed to probe local version information")
return None return None
local_version = response.return_parameters local_version = response
logger.debug( logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} " f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
@@ -569,7 +585,7 @@ class Driver(common.Driver):
return driver_info return driver_info
@classmethod @classmethod
async def for_host(cls, host, force=False): async def for_host(cls, host: Host, force: bool = False):
# Check that a driver is needed for this host # Check that a driver is needed for this host
if not force and not cls.check(host): if not force and not cls.check(host):
return None return None
@@ -626,13 +642,13 @@ class Driver(common.Driver):
async def download_for_rtl8723b(self): async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version: if self.driver_info.has_rom_version:
response = await self.host.send_command( response1 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True HCI_RTK_Read_ROM_Version_Command(), check_status=False
) )
if response.return_parameters.status != hci.HCI_SUCCESS: if response1.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
return None return None
rom_version = response.return_parameters.version rom_version = response1.version
logger.debug(f"ROM version before download: {rom_version:04X}") logger.debug(f"ROM version before download: {rom_version:04X}")
else: else:
rom_version = 0 rom_version = 0
@@ -667,21 +683,20 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH] fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}") logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command( await self.host.send_sync_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment), HCI_RTK_Download_Command(index=download_index, payload=fragment)
check_result=True,
) )
logger.debug("download complete!") logger.debug("download complete!")
# Read the version again # Read the version again
response = await self.host.send_command( response2 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True HCI_RTK_Read_ROM_Version_Command(), check_status=False
) )
if response.return_parameters.status != hci.HCI_SUCCESS: if response2.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
else: else:
rom_version = response.return_parameters.version rom_version = response2.version
logger.debug(f"ROM version after download: {rom_version:02X}") logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version return firmware.version
@@ -703,7 +718,7 @@ class Driver(common.Driver):
async def init_controller(self): async def init_controller(self):
await self.download_firmware() await self.download_firmware()
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True) await self.host.send_sync_command(hci.HCI_Reset_Command())
logger.info(f"loaded FW image {self.driver_info.fw_name}") logger.info(f"loaded FW image {self.driver_info.fw_name}")

File diff suppressed because it is too large Load Diff

View File

@@ -23,11 +23,15 @@ import dataclasses
import logging import logging
import struct import struct
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from bumble import drivers, hci, utils from bumble import drivers, hci, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport from bumble.core import (
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError from bumble.transport.common import TransportLostError
@@ -35,7 +39,6 @@ from bumble.transport.common import TransportLostError
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource from bumble.transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -236,6 +239,9 @@ class IsoLink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
class Host(utils.EventEmitter): class Host(utils.EventEmitter):
connections: dict[int, Connection] connections: dict[int, Connection]
cis_links: dict[int, IsoLink] cis_links: dict[int, IsoLink]
@@ -264,11 +270,13 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: asyncio.Future[Any] | None = None self.pending_response: asyncio.Future[Any] | None = None
self.number_of_supported_advertising_sets = 0 self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31 self.maximum_advertising_data_length = 31
self.local_version = None self.local_version: (
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
) = None
self.local_supported_commands = 0 self.local_supported_commands = 0
self.local_le_features = 0 self.local_le_features = 0
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
@@ -312,7 +320,7 @@ class Host(utils.EventEmitter):
self.emit('flush') self.emit('flush')
self.command_semaphore.release() self.command_semaphore.release()
async def reset(self, driver_factory=drivers.get_driver_for_host): async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
if self.ready: if self.ready:
self.ready = False self.ready = False
await self.flush() await self.flush()
@@ -330,57 +338,53 @@ class Host(utils.EventEmitter):
# Send a reset command unless a driver has already done so. # Send a reset command unless a driver has already done so.
if reset_needed: if reset_needed:
await self.send_command(hci.HCI_Reset_Command(), check_result=True) await self.send_sync_command(hci.HCI_Reset_Command())
self.ready = True self.ready = True
response = await self.send_command( response1 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True hci.HCI_Read_Local_Supported_Commands_Command()
) )
self.local_supported_commands = int.from_bytes( self.local_supported_commands = int.from_bytes(
response.return_parameters.supported_commands, 'little' response1.supported_commands, 'little'
) )
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command( response2 = await self.send_sync_command(
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True hci.HCI_LE_Read_Local_Supported_Features_Command()
) )
self.local_le_features = struct.unpack( self.local_le_features = struct.unpack('<Q', response2.le_features)[0]
'<Q', response.return_parameters.le_features
)[0]
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND): if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command( self.local_version = await self.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True hci.HCI_Read_Local_Version_Information_Command()
) )
self.local_version = response.return_parameters
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND): if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
max_page_number = 0 max_page_number = 0
page_number = 0 page_number = 0
lmp_features = 0 lmp_features = 0
while page_number <= max_page_number: while page_number <= max_page_number:
response = await self.send_command( response4 = await self.send_sync_command(
hci.HCI_Read_Local_Extended_Features_Command( hci.HCI_Read_Local_Extended_Features_Command(
page_number=page_number page_number=page_number
), )
check_result=True,
) )
lmp_features |= int.from_bytes( lmp_features |= int.from_bytes(
response.return_parameters.extended_lmp_features, 'little' response4.extended_lmp_features, 'little'
) << (64 * page_number) ) << (64 * page_number)
max_page_number = response.return_parameters.maximum_page_number max_page_number = response4.maximum_page_number
page_number += 1 page_number += 1
self.local_lmp_features = hci.LmpFeatureMask(lmp_features) self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command( response5 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True hci.HCI_Read_Local_Supported_Features_Command()
) )
self.local_lmp_features = hci.LmpFeatureMask( self.local_lmp_features = hci.LmpFeatureMask(
int.from_bytes(response.return_parameters.lmp_features, 'little') int.from_bytes(response5.lmp_features, 'little')
) )
await self.send_command( await self.send_sync_command(
hci.HCI_Set_Event_Mask_Command( hci.HCI_Set_Event_Mask_Command(
event_mask=hci.HCI_Set_Event_Mask_Command.mask( event_mask=hci.HCI_Set_Event_Mask_Command.mask(
[ [
@@ -437,7 +441,7 @@ class Host(utils.EventEmitter):
) )
) )
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND): if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command( await self.send_sync_command(
hci.HCI_Set_Event_Mask_Page_2_Command( hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask( event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT] [hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
@@ -499,20 +503,14 @@ class Host(utils.EventEmitter):
] ]
) )
await self.send_command( await self.send_sync_command(
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask) hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
) )
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND): if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
hci.HCI_Read_Buffer_Size_Command(), check_result=True hc_acl_data_packet_length = response6.hc_acl_data_packet_length
) hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug( logger.debug(
'HCI ACL flow control: ' 'HCI ACL flow control: '
@@ -531,19 +529,13 @@ class Host(utils.EventEmitter):
iso_data_packet_length = 0 iso_data_packet_length = 0
total_num_iso_data_packets = 0 total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command( response7 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_V2_Command()
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
) )
le_acl_data_packet_length = response7.le_acl_data_packet_length
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
iso_data_packet_length = response7.iso_data_packet_length
total_num_iso_data_packets = response7.total_num_iso_data_packets
logger.debug( logger.debug(
'HCI LE flow control: ' 'HCI LE flow control: '
@@ -553,15 +545,11 @@ class Host(utils.EventEmitter):
f'total_num_iso_data_packets={total_num_iso_data_packets}' f'total_num_iso_data_packets={total_num_iso_data_packets}'
) )
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response8 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_Command()
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
) )
le_acl_data_packet_length = response8.le_acl_data_packet_length
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
@@ -592,16 +580,16 @@ class Host(utils.EventEmitter):
) and self.supports_command( ) and self.supports_command(
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
): ):
response = await self.send_command( response9 = await self.send_sync_command(
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command() hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
) )
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets suggested_max_tx_octets = response9.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time suggested_max_tx_time = response9.suggested_max_tx_time
if ( if (
suggested_max_tx_octets != self.suggested_max_tx_octets suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time or suggested_max_tx_time != self.suggested_max_tx_time
): ):
await self.send_command( await self.send_sync_command(
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command( hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets, suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time, suggested_max_tx_time=self.suggested_max_tx_time,
@@ -611,23 +599,21 @@ class Host(utils.EventEmitter):
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
): ):
response = await self.send_command( response10 = await self.send_sync_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(), hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
check_result=True,
) )
self.number_of_supported_advertising_sets = ( self.number_of_supported_advertising_sets = (
response.return_parameters.num_supported_advertising_sets response10.num_supported_advertising_sets
) )
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
): ):
response = await self.send_command( response11 = await self.send_sync_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(), hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
check_result=True,
) )
self.maximum_advertising_data_length = ( self.maximum_advertising_data_length = (
response.return_parameters.max_advertising_data_length response11.max_advertising_data_length
) )
@property @property
@@ -654,9 +640,11 @@ class Host(utils.EventEmitter):
if self.hci_sink: if self.hci_sink:
self.hci_sink.on_packet(bytes(packet)) self.hci_sink.on_packet(bytes(packet))
async def send_command( async def _send_command(
self, command, check_result=False, response_timeout: int | None = None self,
): command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
async with self.command_semaphore: async with self.command_semaphore:
assert self.pending_command is None assert self.pending_command is None
@@ -668,29 +656,9 @@ class Host(utils.EventEmitter):
try: try:
self.send_hci_packet(command) self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout) return await asyncio.wait_for(
response = self.pending_response.result() self.pending_response, timeout=response_timeout
)
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
except Exception: except Exception:
logger.exception(color("!!! Exception while sending command:", "red")) logger.exception(color("!!! Exception while sending command:", "red"))
raise raise
@@ -698,12 +666,107 @@ class Host(utils.EventEmitter):
self.pending_command = None self.pending_command = None
self.pending_response = None self.pending_response = None
# Use this method to send a command from a task @overload
def send_command_sync(self, command: hci.HCI_Command) -> None: async def send_command(
async def send_command(command: hci.HCI_Command) -> None: self,
await self.send_command(command) command: hci.HCI_SyncCommand[_RP],
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]: ...
asyncio.create_task(send_command(command)) @overload
async def send_command(
self,
command: hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Status_Event: ...
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
response = await self._send_command(command, response_timeout)
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
elif isinstance(
response.return_parameters, hci.HCI_GenericReturnParameters
):
# FIXME: temporary workaround
# NO STATUS
status = hci.HCI_SUCCESS
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
async def send_sync_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_status: bool = True,
response_timeout: float | None = None,
) -> _RP:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event)
return_parameters: _RP = response.return_parameters
assert isinstance(return_parameters, command.return_parameters_class)
# Check the return parameters if required
if check_status:
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_async_command(
self,
command: hci.HCI_AsyncCommand,
check_status: bool = True,
response_timeout: float | None = None,
) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Status_Event)
# Check the return parameters if required
status = response.status
if check_status:
if status != hci.HCI_CommandStatus.PENDING:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return hci.HCI_ErrorCode(status)
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
utils.AsyncRunner.spawn(self.send_async_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None: def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)): if not (connection := self.connections.get(connection_handle)):
@@ -1338,15 +1401,17 @@ class Host(utils.EventEmitter):
# For now, just accept everything # For now, just accept everything
# TODO: delegate the decision # TODO: delegate the decision
self.send_command_sync( utils.AsyncRunner.spawn(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( self.send_sync_command(
connection_handle=event.connection_handle, hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
interval_min=event.interval_min, connection_handle=event.connection_handle,
interval_max=event.interval_max, interval_min=event.interval_min,
max_latency=event.max_latency, interval_max=event.interval_max,
timeout=event.timeout, max_latency=event.max_latency,
min_ce_length=0, timeout=event.timeout,
max_ce_length=0, min_ce_length=0,
max_ce_length=0,
)
) )
) )
@@ -1382,9 +1447,9 @@ class Host(utils.EventEmitter):
connection_handle=event.connection_handle connection_handle=event.connection_handle
) )
await self.send_command(response) await self.send_sync_command(response)
asyncio.create_task(send_long_term_key()) utils.AsyncRunner.spawn(send_long_term_key())
def on_hci_synchronous_connection_complete_event( def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event self, event: hci.HCI_Synchronous_Connection_Complete_Event
@@ -1583,9 +1648,9 @@ class Host(utils.EventEmitter):
bd_addr=event.bd_addr bd_addr=event.bd_addr
) )
await self.send_command(response) await self.send_sync_command(response)
asyncio.create_task(send_link_key()) utils.AsyncRunner.spawn(send_link_key())
def on_hci_io_capability_request_event( def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event self, event: hci.HCI_IO_Capability_Request_Event

View File

@@ -18,10 +18,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from collections.abc import Callable from collections.abc import Callable
from bumble import device from bumble import device, gatt, gatt_adapters, gatt_client
from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -19,19 +19,14 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
from typing import Any
from typing_extensions import Self
from collections.abc import Sequence, Callable
import struct
import enum import enum
import struct
from collections.abc import Callable, Sequence
from typing import Any
from bumble import core from typing_extensions import Self
from bumble import device
from bumble import utils from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils
from bumble import att
from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -43,44 +43,53 @@ hci.HCI_Command.register_commands(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command): class HCI_LE_Get_Vendor_Capabilities_ReturnParameters(hci.HCI_StatusReturnParameters):
max_advt_instances: int = field(metadata=hci.metadata(1), default=0)
offloaded_resolution_of_private_address: int = field(
metadata=hci.metadata(1), default=0
)
total_scan_results_storage: int = field(metadata=hci.metadata(2), default=0)
max_irk_list_sz: int = field(metadata=hci.metadata(1), default=0)
filtering_support: int = field(metadata=hci.metadata(1), default=0)
max_filter: int = field(metadata=hci.metadata(1), default=0)
activity_energy_info_support: int = field(metadata=hci.metadata(1), default=0)
version_supported: int = field(metadata=hci.metadata(2), default=0)
total_num_of_advt_tracked: int = field(metadata=hci.metadata(2), default=0)
extended_scan_support: int = field(metadata=hci.metadata(1), default=0)
debug_logging_supported: int = field(metadata=hci.metadata(1), default=0)
le_address_generation_offloading_support: int = field(
metadata=hci.metadata(1), default=0
)
a2dp_source_offload_capability_mask: int = field(
metadata=hci.metadata(4), default=0
)
bluetooth_quality_report_support: int = field(metadata=hci.metadata(1), default=0)
dynamic_audio_buffer_support: int = field(metadata=hci.metadata(4), default=0)
@hci.HCI_SyncCommand.sync_command(HCI_LE_Get_Vendor_Capabilities_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(
hci.HCI_SyncCommand[HCI_LE_Get_Vendor_Capabilities_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
''' '''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
@classmethod @classmethod
def parse_return_parameters(cls, parameters): def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until # There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to # there are no more bytes to parse, and leave un-signaled parameters set to
# None (older versions) # 0
nones = {field: None for field, _ in cls.return_parameters_fields} return_parameters = HCI_LE_Get_Vendor_Capabilities_ReturnParameters(
return_parameters = hci.HCI_Object(cls.return_parameters_fields, **nones) hci.HCI_ErrorCode.SUCCESS
)
try: try:
offset = 0 offset = 0
for field in cls.return_parameters_fields: for field in cls.return_parameters_class.fields:
field_name, field_type = field field_name, field_type = field
field_value, field_size = hci.HCI_Object.parse_field( field_value, field_size = hci.HCI_Object.parse_field(
parameters, offset, field_type parameters, offset, field_type
@@ -94,9 +103,30 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # APCF Subcommands
class LeApcfOpcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
@dataclasses.dataclass @dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_Command): class HCI_LE_APCF_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_LE_APCF_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_SyncCommand[HCI_LE_APCF_ReturnParameters]):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
@@ -105,52 +135,52 @@ class HCI_LE_APCF_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# APCF Subcommands opcode: int = dataclasses.field(metadata=LeApcfOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command): class HCI_Get_Controller_Activity_Energy_Info_ReturnParameters(
hci.HCI_StatusReturnParameters
):
total_tx_time_ms: int = field(metadata=hci.metadata(4))
total_rx_time_ms: int = field(metadata=hci.metadata(4))
total_idle_time_ms: int = field(metadata=hci.metadata(4))
total_energy_used: int = field(metadata=hci.metadata(4))
@hci.HCI_SyncCommand.sync_command(
HCI_Get_Controller_Activity_Energy_Info_ReturnParameters
)
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(
hci.HCI_SyncCommand[HCI_Get_Controller_Activity_Energy_Info_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
''' '''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # A2DP Hardware Offload Subcommands
class A2dpHardwareOffloadOpcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
@dataclasses.dataclass @dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command): class HCI_A2DP_Hardware_Offload_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_A2DP_Hardware_Offload_ReturnParameters)
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(
hci.HCI_SyncCommand[HCI_A2DP_Hardware_Offload_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
@@ -159,25 +189,27 @@ class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# A2DP Hardware Offload Subcommands opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # Dynamic Audio Buffer Subcommands
class DynamicAudioBufferOpcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command): class HCI_Dynamic_Audio_Buffer_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_Dynamic_Audio_Buffer_ReturnParameters)
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(
hci.HCI_SyncCommand[HCI_Dynamic_Audio_Buffer_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
@@ -186,19 +218,9 @@ class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# Dynamic Audio Buffer Subcommands opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HCI_Android_Vendor_Event(hci.HCI_Extended_Event): class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):

View File

@@ -46,9 +46,19 @@ class TX_Power_Level_Command:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command): class HCI_Write_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
selected_tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Write_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Write_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
''' '''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -61,18 +71,21 @@ class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
connection_handle: int = dataclasses.field(metadata=hci.metadata(2)) connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1)) tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command): class HCI_Read_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Read_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Read_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
''' '''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -83,10 +96,3 @@ class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
handle_type: int = dataclasses.field(metadata=hci.metadata(1)) handle_type: int = dataclasses.field(metadata=hci.metadata(1))
connection_handle: int = dataclasses.field(metadata=hci.metadata(2)) connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
]

View File

@@ -37,15 +37,16 @@ The vendor specific HCI commands to read and write TX power are defined in
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB # set advertising power to -4 dB
response = await host.send_command( response = await host.send_sync_command(
HCI_Write_Tx_Power_Level_Command( HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV, handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0, connection_handle=0,
tx_power_level=-4, tx_power_level=-4,
) ),
check_status=False
) )
if response.return_parameters.status == HCI_SUCCESS: if response.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}") print(f"TX power set to {response.selected_tx_power_level}")
``` ```

View File

@@ -42,7 +42,6 @@ from bumble.hci import (
HCI_CREATE_CONNECTION_COMMAND, HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS, HCI_SUCCESS,
Address, Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_Connection_Complete_Event, HCI_Connection_Complete_Event,
HCI_Connection_Request_Event, HCI_Connection_Request_Event,
@@ -154,10 +153,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet( d1.host.on_hci_packet(
HCI_Command_Complete_Event( HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
) )
) )
@@ -188,10 +187,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet( d2.host.on_hci_packet(
HCI_Command_Complete_Event( HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
) )
) )

View File

@@ -20,7 +20,7 @@ import struct
import pytest import pytest
from bumble import hci from bumble import hci, utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable=invalid-name # pylint: disable=invalid-name
@@ -136,43 +136,25 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command_Complete_Event(): def test_HCI_Command_Complete_Event():
# With a serializable object # With a serializable object
event = hci.HCI_Command_Complete_Event( event1 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=34, num_hci_command_packets=34,
command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND, command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.create_return_parameters( return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.return_parameters_class(
status=0, status=0,
le_acl_data_packet_length=1234, le_acl_data_packet_length=1234,
total_num_le_acl_data_packets=56, total_num_le_acl_data_packets=56,
), ),
) )
basic_check(event) basic_check(event1)
# With an arbitrary byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = hci.HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7
# With a simple status as an integer status # With a simple status as an integer status
event = hci.HCI_Command_Complete_Event( event3 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND, command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=9, return_parameters=hci.HCI_StatusReturnParameters(hci.HCI_ErrorCode(9)),
) )
basic_check(event) basic_check(event3)
assert event.return_parameters == 9 assert event3.return_parameters.status == 9
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -229,6 +211,28 @@ def test_HCI_Vendor_Event():
assert isinstance(parsed, hci.HCI_Vendor_Event) assert isinstance(parsed, hci.HCI_Vendor_Event)
# -----------------------------------------------------------------------------
def test_return_parameters() -> None:
params = hci.HCI_Reset_Command.parse_return_parameters(bytes.fromhex('3C'))
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('3C001122334455')
)
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address)
params = hci.HCI_Read_Local_Name_Command.parse_return_parameters(
bytes.fromhex('0068656c6c6f') + bytes(248 - 5)
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.local_name, bytes)
assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command(): def test_HCI_Command():
command = hci.HCI_Command(op_code=0x5566) command = hci.HCI_Command(op_code=0x5566)
@@ -291,7 +295,7 @@ def test_custom_le_meta_event():
for clazz in inspect.getmembers(hci) for clazz in inspect.getmembers(hci)
if isinstance(clazz[1], type) if isinstance(clazz[1], type)
and issubclass(clazz[1], hci.HCI_Command) and issubclass(clazz[1], hci.HCI_Command)
and clazz[1] is not hci.HCI_Command and clazz[1] not in (hci.HCI_Command, hci.HCI_SyncCommand, hci.HCI_AsyncCommand)
], ],
) )
def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]): def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
@@ -620,21 +624,19 @@ def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete(): def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = ( returned_parameters = hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters( bytes(
bytes( [
[ hci.HCI_SUCCESS,
hci.HCI_SUCCESS, 3,
3, hci.CodecID.A_LOG,
hci.CodecID.A_LOG, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL, hci.CodecID.CVSD,
hci.CodecID.CVSD, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO, hci.CodecID.LINEAR_PCM,
hci.CodecID.LINEAR_PCM, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS, 0,
0, ]
]
)
) )
) )
assert returned_parameters.standard_codec_ids == [ assert returned_parameters.standard_codec_ids == [
@@ -643,9 +645,9 @@ def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
hci.CodecID.LINEAR_PCM, hci.CodecID.LINEAR_PCM,
] ]
assert returned_parameters.standard_codec_transports == [ assert returned_parameters.standard_codec_transports == [
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
] ]
@@ -737,6 +739,7 @@ def run_test_commands():
if __name__ == '__main__': if __name__ == '__main__':
run_test_events() run_test_events()
run_test_commands() run_test_commands()
test_return_parameters()
test_address() test_address()
test_custom() test_custom()
test_iso_data_packet() test_iso_data_packet()

View File

@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Sequence
import asyncio import asyncio
import itertools import itertools
from collections.abc import Sequence
import pytest import pytest

View File

@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import unittest import unittest
import unittest.mock import unittest.mock
@@ -22,9 +23,17 @@ import unittest.mock
import pytest import pytest
from bumble.controller import Controller from bumble.controller import Controller
from bumble.hci import HCI_AclDataPacket from bumble.hci import (
HCI_AclDataPacket,
HCI_Command_Complete_Event,
HCI_Error,
HCI_ErrorCode,
HCI_Event,
HCI_Reset_Command,
HCI_StatusReturnParameters,
)
from bumble.host import DataPacketQueue, Host from bumble.host import DataPacketQueue, Host
from bumble.transport.common import AsyncPipeSink from bumble.transport.common import AsyncPipeSink, TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -151,3 +160,58 @@ def test_data_packet_queue():
assert drain_listener.on_flow.call_count == 1 assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15 assert queue.queued == 15
assert queue.completed == 15 assert queue.completed == 15
# -----------------------------------------------------------------------------
class Source:
terminated: asyncio.Future[None]
sink: TransportSink
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
class Sink:
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
async def test_send_sync_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.SUCCESS),
),
)
host = Host(source, sink)
# Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command())
assert response1.status == HCI_ErrorCode.SUCCESS
# Sync command with error status should raise
error_response = HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.COMMAND_DISALLOWED_ERROR),
)
sink.response = error_response
with pytest.raises(HCI_Error) as excinfo:
await host.send_sync_command(HCI_Reset_Command())
assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with error status should not raise when `check_status` is False
response2 = await host.send_sync_command(HCI_Reset_Command(), check_status=False)
assert response2.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR

View File

@@ -89,7 +89,7 @@ class HeartRateMonitor:
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('### Monitor stopped') print('### Monitor stopped')

View File

@@ -60,7 +60,7 @@ class Scanner(utils.EventEmitter):
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('### Scanner stopped') print('### Scanner stopped')

View File

@@ -311,7 +311,7 @@ class Speaker:
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('Speaker stopped') print('Speaker stopped')