Compare commits

...

17 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 7523118581 typing surrport for HCI commands return parameters 2026-01-17 13:19:36 -08:00
Gilles Boccon-Gibod 2cad743f8c Merge pull request #854 from TinyServal/rtl8761cu
Add support for RTL8761CU
2026-01-08 18:37:21 -08:00
zxzxwu 255414f315 Merge pull request #857 from zxzxwu/testing
Add test for Heart Rate and Battery Service
2026-01-08 17:52:12 +08:00
Josh Wu d2df76f6f4 Add test for Heart Rate and Battery Service 2026-01-08 16:42:05 +08:00
zxzxwu 884b1c20e4 Merge pull request #856 from zxzxwu/typing
Add annotation for Heart Rate and Battery Service
2026-01-08 15:29:50 +08:00
Josh Wu 91a2b4f676 Add annotation for Heart Rate and Battery Service 2026-01-08 14:43:27 +08:00
Bowen Yan 5831f79d62 Add support for the RTL8761CU 2026-01-08 16:50:11 +11:00
zxzxwu 36f81b798c Merge pull request #853 from zxzxwu/l2cap
L2CAP: Fix segmentation and frame ack
2026-01-08 09:40:13 +08:00
Gilles Boccon-Gibod 985183001f Merge pull request #855 from encarbassotnopot/patch-1
docs: fix a small error in hci socket up/down commands
2026-01-07 14:26:15 -08:00
Josh Wu b153d0fcde L2CAP: Fix Enhanced Retransmission Segmentation 2026-01-07 23:49:57 +08:00
Eina Safor 30d912d66e docs: fix a small error in hci socket up/down commands 2026-01-07 15:59:14 +01:00
Bowen Yan 054dc70f3f Exclude macOS xattr files 2026-01-07 15:00:21 +11:00
zxzxwu 8ac8724cd8 Merge pull request #851 from zxzxwu/fix
Fix some typos and annotations
2026-01-06 14:02:40 +08:00
Josh Wu 4c3746a5b2 Fix some typos and annotations 2026-01-05 23:53:22 +08:00
zxzxwu 566ef967f4 Merge pull request #836 from zxzxwu/eatt
Add EATT Support
2026-01-05 22:26:17 +08:00
Josh Wu df697c6513 Add EATT Support 2026-01-04 21:51:50 +08:00
Gilles Boccon-Gibod e3e1b7bc5b Merge pull request #849 from google/gbg/auracast-multi-broadcast 2026-01-02 09:02:15 -08:00
38 changed files with 3136 additions and 2053 deletions
+3
View File
@@ -17,3 +17,6 @@ venv/
.venv/
# snoop logs
out/
# macOS
.DS_Store
._*
+76 -107
View File
@@ -34,11 +34,7 @@ from bumble.hci import (
HCI_READ_BD_ADDR_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
CodecID,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
@@ -59,34 +55,23 @@ from bumble.host import Host
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
print()
print(
color('Public Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command())
print()
print(
color('Public Address:', 'yellow'),
response1.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if command_succeeded(response):
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
response2 = await host.send_sync_command(HCI_Read_Local_Name_Command())
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response2.local_name),
)
# -----------------------------------------------------------------------------
@@ -94,52 +79,50 @@ async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
response1 = await host.send_sync_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response1.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
response2 = await host.send_sync_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response2.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
response3 = await host.send_sync_command(
HCI_LE_Read_Maximum_Data_Length_Command()
)
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response3.supported_max_tx_octets}/'
f'{response3.supported_max_tx_time}, '
f'rx:{response3.supported_max_rx_octets}/'
f'{response3.supported_max_rx_time}'
),
'\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
response4 = await host.send_sync_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response4.suggested_max_tx_octets}/'
f'{response4.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
@@ -151,37 +134,31 @@ async def get_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command())
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
f'{response1.hc_total_num_acl_data_packets} '
f'packets of size {response1.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response2.total_num_le_acl_data_packets} '
f'packets of size {response2.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
f'{response2.total_num_iso_data_packets} '
f'packets of size {response2.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response3.total_num_le_acl_data_packets} '
f'packets of size {response3.le_acl_data_packet_length}',
)
@@ -190,52 +167,44 @@ async def get_codecs_info(host: Host) -> None:
print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
response1 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_V2_Command()
)
print(color('Codecs:', 'yellow'))
for codec_id, transport in zip(
response.return_parameters.standard_codec_ids,
response.return_parameters.standard_codec_transports,
response1.standard_codec_ids,
response1.standard_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
print(f' {codec_id.name} - {transport.name}')
for codec_id, transport in zip(
response.return_parameters.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports,
for vendor_codec_id, vendor_transport in zip(
response1.vendor_specific_codec_ids,
response1.vendor_specific_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}')
if not response.return_parameters.standard_codec_ids:
if not response1.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response1.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
response2 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_Command()
)
print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids:
codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for codec_id in response2.standard_codec_ids:
print(f' {codec_id.name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}')
for vendor_codec_id in response2.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids:
if not response2.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
+11 -9
View File
@@ -85,7 +85,7 @@ class Loopback:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as (
@@ -100,11 +100,15 @@ class Loopback:
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
packet_queue = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
)
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
@@ -128,20 +132,18 @@ class Loopback:
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
await host.send_sync_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
if response.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
+1
View File
@@ -298,6 +298,7 @@ class Speaker:
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
)
device_config.le_enabled = True
+42 -39
View File
@@ -15,6 +15,8 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import os
@@ -63,7 +65,7 @@ POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance = None
instance: Waiter | None = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
@@ -327,25 +329,25 @@ async def on_pairing_failure(connection, reason):
# -----------------------------------------------------------------------------
async def pair(
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuids,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuids: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
):
Waiter.instance = Waiter(linger=linger)
@@ -403,6 +405,7 @@ async def pair(
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
legacy_context: OobLegacyContext | None
if oob == '-':
shared_data = None
legacy_context = OobLegacyContext()
@@ -661,25 +664,25 @@ class LogHandler(logging.Handler):
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuid,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuid: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
):
# Setup logging
log_handler = LogHandler()
+74 -3
View File
@@ -34,10 +34,13 @@ from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
TypeAlias,
TypeVar,
)
from bumble import hci, utils
from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
_T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -58,6 +69,7 @@ _T = TypeVar('_T')
ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
@@ -780,6 +792,43 @@ class AttributeValue(Generic[_T]):
return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
@@ -855,7 +904,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, connection: Connection) -> bytes:
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -890,6 +940,17 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
@@ -897,7 +958,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Connection, value: bytes) -> None:
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
@@ -931,6 +993,15 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value
+1 -1
View File
@@ -421,7 +421,7 @@ class Controller:
hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=command.op_code,
return_parameters=result,
return_parameters=hci.HCI_GenericReturnParameters(data=result),
)
)
+1 -1
View File
@@ -923,7 +923,7 @@ class DeviceClass:
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
def split_class_of_device(class_of_device: int) -> tuple[int, int, int]:
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return (
+340 -333
View File
File diff suppressed because it is too large Load Diff
+41 -43
View File
@@ -89,51 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
tlv: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters]
):
param0: int = dataclasses.field(metadata=hci.metadata(1))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("tlv", "*"),
]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters)
@dataclasses.dataclass
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
class Hci_Intel_Secure_Send_Command(
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters]
):
data_type: int = dataclasses.field(metadata=hci.metadata(1))
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
("status", 1),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_Command):
class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters):
data: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]):
reset_type: int = dataclasses.field(metadata=hci.metadata(1))
patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4))
return_parameters_fields = [
("data", "*"),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
params: bytes = hci.field(metadata=hci.metadata('*'))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("params", "*"),
]
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# -----------------------------------------------------------------------------
@@ -402,7 +405,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
if not event.return_parameters.status == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -641,8 +644,8 @@ class Driver(common.Driver):
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
await self.host.send_sync_command(
HCI_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
@@ -660,31 +663,26 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
response1 = await self.host.send_sync_command(
hci.HCI_Reset_Command(), check_status=False
)
if response1.status not in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
logger.warning(f"unexpected response: {response1}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF)
response2 = await self.host.send_sync_command(
HCI_Intel_Read_Version_Command(param0=0xFF), check_status=False
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
if response2.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
tlvs = _parse_tlv(response2.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.
+79 -41
View File
@@ -16,6 +16,7 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`)
"""
from __future__ import annotations
import asyncio
import enum
@@ -31,10 +32,14 @@ import weakref
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci
from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -77,6 +82,7 @@ class RtlProjectId(enum.IntEnum):
PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25
PROJECT_ID_8761C = 51
RTK_PROJECT_ID_TO_ROM = {
@@ -92,6 +98,7 @@ RTK_PROJECT_ID_TO_ROM = {
18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A,
51: RTK_ROM_LMP_8761A,
}
# List of USB (VendorID, ProductID) for Realtek-based devices.
@@ -123,6 +130,10 @@ RTK_USB_PRODUCTS = {
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
# Realtek 8761CUV
(0x0B05, 0x1BF6),
(0x0BDA, 0xC761),
(0x7392, 0xF611),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
@@ -182,23 +193,36 @@ HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclass
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command):
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)]
class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
version: int = field(metadata=hci.metadata(1))
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_Command):
class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass
@dataclass
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters):
index: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command):
class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass
@@ -363,6 +387,15 @@ class Driver(common.Driver):
fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin",
),
# 8761CU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0E, 0x00),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761cu_fw.bin",
config_name="rtl8761cu_config.bin",
),
# 8822C
DriverInfo(
rom=RTK_ROM_LMP_8822B,
@@ -420,9 +453,17 @@ class Driver(common.Driver):
@staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == (
hci_subversion,
hci_version,
if driver_info.rom == lmp_subversion and (
driver_info.hci
== (
hci_subversion,
hci_version,
)
or driver_info.hci
== (
hci_subversion,
0x0,
)
):
return driver_info
@@ -467,7 +508,7 @@ class Driver(common.Driver):
return None
@staticmethod
def check(host):
def check(host: Host) -> bool:
if not host.hci_metadata:
logger.debug("USB metadata not found")
return False
@@ -491,41 +532,39 @@ class Driver(common.Driver):
return True
@staticmethod
async def get_loaded_firmware_version(host):
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command())
async def get_loaded_firmware_version(host: Host) -> int | None:
response1 = await host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response1.status != hci.HCI_SUCCESS:
return None
response = await host.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
)
return (
response.return_parameters.hci_subversion << 16
| response.return_parameters.lmp_subversion
response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod
async def driver_info_for_host(cls, host):
async def driver_info_for_host(cls, host: Host) -> DriverInfo | None:
try:
await host.send_command(
await host.send_sync_command(
hci.HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(hci.HCI_Reset_Command(), check_result=True)
await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
response = await host.send_sync_command(command, check_status=False)
if response.status != hci.HCI_SUCCESS:
logger.error("failed to probe local version information")
return None
local_version = response.return_parameters
local_version = response
logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
@@ -546,7 +585,7 @@ class Driver(common.Driver):
return driver_info
@classmethod
async def for_host(cls, host, force=False):
async def for_host(cls, host: Host, force: bool = False):
# Check that a driver is needed for this host
if not force and not cls.check(host):
return None
@@ -603,13 +642,13 @@ class Driver(common.Driver):
async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version:
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response1 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response1.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version")
return None
rom_version = response.return_parameters.version
rom_version = response1.version
logger.debug(f"ROM version before download: {rom_version:04X}")
else:
rom_version = 0
@@ -644,21 +683,20 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment),
check_result=True,
await self.host.send_sync_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment)
)
logger.debug("download complete!")
# Read the version again
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response2 = await self.host.send_sync_command(
HCI_RTK_Read_ROM_Version_Command(), check_status=False
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if response2.status != hci.HCI_SUCCESS:
logger.warning("can't get ROM version")
else:
rom_version = response.return_parameters.version
rom_version = response2.version
logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version
@@ -680,7 +718,7 @@ class Driver(common.Driver):
async def init_controller(self):
await self.download_firmware()
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.host.send_sync_command(hci.HCI_Reset_Command())
logger.info(f"loaded FW image {self.driver_info.fw_name}")
+2 -2
View File
@@ -31,7 +31,7 @@ import struct
from collections.abc import Iterable, Sequence
from typing import TypeVar
from bumble.att import Attribute, AttributeValue
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color
from bumble.core import UUID, BaseBumbleError
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
value_str = '<dynamic>'
else:
value_str = '<...>'
+83 -29
View File
@@ -26,6 +26,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import struct
from collections.abc import Callable, Iterable
@@ -35,9 +36,10 @@ from typing import (
Any,
Generic,
TypeVar,
overload,
)
from bumble import att, core, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
@@ -54,12 +56,12 @@ from bumble.gatt import (
)
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T')
# -----------------------------------------------------------------------------
@@ -267,8 +269,8 @@ class Client:
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None:
self.connection = connection
def __init__(self, bearer: att.Bearer) -> None:
self.bearer = bearer
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -278,21 +280,78 @@ class Client:
self.services = []
self.cached_values = {}
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -321,10 +380,7 @@ class Client:
def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
@@ -336,7 +392,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return self.mtu
# Send the request
self.mtu_exchange_done = True
@@ -347,9 +403,9 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
self.mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
return self.mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
@@ -942,7 +998,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -966,7 +1022,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
@@ -1062,14 +1118,13 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
def on_disconnection(self, *args) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -1099,8 +1154,7 @@ class Client:
else:
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
'red',
)
+ str(att_pdu)
+198 -129
View File
@@ -32,9 +32,8 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar
from bumble import att, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -44,14 +43,13 @@ from bumble.gatt import (
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
IncludedServiceDeclaration,
Service,
)
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.device import Device
# -----------------------------------------------------------------------------
# Logging
@@ -65,6 +63,18 @@ logger = logging.getLogger(__name__)
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
else:
return f'[0x{bearer.handle:04X}]'
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -72,9 +82,9 @@ class Server(utils.EventEmitter):
attributes: list[att.Attribute]
services: list[Service]
attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[int, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, asyncio.futures.Future | None]
subscribers: dict[att.Bearer, dict[int, bytes]]
indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -96,8 +106,29 @@ class Server(utils.EventEmitter):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
def register_eatt(
self, spec: l2cap.LeCreditBasedChannelSpec | None = None
) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug(
"New EATT Bearer Connection=0x%04X CID=0x%04X",
channel.connection.handle,
channel.source_cid,
)
channel.att_mtu = att.ATT_DEFAULT_MTU
channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu)
)
return self.device.create_l2cap_server(
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
)
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
if att.is_enhanced_bearer(bearer):
bearer.write(pdu)
else:
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
def next_handle(self) -> int:
return 1 + len(self.attributes)
@@ -138,7 +169,7 @@ class Server(utils.EventEmitter):
None,
)
def get_service_attribute(self, service_uuid: UUID) -> Service | None:
def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
return next(
(
attribute
@@ -151,7 +182,7 @@ class Server(utils.EventEmitter):
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
self, service_uuid: core.UUID, characteristic_uuid: core.UUID
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
@@ -176,7 +207,10 @@ class Server(utils.EventEmitter):
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
self,
service_uuid: core.UUID,
characteristic_uuid: core.UUID,
descriptor_uuid: core.UUID,
) -> Descriptor | None:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
@@ -257,14 +291,7 @@ class Server(utils.EventEmitter):
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
att.Attribute.READABLE | att.Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
self.make_descriptor_value(characteristic),
)
)
@@ -280,10 +307,21 @@ class Server(utils.EventEmitter):
for service in services:
self.add_service(service)
def read_cccd(
self, connection: Connection, characteristic: Characteristic
) -> bytes:
subscribers = self.subscribers.get(connection.handle)
def make_descriptor_value(
self, characteristic: Characteristic
) -> att.AttributeValueV2:
# It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
return att.AttributeValueV2(
lambda bearer, characteristic=characteristic: self.read_cccd(
bearer, characteristic
),
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
bearer, characteristic, value
),
)
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
subscribers = self.subscribers.get(bearer)
cccd = None
if subscribers:
cccd = subscribers.get(characteristic.handle)
@@ -292,12 +330,12 @@ class Server(utils.EventEmitter):
def write_cccd(
self,
connection: Connection,
bearer: att.Bearer,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'Subscription update for connection={_bearer_id(bearer)}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
@@ -306,41 +344,60 @@ class Server(utils.EventEmitter):
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
cccds = self.subscribers.setdefault(bearer, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
characteristic.EVENT_SUBSCRIPTION,
connection,
bearer,
notify_enabled,
indicate_enabled,
)
self.emit(
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
connection,
bearer,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, bytes(response))
def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
self.send_gatt_pdu(bearer, bytes(response))
async def notify_subscriber(
self,
connection: Connection,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._notify_single_subscriber(bearer, attribute, value, force)
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not notifying, no subscribers')
return
@@ -356,34 +413,53 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
async def indicate_subscriber(
self,
connection: Connection,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._indicate_single_bearer(bearer, attribute, value, force)
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not indicating, no subscribers')
return
@@ -399,40 +475,38 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert self.pending_confirmations[connection.handle] is None
async with self.indication_semaphores[bearer]:
assert self.pending_confirmations[bearer] is None
# Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[connection.handle] = (
pending_confirmation = self.pending_confirmations[bearer] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(connection.handle, bytes(indication))
self.send_gatt_pdu(bearer, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[connection.handle] = None
self.pending_confirmations[bearer] = None
async def _notify_or_indicate_subscribers(
self,
@@ -441,24 +515,24 @@ class Server(utils.EventEmitter):
value: bytes | None = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
# Get all the bearers for which there's at least one subscription
bearers: list[att.Bearer] = [
bearer
for bearer, subscribers in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
if bearers:
coroutine = (
self._indicate_single_bearer
if indicate
else self._notify_single_subscriber
)
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
asyncio.create_task(coroutine(bearer, attribute, value, force))
for bearer in bearers
]
)
@@ -480,21 +554,18 @@ class Server(utils.EventEmitter):
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
del self.indication_semaphores[connection.handle]
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_disconnection(self, bearer: att.Bearer) -> None:
self.subscribers.pop(bearer, None)
self.indication_semaphores.pop(bearer, None)
self.pending_confirmations.pop(bearer, None)
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(connection, att_pdu)
handler(bearer, att_pdu)
except att.ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = att.ATT_Error_Response(
@@ -502,7 +573,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
)
self.send_response(connection, response)
self.send_response(bearer, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = att.ATT_Error_Response(
@@ -510,18 +581,18 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
raise
else:
# No specific handler registered
if att_pdu.op_code in att.ATT_REQUESTS:
# Invoke the generic handler
self.on_att_request(connection, att_pdu)
self.on_att_request(bearer, att_pdu)
else:
# Just ignore
logger.warning(
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(att_pdu)
@@ -530,13 +601,14 @@ class Server(utils.EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(pdu)
)
@@ -545,29 +617,28 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
def on_att_exchange_mtu_request(
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
bearer.on_att_mtu_update(mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(
self, connection: Connection, request: att.ATT_Find_Information_Request
self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -580,7 +651,7 @@ class Server(utils.EventEmitter):
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
@@ -590,7 +661,7 @@ class Server(utils.EventEmitter):
return
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[att.Attribute] = []
uuid_size = 0
for attribute in (
@@ -632,18 +703,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes = []
response: att.ATT_PDU
async for attribute in (
@@ -652,7 +723,7 @@ class Server(utils.EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and (await attribute.read_value(bearer)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -688,17 +759,17 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request(
self, connection: Connection, request: att.ATT_Read_By_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
response: att.ATT_PDU = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -716,7 +787,7 @@ class Server(utils.EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
except att.ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -729,7 +800,7 @@ class Server(utils.EventEmitter):
break
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 4, 253)
max_attribute_size = min(bearer.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -756,11 +827,11 @@ class Server(utils.EventEmitter):
else:
logging.debug(f"not found {request}")
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_request(
self, connection: Connection, request: att.ATT_Read_Request
self, bearer: att.Bearer, request: att.ATT_Read_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -769,7 +840,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -777,7 +848,7 @@ class Server(utils.EventEmitter):
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
value_size = min(bearer.att_mtu - 1, len(value))
response = att.ATT_Read_Response(attribute_value=value[:value_size])
else:
response = att.ATT_Error_Response(
@@ -785,11 +856,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request(
self, connection: Connection, request: att.ATT_Read_Blob_Request
self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -798,7 +869,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -812,7 +883,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
elif len(value) <= bearer.att_mtu - 1:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -820,7 +891,7 @@ class Server(utils.EventEmitter):
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
bearer.att_mtu - 1, len(value) - request.value_offset
)
response = att.ATT_Read_Blob_Response(
part_attribute_value=value[
@@ -833,11 +904,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -852,10 +923,10 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
return
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[tuple[int, int, bytes]] = []
for attribute in (
attribute
@@ -867,9 +938,9 @@ class Server(utils.EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
max_attribute_size = min(bearer.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -904,11 +975,11 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_request(
self, connection: Connection, request: att.ATT_Write_Request
self, bearer: att.Bearer, request: att.ATT_Write_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -918,7 +989,7 @@ class Server(utils.EventEmitter):
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -932,7 +1003,7 @@ class Server(utils.EventEmitter):
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -944,7 +1015,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -954,11 +1025,11 @@ class Server(utils.EventEmitter):
else:
# Done
response = att.ATT_Write_Response()
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_command(
self, connection: Connection, request: att.ATT_Write_Command
self, bearer: att.Bearer, request: att.ATT_Write_Command
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
@@ -977,22 +1048,20 @@ class Server(utils.EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except Exception:
logger.exception('!!! ignoring exception')
def on_att_handle_value_confirmation(
self,
connection: Connection,
bearer: att.Bearer,
confirmation: att.ATT_Handle_Value_Confirmation,
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
del confirmation # Unused.
if (
pending_confirmation := self.pending_confirmations[connection.handle]
) is None:
if (pending_confirmation := self.pending_confirmations[bearer]) is None:
# Not expected!
logger.warning(
'!!! unexpected confirmation, there is no pending indication'
+1131 -778
View File
File diff suppressed because it is too large Load Diff
+179 -114
View File
@@ -23,11 +23,15 @@ import dataclasses
import logging
import struct
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from bumble import drivers, hci, utils
from bumble.colors import color
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport
from bumble.core import (
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError
@@ -35,7 +39,6 @@ from bumble.transport.common import TransportLostError
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -236,6 +239,9 @@ class IsoLink:
# -----------------------------------------------------------------------------
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
class Host(utils.EventEmitter):
connections: dict[int, Connection]
cis_links: dict[int, IsoLink]
@@ -264,11 +270,13 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: asyncio.Future[Any] | None = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
self.local_version: (
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
) = None
self.local_supported_commands = 0
self.local_le_features = 0
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
@@ -312,7 +320,7 @@ class Host(utils.EventEmitter):
self.emit('flush')
self.command_semaphore.release()
async def reset(self, driver_factory=drivers.get_driver_for_host):
async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
if self.ready:
self.ready = False
await self.flush()
@@ -330,57 +338,53 @@ class Host(utils.EventEmitter):
# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.send_sync_command(hci.HCI_Reset_Command())
self.ready = True
response = await self.send_command(
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True
response1 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Commands_Command()
)
self.local_supported_commands = int.from_bytes(
response.return_parameters.supported_commands, 'little'
response1.supported_commands, 'little'
)
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
response2 = await self.send_sync_command(
hci.HCI_LE_Read_Local_Supported_Features_Command()
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
self.local_le_features = struct.unpack('<Q', response2.le_features)[0]
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
self.local_version = await self.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
self.local_version = response.return_parameters
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
max_page_number = 0
page_number = 0
lmp_features = 0
while page_number <= max_page_number:
response = await self.send_command(
response4 = await self.send_sync_command(
hci.HCI_Read_Local_Extended_Features_Command(
page_number=page_number
),
check_result=True,
)
)
lmp_features |= int.from_bytes(
response.return_parameters.extended_lmp_features, 'little'
response4.extended_lmp_features, 'little'
) << (64 * page_number)
max_page_number = response.return_parameters.maximum_page_number
max_page_number = response4.maximum_page_number
page_number += 1
self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True
response5 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Features_Command()
)
self.local_lmp_features = hci.LmpFeatureMask(
int.from_bytes(response.return_parameters.lmp_features, 'little')
int.from_bytes(response5.lmp_features, 'little')
)
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Command(
event_mask=hci.HCI_Set_Event_Mask_Command.mask(
[
@@ -437,7 +441,7 @@ class Host(utils.EventEmitter):
)
)
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
@@ -499,20 +503,14 @@ class Host(utils.EventEmitter):
]
)
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_Read_Buffer_Size_Command(), check_result=True
)
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
hc_acl_data_packet_length = response6.hc_acl_data_packet_length
hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
logger.debug(
'HCI ACL flow control: '
@@ -531,19 +529,13 @@ class Host(utils.EventEmitter):
iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
response7 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command()
)
le_acl_data_packet_length = response7.le_acl_data_packet_length
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
iso_data_packet_length = response7.iso_data_packet_length
total_num_iso_data_packets = response7.total_num_iso_data_packets
logger.debug(
'HCI LE flow control: '
@@ -553,15 +545,11 @@ class Host(utils.EventEmitter):
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
response8 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_Command()
)
le_acl_data_packet_length = response8.le_acl_data_packet_length
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
logger.debug(
'HCI LE ACL flow control: '
@@ -592,16 +580,16 @@ class Host(utils.EventEmitter):
) and self.supports_command(
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
):
response = await self.send_command(
response9 = await self.send_sync_command(
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
suggested_max_tx_octets = response9.suggested_max_tx_octets
suggested_max_tx_time = response9.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
@@ -611,23 +599,21 @@ class Host(utils.EventEmitter):
if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(),
check_result=True,
response10 = await self.send_sync_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
self.number_of_supported_advertising_sets = (
response.return_parameters.num_supported_advertising_sets
response10.num_supported_advertising_sets
)
if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(),
check_result=True,
response11 = await self.send_sync_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
self.maximum_advertising_data_length = (
response.return_parameters.max_advertising_data_length
response11.max_advertising_data_length
)
@property
@@ -654,9 +640,11 @@ class Host(utils.EventEmitter):
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: int | None = None
):
async def _send_command(
self,
command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -668,29 +656,9 @@ class Host(utils.EventEmitter):
try:
self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
response = self.pending_response.result()
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
return await asyncio.wait_for(
self.pending_response, timeout=response_timeout
)
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
@@ -698,12 +666,107 @@ class Host(utils.EventEmitter):
self.pending_command = None
self.pending_response = None
# Use this method to send a command from a task
def send_command_sync(self, command: hci.HCI_Command) -> None:
async def send_command(command: hci.HCI_Command) -> None:
await self.send_command(command)
@overload
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]: ...
asyncio.create_task(send_command(command))
@overload
async def send_command(
self,
command: hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Status_Event: ...
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
response = await self._send_command(command, response_timeout)
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
elif isinstance(
response.return_parameters, hci.HCI_GenericReturnParameters
):
# FIXME: temporary workaround
# NO STATUS
status = hci.HCI_SUCCESS
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
async def send_sync_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_status: bool = True,
response_timeout: float | None = None,
) -> _RP:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event)
return_parameters: _RP = response.return_parameters
assert isinstance(return_parameters, command.return_parameters_class)
# Check the return parameters if required
if check_status:
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_async_command(
self,
command: hci.HCI_AsyncCommand,
check_status: bool = True,
response_timeout: float | None = None,
) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Status_Event)
# Check the return parameters if required
status = response.status
if check_status:
if status != hci.HCI_CommandStatus.PENDING:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return hci.HCI_ErrorCode(status)
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
utils.AsyncRunner.spawn(self.send_async_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
@@ -1338,15 +1401,17 @@ class Host(utils.EventEmitter):
# For now, just accept everything
# TODO: delegate the decision
self.send_command_sync(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
utils.AsyncRunner.spawn(
self.send_sync_command(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
)
)
)
@@ -1382,9 +1447,9 @@ class Host(utils.EventEmitter):
connection_handle=event.connection_handle
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_long_term_key())
utils.AsyncRunner.spawn(send_long_term_key())
def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event
@@ -1583,9 +1648,9 @@ class Host(utils.EventEmitter):
bd_addr=event.bd_addr
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_link_key())
utils.AsyncRunner.spawn(send_link_key())
def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event
+123 -62
View File
@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio
import dataclasses
import enum
import itertools
import logging
import struct
from collections import deque
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
@dataclasses.dataclass
class InformationEnhancedControlField(EnhancedControlField):
tx_seq: int = 0
tx_seq: int
sar: int
req_seq: int = 0
segmentation_and_reassembly: int = (
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
)
final: int = 1
frame_type = EnhancedControlField.FieldType.I_FRAME
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
return cls(
tx_seq=(data[0] >> 1) & 0b0111111,
final=(data[0] >> 7) & 0b1,
req_seq=(data[1] & 0b001111111),
segmentation_and_reassembly=(data[1] >> 6) & 0b11,
req_seq=(data[1] & 0b00111111),
sar=(data[1] >> 6) & 0b11,
)
def __bytes__(self) -> bytes:
return bytes(
[
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
self.req_seq | (self.segmentation_and_reassembly << 6),
self.req_seq | (self.sar << 6),
]
)
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
class _PendingPdu:
payload: bytes
tx_seq: int
sar: InformationEnhancedControlField.SegmentationAndReassembly
sdu_length: int = 0
req_seq: int = 0
def __bytes__(self) -> bytes:
return (
bytes(
InformationEnhancedControlField(
tx_seq=self.tx_seq, req_seq=self.req_seq
tx_seq=self.tx_seq,
req_seq=self.req_seq,
sar=self.sar,
)
)
+ (
struct.pack('<H', self.sdu_length)
if self.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
else b''
)
+ self.payload
)
_expected_ack_seq: int = 0
_last_acked_tx_seq: int = 0
_last_acked_rx_seq: int = 0
_next_tx_seq: int = 0
_last_tx_seq: int = 0
_req_seq_num: int = 0
_next_seq_num: int = 0
_remote_is_busy: bool = False
_in_sdu: bytes = b''
_num_receiver_ready_polls_sent: int = 0
_pending_pdus: list[_PendingPdu]
_tx_window: list[_PendingPdu]
_monitor_handle: asyncio.TimerHandle | None = None
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
monitor_timeout: float
retransmission_timeout: float
@classmethod
def _num_frames_between(cls, low: int, high: int) -> int:
if high < low:
high += cls.MAX_SEQ_NUM
return high - low
def __init__(
self,
channel: ClassicChannel,
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
self.peer_mps = peer_mps
self.peer_tx_window_size = peer_tx_window_size
self._pending_pdus = []
self._tx_window = []
self.monitor_timeout = spec.monitor_timeout
self.channel = channel
self.retransmission_timeout = spec.retransmission_timeout
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
def _send_receiver_ready_poll(self) -> None:
self._num_receiver_ready_polls_sent += 1
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
def _get_next_tx_seq(self) -> int:
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
@override
def send_sdu(self, sdu: bytes) -> None:
if len(sdu) > self.peer_mps:
raise InvalidArgumentError(
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}'
if len(sdu) <= self.peer_mps:
pdu = self._PendingPdu(
payload=sdu,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
)
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq())
self._pending_pdus.append(pdu)
self._pending_pdus.append(pdu)
else:
for offset in range(0, len(sdu), self.peer_mps):
payload = sdu[offset : offset + self.peer_mps]
if offset == 0:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.START
)
elif offset + len(payload) >= len(sdu):
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
else:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
)
pdu = self._PendingPdu(
payload=payload,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=sar,
sdu_length=len(sdu),
)
self._pending_pdus.append(pdu)
self._process_output()
@override
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
control_field = EnhancedControlField.from_bytes(pdu)
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
if isinstance(control_field, InformationEnhancedControlField):
if control_field.tx_seq != self._next_seq_num:
if control_field.tx_seq != self._req_seq_num:
logger.error(
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
control_field.tx_seq,
self._req_seq_num,
)
return
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM
self._req_seq_num = self._next_seq_num
self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
ack_frame = SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
req_seq=self._next_seq_num,
)
self.channel.send_pdu(ack_frame)
self.channel.on_sdu(pdu[2:])
if (
control_field.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
):
# Drop Control Field(2) + SDU Length(2)
self._in_sdu += pdu[4:]
else:
# Drop Control Field(2)
self._in_sdu += pdu[2:]
if control_field.sar in (
InformationEnhancedControlField.SegmentationAndReassembly.END,
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
):
self.channel.on_sdu(self._in_sdu)
self._in_sdu = b''
# If sink doesn't trigger any I-frame, ack this frame.
if self._req_seq_num != self._last_acked_rx_seq:
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=0,
)
elif isinstance(control_field, SupervisoryEnhancedControlField):
self._remote_is_busy = (
control_field.supervision_function
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
):
if control_field.poll:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
else:
# TODO: Handle Retransmission.
pass
def _process_output(self) -> None:
if self._remote_is_busy or self._monitor_handle:
if self._remote_is_busy:
logger.debug("Remote is busy")
return
if self._monitor_handle:
logger.debug("Monitor handle is not None")
return
for pdu in self._pending_pdus:
if self._num_unacked_frames >= self.peer_tx_window_size:
return
self._send_pdu(pdu)
self._last_tx_seq = pdu.tx_seq
pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
self._send_i_frame(pdu)
self._pending_pdus = self._pending_pdus[pdu_to_send:]
@property
def _num_unacked_frames(self) -> int:
if not self._pending_pdus:
return 0
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
def _send_pdu(self, pdu: _PendingPdu) -> None:
def _send_i_frame(self, pdu: _PendingPdu) -> None:
pdu.req_seq = self._req_seq_num
self._start_receiver_ready_poll()
self._tx_window.append(pdu)
self.channel.send_pdu(bytes(pdu))
self._last_acked_rx_seq = self._req_seq_num
def _send_s_frame(
self,
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
final: int,
) -> None:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=supervision_function,
final=final,
req_seq=self._req_seq_num,
)
)
self._last_acked_rx_seq = self._req_seq_num
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq)
if num_frames_acked > self._num_unacked_frames:
num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
if num_frames_acked > len(self._tx_window):
logger.error(
"Received acknowledgment for %d frames but only %d frames are pending",
num_frames_acked,
self._num_unacked_frames,
len(self._tx_window),
)
return
if is_poll_response and self._monitor_handle:
self._monitor_handle.cancel()
self._monitor_handle = None
del self._pending_pdus[:num_frames_acked]
self._expected_ack_seq = new_seq
del self._tx_window[:num_frames_acked]
self._last_acked_tx_seq = new_seq
if (
self._expected_ack_seq == self._next_tx_seq
self._last_acked_tx_seq == self._next_tx_seq
and self._receiver_ready_poll_handle
):
self._receiver_ready_poll_handle.cancel()
@@ -1552,6 +1607,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
EVENT_ATT_MTU_UPDATE = "att_mtu_update"
def __init__(
self,
@@ -1591,6 +1647,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.drained = asyncio.Event()
self.att_mtu = 0 # Filled by GATT client or server later.
self.drained.set()
@@ -1821,6 +1878,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_att_mtu_update(self, mtu: int) -> None:
self.att_mtu = mtu
self.emit(self.EVENT_ATT_MTU_UPDATE, mtu)
def flush_output(self) -> None:
self.out_queue.clear()
self.out_sdu = None
+24 -33
View File
@@ -16,35 +16,28 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from collections.abc import Callable
from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_BATTERY_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from bumble import device, gatt, gatt_adapters, gatt_client
# -----------------------------------------------------------------------------
class BatteryService(TemplateService):
UUID = GATT_BATTERY_SERVICE
class BatteryService(gatt.TemplateService):
UUID = gatt.GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: Characteristic[int]
battery_level_characteristic: gatt.Characteristic[int]
def __init__(self, read_battery_level):
self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
def __init__(self, read_battery_level: Callable[[device.Connection], int]) -> None:
self.battery_level_characteristic = gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.READABLE,
value=gatt.CharacteristicValue(read=read_battery_level),
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
@@ -52,19 +45,17 @@ class BatteryService(TemplateService):
# -----------------------------------------------------------------------------
class BatteryServiceProxy(ProfileServiceProxy):
class BatteryServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BatteryService
battery_level: CharacteristicProxy[int] | None
battery_level: gatt_client.CharacteristicProxy[int]
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicProxyAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None
self.battery_level = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
+128 -119
View File
@@ -18,40 +18,30 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
import struct
from enum import IntEnum
from collections.abc import Callable, Sequence
from typing import Any
from bumble import core
from bumble.att import ATT_Error
from bumble.gatt import (
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_HEART_RATE_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from typing_extensions import Self
from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
class HeartRateService(gatt.TemplateService):
UUID = gatt.GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: Characteristic[int]
heart_rate_measurement_characteristic: gatt.Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: gatt.Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: gatt.Characteristic[int]
class BodySensorLocation(IntEnum):
class BodySensorLocation(utils.OpenIntEnum):
OTHER = 0
CHEST = 1
WRIST = 2
@@ -60,82 +50,90 @@ class HeartRateService(TemplateService):
EAR_LOBE = 5
FOOT = 6
@dataclasses.dataclass
class HeartRateMeasurement:
def __init__(
self,
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
heart_rate: int
sensor_contact_detected: bool | None = None
energy_expended: int | None = None
rr_intervals: Sequence[float] | None = None
class Flag(enum.IntFlag):
INT16_HEART_RATE = 1 << 0
SENSOR_CONTACT_DETECTED = 1 << 1
SENSOR_CONTACT_SUPPORTED = 1 << 2
ENERGY_EXPENDED_STATUS = 1 << 3
RR_INTERVAL = 1 << 4
def __post_init__(self) -> None:
if self.heart_rate < 0 or self.heart_rate > 0xFFFF:
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
if self.energy_expended is not None and (
self.energy_expended < 0 or self.energy_expended > 0xFFFF
):
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if self.rr_intervals:
for rr_interval in self.rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
def from_bytes(cls, data: bytes) -> Self:
flags = data[0]
offset = 1
if flags & 1:
hr = struct.unpack_from('<H', data, offset)[0]
if flags & cls.Flag.INT16_HEART_RATE:
heart_rate = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
hr = struct.unpack_from('B', data, offset)[0]
heart_rate = struct.unpack_from('B', data, offset)[0]
offset += 1
if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0
if flags & cls.Flag.SENSOR_CONTACT_SUPPORTED:
sensor_contact_detected = flags & cls.Flag.SENSOR_CONTACT_DETECTED != 0
else:
sensor_contact_detected = None
if flags & (1 << 3):
if flags & cls.Flag.ENERGY_EXPENDED_STATUS:
energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
energy_expended = None
if flags & (1 << 4):
rr_intervals: Sequence[float] | None = None
if flags & cls.Flag.RR_INTERVAL:
rr_intervals = tuple(
struct.unpack_from('<H', data, offset + i * 2)[0] / 1024
for i in range((len(data) - offset) // 2)
struct.unpack_from('<H', data, i)[0] / 1024
for i in range(offset, len(data), 2)
)
else:
rr_intervals = ()
return cls(hr, sensor_contact_detected, energy_expended, rr_intervals)
return cls(
heart_rate=heart_rate,
sensor_contact_detected=sensor_contact_detected,
energy_expended=energy_expended,
rr_intervals=rr_intervals,
)
def __bytes__(self):
def __bytes__(self) -> bytes:
flags = 0
if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate)
else:
flags = 1
flags |= self.Flag.INT16_HEART_RATE
data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None:
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
flags |= self.Flag.SENSOR_CONTACT_SUPPORTED
if self.sensor_contact_detected:
flags |= self.Flag.SENSOR_CONTACT_DETECTED
if self.energy_expended is not None:
flags |= 1 << 3
flags |= self.Flag.ENERGY_EXPENDED_STATUS
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= 1 << 4
if self.rr_intervals is not None:
flags |= self.Flag.RR_INTERVAL
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
@@ -145,57 +143,67 @@ class HeartRateService(TemplateService):
return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None,
read_heart_rate_measurement: Callable[
[device.Connection], HeartRateMeasurement
],
body_sensor_location: HeartRateService.BodySensorLocation | None = None,
reset_energy_expended: Callable[[device.Connection], Any] | None = None,
):
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
self.heart_rate_measurement_characteristic = (
gatt_adapters.SerializableCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions(0),
value=gatt.CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
)
)
characteristics = [self.heart_rate_measurement_characteristic]
characteristics: list[gatt.Characteristic] = [
self.heart_rate_measurement_characteristic
]
if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)]),
self.body_sensor_location_characteristic = (
gatt_adapters.EnumCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.READABLE,
value=body_sensor_location,
),
cls=self.BodySensorLocation,
length=1,
)
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
def write_heart_rate_control_point_value(
connection: device.Connection, value: bytes
) -> None:
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
reset_energy_expended(connection)
else:
raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
raise att.ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point_characteristic = (
gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.WRITEABLE,
value=gatt.CharacteristicValue(
write=write_heart_rate_control_point_value
),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -203,50 +211,51 @@ class HeartRateService(TemplateService):
# -----------------------------------------------------------------------------
class HeartRateServiceProxy(ProfileServiceProxy):
class HeartRateServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
heart_rate_measurement: (
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None
)
heart_rate_measurement: gatt_client.CharacteristicProxy[
HeartRateService.HeartRateMeasurement
]
body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None
gatt_client.CharacteristicProxy[HeartRateService.BodySensorLocation] | None
)
heart_rate_control_point: CharacteristicProxy[int] | None
heart_rate_control_point: gatt_client.CharacteristicProxy[int] | None
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], HeartRateService.HeartRateMeasurement
self.heart_rate_measurement = (
gatt_adapters.SerializableCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
),
HeartRateService.HeartRateMeasurement,
)
else:
self.heart_rate_measurement = None
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
self.body_sensor_location = gatt_adapters.EnumCharacteristicProxyAdapter(
characteristics[0], cls=HeartRateService.BodySensorLocation, length=1
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point = (
gatt_adapters.PackedCharacteristicProxyAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
async def reset_energy_expended(self) -> None:
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
+108 -86
View File
@@ -43,44 +43,53 @@ hci.HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
class HCI_LE_Get_Vendor_Capabilities_ReturnParameters(hci.HCI_StatusReturnParameters):
max_advt_instances: int = field(metadata=hci.metadata(1), default=0)
offloaded_resolution_of_private_address: int = field(
metadata=hci.metadata(1), default=0
)
total_scan_results_storage: int = field(metadata=hci.metadata(2), default=0)
max_irk_list_sz: int = field(metadata=hci.metadata(1), default=0)
filtering_support: int = field(metadata=hci.metadata(1), default=0)
max_filter: int = field(metadata=hci.metadata(1), default=0)
activity_energy_info_support: int = field(metadata=hci.metadata(1), default=0)
version_supported: int = field(metadata=hci.metadata(2), default=0)
total_num_of_advt_tracked: int = field(metadata=hci.metadata(2), default=0)
extended_scan_support: int = field(metadata=hci.metadata(1), default=0)
debug_logging_supported: int = field(metadata=hci.metadata(1), default=0)
le_address_generation_offloading_support: int = field(
metadata=hci.metadata(1), default=0
)
a2dp_source_offload_capability_mask: int = field(
metadata=hci.metadata(4), default=0
)
bluetooth_quality_report_support: int = field(metadata=hci.metadata(1), default=0)
dynamic_audio_buffer_support: int = field(metadata=hci.metadata(4), default=0)
@hci.HCI_SyncCommand.sync_command(HCI_LE_Get_Vendor_Capabilities_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(
hci.HCI_SyncCommand[HCI_LE_Get_Vendor_Capabilities_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = hci.HCI_Object(cls.return_parameters_fields, **nones)
# there are no more bytes to parse, and leave un-signaled parameters set to
# 0
return_parameters = HCI_LE_Get_Vendor_Capabilities_ReturnParameters(
hci.HCI_ErrorCode.SUCCESS
)
try:
offset = 0
for field in cls.return_parameters_fields:
for field in cls.return_parameters_class.fields:
field_name, field_type = field
field_value, field_size = hci.HCI_Object.parse_field(
parameters, offset, field_type
@@ -94,9 +103,30 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# APCF Subcommands
class LeApcfOpcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_Command):
class HCI_LE_APCF_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_LE_APCF_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_SyncCommand[HCI_LE_APCF_ReturnParameters]):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
@@ -105,52 +135,52 @@ class HCI_LE_APCF_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
class Opcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command):
class HCI_Get_Controller_Activity_Energy_Info_ReturnParameters(
hci.HCI_StatusReturnParameters
):
total_tx_time_ms: int = field(metadata=hci.metadata(4))
total_rx_time_ms: int = field(metadata=hci.metadata(4))
total_idle_time_ms: int = field(metadata=hci.metadata(4))
total_energy_used: int = field(metadata=hci.metadata(4))
@hci.HCI_SyncCommand.sync_command(
HCI_Get_Controller_Activity_Energy_Info_ReturnParameters
)
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(
hci.HCI_SyncCommand[HCI_Get_Controller_Activity_Energy_Info_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# A2DP Hardware Offload Subcommands
class A2dpHardwareOffloadOpcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
class HCI_A2DP_Hardware_Offload_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_A2DP_Hardware_Offload_ReturnParameters)
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(
hci.HCI_SyncCommand[HCI_A2DP_Hardware_Offload_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
@@ -159,25 +189,27 @@ class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
class Opcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# Dynamic Audio Buffer Subcommands
class DynamicAudioBufferOpcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
class HCI_Dynamic_Audio_Buffer_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_Dynamic_Audio_Buffer_ReturnParameters)
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(
hci.HCI_SyncCommand[HCI_Dynamic_Audio_Buffer_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
@@ -186,19 +218,9 @@ class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
class Opcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):
+24 -18
View File
@@ -46,9 +46,19 @@ class TX_Power_Level_Command:
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Write_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
selected_tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Write_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Write_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -61,18 +71,21 @@ class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Read_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Read_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Read_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -83,10 +96,3 @@ class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
handle_type: int = dataclasses.field(metadata=hci.metadata(1))
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
]
+1 -1
View File
@@ -63,7 +63,7 @@ HCI sockets provide a way to send/receive HCI packets to/from a Bluetooth contro
See the [HCI Socket Transport page](../transports/hci_socket.md) for details on the `hci-socket` tansport syntax.
The HCI device referenced by an `hci-socket` transport (`hci<X>`, where `<X>` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> up`.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> down`.
!!! tip "HCI Socket Permissions"
By default, when running as a regular user, you won't have the permission to use
+5 -4
View File
@@ -37,15 +37,16 @@ The vendor specific HCI commands to read and write TX power are defined in
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB
response = await host.send_command(
response = await host.send_sync_command(
HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0,
tx_power_level=-4,
)
),
check_status=False
)
if response.return_parameters.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}")
if response.status == HCI_SUCCESS:
print(f"TX power set to {response.selected_tx_power_level}")
```
+2 -2
View File
@@ -71,8 +71,8 @@ async def main() -> None:
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
random.randint(900, 1100) // 1000,
random.randint(900, 1100) // 1000,
),
None,
)
+12 -9
View File
@@ -19,10 +19,10 @@ import asyncio
import sys
import bumble.logging
from bumble import gatt_client
from bumble.colors import color
from bumble.core import ProtocolError
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.device import Connection, Device
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -34,24 +34,27 @@ class Listener(Device.Listener):
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
async def on_connection(self, connection: Connection):
print(f'=== Connected to {connection}')
# Discover all services
print('=== Discovering services')
peer = Peer(connection)
await peer.discover_services()
for service in peer.services:
if connection.device.config.eatt_enabled:
client = await gatt_client.Client.connect_eatt(connection)
else:
client = connection.gatt_client
await client.discover_services()
for service in client.services:
await service.discover_characteristics()
for characteristic in service.characteristics:
await characteristic.discover_descriptors()
print('=== Services discovered')
show_services(peer.services)
gatt_client.show_services(client.services)
# Discover all attributes
print('=== Discovering attributes')
attributes = await peer.discover_attributes()
attributes = await client.discover_attributes()
for attribute in attributes:
print(attribute)
print('=== Attributes discovered')
@@ -59,7 +62,7 @@ class Listener(Device.Listener):
# Read all attributes
for attribute in attributes:
try:
value = await peer.read_value(attribute)
value = await client.read_value(attribute)
print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green'))
except ProtocolError as error:
print(color(f'cannot read {attribute.handle:04X}:', 'red'), error)
+1 -1
View File
@@ -17,6 +17,6 @@ use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len());
assert_eq!(13, DriverInfo::all_drivers()?.len());
Ok(())
}
+34
View File
@@ -0,0 +1,34 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from bumble import device as device_module
from bumble.profiles import battery_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_battery_level():
devices = await test_utils.TwoDevices.create_with_connection()
service = battery_service.BatteryService(lambda _: 1)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(battery_service.BatteryServiceProxy)
assert client
assert await client.battery_level.read_value() == 1
+6 -7
View File
@@ -42,7 +42,6 @@ from bumble.hci import (
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
HCI_Connection_Request_Event,
@@ -154,10 +153,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
@@ -188,10 +187,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
@@ -292,9 +291,9 @@ async def test_legacy_advertising_disconnection(auto_restart):
await devices[0].start_advertising(
auto_restart=auto_restart, advertising_interval_min=1.0
)
connecion = await devices[1].connect(devices[0].random_address)
connection = await devices[1].connect(devices[0].random_address)
await connecion.disconnect()
await connection.disconnect()
await async_barrier()
await async_barrier()
+180 -5
View File
@@ -28,6 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
import pytest
from typing_extensions import Self
from bumble import gatt_client, l2cap
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
@@ -63,7 +64,6 @@ from bumble.gatt_adapters import (
UTF8CharacteristicAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy
from .test_utils import Devices, TwoDevices, async_barrier
@@ -140,7 +140,7 @@ async def test_characteristic_encoding():
await c.write_value(Mock(), bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
class FooProxy(gatt_client.CharacteristicProxy):
def __init__(self, characteristic):
super().__init__(
characteristic.client,
@@ -456,7 +456,7 @@ async def test_CharacteristicProxyAdapter() -> None:
async def write_value(self, handle, value, with_response=False):
self.value = value
class TestAttributeProxy(CharacteristicProxy):
class TestAttributeProxy(gatt_client.CharacteristicProxy):
def __init__(self, value) -> None:
super().__init__(Client(value), 0, 0, None, 0) # type: ignore
@@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid():
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0
@@ -1463,6 +1463,181 @@ async def test_write_return_error():
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_read():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.READ,
Characteristic.Permissions.READABLE,
b'9999',
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
assert await characteristic_proxy.read_value() == b'9999'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_write():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
write_queue = asyncio.Queue()
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)),
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.write_value(b'9999')
assert await write_queue.get() == (devices.connections[1], b'9999')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_notify():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True)
await devices[1].gatt_server.notify_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.notify_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_indicate():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.INDICATE,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False)
await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.indicate_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_connection_failure():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await gatt_client.Client.connect_eatt(devices.connections[0])
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
+48 -45
View File
@@ -20,7 +20,7 @@ import struct
import pytest
from bumble import hci
from bumble import hci, utils
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -136,43 +136,25 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
# -----------------------------------------------------------------------------
def test_HCI_Command_Complete_Event():
# With a serializable object
event = hci.HCI_Command_Complete_Event(
event1 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=34,
command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.return_parameters_class(
status=0,
le_acl_data_packet_length=1234,
total_num_le_acl_data_packets=56,
),
)
basic_check(event)
# With an arbitrary byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = hci.HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7
basic_check(event1)
# With a simple status as an integer status
event = hci.HCI_Command_Complete_Event(
event3 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=9,
return_parameters=hci.HCI_StatusReturnParameters(hci.HCI_ErrorCode(9)),
)
basic_check(event)
assert event.return_parameters == 9
basic_check(event3)
assert event3.return_parameters.status == 9
# -----------------------------------------------------------------------------
@@ -229,6 +211,28 @@ def test_HCI_Vendor_Event():
assert isinstance(parsed, hci.HCI_Vendor_Event)
# -----------------------------------------------------------------------------
def test_return_parameters() -> None:
params = hci.HCI_Reset_Command.parse_return_parameters(bytes.fromhex('3C'))
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('3C001122334455')
)
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address)
params = hci.HCI_Read_Local_Name_Command.parse_return_parameters(
bytes.fromhex('0068656c6c6f') + bytes(248 - 5)
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.local_name, bytes)
assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# -----------------------------------------------------------------------------
def test_HCI_Command():
command = hci.HCI_Command(op_code=0x5566)
@@ -291,7 +295,7 @@ def test_custom_le_meta_event():
for clazz in inspect.getmembers(hci)
if isinstance(clazz[1], type)
and issubclass(clazz[1], hci.HCI_Command)
and clazz[1] is not hci.HCI_Command
and clazz[1] not in (hci.HCI_Command, hci.HCI_SyncCommand, hci.HCI_AsyncCommand)
],
)
def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
@@ -620,21 +624,19 @@ def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = (
hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
0,
]
)
returned_parameters = hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
0,
]
)
)
assert returned_parameters.standard_codec_ids == [
@@ -643,9 +645,9 @@ def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
hci.CodecID.LINEAR_PCM,
]
assert returned_parameters.standard_codec_transports == [
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
]
@@ -737,6 +739,7 @@ def run_test_commands():
if __name__ == '__main__':
run_test_events()
run_test_commands()
test_return_parameters()
test_address()
test_custom()
test_iso_data_packet()
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import itertools
from collections.abc import Sequence
import pytest
from bumble import device as device_module
from bumble.profiles import heart_rate_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"heart_rate, sensor_contact_detected, energy_expanded, rr_intervals",
itertools.product(
(1, 1000), (True, False, None), (2, None), ((3.0, 4.0, 5.0), None)
),
)
async def test_read_measurement(
heart_rate: int,
sensor_contact_detected: bool | None,
energy_expanded: int | None,
rr_intervals: Sequence[int] | None,
):
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(
heart_rate, sensor_contact_detected, energy_expanded, rr_intervals
)
service = heart_rate_service.HeartRateService(lambda _: measurement)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert await client.heart_rate_measurement.read_value() == measurement
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_body_sensor_location():
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(0)
location = heart_rate_service.HeartRateService.BodySensorLocation.FINGER
service = heart_rate_service.HeartRateService(
lambda _: measurement,
body_sensor_location=location,
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert client.body_sensor_location
assert await client.body_sensor_location.read_value() == location
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reset_energy_expended() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(1)
reset_energy_expended = asyncio.Queue[None]()
service = heart_rate_service.HeartRateService(
lambda _: measurement,
reset_energy_expended=lambda _: reset_energy_expended.put_nowait(None),
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
await client.reset_energy_expended()
await reset_energy_expended.get()
+66 -2
View File
@@ -15,6 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import unittest
import unittest.mock
@@ -22,9 +23,17 @@ import unittest.mock
import pytest
from bumble.controller import Controller
from bumble.hci import HCI_AclDataPacket
from bumble.hci import (
HCI_AclDataPacket,
HCI_Command_Complete_Event,
HCI_Error,
HCI_ErrorCode,
HCI_Event,
HCI_Reset_Command,
HCI_StatusReturnParameters,
)
from bumble.host import DataPacketQueue, Host
from bumble.transport.common import AsyncPipeSink
from bumble.transport.common import AsyncPipeSink, TransportSink
# -----------------------------------------------------------------------------
# Logging
@@ -151,3 +160,58 @@ def test_data_packet_queue():
assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15
assert queue.completed == 15
# -----------------------------------------------------------------------------
class Source:
terminated: asyncio.Future[None]
sink: TransportSink
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
class Sink:
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
async def test_send_sync_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.SUCCESS),
),
)
host = Host(source, sink)
# Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command())
assert response1.status == HCI_ErrorCode.SUCCESS
# Sync command with error status should raise
error_response = HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.COMMAND_DISALLOWED_ERROR),
)
sink.response = error_response
with pytest.raises(HCI_Error) as excinfo:
await host.send_sync_command(HCI_Reset_Command())
assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with error status should not raise when `check_status` is False
response2 = await host.send_sync_command(HCI_Reset_Command(), check_status=False)
assert response2.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR
+10 -17
View File
@@ -239,20 +239,7 @@ async def transfer_payload(
channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523)
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
for message in messages:
channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
@@ -334,20 +321,26 @@ async def test_mtu():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_retransmission_mode():
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices()
await devices.setup_connection()
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
server = devices.devices[1].create_l2cap_server(
spec=l2cap.ClassicChannelSpec(
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=256,
),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
server.psm,
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=1024,
)
)
server_channel = await server_channels.get()
+2 -4
View File
@@ -57,15 +57,13 @@ async def test_self_disconnection():
await two_devices.setup_connection()
await two_devices.connections[0].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
assert not two_devices.connections
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[1].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
assert not two_devices.connections
# -----------------------------------------------------------------------------
+8 -6
View File
@@ -31,10 +31,10 @@ from bumble.transport.common import AsyncPipeSink
# -----------------------------------------------------------------------------
class Devices:
connections: list[Connection | None]
connections: dict[int, Connection]
def __init__(self, num_devices: int) -> None:
self.connections = [None for _ in range(num_devices)]
self.connections = {}
self.link = LocalLink()
addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)]
@@ -60,12 +60,14 @@ class Devices:
asyncio.get_event_loop().create_future() for _ in range(num_devices)
]
def on_connection(self, which, connection):
def on_connection(self, which: int, connection: Connection) -> None:
self.connections[which] = connection
connection.on('disconnection', lambda code: self.on_disconnection(which))
connection.on(
connection.EVENT_DISCONNECTION, lambda *_: self.on_disconnection(which)
)
def on_disconnection(self, which):
self.connections[which] = None
def on_disconnection(self, which: int) -> None:
self.connections.pop(which, None)
def on_paired(self, which: int, keys: PairingKeys) -> None:
self.paired[which].set_result(keys)
+1 -1
View File
@@ -89,7 +89,7 @@ class HeartRateMonitor:
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('### Monitor stopped')
+1 -1
View File
@@ -60,7 +60,7 @@ class Scanner(utils.EventEmitter):
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('### Scanner stopped')
+1 -1
View File
@@ -311,7 +311,7 @@ class Speaker:
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('Speaker stopped')