Compare commits

...

71 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 3894b14467 better handling of complete/status events 2026-02-02 23:28:40 -08:00
Gilles Boccon-Gibod e62f947430 add workaround for some buggy controllers 2026-02-02 13:19:55 -08:00
Gilles Boccon-Gibod dcb8a4b607 Merge pull request #877 from google/gbg/hci-fixes
fix a few HCI types and make the bridge more robust
2026-02-02 11:19:28 -08:00
Gilles Boccon-Gibod 81985c47a9 remove superfluous statement 2026-02-02 11:12:28 -08:00
Gilles Boccon-Gibod 7118328b07 Merge pull request #879 from google/gbg/resolve-when-bonded
resolve addresses when connecting to bonded peers
2026-01-31 11:09:55 -08:00
Gilles Boccon-Gibod 5dc01d792a address PR comments 2026-01-31 10:55:58 -08:00
Gilles Boccon-Gibod 255f357975 resolve when bonded 2026-01-30 21:53:01 -08:00
Josh Wu c86920558b Merge pull request #878 from zxzxwu/avrcp
AVRCP: SDP record classes and some delegation
2026-01-31 00:01:55 +08:00
Josh Wu 8e6efd0b2f Fix error in AVRCP example 2026-01-30 23:01:11 +08:00
Gilles Boccon-Gibod 2a59e19283 fix comment 2026-01-29 19:09:46 -08:00
Josh Wu 34f5b81c7d AVRCP: Delegate Company ID capabilities 2026-01-29 22:13:14 +08:00
Josh Wu d34d6a5c98 AVRCP: Delegate Playback Status 2026-01-29 21:33:57 +08:00
Josh Wu aedc971653 AVRCP: Add SDP record class and finder 2026-01-29 16:00:50 +08:00
Josh Wu c6815fb820 AVRCP: Delegate passthrough key event 2026-01-29 14:50:14 +08:00
Gilles Boccon-Gibod f44d013690 make bridge more robust 2026-01-27 09:47:52 -08:00
Gilles Boccon-Gibod e63dc15ede fix handling of return parameters 2026-01-27 09:39:22 -08:00
Gilles Boccon-Gibod c901e15666 fix a few HCI types and make the bridge more robust 2026-01-25 13:47:14 -08:00
Gilles Boccon-Gibod 022323b19c Merge pull request #871 from google/gbg/sci
add basic support for SCI
2026-01-24 10:39:11 -08:00
Gilles Boccon-Gibod a0d24e95e7 fix spacing_type 2026-01-24 10:15:32 -08:00
Josh Wu 7efbd303e0 Merge pull request #876 from ttdennis/await_termination_fix
Update apps and examples to await .terminated instead of wait_for_termination()
2026-01-24 11:44:19 +08:00
Dennis Heinze 49530d8d6d Update apps and examples to await .terminated 2026-01-24 00:20:55 +01:00
Gilles Boccon-Gibod 85b78b46f8 Merge pull request #870 from antipatico/feat_AV53C1 2026-01-23 13:43:12 -08:00
Josh Wu 3f9ef5aac2 Merge pull request #873 from zxzxwu/l2cap
L2CAP: Fix wrong CID on reject
2026-01-23 12:44:59 +08:00
Josh Wu e488ea9783 Merge pull request #872 from zxzxwu/avrcp
AVRCP: Fix wrong field specs
2026-01-23 12:36:14 +08:00
Josh Wu 21d937c2f1 Merge pull request #865 from willnix/pcapsnoop
Added a PcapSnooper class
2026-01-23 12:33:15 +08:00
Frieder Steinmetz a8396e6cce Formatted with black again. 2026-01-22 17:49:58 +01:00
Josh Wu 7e1b1c8f78 L2CAP: Fix wrong CID on reject 2026-01-22 23:16:25 +08:00
Josh Wu 55719bf6de AVRCP: Fix wrong field specs 2026-01-22 22:18:58 +08:00
Frieder Steinmetz 5059920696 Please mypy.\n\nTwo calls to open(), some more annotations and a rescoped global were needed. 2026-01-22 10:40:08 +01:00
Gilles Boccon-Gibod c577f17c99 add basic support for SCI 2026-01-20 15:32:55 -08:00
Gilles Boccon-Gibod 252f3e49b6 Merge pull request #870 from antipatico/feat_AV53C1 2026-01-20 10:46:52 -08:00
Jacopo Scannella f3ecf04479 Added support for STA-AV53C1-USB-BLUETOOTH StarTech(dot)com dongle - RTL8761BUE 2026-01-20 09:32:51 +01:00
Gilles Boccon-Gibod 4986f55043 Merge pull request #869 from timrid/android-fix
Make bumble work on Android using briefcase/chaquopy
2026-01-19 09:50:08 -08:00
Gilles Boccon-Gibod 7e89c8a7f8 Merge pull request #868 from google/gbg/return-parameters
typing support for HCI commands return parameters
2026-01-19 09:49:15 -08:00
timrid 085905a7bf Make bumble work on Android using briefcase that is using chaquopy under the hood. 2026-01-18 23:32:37 +01:00
Gilles Boccon-Gibod 7523118581 typing surrport for HCI commands return parameters 2026-01-17 13:19:36 -08:00
zxzxwu c619f1f21b Merge pull request #867 from zxzxwu/fix-import-error
Fix missing ClassVar import
2026-01-16 15:33:07 +08:00
Josh Wu d4b0da9265 Fix missing ClassVar import 2026-01-16 15:21:26 +08:00
zxzxwu f1058e4d4e Merge pull request #859 from istemon/att-read-by-type-request-fix
Return 'invalid handle' for malformed read by type request
2026-01-16 15:09:20 +08:00
zxzxwu 454d477d7e Merge pull request #864 from zxzxwu/hci-packets-typing
Add HCI Packets annotations and send_sco_sdu
2026-01-16 15:08:42 +08:00
zxzxwu 6966228d74 Merge pull request #863 from zxzxwu/eatt-mtu
Correct ATT_MTU in enhanced bearers
2026-01-16 15:08:12 +08:00
zxzxwu f4271a5646 Merge pull request #862 from zxzxwu/gatt-multiple
GATT: Support Multiple Requests
2026-01-16 15:08:02 +08:00
zxzxwu 534209f0af Merge pull request #861 from zxzxwu/l2cap
Replace send_pdu() with write()
2026-01-16 15:07:54 +08:00
zxzxwu 549b82999a Merge pull request #860 from zxzxwu/address
Improve Address type annotations
2026-01-16 14:04:56 +08:00
zxzxwu 551f577b2a Merge pull request #866 from zxzxwu/template-service
Fix GATT TemplateSerivce annotations
2026-01-16 09:41:48 +08:00
Frieder Steinmetz c69c1532cc Fix comments that were messed up by black 2026-01-15 19:06:03 +01:00
Frieder Steinmetz f95b2054c8 Formatted with 2026-01-15 10:50:33 +01:00
Josh Wu 84a6453dda Fix GATT TemplateSerivce annotations 2026-01-15 12:06:05 +08:00
Frieder Steinmetz 3fdd7ee45e Added the PcapSnooper class.
The class implements a bumble snooper that writes PCAP records.
It can write to either a file or a named pipe.
The latter is useful to bridge with wireshark extcap for live logging.
2026-01-14 23:40:59 +01:00
Gilles Boccon-Gibod 591ed61686 Merge pull request #858 from klow68/feat/add-usb-probe-filtering 2026-01-13 08:54:55 -08:00
Josh Wu 3d3acbb374 Add HCI Packets annotations and send_sco_sdu 2026-01-13 17:58:37 +08:00
Stryxion 671f306a27 fix: black 2026-01-13 09:42:40 +01:00
Josh Wu f7364db992 Correct ATT_MTU in enhanced bearers 2026-01-12 21:03:14 +08:00
Josh Wu 0fb2b3bd66 GATT: Support Multiple Requests 2026-01-12 20:51:38 +08:00
Stryxion 9e270d4d62 fix: mypy 2026-01-12 09:36:35 +01:00
Josh Wu cf60b5ffbb Replace send_pdu() with write() 2026-01-12 13:16:49 +08:00
Josh Wu aa4c57d105 Improve Address type annotations
* Add missing annotations
* Declare address constants as ClassVar
2026-01-12 13:07:04 +08:00
Istemon 61a601e6e2 Return 'invalid handle' for malformed read by type request 2026-01-10 01:43:30 +00:00
Stryxion 05fd4fbfc6 fix: review 2026-01-09 08:46:31 +01:00
Gilles Boccon-Gibod 2cad743f8c Merge pull request #854 from TinyServal/rtl8761cu
Add support for RTL8761CU
2026-01-08 18:37:21 -08:00
Stryxion 6aa9e0bdf7 feat: Add filtering options for usb probe 2026-01-08 14:54:58 +01:00
zxzxwu 255414f315 Merge pull request #857 from zxzxwu/testing
Add test for Heart Rate and Battery Service
2026-01-08 17:52:12 +08:00
Josh Wu d2df76f6f4 Add test for Heart Rate and Battery Service 2026-01-08 16:42:05 +08:00
zxzxwu 884b1c20e4 Merge pull request #856 from zxzxwu/typing
Add annotation for Heart Rate and Battery Service
2026-01-08 15:29:50 +08:00
Josh Wu 91a2b4f676 Add annotation for Heart Rate and Battery Service 2026-01-08 14:43:27 +08:00
Bowen Yan 5831f79d62 Add support for the RTL8761CU 2026-01-08 16:50:11 +11:00
zxzxwu 36f81b798c Merge pull request #853 from zxzxwu/l2cap
L2CAP: Fix segmentation and frame ack
2026-01-08 09:40:13 +08:00
Gilles Boccon-Gibod 985183001f Merge pull request #855 from encarbassotnopot/patch-1
docs: fix a small error in hci socket up/down commands
2026-01-07 14:26:15 -08:00
Josh Wu b153d0fcde L2CAP: Fix Enhanced Retransmission Segmentation 2026-01-07 23:49:57 +08:00
Eina Safor 30d912d66e docs: fix a small error in hci socket up/down commands 2026-01-07 15:59:14 +01:00
Bowen Yan 054dc70f3f Exclude macOS xattr files 2026-01-07 15:00:21 +11:00
74 changed files with 5055 additions and 2526 deletions
+3
View File
@@ -17,3 +17,6 @@ venv/
.venv/ .venv/
# snoop logs # snoop logs
out/ out/
# macOS
.DS_Store
._*
+90 -114
View File
@@ -27,23 +27,17 @@ from bumble.core import name_or_number
from bumble.hci import ( from bumble.hci import (
HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND, HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND, HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND, HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND, HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
CodecID,
HCI_Command, HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_LE_Read_Buffer_Size_Command, HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Buffer_Size_V2_Command, HCI_LE_Read_Buffer_Size_V2_Command,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_Read_Maximum_Data_Length_Command, HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command, HCI_LE_Read_Minimum_Supported_Connection_Interval_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command, HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_Read_Buffer_Size_Command, HCI_Read_Buffer_Size_Command,
@@ -59,85 +53,81 @@ from bumble.host import Host
from bumble.transport import open_transport from bumble.transport import open_transport
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None: async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response): print()
print() print(
print( color('Public Address:', 'yellow'),
color('Public Address:', 'yellow'), response1.bd_addr.to_string(False),
response.return_parameters.bd_addr.to_string(False), )
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command()) response2 = await host.send_sync_command(HCI_Read_Local_Name_Command())
if command_succeeded(response): print()
print() print(
print( color('Local Name:', 'yellow'),
color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response2.local_name),
map_null_terminated_utf8_string(response.return_parameters.local_name), )
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None: async def get_le_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND): print(
response = await host.send_command( color('LE Number Of Supported Advertising Sets:', 'yellow'),
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command() host.number_of_supported_advertising_sets,
) '\n',
if command_succeeded(response): )
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND): print(
response = await host.send_command( color('LE Maximum Advertising Data Length:', 'yellow'),
HCI_LE_Read_Maximum_Advertising_Data_Length_Command() host.maximum_advertising_data_length,
) '\n',
if command_succeeded(response): )
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command()) response1 = await host.send_sync_command(
if command_succeeded(response): HCI_LE_Read_Maximum_Data_Length_Command()
print( )
color('Maximum Data Length:', 'yellow'), print(
( color('LE Maximum Data Length:', 'yellow'),
f'tx:{response.return_parameters.supported_max_tx_octets}/' (
f'{response.return_parameters.supported_max_tx_time}, ' f'tx:{response1.supported_max_tx_octets}/'
f'rx:{response.return_parameters.supported_max_rx_octets}/' f'{response1.supported_max_tx_time}, '
f'{response.return_parameters.supported_max_rx_time}' f'rx:{response1.supported_max_rx_octets}/'
), f'{response1.supported_max_rx_time}'
'\n', ),
) )
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command( response2 = await host.send_sync_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command() HCI_LE_Read_Suggested_Default_Data_Length_Command()
) )
if command_succeeded(response): print(
color('LE Suggested Default Data Length:', 'yellow'),
f'{response2.suggested_max_tx_octets}/'
f'{response2.suggested_max_tx_time}',
'\n',
)
if host.supports_command(HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND):
response3 = await host.send_sync_command(
HCI_LE_Read_Minimum_Supported_Connection_Interval_Command()
)
print(
color('LE Minimum Supported Connection Interval:', 'yellow'),
f'{response3.minimum_supported_connection_interval * 125} µs',
)
for group in range(len(response3.group_min)):
print( print(
color('Suggested Default Data Length:', 'yellow'), f' Group {group}: '
f'{response.return_parameters.suggested_max_tx_octets}/' f'{response3.group_min[group] * 125} µs to '
f'{response.return_parameters.suggested_max_tx_time}', f'{response3.group_max[group] * 125} µs '
'by increments of '
f'{response3.group_stride[group] * 125} µs',
'\n', '\n',
) )
@@ -151,37 +141,31 @@ async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command())
HCI_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('ACL Flow Control:', 'yellow'), color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} ' f'{response1.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}', f'packets of size {response1.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command( response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command())
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} ' f'{response2.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}', f'packets of size {response2.le_acl_data_packet_length}',
) )
print( print(
color('LE ISO Flow Control:', 'yellow'), color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} ' f'{response2.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}', f'packets of size {response2.iso_data_packet_length}',
) )
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command())
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} ' f'{response3.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}', f'packets of size {response3.le_acl_data_packet_length}',
) )
@@ -190,52 +174,44 @@ async def get_codecs_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command( response1 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True HCI_Read_Local_Supported_Codecs_V2_Command()
) )
print(color('Codecs:', 'yellow')) print(color('Codecs:', 'yellow'))
for codec_id, transport in zip( for codec_id, transport in zip(
response.return_parameters.standard_codec_ids, response1.standard_codec_ids,
response.return_parameters.standard_codec_transports, response1.standard_codec_transports,
): ):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport( print(f' {codec_id.name} - {transport.name}')
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
for codec_id, transport in zip( for vendor_codec_id, vendor_transport in zip(
response.return_parameters.vendor_specific_codec_ids, response1.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports, response1.vendor_specific_codec_transports,
): ):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport( company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
transport print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}')
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
if not response.return_parameters.standard_codec_ids: if not response1.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids: if not response1.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command( response2 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True HCI_Read_Local_Supported_Codecs_Command()
) )
print(color('Codecs (BR/EDR):', 'yellow')) print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids: for codec_id in response2.standard_codec_ids:
codec_name = CodecID(codec_id).name print(f' {codec_id.name}')
print(f' {codec_name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids: for vendor_codec_id in response2.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16) company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}') print(f' {company} / {vendor_codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids: if not response2.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids: if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')
+11 -9
View File
@@ -85,7 +85,7 @@ class Loopback:
print(color('@@@ Received last packet', 'green')) print(color('@@@ Received last packet', 'green'))
self.done.set() self.done.set()
async def run(self): async def run(self) -> None:
"""Run a loopback throughput test""" """Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green')) print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as ( async with await open_transport(self.transport) as (
@@ -100,11 +100,15 @@ class Loopback:
# make sure data can fit in one l2cap pdu # make sure data can fit in one l2cap pdu
l2cap_header_size = 4 l2cap_header_size = 4
max_packet_size = ( packet_queue = (
host.acl_packet_queue host.acl_packet_queue
if host.acl_packet_queue if host.acl_packet_queue
else host.le_acl_packet_queue else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size )
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size: if self.packet_size > max_packet_size:
print( print(
color( color(
@@ -128,20 +132,18 @@ class Loopback:
loopback_mode = LoopbackMode.LOCAL loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue')) print(color('### Setting loopback mode', 'blue'))
await host.send_command( await host.send_sync_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL), HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
) )
print(color('### Checking loopback mode', 'blue')) print(color('### Checking loopback mode', 'blue'))
response = await host.send_command( response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
HCI_Read_Loopback_Mode_Command(), check_result=True if response.loopback_mode != loopback_mode:
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red')) print(color('!!! Loopback mode mismatch', 'red'))
return return
await self.connection_event.wait() await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan')) print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta')) print(color('=== Start sending', 'magenta'))
+1 -1
View File
@@ -352,7 +352,7 @@ async def run(
await bridge.start() await bridge.start()
# Wait until the source terminates # Wait until the source terminates
await hci_source.wait_for_termination() await hci_source.terminated
@click.command() @click.command()
+3 -1
View File
@@ -81,7 +81,9 @@ async def async_main():
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci_packet.op_code, command_opcode=hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]), return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode.SUCCESS
),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (bytes(response), True) return (bytes(response), True)
+1 -1
View File
@@ -268,7 +268,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device) await bridge.start(device)
# Wait until the transport terminates # Wait until the transport terminates
await hci_source.wait_for_termination() await hci_source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -421,7 +421,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device) await bridge.start(device)
# Wait until the transport terminates # Wait until the transport terminates
await hci_source.wait_for_termination() await hci_source.terminated
except core.ConnectionError as error: except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red")) print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error: except Exception as error:
+10 -4
View File
@@ -22,7 +22,7 @@ import click
import bumble.logging import bumble.logging
from bumble import data_types from bumble import data_types
from bumble.colors import color from bumble.colors import color
from bumble.device import Advertisement, Device from bumble.device import Advertisement, Device, DeviceConfiguration
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
@@ -144,8 +144,14 @@ async def scan(
device_config, hci_source, hci_sink device_config, hci_source, hci_sink
) )
else: else:
device = Device.with_hci( device = Device.from_config_with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink DeviceConfiguration(
name='Bumble',
address=Address('F0:F1:F2:F3:F4:F5'),
keystore='JsonKeyStore',
),
hci_source,
hci_sink,
) )
await device.power_on() await device.power_on()
@@ -190,7 +196,7 @@ async def scan(
scanning_phys=scanning_phys, scanning_phys=scanning_phys,
) )
await hci_source.wait_for_termination() await hci_source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -726,7 +726,7 @@ class Speaker:
print("Waiting for connection...") print("Waiting for connection...")
await self.advertise() await self.advertise()
await hci_source.wait_for_termination() await hci_source.terminated
for output in self.outputs: for output in self.outputs:
await output.stop() await output.stop()
+15 -2
View File
@@ -26,6 +26,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from typing import Any
import click import click
import usb1 import usb1
@@ -166,13 +168,16 @@ def is_bluetooth_hci(device):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details') @click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose): @click.option('--hci-only', is_flag=True, default=False, help='only show HCI device')
@click.option('--manufacturer', help='filter by manufacturer')
@click.option('--product', help='filter by product')
def main(verbose: bool, manufacturer: str, product: str, hci_only: bool):
bumble.logging.setup_basic_logging('WARNING') bumble.logging.setup_basic_logging('WARNING')
load_libusb() load_libusb()
with usb1.USBContext() as context: with usb1.USBContext() as context:
bluetooth_device_count = 0 bluetooth_device_count = 0
devices = {} devices: dict[tuple[Any, Any], list[str | None]] = {}
for device in context.getDeviceIterator(skip_on_error=True): for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass() device_class = device.getDeviceClass()
@@ -234,6 +239,14 @@ def main(verbose):
f'{basic_transport_name}/{device_serial_number}' f'{basic_transport_name}/{device_serial_number}'
) )
# Filter
if product and device_product != product:
continue
if manufacturer and device_manufacturer != manufacturer:
continue
if not is_bluetooth_hci(device) and hci_only:
continue
# Print the results # Print the results
print( print(
color( color(
+97 -32
View File
@@ -29,7 +29,7 @@ import enum
import functools import functools
import inspect import inspect
import struct import struct
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Sequence
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
ClassVar, ClassVar,
@@ -72,34 +72,36 @@ ATT_PSM = 0x001F
EATT_PSM = 0x0027 EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum): class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02 ATT_EXCHANGE_MTU_REQUEST = 0x02
ATT_EXCHANGE_MTU_RESPONSE = 0x03 ATT_EXCHANGE_MTU_RESPONSE = 0x03
ATT_FIND_INFORMATION_REQUEST = 0x04 ATT_FIND_INFORMATION_REQUEST = 0x04
ATT_FIND_INFORMATION_RESPONSE = 0x05 ATT_FIND_INFORMATION_RESPONSE = 0x05
ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06 ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06
ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07 ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07
ATT_READ_BY_TYPE_REQUEST = 0x08 ATT_READ_BY_TYPE_REQUEST = 0x08
ATT_READ_BY_TYPE_RESPONSE = 0x09 ATT_READ_BY_TYPE_RESPONSE = 0x09
ATT_READ_REQUEST = 0x0A ATT_READ_REQUEST = 0x0A
ATT_READ_RESPONSE = 0x0B ATT_READ_RESPONSE = 0x0B
ATT_READ_BLOB_REQUEST = 0x0C ATT_READ_BLOB_REQUEST = 0x0C
ATT_READ_BLOB_RESPONSE = 0x0D ATT_READ_BLOB_RESPONSE = 0x0D
ATT_READ_MULTIPLE_REQUEST = 0x0E ATT_READ_MULTIPLE_REQUEST = 0x0E
ATT_READ_MULTIPLE_RESPONSE = 0x0F ATT_READ_MULTIPLE_RESPONSE = 0x0F
ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10 ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10
ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11 ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11
ATT_WRITE_REQUEST = 0x12 ATT_READ_MULTIPLE_VARIABLE_REQUEST = 0x20
ATT_WRITE_RESPONSE = 0x13 ATT_READ_MULTIPLE_VARIABLE_RESPONSE = 0x21
ATT_WRITE_COMMAND = 0x52 ATT_WRITE_REQUEST = 0x12
ATT_SIGNED_WRITE_COMMAND = 0xD2 ATT_WRITE_RESPONSE = 0x13
ATT_PREPARE_WRITE_REQUEST = 0x16 ATT_WRITE_COMMAND = 0x52
ATT_PREPARE_WRITE_RESPONSE = 0x17 ATT_SIGNED_WRITE_COMMAND = 0xD2
ATT_EXECUTE_WRITE_REQUEST = 0x18 ATT_PREPARE_WRITE_REQUEST = 0x16
ATT_EXECUTE_WRITE_RESPONSE = 0x19 ATT_PREPARE_WRITE_RESPONSE = 0x17
ATT_HANDLE_VALUE_NOTIFICATION = 0x1B ATT_EXECUTE_WRITE_REQUEST = 0x18
ATT_HANDLE_VALUE_INDICATION = 0x1D ATT_EXECUTE_WRITE_RESPONSE = 0x19
ATT_HANDLE_VALUE_CONFIRMATION = 0x1E ATT_HANDLE_VALUE_NOTIFICATION = 0x1B
ATT_HANDLE_VALUE_INDICATION = 0x1D
ATT_HANDLE_VALUE_CONFIRMATION = 0x1E
ATT_REQUESTS = [ ATT_REQUESTS = [
Opcode.ATT_EXCHANGE_MTU_REQUEST, Opcode.ATT_EXCHANGE_MTU_REQUEST,
@@ -110,9 +112,10 @@ ATT_REQUESTS = [
Opcode.ATT_READ_BLOB_REQUEST, Opcode.ATT_READ_BLOB_REQUEST,
Opcode.ATT_READ_MULTIPLE_REQUEST, Opcode.ATT_READ_MULTIPLE_REQUEST,
Opcode.ATT_READ_BY_GROUP_TYPE_REQUEST, Opcode.ATT_READ_BY_GROUP_TYPE_REQUEST,
Opcode.ATT_READ_MULTIPLE_VARIABLE_REQUEST,
Opcode.ATT_WRITE_REQUEST, Opcode.ATT_WRITE_REQUEST,
Opcode.ATT_PREPARE_WRITE_REQUEST, Opcode.ATT_PREPARE_WRITE_REQUEST,
Opcode.ATT_EXECUTE_WRITE_REQUEST Opcode.ATT_EXECUTE_WRITE_REQUEST,
] ]
ATT_RESPONSES = [ ATT_RESPONSES = [
@@ -125,9 +128,10 @@ ATT_RESPONSES = [
Opcode.ATT_READ_BLOB_RESPONSE, Opcode.ATT_READ_BLOB_RESPONSE,
Opcode.ATT_READ_MULTIPLE_RESPONSE, Opcode.ATT_READ_MULTIPLE_RESPONSE,
Opcode.ATT_READ_BY_GROUP_TYPE_RESPONSE, Opcode.ATT_READ_BY_GROUP_TYPE_RESPONSE,
Opcode.ATT_READ_MULTIPLE_VARIABLE_RESPONSE,
Opcode.ATT_WRITE_RESPONSE, Opcode.ATT_WRITE_RESPONSE,
Opcode.ATT_PREPARE_WRITE_RESPONSE, Opcode.ATT_PREPARE_WRITE_RESPONSE,
Opcode.ATT_EXECUTE_WRITE_RESPONSE Opcode.ATT_EXECUTE_WRITE_RESPONSE,
] ]
class ErrorCode(hci.SpecableEnum): class ErrorCode(hci.SpecableEnum):
@@ -185,6 +189,18 @@ ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
ATT_DEFAULT_MTU = 23 ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'} HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
_SET_OF_HANDLES_METADATA = hci.metadata({
'parser': lambda data, offset: (
len(data),
[
struct.unpack_from('<H', data, i)[0]
for i in range(offset, len(data), 2)
],
),
'serializer': lambda handles: b''.join(
[struct.pack('<H', handle) for handle in handles]
),
})
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -554,7 +570,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
''' '''
set_of_handles: bytes = dataclasses.field(metadata=hci.metadata("*")) set_of_handles: Sequence[int] = dataclasses.field(metadata=_SET_OF_HANDLES_METADATA)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -635,6 +651,55 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
return result return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass
@dataclasses.dataclass
class ATT_Read_Multiple_Variable_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.11 Read Multiple Variable Request
'''
set_of_handles: Sequence[int] = dataclasses.field(metadata=_SET_OF_HANDLES_METADATA)
# -----------------------------------------------------------------------------
@ATT_PDU.subclass
@dataclasses.dataclass
class ATT_Read_Multiple_Variable_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.12 Read Multiple Variable Response
'''
@classmethod
def _parse_length_value_tuples(
cls, data: bytes, offset: int
) -> tuple[int, list[tuple[int, bytes]]]:
length_value_tuple_list: list[tuple[int, bytes]] = []
while offset < len(data):
length = struct.unpack_from('<H', data, offset)[0]
length_value_tuple_list.append(
(length, data[offset + 2 : offset + 2 + length])
)
offset += 2 + length
return (len(data), length_value_tuple_list)
length_value_tuple_list: Sequence[tuple[int, bytes]] = dataclasses.field(
metadata=hci.metadata(
{
'parser': lambda data, offset: ATT_Read_Multiple_Variable_Response._parse_length_value_tuples(
data, offset
),
'serializer': lambda length_value_tuple_list: b''.join(
[
struct.pack('<H', length) + value
for length, value in length_value_tuple_list
]
),
}
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass @ATT_PDU.subclass
@dataclasses.dataclass @dataclasses.dataclass
+1 -1
View File
@@ -235,7 +235,7 @@ class Protocol:
) )
+ payload + payload
) )
self.l2cap_channel.send_pdu(pdu) self.l2cap_channel.write(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None: def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug( logger.debug(
+3 -3
View File
@@ -268,7 +268,7 @@ class MediaPacketPump:
await self.clock.sleep(delay) await self.clock.sleep(delay)
# Emit # Emit
rtp_channel.send_pdu(bytes(packet)) rtp_channel.write(bytes(packet))
logger.debug( logger.debug(
f'{color(">>> sending RTP packet:", "green")} {packet}' f'{color(">>> sending RTP packet:", "green")} {packet}'
) )
@@ -1519,7 +1519,7 @@ class Protocol(utils.EventEmitter):
header = bytes([first_header_byte]) header = bytes([first_header_byte])
# Send one packet # Send one packet
self.l2cap_channel.send_pdu(header + payload[:max_fragment_size]) self.l2cap_channel.write(header + payload[:max_fragment_size])
# Prepare for the next packet # Prepare for the next packet
payload = payload[max_fragment_size:] payload = payload[max_fragment_size:]
@@ -1829,7 +1829,7 @@ class Stream:
def send_media_packet(self, packet: MediaPacket) -> None: def send_media_packet(self, packet: MediaPacket) -> None:
assert self.rtp_channel assert self.rtp_channel
self.rtp_channel.send_pdu(bytes(packet)) self.rtp_channel.write(bytes(packet))
async def configure(self) -> None: async def configure(self) -> None:
if self.state != State.IDLE: if self.state != State.IDLE:
+374 -198
View File
@@ -26,7 +26,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequen
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ClassVar, SupportsBytes, TypeVar from typing import ClassVar, SupportsBytes, TypeVar
from bumble import avc, avctp, core, hci, l2cap, utils from bumble import avc, avctp, core, hci, l2cap, sdp, utils
from bumble.colors import color from bumble.colors import color
from bumble.device import Connection, Device from bumble.device import Connection, Device
from bumble.sdp import ( from bumble.sdp import (
@@ -55,13 +55,15 @@ AVRCP_PID = 0x110E
AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958 AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958
_UINT64_BE_METADATA = { _UINT64_BE_METADATA = hci.metadata(
'parser': lambda data, offset: ( {
offset + 8, 'parser': lambda data, offset: (
int.from_bytes(data[offset : offset + 8], byteorder='big'), offset + 8,
), int.from_bytes(data[offset : offset + 8], byteorder='big'),
'serializer': lambda x: x.to_bytes(8, byteorder='big'), ),
} 'serializer': lambda x: x.to_bytes(8, byteorder='big'),
}
)
class PduId(utils.OpenIntEnum): class PduId(utils.OpenIntEnum):
@@ -92,7 +94,7 @@ class PduId(utils.OpenIntEnum):
class CharacterSetId(hci.SpecableEnum): class CharacterSetId(hci.SpecableEnum):
UTF_8 = 0x06 UTF_8 = 0x6A
class MediaAttributeId(hci.SpecableEnum): class MediaAttributeId(hci.SpecableEnum):
@@ -192,82 +194,43 @@ class TargetFeatures(enum.IntFlag):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_controller_service_sdp_records( @dataclass
service_record_handle: int, class ControllerServiceSdpRecord:
avctp_version: tuple[int, int] = (1, 4), service_record_handle: int
avrcp_version: tuple[int, int] = (1, 6), avctp_version: tuple[int, int] = (1, 4)
supported_features: int | ControllerFeatures = 1, avrcp_version: tuple[int, int] = (1, 6)
) -> list[ServiceAttribute]: supported_features: int | ControllerFeatures = ControllerFeatures(1)
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
attributes = [ def to_service_attributes(self) -> list[ServiceAttribute]:
ServiceAttribute( avctp_version_int = self.avctp_version[0] << 8 | self.avctp_version[1]
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, avrcp_version_int = self.avrcp_version[0] << 8 | self.avrcp_version[1]
DataElement.unsigned_integer_32(service_record_handle),
), attributes = [
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(supported_features),
),
]
if supported_features & ControllerFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute( ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(self.service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.sequence(
[ [
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID), DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16( DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
avctp.AVCTP_BROWSING_PSM
),
] ]
), ),
DataElement.sequence( DataElement.sequence(
@@ -279,87 +242,130 @@ def make_controller_service_sdp_records(
] ]
), ),
), ),
) ServiceAttribute(
return attributes SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(self.supported_features),
),
]
if self.supported_features & ControllerFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(
avctp.AVCTP_BROWSING_PSM
),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
)
return attributes
@classmethod
async def find(cls, connection: Connection) -> list[ControllerServiceSdpRecord]:
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE],
attribute_ids=[
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
],
)
records: list[ControllerServiceSdpRecord] = []
for attribute_lists in search_result:
record = cls(0)
for attribute in attribute_lists:
if attribute.id == SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:
record.service_record_handle = attribute.value.value
elif attribute.id == SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
# [[L2CAP, PSM], [AVCTP, version]]
record.avctp_version = (
attribute.value.value[1].value[1].value >> 8,
attribute.value.value[1].value[1].value & 0xFF,
)
elif (
attribute.id
== SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
):
# [[AV_REMOTE_CONTROL, version]]
record.avrcp_version = (
attribute.value.value[0].value[1].value >> 8,
attribute.value.value[0].value[1].value & 0xFF,
)
elif attribute.id == SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID:
record.supported_features = ControllerFeatures(
attribute.value.value
)
records.append(record)
return records
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_target_service_sdp_records( @dataclass
service_record_handle: int, class TargetServiceSdpRecord:
avctp_version: tuple[int, int] = (1, 4), service_record_handle: int
avrcp_version: tuple[int, int] = (1, 6), avctp_version: tuple[int, int] = (1, 4)
supported_features: int | TargetFeatures = 0x23, avrcp_version: tuple[int, int] = (1, 6)
) -> list[ServiceAttribute]: supported_features: int | TargetFeatures = TargetFeatures(0x23)
# TODO: support a way to compute the supported features from a feature list
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
attributes = [ def to_service_attributes(self) -> list[ServiceAttribute]:
ServiceAttribute( # TODO: support a way to compute the supported features from a feature list
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, avctp_version_int = self.avctp_version[0] << 8 | self.avctp_version[1]
DataElement.unsigned_integer_32(service_record_handle), avrcp_version_int = self.avrcp_version[0] << 8 | self.avrcp_version[1]
),
ServiceAttribute( attributes = [
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(supported_features),
),
]
if supported_features & TargetFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute( ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(self.service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.sequence(
[ [
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID), DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16( DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
avctp.AVCTP_BROWSING_PSM
),
] ]
), ),
DataElement.sequence( DataElement.sequence(
@@ -371,8 +377,90 @@ def make_target_service_sdp_records(
] ]
), ),
), ),
) ServiceAttribute(
return attributes SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_AV_REMOTE_CONTROL_SERVICE),
DataElement.unsigned_integer_16(avrcp_version_int),
]
),
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(self.supported_features),
),
]
if self.supported_features & TargetFeatures.SUPPORTS_BROWSING:
attributes.append(
ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(core.BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(
avctp.AVCTP_BROWSING_PSM
),
]
),
DataElement.sequence(
[
DataElement.uuid(core.BT_AVCTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(avctp_version_int),
]
),
]
),
),
)
return attributes
@classmethod
async def find(cls, connection: Connection) -> list[TargetServiceSdpRecord]:
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_AV_REMOTE_CONTROL_TARGET_SERVICE],
attribute_ids=[
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
],
)
records: list[TargetServiceSdpRecord] = []
for attribute_lists in search_result:
record = cls(0)
for attribute in attribute_lists:
if attribute.id == SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:
record.service_record_handle = attribute.value.value
elif attribute.id == SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
# [[L2CAP, PSM], [AVCTP, version]]
record.avctp_version = (
attribute.value.value[1].value[1].value >> 8,
attribute.value.value[1].value[1].value & 0xFF,
)
elif (
attribute.id
== SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
):
# [[AV_REMOTE_CONTROL, version]]
record.avrcp_version = (
attribute.value.value[0].value[1].value >> 8,
attribute.value.value[0].value[1].value & 0xFF,
)
elif attribute.id == SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID:
record.supported_features = TargetFeatures(
attribute.value.value
)
records.append(record)
return records
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -491,14 +579,12 @@ class BrowseableItem:
**hci.HCI_Object.dict_from_bytes(data, offset + 3, subclass.fields) **hci.HCI_Object.dict_from_bytes(data, offset + 3, subclass.fields)
) )
instance._payload = data[3:] instance._payload = data[3:]
return offset + length, instance return offset + length + 3, instance
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
if self._payload is None: if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return ( return struct.pack('>BH', self.item_type, len(self._payload)) + self._payload
struct.pack('>BH', self.item_type, len(self._payload) + 3) + self._payload
)
_Item = TypeVar('_Item', bound='BrowseableItem') _Item = TypeVar('_Item', bound='BrowseableItem')
@@ -601,11 +687,11 @@ class MediaPlayerItem(BrowseableItem):
metadata=MajorPlayerType.type_metadata(1) metadata=MajorPlayerType.type_metadata(1)
) )
player_sub_type: PlayerSubType = field( player_sub_type: PlayerSubType = field(
metadata=PlayerSubType.type_metadata(4, byteorder='big') metadata=PlayerSubType.type_metadata(4, byteorder='little')
) )
play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1))
feature_bitmask: Features = field( feature_bitmask: Features = field(
metadata=Features.type_metadata(16, byteorder='big') metadata=Features.type_metadata(16, byteorder='little')
) )
character_set_id: CharacterSetId = field( character_set_id: CharacterSetId = field(
metadata=CharacterSetId.type_metadata(2, byteorder='big') metadata=CharacterSetId.type_metadata(2, byteorder='big')
@@ -634,7 +720,7 @@ class FolderItem(BrowseableItem):
folder_uid: int = field(metadata=_UINT64_BE_METADATA) folder_uid: int = field(metadata=_UINT64_BE_METADATA)
folder_type: FolderType = field(metadata=FolderType.type_metadata(1)) folder_type: FolderType = field(metadata=FolderType.type_metadata(1))
is_playable: FolderType = field(metadata=Playable.type_metadata(1)) is_playable: Playable = field(metadata=Playable.type_metadata(1))
character_set_id: CharacterSetId = field( character_set_id: CharacterSetId = field(
metadata=CharacterSetId.type_metadata(2, byteorder='big') metadata=CharacterSetId.type_metadata(2, byteorder='big')
) )
@@ -876,7 +962,7 @@ class GetPlayStatusCommand(Command):
class GetElementAttributesCommand(Command): class GetElementAttributesCommand(Command):
pdu_id = PduId.GET_ELEMENT_ATTRIBUTES pdu_id = PduId.GET_ELEMENT_ATTRIBUTES
identifier: int = field(metadata=hci.metadata(_UINT64_BE_METADATA)) identifier: int = field(metadata=_UINT64_BE_METADATA)
attribute_ids: Sequence[MediaAttributeId] = field( attribute_ids: Sequence[MediaAttributeId] = field(
metadata=MediaAttributeId.type_metadata( metadata=MediaAttributeId.type_metadata(
4, list_begin=True, list_end=True, byteorder='big' 4, list_begin=True, list_end=True, byteorder='big'
@@ -951,7 +1037,7 @@ class ChangePathCommand(Command):
uid_counter: int = field(metadata=hci.metadata('>2')) uid_counter: int = field(metadata=hci.metadata('>2'))
direction: Direction = field(metadata=Direction.type_metadata(1)) direction: Direction = field(metadata=Direction.type_metadata(1))
folder_uid: int = field(metadata=hci.metadata(_UINT64_BE_METADATA)) folder_uid: int = field(metadata=_UINT64_BE_METADATA)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -961,7 +1047,7 @@ class GetItemAttributesCommand(Command):
pdu_id = PduId.GET_ITEM_ATTRIBUTES pdu_id = PduId.GET_ITEM_ATTRIBUTES
scope: Scope = field(metadata=Scope.type_metadata(1)) scope: Scope = field(metadata=Scope.type_metadata(1))
uid: int = field(metadata=hci.metadata(_UINT64_BE_METADATA)) uid: int = field(metadata=_UINT64_BE_METADATA)
uid_counter: int = field(metadata=hci.metadata('>2')) uid_counter: int = field(metadata=hci.metadata('>2'))
start_item: int = field(metadata=hci.metadata('>4')) start_item: int = field(metadata=hci.metadata('>4'))
end_item: int = field(metadata=hci.metadata('>4')) end_item: int = field(metadata=hci.metadata('>4'))
@@ -999,7 +1085,7 @@ class PlayItemCommand(Command):
pdu_id = PduId.PLAY_ITEM pdu_id = PduId.PLAY_ITEM
scope: Scope = field(metadata=Scope.type_metadata(1)) scope: Scope = field(metadata=Scope.type_metadata(1))
uid: int = field(metadata=hci.metadata(_UINT64_BE_METADATA)) uid: int = field(metadata=_UINT64_BE_METADATA)
uid_counter: int = field(metadata=hci.metadata('>2')) uid_counter: int = field(metadata=hci.metadata('>2'))
@@ -1010,7 +1096,7 @@ class AddToNowPlayingCommand(Command):
pdu_id = PduId.ADD_TO_NOW_PLAYING pdu_id = PduId.ADD_TO_NOW_PLAYING
scope: Scope = field(metadata=Scope.type_metadata(1)) scope: Scope = field(metadata=Scope.type_metadata(1))
uid: int = field(metadata=hci.metadata(_UINT64_BE_METADATA)) uid: int = field(metadata=_UINT64_BE_METADATA)
uid_counter: int = field(metadata=hci.metadata('>2')) uid_counter: int = field(metadata=hci.metadata('>2'))
@@ -1204,6 +1290,10 @@ class InformBatteryStatusOfCtResponse(Response):
@dataclass @dataclass
class GetPlayStatusResponse(Response): class GetPlayStatusResponse(Response):
pdu_id = PduId.GET_PLAY_STATUS pdu_id = PduId.GET_PLAY_STATUS
# TG doesn't support Song Length or Position.
UNAVAILABLE = 0xFFFFFFFF
song_length: int = field(metadata=hci.metadata(">4")) song_length: int = field(metadata=hci.metadata(">4"))
song_position: int = field(metadata=hci.metadata(">4")) song_position: int = field(metadata=hci.metadata(">4"))
play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1))
@@ -1521,16 +1611,33 @@ class Delegate:
def __init__(self, status_code: StatusCode) -> None: def __init__(self, status_code: StatusCode) -> None:
self.status_code = status_code self.status_code = status_code
supported_events: list[EventId] class AvcError(Exception):
volume: int """The delegate AVC method failed, with a specified status code."""
def __init__(self, supported_events: Iterable[EventId] = ()) -> None: def __init__(self, status_code: avc.ResponseFrame.ResponseCode) -> None:
self.status_code = status_code
supported_events: list[EventId]
supported_company_ids: list[int]
volume: int
playback_status: PlayStatus
def __init__(
self,
supported_events: Iterable[EventId] = (),
supported_company_ids: Iterable[int] = (AVRCP_BLUETOOTH_SIG_COMPANY_ID,),
) -> None:
self.supported_company_ids = list(supported_company_ids)
self.supported_events = list(supported_events) self.supported_events = list(supported_events)
self.volume = 0 self.volume = 0
self.playback_status = PlayStatus.STOPPED
async def get_supported_events(self) -> list[EventId]: async def get_supported_events(self) -> list[EventId]:
return self.supported_events return self.supported_events
async def get_supported_company_ids(self) -> list[int]:
return self.supported_company_ids
async def set_absolute_volume(self, volume: int) -> None: async def set_absolute_volume(self, volume: int) -> None:
""" """
Set the absolute volume. Set the absolute volume.
@@ -1543,6 +1650,19 @@ class Delegate:
async def get_absolute_volume(self) -> int: async def get_absolute_volume(self) -> int:
return self.volume return self.volume
async def on_key_event(
self,
key: avc.PassThroughFrame.OperationId,
pressed: bool,
data: bytes,
) -> None:
logger.debug(
"@@@ on_key_event: key=%s, pressed=%s, data=%s", key, pressed, data.hex()
)
async def get_playback_status(self) -> PlayStatus:
return self.playback_status
# TODO add other delegate methods # TODO add other delegate methods
@@ -1756,6 +1876,19 @@ class Protocol(utils.EventEmitter):
if isinstance(capability, EventId) if isinstance(capability, EventId)
) )
async def get_supported_company_ids(self) -> list[int]:
"""Get the list of events supported by the connected peer."""
response_context = await self.send_avrcp_command(
avc.CommandFrame.CommandType.STATUS,
GetCapabilitiesCommand(GetCapabilitiesCommand.CapabilityId.COMPANY_ID),
)
response = self._check_response(response_context, GetCapabilitiesResponse)
return list(
int.from_bytes(capability, 'big')
for capability in response.capabilities
if isinstance(capability, bytes)
)
async def get_play_status(self) -> SongAndPlayStatus: async def get_play_status(self) -> SongAndPlayStatus:
"""Get the play status of the connected peer.""" """Get the play status of the connected peer."""
response_context = await self.send_avrcp_command( response_context = await self.send_avrcp_command(
@@ -2052,16 +2185,28 @@ class Protocol(utils.EventEmitter):
return return
if isinstance(command, avc.PassThroughCommandFrame): if isinstance(command, avc.PassThroughCommandFrame):
# TODO: delegate
response = avc.PassThroughResponseFrame( async def dispatch_key_event() -> None:
avc.ResponseFrame.ResponseCode.ACCEPTED, try:
command.subunit_type, await self.delegate.on_key_event(
command.subunit_id, command.operation_id,
command.state_flag, command.state_flag == avc.PassThroughFrame.StateFlag.PRESSED,
command.operation_id, command.operation_data,
command.operation_data, )
) response_code = avc.ResponseFrame.ResponseCode.ACCEPTED
self.send_response(transaction_label, response) except Delegate.AvcError as error:
logger.exception("delegate method raised exception")
response_code = error.status_code
except Exception:
logger.exception("delegate method raised exception")
response_code = avc.ResponseFrame.ResponseCode.REJECTED
self.send_passthrough_response(
transaction_label=transaction_label,
command=command,
response_code=response_code,
)
utils.AsyncRunner.spawn(dispatch_key_event())
return return
# TODO handle other types # TODO handle other types
@@ -2141,6 +2286,8 @@ class Protocol(utils.EventEmitter):
self._on_set_absolute_volume_command(transaction_label, command) self._on_set_absolute_volume_command(transaction_label, command)
elif isinstance(command, RegisterNotificationCommand): elif isinstance(command, RegisterNotificationCommand):
self._on_register_notification_command(transaction_label, command) self._on_register_notification_command(transaction_label, command)
elif isinstance(command, GetPlayStatusCommand):
self._on_get_play_status_command(transaction_label, command)
else: else:
# Not supported. # Not supported.
# TODO: check that this is the right way to respond in this case. # TODO: check that this is the right way to respond in this case.
@@ -2364,17 +2511,27 @@ class Protocol(utils.EventEmitter):
logger.debug(f"<<< AVRCP command PDU: {command}") logger.debug(f"<<< AVRCP command PDU: {command}")
async def get_supported_events() -> None: async def get_supported_events() -> None:
capabilities: Sequence[bytes | SupportsBytes]
if ( if (
command.capability_id command.capability_id
!= GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
): ):
raise core.InvalidArgumentError() capabilities = await self.delegate.get_supported_events()
elif (
supported_events = await self.delegate.get_supported_events() command.capability_id == GetCapabilitiesCommand.CapabilityId.COMPANY_ID
):
company_ids = await self.delegate.get_supported_company_ids()
capabilities = [
company_id.to_bytes(3, 'big') for company_id in company_ids
]
else:
raise core.InvalidArgumentError(
f"Unsupported capability: {command.capability_id}"
)
self.send_avrcp_response( self.send_avrcp_response(
transaction_label, transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetCapabilitiesResponse(command.capability_id, supported_events), GetCapabilitiesResponse(command.capability_id, capabilities),
) )
self._delegate_command(transaction_label, command, get_supported_events()) self._delegate_command(transaction_label, command, get_supported_events())
@@ -2395,6 +2552,26 @@ class Protocol(utils.EventEmitter):
self._delegate_command(transaction_label, command, set_absolute_volume()) self._delegate_command(transaction_label, command, set_absolute_volume())
def _on_get_play_status_command(
self, transaction_label: int, command: GetPlayStatusCommand
) -> None:
logger.debug("<<< AVRCP command PDU: %s", command)
async def get_playback_status() -> None:
play_status: PlayStatus = await self.delegate.get_playback_status()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetPlayStatusResponse(
# TODO: Delegate this.
song_length=GetPlayStatusResponse.UNAVAILABLE,
song_position=GetPlayStatusResponse.UNAVAILABLE,
play_status=play_status,
),
)
self._delegate_command(transaction_label, command, get_playback_status())
def _on_register_notification_command( def _on_register_notification_command(
self, transaction_label: int, command: RegisterNotificationCommand self, transaction_label: int, command: RegisterNotificationCommand
) -> None: ) -> None:
@@ -2410,28 +2587,27 @@ class Protocol(utils.EventEmitter):
) )
return return
response: Response
if command.event_id == EventId.VOLUME_CHANGED: if command.event_id == EventId.VOLUME_CHANGED:
volume = await self.delegate.get_absolute_volume() volume = await self.delegate.get_absolute_volume()
response = RegisterNotificationResponse(VolumeChangedEvent(volume)) response = RegisterNotificationResponse(VolumeChangedEvent(volume))
self.send_avrcp_response( elif command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
transaction_label, playback_status = await self.delegate.get_playback_status()
avc.ResponseFrame.ResponseCode.INTERIM, response = RegisterNotificationResponse(
response, PlaybackStatusChangedEvent(play_status=playback_status)
) )
self._register_notification_listener(transaction_label, command) elif command.event_id == EventId.NOW_PLAYING_CONTENT_CHANGED:
playback_status = await self.delegate.get_playback_status()
response = RegisterNotificationResponse(NowPlayingContentChangedEvent())
else:
logger.warning("Event supported but not handled %s", command.event_id)
return return
if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: self.send_avrcp_response(
# TODO: testing only, use delegate transaction_label,
response = RegisterNotificationResponse( avc.ResponseFrame.ResponseCode.INTERIM,
PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING) response,
) )
self.send_avrcp_response( self._register_notification_listener(transaction_label, command)
transaction_label,
avc.ResponseFrame.ResponseCode.INTERIM,
response,
)
self._register_notification_listener(transaction_label, command)
return
self._delegate_command(transaction_label, command, register_notification()) self._delegate_command(transaction_label, command, register_notification())
+10 -2
View File
@@ -37,7 +37,12 @@ class HCI_Bridge:
def on_packet(self, packet): def on_packet(self, packet):
# Convert the packet bytes to an object # Convert the packet bytes to an object
hci_packet = HCI_Packet.from_bytes(packet) try:
hci_packet = HCI_Packet.from_bytes(packet)
except Exception:
logger.warning('forwarding unparsed packet as-is')
self.hci_sink.on_packet(packet)
return
# Filter the packet # Filter the packet
if self.packet_filter is not None: if self.packet_filter is not None:
@@ -50,7 +55,10 @@ class HCI_Bridge:
return return
# Analyze the packet # Analyze the packet
self.trace(hci_packet) try:
self.trace(hci_packet)
except Exception:
logger.exception('Exception while tracing packet')
# Bridge the packet # Bridge the packet
self.hci_sink.on_packet(packet) self.hci_sink.on_packet(packet)
+14 -1
View File
@@ -421,7 +421,7 @@ class Controller:
hci.HCI_Command_Complete_Event( hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=command.op_code, command_opcode=command.op_code,
return_parameters=result, return_parameters=hci.HCI_GenericReturnParameters(data=result),
) )
) )
@@ -1898,6 +1898,19 @@ class Controller:
''' '''
return bytes([hci.HCI_SUCCESS]) + self.le_features.value.to_bytes(8, 'little') return bytes([hci.HCI_SUCCESS]) + self.le_features.value.to_bytes(8, 'little')
def on_hci_le_read_all_local_supported_features_command(
self, _command: hci.HCI_LE_Read_All_Local_Supported_Features_Command
) -> bytes | None:
'''
See Bluetooth spec Vol 4, Part E - 7.8.128 LE Read All Local Supported Features
Command
'''
return (
bytes([hci.HCI_SUCCESS])
+ bytes([0])
+ self.le_features.value.to_bytes(248, 'little')
)
def on_hci_le_set_random_address_command( def on_hci_le_set_random_address_command(
self, command: hci.HCI_LE_Set_Random_Address_Command self, command: hci.HCI_LE_Set_Random_Address_Command
) -> bytes | None: ) -> bytes | None:
+1 -1
View File
@@ -923,7 +923,7 @@ class DeviceClass:
# pylint: enable=line-too-long # pylint: enable=line-too-long
@staticmethod @staticmethod
def split_class_of_device(class_of_device): def split_class_of_device(class_of_device: int) -> tuple[int, int, int]:
# Split the bit fields of the composite class of device value into: # Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class) # (service_classes, major_device_class, minor_device_class)
return ( return (
+1021 -599
View File
File diff suppressed because it is too large Load Diff
+47 -41
View File
@@ -89,51 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Intel_Read_Version_Command(hci.HCI_Command): class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
tlv: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters]
):
param0: int = dataclasses.field(metadata=hci.metadata(1)) param0: int = dataclasses.field(metadata=hci.metadata(1))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("tlv", "*"),
]
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters)
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class Hci_Intel_Secure_Send_Command(hci.HCI_Command): class Hci_Intel_Secure_Send_Command(
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters]
):
data_type: int = dataclasses.field(metadata=hci.metadata(1)) data_type: int = dataclasses.field(metadata=hci.metadata(1))
data: bytes = dataclasses.field(metadata=hci.metadata("*")) data: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
("status", 1),
]
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_Command): class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters):
data: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]):
reset_type: int = dataclasses.field(metadata=hci.metadata(1)) reset_type: int = dataclasses.field(metadata=hci.metadata(1))
patch_enable: int = dataclasses.field(metadata=hci.metadata(1)) patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1)) ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1)) boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4)) boot_address: int = dataclasses.field(metadata=hci.metadata(4))
return_parameters_fields = [
("data", "*"),
]
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command): class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
data: bytes = dataclasses.field(metadata=hci.metadata("*")) params: bytes = hci.field(metadata=hci.metadata('*'))
return_parameters_fields = [
("status", hci.STATUS_SPEC), @hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
("params", "*"), @dataclasses.dataclass
] class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -402,7 +405,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event) self.host.on_hci_event_packet(event)
return return
if not event.return_parameters == hci.HCI_SUCCESS: if not event.return_parameters.status == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error") raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets: if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -641,8 +644,8 @@ class Driver(common.Driver):
while ddc_data: while ddc_data:
ddc_len = 1 + ddc_data[0] ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len] ddc_payload = ddc_data[:ddc_len]
await self.host.send_command( await self.host.send_sync_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload) HCI_Intel_Write_Device_Config_Command(data=ddc_payload)
) )
ddc_data = ddc_data[ddc_len:] ddc_data = ddc_data[ddc_len:]
@@ -660,31 +663,34 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]: async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command()) response1 = await self.host.send_sync_command_raw(hci.HCI_Reset_Command())
if not ( if not isinstance(
isinstance(response, hci.HCI_Command_Complete_Event) response1.return_parameters, hci.HCI_StatusReturnParameters
and response.return_parameters ) or response1.return_parameters.status not in (
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS) hci.HCI_UNKNOWN_HCI_COMMAND_ERROR,
hci.HCI_SUCCESS,
): ):
# When the controller is in operational mode, the response is a # When the controller is in operational mode, the response is a
# successful response. # successful response.
# When the controller is in bootloader mode, # When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything # HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure. # else is a failure.
logger.warning(f"unexpected response: {response}") logger.warning(f"unexpected response: {response1}")
raise DriverError("unexpected HCI response") raise DriverError("unexpected HCI response")
# Read the firmware version. # Read the firmware version.
response = await self.host.send_command( response2 = await self.host.send_sync_command_raw(
HCI_Intel_Read_Version_Command(param0=0xFF) HCI_Intel_Read_Version_Command(param0=0xFF)
) )
if not isinstance(response, hci.HCI_Command_Complete_Event): if (
raise DriverError("unexpected HCI response") not isinstance(
response2.return_parameters, HCI_Intel_Read_Version_ReturnParameters
if response.return_parameters.status != 0: # type: ignore )
or response2.return_parameters.status != 0
):
raise DriverError("HCI_Intel_Read_Version_Command error") raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore tlvs = _parse_tlv(response2.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type # Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once. # to appear just once.
+102 -43
View File
@@ -16,6 +16,7 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel. Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`) (see `drivers/bluetooth/btrtl.c`)
""" """
from __future__ import annotations
import asyncio import asyncio
import enum import enum
@@ -31,10 +32,14 @@ import weakref
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci from bumble import core, hci
from bumble.drivers import common from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -77,6 +82,7 @@ class RtlProjectId(enum.IntEnum):
PROJECT_ID_8852A = 18 PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20 PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25 PROJECT_ID_8852C = 25
PROJECT_ID_8761C = 51
RTK_PROJECT_ID_TO_ROM = { RTK_PROJECT_ID_TO_ROM = {
@@ -92,6 +98,7 @@ RTK_PROJECT_ID_TO_ROM = {
18: RTK_ROM_LMP_8852A, 18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A, 20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A, 25: RTK_ROM_LMP_8852A,
51: RTK_ROM_LMP_8761A,
} }
# List of USB (VendorID, ProductID) for Realtek-based devices. # List of USB (VendorID, ProductID) for Realtek-based devices.
@@ -122,7 +129,12 @@ RTK_USB_PRODUCTS = {
(0x2357, 0x0604), (0x2357, 0x0604),
(0x2550, 0x8761), (0x2550, 0x8761),
(0x2B89, 0x8761), (0x2B89, 0x8761),
(0x2C0A, 0x8761),
(0x7392, 0xC611), (0x7392, 0xC611),
# Realtek 8761CUV
(0x0B05, 0x1BF6),
(0x0BDA, 0xC761),
(0x7392, 0xF611),
# Realtek 8821AE # Realtek 8821AE
(0x0B05, 0x17DC), (0x0B05, 0x17DC),
(0x13D3, 0x3414), (0x13D3, 0x3414),
@@ -182,23 +194,36 @@ HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclass @dataclass
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command): class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)] version: int = field(metadata=hci.metadata(1))
@hci.HCI_Command.command @hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass @dataclass
class HCI_RTK_Download_Command(hci.HCI_Command): class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass
@dataclass
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters):
index: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1)) index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH)) payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
@hci.HCI_Command.command @hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass @dataclass
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command): class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass pass
@@ -363,6 +388,15 @@ class Driver(common.Driver):
fw_name="rtl8761bu_fw.bin", fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin", config_name="rtl8761bu_config.bin",
), ),
# 8761CU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0E, 0x00),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761cu_fw.bin",
config_name="rtl8761cu_config.bin",
),
# 8822C # 8822C
DriverInfo( DriverInfo(
rom=RTK_ROM_LMP_8822B, rom=RTK_ROM_LMP_8822B,
@@ -420,9 +454,17 @@ class Driver(common.Driver):
@staticmethod @staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion): def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS: for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == ( if driver_info.rom == lmp_subversion and (
hci_subversion, driver_info.hci
hci_version, == (
hci_subversion,
hci_version,
)
or driver_info.hci
== (
hci_subversion,
0x0,
)
): ):
return driver_info return driver_info
@@ -467,7 +509,7 @@ class Driver(common.Driver):
return None return None
@staticmethod @staticmethod
def check(host): def check(host: Host) -> bool:
if not host.hci_metadata: if not host.hci_metadata:
logger.debug("USB metadata not found") logger.debug("USB metadata not found")
return False return False
@@ -491,37 +533,44 @@ class Driver(common.Driver):
return True return True
@staticmethod @staticmethod
async def get_loaded_firmware_version(host): async def get_loaded_firmware_version(host: Host) -> int | None:
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command()) response1 = await host.send_sync_command_raw(HCI_RTK_Read_ROM_Version_Command())
if (
if response.return_parameters.status != hci.HCI_SUCCESS: not isinstance(
response1.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
return None return None
response = await host.send_command( response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True hci.HCI_Read_Local_Version_Information_Command()
)
return (
response.return_parameters.hci_subversion << 16
| response.return_parameters.lmp_subversion
) )
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod @classmethod
async def driver_info_for_host(cls, host): async def driver_info_for_host(cls, host: Host) -> DriverInfo | None:
try: try:
await host.send_command( await host.send_sync_command(
hci.HCI_Reset_Command(), hci.HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY, response_timeout=cls.POST_RESET_DELAY,
) )
host.ready = True # Needed to let the host know the controller is ready. host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError: except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying") logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(hci.HCI_Reset_Command(), check_result=True) await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command() response = await host.send_sync_command_raw(
response = await host.send_command(command, check_result=True) hci.HCI_Read_Local_Version_Information_Command()
if response.command_opcode != command.op_code: )
if (
not isinstance(
response.return_parameters,
hci.HCI_Read_Local_Version_Information_ReturnParameters,
)
or response.return_parameters.status != hci.HCI_SUCCESS
):
logger.error("failed to probe local version information") logger.error("failed to probe local version information")
return None return None
@@ -546,7 +595,7 @@ class Driver(common.Driver):
return driver_info return driver_info
@classmethod @classmethod
async def for_host(cls, host, force=False): async def for_host(cls, host: Host, force: bool = False):
# Check that a driver is needed for this host # Check that a driver is needed for this host
if not force and not cls.check(host): if not force and not cls.check(host):
return None return None
@@ -601,15 +650,21 @@ class Driver(common.Driver):
# TODO: load the firmware # TODO: load the firmware
async def download_for_rtl8723b(self): async def download_for_rtl8723b(self) -> int | None:
if self.driver_info.has_rom_version: if self.driver_info.has_rom_version:
response = await self.host.send_command( response1 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command(), check_result=True HCI_RTK_Read_ROM_Version_Command()
) )
if response.return_parameters.status != hci.HCI_SUCCESS: if (
not isinstance(
response1.return_parameters,
HCI_RTK_Read_ROM_Version_ReturnParameters,
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
return None return None
rom_version = response.return_parameters.version rom_version = response1.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}") logger.debug(f"ROM version before download: {rom_version:04X}")
else: else:
rom_version = 0 rom_version = 0
@@ -644,21 +699,25 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH] fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}") logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command( await self.host.send_sync_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment), HCI_RTK_Download_Command(index=download_index, payload=fragment)
check_result=True,
) )
logger.debug("download complete!") logger.debug("download complete!")
# Read the version again # Read the version again
response = await self.host.send_command( response2 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command(), check_result=True HCI_RTK_Read_ROM_Version_Command()
) )
if response.return_parameters.status != hci.HCI_SUCCESS: if (
not isinstance(
response2.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response2.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
else: else:
rom_version = response.return_parameters.version rom_version = response2.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:02X}") logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version return firmware.version
@@ -680,7 +739,7 @@ class Driver(common.Driver):
async def init_controller(self): async def init_controller(self):
await self.download_firmware() await self.download_firmware()
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True) await self.host.send_sync_command(hci.HCI_Reset_Command())
logger.info(f"loaded FW image {self.driver_info.fw_name}") logger.info(f"loaded FW image {self.driver_info.fw_name}")
+2 -2
View File
@@ -29,7 +29,7 @@ import functools
import logging import logging
import struct import struct
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import TypeVar from typing import ClassVar, TypeVar
from bumble.att import Attribute, AttributeValue, AttributeValueV2 from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color from bumble.colors import color
@@ -403,7 +403,7 @@ class TemplateService(Service):
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID: UUID UUID: ClassVar[UUID]
def __init__( def __init__(
self, self,
+5 -4
View File
@@ -34,11 +34,14 @@ from datetime import datetime
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
ClassVar,
Generic, Generic,
TypeVar, TypeVar,
overload, overload,
) )
from typing_extensions import Self
from bumble import att, core, l2cap, utils from bumble import att, core, l2cap, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID, InvalidStateError from bumble.core import UUID, InvalidStateError
@@ -249,10 +252,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies Base class for profile-specific service proxies
''' '''
SERVICE_CLASS: type[TemplateService] SERVICE_CLASS: ClassVar[type[TemplateService]]
@classmethod @classmethod
def from_client(cls, client: Client) -> ProfileServiceProxy | None: def from_client(cls, client: Client) -> Self | None:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -285,8 +288,6 @@ class Client:
self._bearer_id = ( self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]' f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
) )
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection self.connection = bearer.connection
else: else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection) bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
+100 -1
View File
@@ -115,7 +115,6 @@ class Server(utils.EventEmitter):
channel.connection.handle, channel.connection.handle,
channel.source_cid, channel.source_cid,
) )
channel.att_mtu = att.ATT_DEFAULT_MTU
channel.sink = lambda pdu: self.on_gatt_pdu( channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu) channel, att.ATT_PDU.from_bytes(pdu)
) )
@@ -777,6 +776,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
if (
request.starting_handle == 0x0000
or request.starting_handle > request.ending_handle
):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(bearer, response)
return
attributes: list[tuple[int, bytes]] = [] attributes: list[tuple[int, bytes]] = []
for attribute in ( for attribute in (
attribute attribute
@@ -977,6 +988,94 @@ class Server(utils.EventEmitter):
self.send_response(bearer, response) self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_multiple_request(
self, bearer: att.Bearer, request: att.ATT_Read_Multiple_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.7 Read Multiple Request.
'''
response: att.ATT_PDU
pdu_space_available = bearer.att_mtu - 1
values: list[bytes] = []
for handle in request.set_of_handles:
if not (attribute := self.get_attribute(handle)):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=handle,
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
return
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(bearer)
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 1, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
# Check if there is enough space
entry_size = len(attribute_value)
if pdu_space_available < entry_size:
break
# Add the attribute to the list
values.append(attribute_value)
pdu_space_available -= entry_size
response = att.ATT_Read_Multiple_Response(set_of_values=b''.join(values))
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_multiple_variable_request(
self, bearer: att.Bearer, request: att.ATT_Read_Multiple_Variable_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.11 Read Multiple Variable Request.
'''
response: att.ATT_PDU
pdu_space_available = bearer.att_mtu - 1
length_value_tuple_list: list[tuple[int, bytes]] = []
for handle in request.set_of_handles:
if not (attribute := self.get_attribute(handle)):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=handle,
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
return
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(bearer)
length = len(attribute_value)
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 3, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
# Check if there is enough space
entry_size = 2 + len(attribute_value)
# Add the attribute to the list
length_value_tuple_list.append((length, attribute_value))
pdu_space_available -= entry_size
if pdu_space_available <= 0:
break
response = att.ATT_Read_Multiple_Variable_Response(
length_value_tuple_list=length_value_tuple_list
)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_write_request( async def on_att_write_request(
self, bearer: att.Bearer, request: att.ATT_Write_Request self, bearer: att.Bearer, request: att.ATT_Write_Request
+1368 -845
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -312,11 +312,11 @@ class HID(ABC, utils.EventEmitter):
def send_pdu_on_ctrl(self, msg: bytes) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg) self.l2cap_ctrl_channel.write(msg)
def send_pdu_on_intr(self, msg: bytes) -> None: def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg) self.l2cap_intr_channel.write(msg)
def send_data(self, data: bytes) -> None: def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST: if self.role == HID.Role.HOST:
+335 -153
View File
@@ -21,13 +21,16 @@ import asyncio
import collections import collections
import dataclasses import dataclasses
import logging import logging
import struct
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from bumble import drivers, hci, utils from bumble import drivers, hci, utils
from bumble.colors import color from bumble.colors import color
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport from bumble.core import (
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError from bumble.transport.common import TransportLostError
@@ -35,7 +38,6 @@ from bumble.transport.common import TransportLostError
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource from bumble.transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -236,6 +238,9 @@ class IsoLink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
class Host(utils.EventEmitter): class Host(utils.EventEmitter):
connections: dict[int, Connection] connections: dict[int, Connection]
cis_links: dict[int, IsoLink] cis_links: dict[int, IsoLink]
@@ -264,13 +269,20 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: asyncio.Future[Any] | None = None self.pending_response: (
asyncio.Future[
hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event
]
| None
) = None
self.number_of_supported_advertising_sets = 0 self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31 self.maximum_advertising_data_length = 31
self.local_version = None self.local_version: (
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
) = None
self.local_supported_commands = 0 self.local_supported_commands = 0
self.local_le_features = 0 self.local_le_features = hci.LeFeatureMask(0) # LE features
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
self.suggested_max_tx_octets = 251 # Max allowed self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed self.suggested_max_tx_time = 2120 # Max allowed
@@ -312,7 +324,7 @@ class Host(utils.EventEmitter):
self.emit('flush') self.emit('flush')
self.command_semaphore.release() self.command_semaphore.release()
async def reset(self, driver_factory=drivers.get_driver_for_host): async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
if self.ready: if self.ready:
self.ready = False self.ready = False
await self.flush() await self.flush()
@@ -330,57 +342,61 @@ class Host(utils.EventEmitter):
# Send a reset command unless a driver has already done so. # Send a reset command unless a driver has already done so.
if reset_needed: if reset_needed:
await self.send_command(hci.HCI_Reset_Command(), check_result=True) await self.send_sync_command(hci.HCI_Reset_Command())
self.ready = True self.ready = True
response = await self.send_command( response1 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True hci.HCI_Read_Local_Supported_Commands_Command()
) )
self.local_supported_commands = int.from_bytes( self.local_supported_commands = int.from_bytes(
response.return_parameters.supported_commands, 'little' response1.supported_commands, 'little'
) )
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND): if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command( self.local_version = await self.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True hci.HCI_Read_Local_Version_Information_Command()
)
if self.supports_command(hci.HCI_LE_READ_ALL_LOCAL_SUPPORTED_FEATURES_COMMAND):
response2 = await self.send_sync_command(
hci.HCI_LE_Read_All_Local_Supported_Features_Command()
)
self.local_le_features = hci.LeFeatureMask(
int.from_bytes(response2.le_features, 'little')
)
elif self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response3 = await self.send_sync_command(
hci.HCI_LE_Read_Local_Supported_Features_Command()
)
self.local_le_features = hci.LeFeatureMask(
int.from_bytes(response3.le_features, 'little')
) )
self.local_version = response.return_parameters
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND): if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
max_page_number = 0 max_page_number = 0
page_number = 0 page_number = 0
lmp_features = 0 lmp_features = 0
while page_number <= max_page_number: while page_number <= max_page_number:
response = await self.send_command( response4 = await self.send_sync_command(
hci.HCI_Read_Local_Extended_Features_Command( hci.HCI_Read_Local_Extended_Features_Command(
page_number=page_number page_number=page_number
), )
check_result=True,
) )
lmp_features |= int.from_bytes( lmp_features |= int.from_bytes(
response.return_parameters.extended_lmp_features, 'little' response4.extended_lmp_features, 'little'
) << (64 * page_number) ) << (64 * page_number)
max_page_number = response.return_parameters.maximum_page_number max_page_number = response4.maximum_page_number
page_number += 1 page_number += 1
self.local_lmp_features = hci.LmpFeatureMask(lmp_features) self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command( response5 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True hci.HCI_Read_Local_Supported_Features_Command()
) )
self.local_lmp_features = hci.LmpFeatureMask( self.local_lmp_features = hci.LmpFeatureMask(
int.from_bytes(response.return_parameters.lmp_features, 'little') int.from_bytes(response5.lmp_features, 'little')
) )
await self.send_command( await self.send_sync_command(
hci.HCI_Set_Event_Mask_Command( hci.HCI_Set_Event_Mask_Command(
event_mask=hci.HCI_Set_Event_Mask_Command.mask( event_mask=hci.HCI_Set_Event_Mask_Command.mask(
[ [
@@ -437,7 +453,7 @@ class Host(utils.EventEmitter):
) )
) )
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND): if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command( await self.send_sync_command(
hci.HCI_Set_Event_Mask_Page_2_Command( hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask( event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT] [hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
@@ -490,29 +506,28 @@ class Host(utils.EventEmitter):
hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT, hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT, hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_SUBRATE_CHANGE_EVENT, hci.HCI_LE_SUBRATE_CHANGE_EVENT,
hci.HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT,
hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT, hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT,
hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT, hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT, hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT, hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT, hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT, hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT,
hci.HCI_LE_MONITORED_ADVERTISERS_REPORT_EVENT,
hci.HCI_LE_FRAME_SPACE_UPDATE_COMPLETE_EVENT,
hci.HCI_LE_UTP_RECEIVE_EVENT,
hci.HCI_LE_CONNECTION_RATE_CHANGE_EVENT,
] ]
) )
await self.send_command( await self.send_sync_command(
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask) hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
) )
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND): if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
hci.HCI_Read_Buffer_Size_Command(), check_result=True hc_acl_data_packet_length = response6.hc_acl_data_packet_length
) hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug( logger.debug(
'HCI ACL flow control: ' 'HCI ACL flow control: '
@@ -531,19 +546,13 @@ class Host(utils.EventEmitter):
iso_data_packet_length = 0 iso_data_packet_length = 0
total_num_iso_data_packets = 0 total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command( response7 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_V2_Command()
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
) )
le_acl_data_packet_length = response7.le_acl_data_packet_length
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
iso_data_packet_length = response7.iso_data_packet_length
total_num_iso_data_packets = response7.total_num_iso_data_packets
logger.debug( logger.debug(
'HCI LE flow control: ' 'HCI LE flow control: '
@@ -553,15 +562,11 @@ class Host(utils.EventEmitter):
f'total_num_iso_data_packets={total_num_iso_data_packets}' f'total_num_iso_data_packets={total_num_iso_data_packets}'
) )
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response8 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_Command()
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
) )
le_acl_data_packet_length = response8.le_acl_data_packet_length
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
@@ -592,16 +597,16 @@ class Host(utils.EventEmitter):
) and self.supports_command( ) and self.supports_command(
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
): ):
response = await self.send_command( response9 = await self.send_sync_command(
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command() hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
) )
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets suggested_max_tx_octets = response9.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time suggested_max_tx_time = response9.suggested_max_tx_time
if ( if (
suggested_max_tx_octets != self.suggested_max_tx_octets suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time or suggested_max_tx_time != self.suggested_max_tx_time
): ):
await self.send_command( await self.send_sync_command(
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command( hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets, suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time, suggested_max_tx_time=self.suggested_max_tx_time,
@@ -611,24 +616,28 @@ class Host(utils.EventEmitter):
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
): ):
response = await self.send_command( try:
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(), response10 = await self.send_sync_command(
check_result=True, hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
) )
self.number_of_supported_advertising_sets = ( self.number_of_supported_advertising_sets = (
response.return_parameters.num_supported_advertising_sets response10.num_supported_advertising_sets
) )
except hci.HCI_Error:
logger.warning('Failed to read number of supported advertising sets')
if self.supports_command( if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
): ):
response = await self.send_command( try:
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(), response11 = await self.send_sync_command(
check_result=True, hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
) )
self.maximum_advertising_data_length = ( self.maximum_advertising_data_length = (
response.return_parameters.max_advertising_data_length response11.max_advertising_data_length
) )
except hci.HCI_Error:
logger.warning('Failed to read maximum advertising data length')
@property @property
def controller(self) -> TransportSink | None: def controller(self) -> TransportSink | None:
@@ -654,56 +663,175 @@ class Host(utils.EventEmitter):
if self.hci_sink: if self.hci_sink:
self.hci_sink.on_packet(bytes(packet)) self.hci_sink.on_packet(bytes(packet))
async def send_command( async def _send_command(
self, command, check_result=False, response_timeout: int | None = None self,
): command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
async with self.command_semaphore: await self.command_semaphore.acquire()
assert self.pending_command is None
assert self.pending_response is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future() assert self.pending_command is None
self.pending_command = command assert self.pending_response is None
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_command = command
try: response: (
self.send_hci_packet(command) hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event | None
await asyncio.wait_for(self.pending_response, timeout=response_timeout) ) = None
response = self.pending_response.result() try:
self.send_hci_packet(command)
response = await asyncio.wait_for(
self.pending_response, timeout=response_timeout
)
return response
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
if (
response is not None
and response.num_hci_command_packets
and self.command_semaphore.locked()
):
self.command_semaphore.release()
# Check the return parameters if required @overload
if check_result: async def send_command(
if isinstance(response, hci.HCI_Command_Status_Event): self,
status = response.status # type: ignore[attr-defined] command: hci.HCI_SyncCommand[_RP],
elif isinstance(response.return_parameters, int): check_result: bool = False,
status = response.return_parameters response_timeout: float | None = None,
elif isinstance(response.return_parameters, bytes): ) -> hci.HCI_Command_Complete_Event[_RP]: ...
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != hci.HCI_SUCCESS: @overload
logger.warning( async def send_command(
f'{command.name} failed ' self,
f'({hci.HCI_Constant.error_name(status)})' command: hci.HCI_AsyncCommand,
) check_result: bool = False,
raise hci.HCI_Error(status) response_timeout: float | None = None,
) -> hci.HCI_Command_Status_Event: ...
return response async def send_command(
except Exception: self,
logger.exception(color("!!! Exception while sending command:", "red")) command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
raise check_result: bool = False,
finally: response_timeout: float | None = None,
self.pending_command = None ) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
self.pending_response = None response = await self._send_command(command, response_timeout)
# Use this method to send a command from a task # Check the return parameters if required
def send_command_sync(self, command: hci.HCI_Command) -> None: if check_result:
async def send_command(command: hci.HCI_Command) -> None: if isinstance(response, hci.HCI_Command_Status_Event):
await self.send_command(command) status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
elif isinstance(
response.return_parameters, hci.HCI_GenericReturnParameters
):
# FIXME: temporary workaround
# NO STATUS
status = hci.HCI_SUCCESS
else:
status = response.return_parameters.status
asyncio.create_task(send_command(command)) if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
async def send_sync_command(
self, command: hci.HCI_SyncCommand[_RP], response_timeout: float | None = None
) -> _RP:
response = await self.send_sync_command_raw(command, response_timeout)
return_parameters = response.return_parameters
# Check the return parameters's status
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
elif isinstance(return_parameters, hci.HCI_GenericReturnParameters):
# if the payload has at least one byte, assume the first byte is the status
if not return_parameters.data:
raise RuntimeError('no status byte in return parameters')
status = hci.HCI_ErrorCode(return_parameters.data[0])
else:
raise RuntimeError(
f'unexpected return parameters type ({type(return_parameters)})'
)
if status != hci.HCI_ErrorCode.SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_sync_command_raw(
self,
command: hci.HCI_SyncCommand[_RP],
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]:
response = await self._send_command(command, response_timeout)
# For unknown HCI commands, some controllers return Command Status instead of
# Command Complete.
if (
isinstance(response, hci.HCI_Command_Status_Event)
and response.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
):
return hci.HCI_Command_Complete_Event(
num_hci_command_packets=response.num_hci_command_packets,
command_opcode=command.op_code,
return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode(response.status)
), # type: ignore
)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event)
return response
async def send_async_command(
self,
command: hci.HCI_AsyncCommand,
check_status: bool = True,
response_timeout: float | None = None,
) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout)
# For unknown HCI commands, some controllers return Command Complete instead of
# Command Status.
if isinstance(response, hci.HCI_Command_Complete_Event):
# Assume the first byte of the return parameters is the status
if (
status := hci.HCI_ErrorCode(response.parameters[3])
) != hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR:
logger.warning(f'unexpected return paramerers status {status}')
else:
assert isinstance(response, hci.HCI_Command_Status_Event)
status = hci.HCI_ErrorCode(response.status)
# Check the status if required
if check_status:
if status != hci.HCI_CommandStatus.PENDING:
logger.warning(f'{command.name} failed ' f'({status.name})')
raise hci.HCI_Error(status)
return status
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
utils.AsyncRunner.spawn(self.send_async_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None: def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)): if not (connection := self.connections.get(connection_handle)):
@@ -728,10 +856,22 @@ class Host(utils.EventEmitter):
data=pdu, data=pdu,
) )
logger.debug( logger.debug(
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu '>>> ACL packet enqueue: (handle=0x%04X) %s',
connection_handle,
pdu.hex(),
) )
packet_queue.enqueue(acl_packet, connection_handle) packet_queue.enqueue(acl_packet, connection_handle)
def send_sco_sdu(self, connection_handle: int, sdu: bytes) -> None:
self.send_hci_packet(
hci.HCI_SynchronousDataPacket(
connection_handle=connection_handle,
packet_status=0,
data_total_length=len(sdu),
data=sdu,
)
)
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu))) self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu)))
@@ -816,16 +956,18 @@ class Host(utils.EventEmitter):
if self.local_supported_commands & mask if self.local_supported_commands & mask
) )
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool: def supports_le_features(self, features: hci.LeFeatureMask) -> bool:
return (self.local_le_features & feature) == feature return (self.local_le_features & features) == features
def supports_lmp_features(self, feature: hci.LmpFeatureMask) -> bool: def supports_lmp_features(self, features: hci.LmpFeatureMask) -> bool:
return self.local_lmp_features & (feature) == feature return self.local_lmp_features & (features) == features
@property @property
def supported_le_features(self): def supported_le_features(self) -> list[hci.LeFeature]:
return [ return [
feature for feature in range(64) if self.local_le_features & (1 << feature) feature
for feature in hci.LeFeature
if self.local_le_features & (1 << feature)
] ]
# Packet Sink protocol (packets coming from the controller via HCI) # Packet Sink protocol (packets coming from the controller via HCI)
@@ -914,6 +1056,8 @@ class Host(utils.EventEmitter):
self.pending_response.set_result(event) self.pending_response.set_result(event)
else: else:
logger.warning('!!! no pending response future to set') logger.warning('!!! no pending response future to set')
if event.num_hci_command_packets and self.command_semaphore.locked():
self.command_semaphore.release()
############################################################ ############################################################
# HCI handlers # HCI handlers
@@ -925,7 +1069,13 @@ class Host(utils.EventEmitter):
if event.command_opcode == 0: if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to # This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command # an actual command
logger.debug('no-command event') logger.debug('no-command event for flow control')
# Release the command semaphore if needed
if event.num_hci_command_packets and self.command_semaphore.locked():
logger.debug('command complete event releasing semaphore')
self.command_semaphore.release()
return return
return self.on_command_processed(event) return self.on_command_processed(event)
@@ -1106,7 +1256,7 @@ class Host(utils.EventEmitter):
self, event: hci.HCI_LE_Connection_Update_Complete_Event self, event: hci.HCI_LE_Connection_Update_Complete_Event
): ):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle') logger.warning('!!! CONNECTION UPDATE COMPLETE: unknown handle')
return return
# Notify the client # Notify the client
@@ -1123,6 +1273,29 @@ class Host(utils.EventEmitter):
'connection_parameters_update_failure', connection.handle, event.status 'connection_parameters_update_failure', connection.handle, event.status
) )
def on_hci_le_connection_rate_change_event(
self, event: hci.HCI_LE_Connection_Rate_Change_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION RATE CHANGE: unknown handle')
return
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit(
'le_connection_rate_change',
connection.handle,
event.connection_interval,
event.subrate_factor,
event.peripheral_latency,
event.continuation_number,
event.supervision_timeout,
)
else:
self.emit(
'le_connection_rate_change_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event( def on_hci_le_phy_update_complete_event(
self, event: hci.HCI_LE_PHY_Update_Complete_Event self, event: hci.HCI_LE_PHY_Update_Complete_Event
): ):
@@ -1338,15 +1511,17 @@ class Host(utils.EventEmitter):
# For now, just accept everything # For now, just accept everything
# TODO: delegate the decision # TODO: delegate the decision
self.send_command_sync( utils.AsyncRunner.spawn(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( self.send_sync_command(
connection_handle=event.connection_handle, hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
interval_min=event.interval_min, connection_handle=event.connection_handle,
interval_max=event.interval_max, interval_min=event.interval_min,
max_latency=event.max_latency, interval_max=event.interval_max,
timeout=event.timeout, max_latency=event.max_latency,
min_ce_length=0, timeout=event.timeout,
max_ce_length=0, min_ce_length=0,
max_ce_length=0,
)
) )
) )
@@ -1382,9 +1557,9 @@ class Host(utils.EventEmitter):
connection_handle=event.connection_handle connection_handle=event.connection_handle
) )
await self.send_command(response) await self.send_sync_command(response)
asyncio.create_task(send_long_term_key()) utils.AsyncRunner.spawn(send_long_term_key())
def on_hci_synchronous_connection_complete_event( def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event self, event: hci.HCI_Synchronous_Connection_Complete_Event
@@ -1583,9 +1758,9 @@ class Host(utils.EventEmitter):
bd_addr=event.bd_addr bd_addr=event.bd_addr
) )
await self.send_command(response) await self.send_sync_command(response)
asyncio.create_task(send_link_key()) utils.AsyncRunner.spawn(send_link_key())
def on_hci_io_capability_request_event( def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event self, event: hci.HCI_IO_Capability_Request_Event
@@ -1680,12 +1855,13 @@ class Host(utils.EventEmitter):
self.emit( self.emit(
'le_remote_features_failure', event.connection_handle, event.status 'le_remote_features_failure', event.connection_handle, event.status
) )
else: return
self.emit(
'le_remote_features', self.emit(
event.connection_handle, 'le_remote_features',
int.from_bytes(event.le_features, 'little'), event.connection_handle,
) hci.LeFeatureMask(int.from_bytes(event.le_features, 'little')),
)
def on_hci_le_cs_read_remote_supported_capabilities_complete_event( def on_hci_le_cs_read_remote_supported_capabilities_complete_event(
self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event
@@ -1718,6 +1894,12 @@ class Host(utils.EventEmitter):
self.emit('cs_subevent_result_continue', event) self.emit('cs_subevent_result_continue', event)
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event): def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
if event.status != hci.HCI_SUCCESS:
self.emit(
'le_subrate_change_failure', event.connection_handle, event.status
)
return
self.emit( self.emit(
'le_subrate_change', 'le_subrate_change',
event.connection_handle, event.connection_handle,
+130 -68
View File
@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import enum import enum
import itertools
import logging import logging
import struct import struct
from collections import deque from collections import deque
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
@dataclasses.dataclass @dataclasses.dataclass
class InformationEnhancedControlField(EnhancedControlField): class InformationEnhancedControlField(EnhancedControlField):
tx_seq: int = 0 tx_seq: int
sar: int
req_seq: int = 0 req_seq: int = 0
segmentation_and_reassembly: int = (
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
)
final: int = 1 final: int = 1
frame_type = EnhancedControlField.FieldType.I_FRAME frame_type = EnhancedControlField.FieldType.I_FRAME
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
return cls( return cls(
tx_seq=(data[0] >> 1) & 0b0111111, tx_seq=(data[0] >> 1) & 0b0111111,
final=(data[0] >> 7) & 0b1, final=(data[0] >> 7) & 0b1,
req_seq=(data[1] & 0b001111111), req_seq=(data[1] & 0b00111111),
segmentation_and_reassembly=(data[1] >> 6) & 0b11, sar=(data[1] >> 6) & 0b11,
) )
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return bytes( return bytes(
[ [
self.frame_type | (self.tx_seq << 1) | (self.final << 7), self.frame_type | (self.tx_seq << 1) | (self.final << 7),
self.req_seq | (self.segmentation_and_reassembly << 6), self.req_seq | (self.sar << 6),
] ]
) )
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
class _PendingPdu: class _PendingPdu:
payload: bytes payload: bytes
tx_seq: int tx_seq: int
sar: InformationEnhancedControlField.SegmentationAndReassembly
sdu_length: int = 0
req_seq: int = 0 req_seq: int = 0
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return ( return (
bytes( bytes(
InformationEnhancedControlField( InformationEnhancedControlField(
tx_seq=self.tx_seq, req_seq=self.req_seq tx_seq=self.tx_seq,
req_seq=self.req_seq,
sar=self.sar,
) )
) )
+ (
struct.pack('<H', self.sdu_length)
if self.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
else b''
)
+ self.payload + self.payload
) )
_expected_ack_seq: int = 0 _last_acked_tx_seq: int = 0
_last_acked_rx_seq: int = 0
_next_tx_seq: int = 0 _next_tx_seq: int = 0
_last_tx_seq: int = 0
_req_seq_num: int = 0 _req_seq_num: int = 0
_next_seq_num: int = 0
_remote_is_busy: bool = False _remote_is_busy: bool = False
_in_sdu: bytes = b''
_num_receiver_ready_polls_sent: int = 0 _num_receiver_ready_polls_sent: int = 0
_pending_pdus: list[_PendingPdu] _pending_pdus: list[_PendingPdu]
_tx_window: list[_PendingPdu]
_monitor_handle: asyncio.TimerHandle | None = None _monitor_handle: asyncio.TimerHandle | None = None
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None _receiver_ready_poll_handle: asyncio.TimerHandle | None = None
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
monitor_timeout: float monitor_timeout: float
retransmission_timeout: float retransmission_timeout: float
@classmethod
def _num_frames_between(cls, low: int, high: int) -> int:
if high < low:
high += cls.MAX_SEQ_NUM
return high - low
def __init__( def __init__(
self, self,
channel: ClassicChannel, channel: ClassicChannel,
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
self.peer_mps = peer_mps self.peer_mps = peer_mps
self.peer_tx_window_size = peer_tx_window_size self.peer_tx_window_size = peer_tx_window_size
self._pending_pdus = [] self._pending_pdus = []
self._tx_window = []
self.monitor_timeout = spec.monitor_timeout self.monitor_timeout = spec.monitor_timeout
self.channel = channel self.channel = channel
self.retransmission_timeout = spec.retransmission_timeout self.retransmission_timeout = spec.retransmission_timeout
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
def _send_receiver_ready_poll(self) -> None: def _send_receiver_ready_poll(self) -> None:
self._num_receiver_ready_polls_sent += 1 self._num_receiver_ready_polls_sent += 1
self.channel.send_pdu( self._send_s_frame(
SupervisoryEnhancedControlField( supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, final=1,
final=1,
req_seq=self._next_seq_num,
)
) )
def _get_next_tx_seq(self) -> int: def _get_next_tx_seq(self) -> int:
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
@override @override
def send_sdu(self, sdu: bytes) -> None: def send_sdu(self, sdu: bytes) -> None:
if len(sdu) > self.peer_mps: if len(sdu) <= self.peer_mps:
raise InvalidArgumentError( pdu = self._PendingPdu(
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}' payload=sdu,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
) )
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq()) self._pending_pdus.append(pdu)
self._pending_pdus.append(pdu) else:
for offset in range(0, len(sdu), self.peer_mps):
payload = sdu[offset : offset + self.peer_mps]
if offset == 0:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.START
)
elif offset + len(payload) >= len(sdu):
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
else:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
)
pdu = self._PendingPdu(
payload=payload,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=sar,
sdu_length=len(sdu),
)
self._pending_pdus.append(pdu)
self._process_output() self._process_output()
@override @override
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
control_field = EnhancedControlField.from_bytes(pdu) control_field = EnhancedControlField.from_bytes(pdu)
self._update_ack_seq(control_field.req_seq, control_field.final != 0) self._update_ack_seq(control_field.req_seq, control_field.final != 0)
if isinstance(control_field, InformationEnhancedControlField): if isinstance(control_field, InformationEnhancedControlField):
if control_field.tx_seq != self._next_seq_num: if control_field.tx_seq != self._req_seq_num:
logger.error(
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
control_field.tx_seq,
self._req_seq_num,
)
return return
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
self._req_seq_num = self._next_seq_num
ack_frame = SupervisoryEnhancedControlField( if (
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, control_field.sar
req_seq=self._next_seq_num, == InformationEnhancedControlField.SegmentationAndReassembly.START
) ):
self.channel.send_pdu(ack_frame) # Drop Control Field(2) + SDU Length(2)
self.channel.on_sdu(pdu[2:]) self._in_sdu += pdu[4:]
else:
# Drop Control Field(2)
self._in_sdu += pdu[2:]
if control_field.sar in (
InformationEnhancedControlField.SegmentationAndReassembly.END,
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
):
self.channel.on_sdu(self._in_sdu)
self._in_sdu = b''
# If sink doesn't trigger any I-frame, ack this frame.
if self._req_seq_num != self._last_acked_rx_seq:
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=0,
)
elif isinstance(control_field, SupervisoryEnhancedControlField): elif isinstance(control_field, SupervisoryEnhancedControlField):
self._remote_is_busy = ( self._remote_is_busy = (
control_field.supervision_function control_field.supervision_function
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
SupervisoryEnhancedControlField.SupervisoryFunction.RNR, SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
): ):
if control_field.poll: if control_field.poll:
self.channel.send_pdu( self._send_s_frame(
SupervisoryEnhancedControlField( supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR, final=1,
final=1,
req_seq=self._next_seq_num,
)
) )
else: else:
# TODO: Handle Retransmission. # TODO: Handle Retransmission.
pass pass
def _process_output(self) -> None: def _process_output(self) -> None:
if self._remote_is_busy or self._monitor_handle: if self._remote_is_busy:
logger.debug("Remote is busy")
return
if self._monitor_handle:
logger.debug("Monitor handle is not None")
return return
for pdu in self._pending_pdus: pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
if self._num_unacked_frames >= self.peer_tx_window_size: for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
return self._send_i_frame(pdu)
self._send_pdu(pdu) self._pending_pdus = self._pending_pdus[pdu_to_send:]
self._last_tx_seq = pdu.tx_seq
@property def _send_i_frame(self, pdu: _PendingPdu) -> None:
def _num_unacked_frames(self) -> int:
if not self._pending_pdus:
return 0
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
def _send_pdu(self, pdu: _PendingPdu) -> None:
pdu.req_seq = self._req_seq_num pdu.req_seq = self._req_seq_num
self._start_receiver_ready_poll() self._start_receiver_ready_poll()
self._tx_window.append(pdu)
self.channel.send_pdu(bytes(pdu)) self.channel.send_pdu(bytes(pdu))
self._last_acked_rx_seq = self._req_seq_num
def _send_s_frame(
self,
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
final: int,
) -> None:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=supervision_function,
final=final,
req_seq=self._req_seq_num,
)
)
self._last_acked_rx_seq = self._req_seq_num
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None: def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq) num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
if num_frames_acked > self._num_unacked_frames: if num_frames_acked > len(self._tx_window):
logger.error( logger.error(
"Received acknowledgment for %d frames but only %d frames are pending", "Received acknowledgment for %d frames but only %d frames are pending",
num_frames_acked, num_frames_acked,
self._num_unacked_frames, len(self._tx_window),
) )
return return
if is_poll_response and self._monitor_handle: if is_poll_response and self._monitor_handle:
self._monitor_handle.cancel() self._monitor_handle.cancel()
self._monitor_handle = None self._monitor_handle = None
del self._pending_pdus[:num_frames_acked] del self._tx_window[:num_frames_acked]
self._expected_ack_seq = new_seq self._last_acked_tx_seq = new_seq
if ( if (
self._expected_ack_seq == self._next_tx_seq self._last_acked_tx_seq == self._next_tx_seq
and self._receiver_ready_poll_handle and self._receiver_ready_poll_handle
): ):
self._receiver_ready_poll_handle.cancel() self._receiver_ready_poll_handle.cancel()
@@ -1592,7 +1647,9 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.drained = asyncio.Event() self.drained = asyncio.Event()
self.att_mtu = 0 # Filled by GATT client or server later. # Core Specification Vol 3, Part G, 5.3.1 ATT_MTU
# ATT_MTU shall be set to the minimum of the MTU field values of the two devices.
self.att_mtu = min(mtu, peer_mtu)
self.drained.set() self.drained.set()
@@ -2285,8 +2342,8 @@ class ChannelManager:
cid, cid,
L2CAP_Connection_Response( L2CAP_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
destination_cid=request.source_cid, destination_cid=0,
source_cid=0, source_cid=request.source_cid,
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000, status=0x0000,
), ),
@@ -2298,7 +2355,12 @@ class ChannelManager:
f'creating server channel with cid={source_cid} for psm {request.psm}' f'creating server channel with cid={source_cid} for psm {request.psm}'
) )
channel = ClassicChannel( channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.spec manager=self,
connection=connection,
signaling_cid=cid,
psm=request.psm,
source_cid=source_cid,
spec=server.spec,
) )
connection_channels[source_cid] = channel connection_channels[source_cid] = channel
@@ -2315,8 +2377,8 @@ class ChannelManager:
cid, cid,
L2CAP_Connection_Response( L2CAP_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
destination_cid=request.source_cid, destination_cid=0,
source_cid=0, source_cid=request.source_cid,
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000, status=0x0000,
), ),
+1 -1
View File
@@ -278,7 +278,7 @@ class L2CAPService(L2CAPServicer):
if not l2cap_channel: if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD) return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel): if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.send_pdu(request.data) l2cap_channel.write(request.data)
else: else:
l2cap_channel.write(request.data) l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty()) return SendResponse(success=empty_pb2.Empty())
+24 -33
View File
@@ -16,35 +16,28 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from collections.abc import Callable
from bumble.gatt import ( from bumble import device, gatt, gatt_adapters, gatt_client
GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_BATTERY_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BatteryService(TemplateService): class BatteryService(gatt.TemplateService):
UUID = GATT_BATTERY_SERVICE UUID = gatt.GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B' BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: Characteristic[int] battery_level_characteristic: gatt.Characteristic[int]
def __init__(self, read_battery_level): def __init__(self, read_battery_level: Callable[[device.Connection], int]) -> None:
self.battery_level_characteristic = PackedCharacteristicAdapter( self.battery_level_characteristic = gatt_adapters.PackedCharacteristicAdapter(
Characteristic( gatt.Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC, gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, properties=(
Characteristic.READABLE, gatt.Characteristic.Properties.READ
CharacteristicValue(read=read_battery_level), | gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.READABLE,
value=gatt.CharacteristicValue(read=read_battery_level),
), ),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT, pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
) )
@@ -52,19 +45,17 @@ class BatteryService(TemplateService):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BatteryServiceProxy(ProfileServiceProxy): class BatteryServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BatteryService SERVICE_CLASS = BatteryService
battery_level: CharacteristicProxy[int] | None battery_level: gatt_client.CharacteristicProxy[int]
def __init__(self, service_proxy): def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( self.battery_level = gatt_adapters.PackedCharacteristicProxyAdapter(
GATT_BATTERY_LEVEL_CHARACTERISTIC service_proxy.get_required_characteristic_by_uuid(
): gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC
self.battery_level = PackedCharacteristicProxyAdapter( ),
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
) )
else:
self.battery_level = None
+128 -119
View File
@@ -18,40 +18,30 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum
import struct import struct
from enum import IntEnum from collections.abc import Callable, Sequence
from typing import Any
from bumble import core from typing_extensions import Self
from bumble.att import ATT_Error
from bumble.gatt import ( from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_HEART_RATE_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HeartRateService(TemplateService): class HeartRateService(gatt.TemplateService):
UUID = GATT_HEART_RATE_SERVICE UUID = gatt.GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B' HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80 CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01 RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement] heart_rate_measurement_characteristic: gatt.Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: Characteristic[BodySensorLocation] body_sensor_location_characteristic: gatt.Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: Characteristic[int] heart_rate_control_point_characteristic: gatt.Characteristic[int]
class BodySensorLocation(IntEnum): class BodySensorLocation(utils.OpenIntEnum):
OTHER = 0 OTHER = 0
CHEST = 1 CHEST = 1
WRIST = 2 WRIST = 2
@@ -60,82 +50,90 @@ class HeartRateService(TemplateService):
EAR_LOBE = 5 EAR_LOBE = 5
FOOT = 6 FOOT = 6
@dataclasses.dataclass
class HeartRateMeasurement: class HeartRateMeasurement:
def __init__( heart_rate: int
self, sensor_contact_detected: bool | None = None
heart_rate, energy_expended: int | None = None
sensor_contact_detected=None, rr_intervals: Sequence[float] | None = None
energy_expended=None,
rr_intervals=None, class Flag(enum.IntFlag):
): INT16_HEART_RATE = 1 << 0
if heart_rate < 0 or heart_rate > 0xFFFF: SENSOR_CONTACT_DETECTED = 1 << 1
SENSOR_CONTACT_SUPPORTED = 1 << 2
ENERGY_EXPENDED_STATUS = 1 << 3
RR_INTERVAL = 1 << 4
def __post_init__(self) -> None:
if self.heart_rate < 0 or self.heart_rate > 0xFFFF:
raise core.InvalidArgumentError('heart_rate out of range') raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and ( if self.energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF self.energy_expended < 0 or self.energy_expended > 0xFFFF
): ):
raise core.InvalidArgumentError('energy_expended out of range') raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals: if self.rr_intervals:
for rr_interval in rr_intervals: for rr_interval in self.rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise core.InvalidArgumentError('rr_intervals out of range') raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod @classmethod
def from_bytes(cls, data): def from_bytes(cls, data: bytes) -> Self:
flags = data[0] flags = data[0]
offset = 1 offset = 1
if flags & 1: if flags & cls.Flag.INT16_HEART_RATE:
hr = struct.unpack_from('<H', data, offset)[0] heart_rate = struct.unpack_from('<H', data, offset)[0]
offset += 2 offset += 2
else: else:
hr = struct.unpack_from('B', data, offset)[0] heart_rate = struct.unpack_from('B', data, offset)[0]
offset += 1 offset += 1
if flags & (1 << 2): if flags & cls.Flag.SENSOR_CONTACT_SUPPORTED:
sensor_contact_detected = flags & (1 << 1) != 0 sensor_contact_detected = flags & cls.Flag.SENSOR_CONTACT_DETECTED != 0
else: else:
sensor_contact_detected = None sensor_contact_detected = None
if flags & (1 << 3): if flags & cls.Flag.ENERGY_EXPENDED_STATUS:
energy_expended = struct.unpack_from('<H', data, offset)[0] energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2 offset += 2
else: else:
energy_expended = None energy_expended = None
if flags & (1 << 4): rr_intervals: Sequence[float] | None = None
if flags & cls.Flag.RR_INTERVAL:
rr_intervals = tuple( rr_intervals = tuple(
struct.unpack_from('<H', data, offset + i * 2)[0] / 1024 struct.unpack_from('<H', data, i)[0] / 1024
for i in range((len(data) - offset) // 2) for i in range(offset, len(data), 2)
) )
else:
rr_intervals = ()
return cls(hr, sensor_contact_detected, energy_expended, rr_intervals) return cls(
heart_rate=heart_rate,
sensor_contact_detected=sensor_contact_detected,
energy_expended=energy_expended,
rr_intervals=rr_intervals,
)
def __bytes__(self): def __bytes__(self) -> bytes:
flags = 0
if self.heart_rate < 256: if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate) data = struct.pack('B', self.heart_rate)
else: else:
flags = 1 flags |= self.Flag.INT16_HEART_RATE
data = struct.pack('<H', self.heart_rate) data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None: if self.sensor_contact_detected is not None:
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2) flags |= self.Flag.SENSOR_CONTACT_SUPPORTED
if self.sensor_contact_detected:
flags |= self.Flag.SENSOR_CONTACT_DETECTED
if self.energy_expended is not None: if self.energy_expended is not None:
flags |= 1 << 3 flags |= self.Flag.ENERGY_EXPENDED_STATUS
data += struct.pack('<H', self.energy_expended) data += struct.pack('<H', self.energy_expended)
if self.rr_intervals: if self.rr_intervals is not None:
flags |= 1 << 4 flags |= self.Flag.RR_INTERVAL
data += b''.join( data += b''.join(
[ [
struct.pack('<H', int(rr_interval * 1024)) struct.pack('<H', int(rr_interval * 1024))
@@ -145,57 +143,67 @@ class HeartRateService(TemplateService):
return bytes([flags]) + data return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__( def __init__(
self, self,
read_heart_rate_measurement, read_heart_rate_measurement: Callable[
body_sensor_location=None, [device.Connection], HeartRateMeasurement
reset_energy_expended=None, ],
body_sensor_location: HeartRateService.BodySensorLocation | None = None,
reset_energy_expended: Callable[[device.Connection], Any] | None = None,
): ):
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter( self.heart_rate_measurement_characteristic = (
Characteristic( gatt_adapters.SerializableCharacteristicAdapter(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, gatt.Characteristic(
Characteristic.Properties.NOTIFY, uuid=gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
0, properties=gatt.Characteristic.Properties.NOTIFY,
CharacteristicValue(read=read_heart_rate_measurement), permissions=gatt.Characteristic.Permissions(0),
), value=gatt.CharacteristicValue(read=read_heart_rate_measurement),
HeartRateService.HeartRateMeasurement, ),
HeartRateService.HeartRateMeasurement,
)
) )
characteristics = [self.heart_rate_measurement_characteristic] characteristics: list[gatt.Characteristic] = [
self.heart_rate_measurement_characteristic
]
if body_sensor_location is not None: if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic( self.body_sensor_location_characteristic = (
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC, gatt_adapters.EnumCharacteristicAdapter(
Characteristic.Properties.READ, gatt.Characteristic(
Characteristic.READABLE, uuid=gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
bytes([int(body_sensor_location)]), properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.READABLE,
value=body_sensor_location,
),
cls=self.BodySensorLocation,
length=1,
)
) )
characteristics.append(self.body_sensor_location_characteristic) characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended: if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value): def write_heart_rate_control_point_value(
connection: device.Connection, value: bytes
) -> None:
if value == self.RESET_ENERGY_EXPENDED: if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None: if reset_energy_expended is not None:
reset_energy_expended(connection) reset_energy_expended(connection)
else: else:
raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED) raise att.ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter( self.heart_rate_control_point_characteristic = (
Characteristic( gatt_adapters.PackedCharacteristicAdapter(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC, gatt.Characteristic(
Characteristic.Properties.WRITE, uuid=gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITEABLE, properties=gatt.Characteristic.Properties.WRITE,
CharacteristicValue(write=write_heart_rate_control_point_value), permissions=gatt.Characteristic.WRITEABLE,
), value=gatt.CharacteristicValue(
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT, write=write_heart_rate_control_point_value
),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
) )
characteristics.append(self.heart_rate_control_point_characteristic) characteristics.append(self.heart_rate_control_point_characteristic)
@@ -203,50 +211,51 @@ class HeartRateService(TemplateService):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HeartRateServiceProxy(ProfileServiceProxy): class HeartRateServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HeartRateService SERVICE_CLASS = HeartRateService
heart_rate_measurement: ( heart_rate_measurement: gatt_client.CharacteristicProxy[
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None HeartRateService.HeartRateMeasurement
) ]
body_sensor_location: ( body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None gatt_client.CharacteristicProxy[HeartRateService.BodySensorLocation] | None
) )
heart_rate_control_point: CharacteristicProxy[int] | None heart_rate_control_point: gatt_client.CharacteristicProxy[int] | None
def __init__(self, service_proxy): def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( self.heart_rate_measurement = (
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC gatt_adapters.SerializableCharacteristicProxyAdapter(
): service_proxy.get_required_characteristic_by_uuid(
self.heart_rate_measurement = SerializableCharacteristicAdapter( gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
characteristics[0], HeartRateService.HeartRateMeasurement ),
HeartRateService.HeartRateMeasurement,
) )
else: )
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
): ):
self.body_sensor_location = DelegatedCharacteristicAdapter( self.body_sensor_location = gatt_adapters.EnumCharacteristicProxyAdapter(
characteristics[0], characteristics[0], cls=HeartRateService.BodySensorLocation, length=1
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
) )
else: else:
self.body_sensor_location = None self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
): ):
self.heart_rate_control_point = PackedCharacteristicAdapter( self.heart_rate_control_point = (
characteristics[0], gatt_adapters.PackedCharacteristicProxyAdapter(
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT, characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
) )
else: else:
self.heart_rate_control_point = None self.heart_rate_control_point = None
async def reset_energy_expended(self): async def reset_energy_expended(self) -> None:
if self.heart_rate_control_point is not None: if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value( return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED HeartRateService.RESET_ENERGY_EXPENDED
+1 -1
View File
@@ -800,7 +800,7 @@ class Multiplexer(utils.EventEmitter):
def send_frame(self, frame: RFCOMM_Frame) -> None: def send_frame(self, frame: RFCOMM_Frame) -> None:
logger.debug(f'>>> Multiplexer sending {frame}') logger.debug(f'>>> Multiplexer sending {frame}')
self.l2cap_channel.send_pdu(frame) self.l2cap_channel.write(bytes(frame))
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
frame = RFCOMM_Frame.from_bytes(pdu) frame = RFCOMM_Frame.from_bytes(pdu)
+2 -2
View File
@@ -847,7 +847,7 @@ class Client:
self.pending_request = request self.pending_request = request
try: try:
self.channel.send_pdu(bytes(request)) self.channel.write(bytes(request))
return await self.pending_response return await self.pending_response
finally: finally:
self.pending_request = None self.pending_request = None
@@ -1061,7 +1061,7 @@ class Server:
def send_response(self, response): def send_response(self, response):
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response) self.channel.write(response)
def match_services(self, search_pattern: DataElement) -> dict[int, Service]: def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the # Find the services for which the attributes in the pattern is a subset of the
+8 -3
View File
@@ -27,7 +27,7 @@ from __future__ import annotations
import asyncio import asyncio
import enum import enum
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
@@ -507,10 +507,15 @@ def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AddressResolver: class AddressResolver:
def __init__(self, resolving_keys): def __init__(self, resolving_keys: Sequence[tuple[bytes, Address]]) -> None:
self.resolving_keys = resolving_keys self.resolving_keys = resolving_keys
def resolve(self, address): def can_resolve_to(self, address: Address) -> bool:
return any(
resolved_address == address for _, resolved_address in self.resolving_keys
)
def resolve(self, address: Address) -> Address | None:
address_bytes = bytes(address) address_bytes = bytes(address)
hash_part = address_bytes[0:3] hash_part = address_bytes[0:3]
prand = address_bytes[3:6] prand = address_bytes[3:6]
+111 -1
View File
@@ -110,6 +110,53 @@ class BtSnooper(Snooper):
) )
# -----------------------------------------------------------------------------
class PcapSnooper(Snooper):
"""
Snooper that saves or streames HCI packets using the PCAP format.
"""
PCAP_MAGIC = 0xA1B2C3D4
DLT_BLUETOOTH_HCI_H4_WITH_PHDR = 201
def __init__(self, output: BinaryIO):
self.output = output
# Write the header
self.output.write(
struct.pack(
"<IHHIIII",
self.PCAP_MAGIC,
2, # Major PCAP Version
4, # Minor PCAP Version
0, # Reserved 1
0, # Reserved 2
65535, # SnapLen
# FCS and f are set to 0 implicitly by the next line
self.DLT_BLUETOOTH_HCI_H4_WITH_PHDR, # The DLT in this PCAP
)
)
def snoop(self, hci_packet: bytes, direction: Snooper.Direction):
now = datetime.datetime.now(datetime.timezone.utc)
sec = int(now.timestamp())
usec = now.microsecond
# Emit the record
self.output.write(
struct.pack(
"<IIII",
sec, # Timestamp (Seconds)
usec, # Timestamp (Microseconds)
len(hci_packet) + 4,
len(hci_packet) + 4, # +4 because of the addtional direction info...
)
+ struct.pack(">I", int(direction)) # ...thats being added here
+ hci_packet
)
self.output.flush() # flush after every packet for live logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0 _SNOOPER_INSTANCE_COUNT = 0
@@ -140,9 +187,38 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
pid: the current process ID. pid: the current process ID.
instance: the instance ID in the current process. instance: the instance ID in the current process.
pcapsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
pid: the current process ID.
instance: the instance ID in the current process.
pipe
The type-specific arguments for this I/O type is a string that is converted
to a path using the python `str.format()` string formatting. The log
records will be written to the named pipe referenced by this path
if it can be opened. The keyword args that may be referenced by the
string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
pid: the current process ID.
instance: the instance ID in the current process.
Examples: Examples:
btsnoop:file:my_btsnoop.log btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
pcapsnoop:pipe:/tmp/bumble-extcap
""" """
if ':' not in spec: if ':' not in spec:
@@ -150,6 +226,8 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
snooper_type, snooper_args = spec.split(':', maxsplit=1) snooper_type, snooper_args = spec.split(':', maxsplit=1)
global _SNOOPER_INSTANCE_COUNT
if snooper_type == 'btsnoop': if snooper_type == 'btsnoop':
if ':' not in snooper_args: if ':' not in snooper_args:
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing') raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
@@ -157,7 +235,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
io_type, io_name = snooper_args.split(':', maxsplit=1) io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file': if io_type == 'file':
# Process the file name string pattern. # Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format( file_path = io_name.format(
now=datetime.datetime.now(), now=datetime.datetime.now(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc), utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
@@ -173,6 +250,39 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1 _SNOOPER_INSTANCE_COUNT -= 1
return return
elif snooper_type == 'pcapsnoop':
if ':' not in snooper_args:
raise core.InvalidArgumentError(
'I/O type for pcapsnoop snooper type missing'
)
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type in {'pipe', 'file'}:
# Process the file name string pattern.
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open a file or pipe
logger.debug(f'PCAP file: {file_path}')
# Pipes we have to open with unbuffered binary I/O
# so we pass ``buffering`` for pipes but not for files
pcap_file: BinaryIO
if io_type == 'pipe':
pcap_file = open(file_path, 'wb', buffering=0)
else:
pcap_file = open(file_path, 'wb')
with pcap_file:
_SNOOPER_INSTANCE_COUNT += 1
yield PcapSnooper(pcap_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise core.InvalidArgumentError(f'I/O type {io_type} not supported') raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found') raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
+1 -1
View File
@@ -194,7 +194,7 @@ async def open_android_netsim_controller_transport(
# We only accept BLUETOOTH # We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH: if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type') logger.debug('Request for unsupported chip type')
error = PacketResponse(error='Unsupported chip type') error = PacketResponse(error='Unsupported chip type')
await self.context.write(error) await self.context.write(error)
# return # return
+108 -86
View File
@@ -43,44 +43,53 @@ hci.HCI_Command.register_commands(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command): class HCI_LE_Get_Vendor_Capabilities_ReturnParameters(hci.HCI_StatusReturnParameters):
max_advt_instances: int = field(metadata=hci.metadata(1), default=0)
offloaded_resolution_of_private_address: int = field(
metadata=hci.metadata(1), default=0
)
total_scan_results_storage: int = field(metadata=hci.metadata(2), default=0)
max_irk_list_sz: int = field(metadata=hci.metadata(1), default=0)
filtering_support: int = field(metadata=hci.metadata(1), default=0)
max_filter: int = field(metadata=hci.metadata(1), default=0)
activity_energy_info_support: int = field(metadata=hci.metadata(1), default=0)
version_supported: int = field(metadata=hci.metadata(2), default=0)
total_num_of_advt_tracked: int = field(metadata=hci.metadata(2), default=0)
extended_scan_support: int = field(metadata=hci.metadata(1), default=0)
debug_logging_supported: int = field(metadata=hci.metadata(1), default=0)
le_address_generation_offloading_support: int = field(
metadata=hci.metadata(1), default=0
)
a2dp_source_offload_capability_mask: int = field(
metadata=hci.metadata(4), default=0
)
bluetooth_quality_report_support: int = field(metadata=hci.metadata(1), default=0)
dynamic_audio_buffer_support: int = field(metadata=hci.metadata(4), default=0)
@hci.HCI_SyncCommand.sync_command(HCI_LE_Get_Vendor_Capabilities_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(
hci.HCI_SyncCommand[HCI_LE_Get_Vendor_Capabilities_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
''' '''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
@classmethod @classmethod
def parse_return_parameters(cls, parameters): def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until # There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to # there are no more bytes to parse, and leave un-signaled parameters set to
# None (older versions) # 0
nones = {field: None for field, _ in cls.return_parameters_fields} return_parameters = HCI_LE_Get_Vendor_Capabilities_ReturnParameters(
return_parameters = hci.HCI_Object(cls.return_parameters_fields, **nones) hci.HCI_ErrorCode.SUCCESS
)
try: try:
offset = 0 offset = 0
for field in cls.return_parameters_fields: for field in cls.return_parameters_class.fields:
field_name, field_type = field field_name, field_type = field
field_value, field_size = hci.HCI_Object.parse_field( field_value, field_size = hci.HCI_Object.parse_field(
parameters, offset, field_type parameters, offset, field_type
@@ -94,9 +103,30 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # APCF Subcommands
class LeApcfOpcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
@dataclasses.dataclass @dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_Command): class HCI_LE_APCF_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_LE_APCF_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_SyncCommand[HCI_LE_APCF_ReturnParameters]):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
@@ -105,52 +135,52 @@ class HCI_LE_APCF_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# APCF Subcommands opcode: int = dataclasses.field(metadata=LeApcfOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command): class HCI_Get_Controller_Activity_Energy_Info_ReturnParameters(
hci.HCI_StatusReturnParameters
):
total_tx_time_ms: int = field(metadata=hci.metadata(4))
total_rx_time_ms: int = field(metadata=hci.metadata(4))
total_idle_time_ms: int = field(metadata=hci.metadata(4))
total_energy_used: int = field(metadata=hci.metadata(4))
@hci.HCI_SyncCommand.sync_command(
HCI_Get_Controller_Activity_Energy_Info_ReturnParameters
)
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(
hci.HCI_SyncCommand[HCI_Get_Controller_Activity_Energy_Info_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
''' '''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # A2DP Hardware Offload Subcommands
class A2dpHardwareOffloadOpcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
@dataclasses.dataclass @dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command): class HCI_A2DP_Hardware_Offload_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_A2DP_Hardware_Offload_ReturnParameters)
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(
hci.HCI_SyncCommand[HCI_A2DP_Hardware_Offload_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
@@ -159,25 +189,27 @@ class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# A2DP Hardware Offload Subcommands opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command # Dynamic Audio Buffer Subcommands
class DynamicAudioBufferOpcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command): class HCI_Dynamic_Audio_Buffer_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_Dynamic_Audio_Buffer_ReturnParameters)
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(
hci.HCI_SyncCommand[HCI_Dynamic_Audio_Buffer_ReturnParameters]
):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
@@ -186,19 +218,9 @@ class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures. implementation. A future enhancement may define subcommand-specific data structures.
''' '''
# Dynamic Audio Buffer Subcommands opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
class Opcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*")) payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HCI_Android_Vendor_Event(hci.HCI_Extended_Event): class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):
+24 -18
View File
@@ -46,9 +46,19 @@ class TX_Power_Level_Command:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command): class HCI_Write_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
selected_tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Write_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Write_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
''' '''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -61,18 +71,21 @@ class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
connection_handle: int = dataclasses.field(metadata=hci.metadata(2)) connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1)) tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass @dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command): class HCI_Read_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Read_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Read_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
''' '''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -83,10 +96,3 @@ class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
handle_type: int = dataclasses.field(metadata=hci.metadata(1)) handle_type: int = dataclasses.field(metadata=hci.metadata(1))
connection_handle: int = dataclasses.field(metadata=hci.metadata(2)) connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
]
+1 -1
View File
@@ -63,7 +63,7 @@ HCI sockets provide a way to send/receive HCI packets to/from a Bluetooth contro
See the [HCI Socket Transport page](../transports/hci_socket.md) for details on the `hci-socket` tansport syntax. See the [HCI Socket Transport page](../transports/hci_socket.md) for details on the `hci-socket` tansport syntax.
The HCI device referenced by an `hci-socket` transport (`hci<X>`, where `<X>` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport. The HCI device referenced by an `hci-socket` transport (`hci<X>`, where `<X>` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> up`. You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> down`.
!!! tip "HCI Socket Permissions" !!! tip "HCI Socket Permissions"
By default, when running as a regular user, you won't have the permission to use By default, when running as a regular user, you won't have the permission to use
+3 -3
View File
@@ -37,7 +37,7 @@ The vendor specific HCI commands to read and write TX power are defined in
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB # set advertising power to -4 dB
response = await host.send_command( response = await host.send_sync_command(
HCI_Write_Tx_Power_Level_Command( HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV, handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0, connection_handle=0,
@@ -45,7 +45,7 @@ response = await host.send_command(
) )
) )
if response.return_parameters.status == HCI_SUCCESS: if response.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}") print(f"TX power set to {response.selected_tx_power_level}")
``` ```
+1 -1
View File
@@ -65,7 +65,7 @@ async def main() -> None:
# Go! # Go!
await device.power_on() await device.power_on()
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+2 -2
View File
@@ -71,8 +71,8 @@ async def main() -> None:
rr_intervals=random.choice( rr_intervals=random.choice(
( (
( (
random.randint(900, 1100) / 1000, random.randint(900, 1100) // 1000,
random.randint(900, 1100) / 1000, random.randint(900, 1100) // 1000,
), ),
None, None,
) )
+1 -1
View File
@@ -161,7 +161,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -181,7 +181,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -70,7 +70,7 @@ async def main() -> None:
await device.power_on() await device.power_on()
await device.start_advertising(advertising_type=advertising_type, target=target) await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+24 -22
View File
@@ -25,7 +25,7 @@ import sys
import websockets.asyncio.server import websockets.asyncio.server
import bumble.logging import bumble.logging
from bumble import a2dp, avc, avdtp, avrcp, utils from bumble import a2dp, avc, avdtp, avrcp, sdp, utils
from bumble.core import PhysicalTransport from bumble.core import PhysicalTransport
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport from bumble.transport import open_transport
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sdp_records(): def sdp_records() -> dict[int, list[sdp.ServiceAttribute]]:
a2dp_sink_service_record_handle = 0x00010001 a2dp_sink_service_record_handle = 0x00010001
avrcp_controller_service_record_handle = 0x00010002 avrcp_controller_service_record_handle = 0x00010002
avrcp_target_service_record_handle = 0x00010003 avrcp_target_service_record_handle = 0x00010003
@@ -43,17 +43,17 @@ def sdp_records():
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records( a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
a2dp_sink_service_record_handle a2dp_sink_service_record_handle
), ),
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records( avrcp_controller_service_record_handle: avrcp.ControllerServiceSdpRecord(
avrcp_controller_service_record_handle avrcp_controller_service_record_handle
), ).to_service_attributes(),
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records( avrcp_target_service_record_handle: avrcp.TargetServiceSdpRecord(
avrcp_controller_service_record_handle avrcp_target_service_record_handle
), ).to_service_attributes(),
} }
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def codec_capabilities(): def codec_capabilities() -> avdtp.MediaCodecCapabilities:
return avdtp.MediaCodecCapabilities( return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE, media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE, media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
@@ -81,20 +81,22 @@ def codec_capabilities():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection(server): def on_avdtp_connection(server: avdtp.Protocol) -> None:
# Add a sink endpoint to the server # Add a sink endpoint to the server
sink = server.add_sink(codec_capabilities()) sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet) sink.on(sink.EVENT_RTP_PACKET, on_rtp_packet)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_rtp_packet(packet): def on_rtp_packet(packet: avdtp.MediaPacket) -> None:
print(f'RTP: {packet}') print(f'RTP: {packet}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer): def on_avrcp_start(
async def get_supported_events(): avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer
) -> None:
async def get_supported_events() -> None:
events = await avrcp_protocol.get_supported_events() events = await avrcp_protocol.get_supported_events()
print("SUPPORTED EVENTS:", events) print("SUPPORTED EVENTS:", events)
websocket_server.send_message( websocket_server.send_message(
@@ -130,14 +132,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
utils.AsyncRunner.spawn(get_supported_events()) utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed(): async def monitor_track_changed() -> None:
async for identifier in avrcp_protocol.monitor_track_changed(): async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex()) print("TRACK CHANGED:", identifier.hex())
websocket_server.send_message( websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}} {"type": "track-changed", "params": {"identifier": identifier.hex()}}
) )
async def monitor_playback_status(): async def monitor_playback_status() -> None:
async for playback_status in avrcp_protocol.monitor_playback_status(): async for playback_status in avrcp_protocol.monitor_playback_status():
print("PLAYBACK STATUS CHANGED:", playback_status.name) print("PLAYBACK STATUS CHANGED:", playback_status.name)
websocket_server.send_message( websocket_server.send_message(
@@ -147,7 +149,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_playback_position(): async def monitor_playback_position() -> None:
async for playback_position in avrcp_protocol.monitor_playback_position( async for playback_position in avrcp_protocol.monitor_playback_position(
playback_interval=1 playback_interval=1
): ):
@@ -159,7 +161,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_player_application_settings(): async def monitor_player_application_settings() -> None:
async for settings in avrcp_protocol.monitor_player_application_settings(): async for settings in avrcp_protocol.monitor_player_application_settings():
print("PLAYER APPLICATION SETTINGS:", settings) print("PLAYER APPLICATION SETTINGS:", settings)
settings_as_dict = [ settings_as_dict = [
@@ -173,14 +175,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_available_players(): async def monitor_available_players() -> None:
async for _ in avrcp_protocol.monitor_available_players(): async for _ in avrcp_protocol.monitor_available_players():
print("AVAILABLE PLAYERS CHANGED") print("AVAILABLE PLAYERS CHANGED")
websocket_server.send_message( websocket_server.send_message(
{"type": "available-players-changed", "params": {}} {"type": "available-players-changed", "params": {}}
) )
async def monitor_addressed_player(): async def monitor_addressed_player() -> None:
async for player in avrcp_protocol.monitor_addressed_player(): async for player in avrcp_protocol.monitor_addressed_player():
print("ADDRESSED PLAYER CHANGED") print("ADDRESSED PLAYER CHANGED")
websocket_server.send_message( websocket_server.send_message(
@@ -195,7 +197,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_uids(): async def monitor_uids() -> None:
async for uid_counter in avrcp_protocol.monitor_uids(): async for uid_counter in avrcp_protocol.monitor_uids():
print("UIDS CHANGED") print("UIDS CHANGED")
websocket_server.send_message( websocket_server.send_message(
@@ -207,7 +209,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
} }
) )
async def monitor_volume(): async def monitor_volume() -> None:
async for volume in avrcp_protocol.monitor_volume(): async for volume in avrcp_protocol.monitor_volume():
print("VOLUME CHANGED:", volume) print("VOLUME CHANGED:", volume)
websocket_server.send_message( websocket_server.send_message(
@@ -360,7 +362,7 @@ async def main() -> None:
# Create a listener to wait for AVDTP connections # Create a listener to wait for AVDTP connections
listener = avdtp.Listener(avdtp.Listener.create_registrar(device)) listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
listener.on('connection', on_avdtp_connection) listener.on(listener.EVENT_CONNECTION, on_avdtp_connection)
avrcp_delegate = Delegate() avrcp_delegate = Delegate()
avrcp_protocol = avrcp.Protocol(avrcp_delegate) avrcp_protocol = avrcp.Protocol(avrcp_delegate)
+1 -1
View File
@@ -112,7 +112,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -73,7 +73,7 @@ async def main() -> None:
await device.power_on() await device.power_on()
await device.start_discovery() await device.start_discovery()
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -57,7 +57,7 @@ async def main() -> None:
print(f'!!! Encryption failed: {error}') print(f'!!! Encryption failed: {error}')
return return
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+201
View File
@@ -0,0 +1,201 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import sys
from collections.abc import Callable
import bumble.logging
from bumble.core import BaseError
from bumble.device import Connection, Device
from bumble.hci import Address, LeFeatureMask
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
DEFAULT_CENTRAL_ADDRESS = Address("F0:F0:F0:F0:F0:F0")
DEFAULT_PERIPHERAL_ADDRESS = Address("F1:F1:F1:F1:F1:F1")
# -----------------------------------------------------------------------------
async def run_as_central(
device: Device,
scenario: Callable | None,
) -> None:
# Connect to the peripheral
print(f'=== Connecting to {DEFAULT_PERIPHERAL_ADDRESS}...')
connection = await device.connect(DEFAULT_PERIPHERAL_ADDRESS)
print("=== Connected")
if scenario is not None:
await asyncio.sleep(1)
await scenario(connection)
await asyncio.get_running_loop().create_future()
async def run_as_peripheral(device: Device, scenario: Callable | None) -> None:
# Wait for a connection from the central
print(f'=== Advertising as {DEFAULT_PERIPHERAL_ADDRESS}...')
await device.start_advertising(auto_restart=True)
async def on_connection(connection: Connection) -> None:
assert scenario is not None
await asyncio.sleep(1)
await scenario(connection)
if scenario is not None:
device.on(Device.EVENT_CONNECTION, on_connection)
await asyncio.get_running_loop().create_future()
async def change_parameters(
connection: Connection,
parameter_request_procedure_supported: bool,
subrating_supported: bool,
shorter_connection_intervals_supported: bool,
) -> None:
if parameter_request_procedure_supported:
try:
print(">>> update_parameters(7.5, 200, 0, 4000)")
await connection.update_parameters(7.5, 200, 0, 4000)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
if subrating_supported:
try:
print(">>> update_subrate(1, 2, 2, 1, 4000)")
await connection.update_subrate(1, 2, 2, 1, 4000)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
if shorter_connection_intervals_supported:
try:
print(
">>> update_parameters_with_subrate(7.5, 200, 1, 1, 0, 0, 4000, 5, 1000)"
)
await connection.update_parameters_with_subrate(
7.5, 200, 1, 1, 0, 0, 4000, 5, 1000
)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
try:
print(
">>> update_parameters_with_subrate(0.750, 5, 1, 1, 0, 0, 4000, 0.125, 1000)"
)
await connection.update_parameters_with_subrate(
0.750, 5, 1, 1, 0, 0, 4000, 0.125, 1000
)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
print(">>> done")
def on_connection(connection: Connection) -> None:
print(f"+++ Connection established: {connection}")
def on_le_remote_features_change() -> None:
print(f'... LE Remote Features change: {connection.peer_le_features.name}')
connection.on(
connection.EVENT_LE_REMOTE_FEATURES_CHANGE, on_le_remote_features_change
)
def on_connection_parameters_change() -> None:
print(f'... LE Connection Parameters change: {connection.parameters}')
connection.on(
connection.EVENT_CONNECTION_PARAMETERS_UPDATE, on_connection_parameters_change
)
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_connection_updates.py <transport-spec> '
'central|peripheral initiator|responder'
)
return
print('<<< connecting to HCI...')
async with await open_transport(sys.argv[1]) as hci_transport:
print('<<< connected')
role = sys.argv[2]
direction = sys.argv[3]
device = Device.with_hci(
role,
(
DEFAULT_CENTRAL_ADDRESS
if role == "central"
else DEFAULT_PERIPHERAL_ADDRESS
),
hci_transport.source,
hci_transport.sink,
)
device.le_subrate_enabled = True
device.le_shorter_connection_intervals_enabled = True
await device.power_on()
parameter_request_procedure_supported = device.supports_le_features(
LeFeatureMask.CONNECTION_PARAMETERS_REQUEST_PROCEDURE
)
print(
"Parameters Request Procedure supported: "
f"{parameter_request_procedure_supported}"
)
subrating_supported = device.supports_le_features(
LeFeatureMask.CONNECTION_SUBRATING
)
print(f"Subrating supported: {subrating_supported}")
shorter_connection_intervals_supported = device.supports_le_features(
LeFeatureMask.SHORTER_CONNECTION_INTERVALS
)
print(
"Shorter Connection Intervals supported: "
f"{shorter_connection_intervals_supported}"
)
device.on(Device.EVENT_CONNECTION, on_connection)
async def run(connection: Connection) -> None:
await change_parameters(
connection,
parameter_request_procedure_supported,
subrating_supported,
shorter_connection_intervals_supported,
)
scenario = run if direction == "initiator" else None
if role == "central":
await run_as_central(device, scenario)
else:
await run_as_peripheral(device, scenario)
# -----------------------------------------------------------------------------
bumble.logging.setup_basic_logging('DEBUG')
asyncio.run(main())
+1 -1
View File
@@ -101,7 +101,7 @@ async def main() -> None:
await device.start_advertising() await device.start_advertising()
await device.start_scanning() await device.start_scanning()
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -48,7 +48,7 @@ async def main() -> None:
await device.power_on() await device.power_on()
await device.start_scanning() await device.start_scanning()
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -147,7 +147,7 @@ async def main() -> None:
else: else:
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -99,7 +99,7 @@ async def main() -> None:
else: else:
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -422,7 +422,7 @@ async def main() -> None:
# Setup a server # Setup a server
await server(device) await server(device)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+3 -7
View File
@@ -100,13 +100,9 @@ def on_sco_packet(packet: hci.HCI_SynchronousDataPacket):
if source_file and (pcm_data := source_file.read(packet.data_total_length)): if source_file and (pcm_data := source_file.read(packet.data_total_length)):
assert ag_protocol assert ag_protocol
host = ag_protocol.dlc.multiplexer.l2cap_channel.connection.device.host host = ag_protocol.dlc.multiplexer.l2cap_channel.connection.device.host
host.send_hci_packet( host.send_sco_sdu(
hci.HCI_SynchronousDataPacket( connection_handle=packet.connection_handle,
connection_handle=packet.connection_handle, sdu=pcm_data,
packet_status=0,
data_total_length=len(pcm_data),
data=pcm_data,
)
) )
+1 -1
View File
@@ -167,7 +167,7 @@ async def main() -> None:
await websockets.asyncio.server.serve(serve, 'localhost', 8989) await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -735,7 +735,7 @@ async def main() -> None:
print("Executing in Web mode") print("Executing in Web mode")
await keyboard_device(hid_device) await keyboard_device(hid_device)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -556,7 +556,7 @@ async def main() -> None:
# Interrupt Channel # Interrupt Channel
await hid_host.connect_interrupt_channel() await hid_host.connect_interrupt_channel()
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -227,7 +227,7 @@ async def main() -> None:
tcp_port = int(sys.argv[5]) tcp_port = int(sys.argv[5])
asyncio.create_task(tcp_server(tcp_port, session)) asyncio.create_task(tcp_server(tcp_port, session))
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -153,7 +153,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+1 -1
View File
@@ -75,7 +75,7 @@ async def main() -> None:
await device.power_on() await device.power_on()
await device.start_scanning(filter_duplicates=filter_duplicates) await device.start_scanning(filter_duplicates=filter_duplicates)
await hci_transport.source.wait_for_termination() await hci_transport.source.terminated
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+7 -2
View File
@@ -15,15 +15,20 @@ dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'", "aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'", "appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'", "click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten'", "cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the # Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being # versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch. # completely unable to import the package in case of version mismatch.
"cryptography >= 44.0.3; platform_system=='Emscripten'", "cryptography >= 44.0.3; platform_system=='Emscripten'",
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
# updated. Relax the version requirement since it's better than being completely unable
# to import the package in case of version mismatch.
"cryptography >= 42.0.8; platform_system=='Android'",
"grpcio >= 1.62.1; platform_system!='Emscripten'", "grpcio >= 1.62.1; platform_system!='Emscripten'",
"humanize >= 4.6.0; platform_system!='Emscripten'", "humanize >= 4.6.0; platform_system!='Emscripten'",
"libusb1 >= 2.0.1; platform_system!='Emscripten'", "libusb1 >= 2.0.1; platform_system!='Emscripten'",
"libusb-package == 1.0.26.1; platform_system!='Emscripten'", "libusb-package == 1.0.26.1; platform_system!='Emscripten' and platform_system!='Android'",
"platformdirs >= 3.10.0; platform_system!='Emscripten'", "platformdirs >= 3.10.0; platform_system!='Emscripten'",
"prompt_toolkit >= 3.0.16; platform_system!='Emscripten'", "prompt_toolkit >= 3.0.16; platform_system!='Emscripten'",
"prettytable >= 3.6.0; platform_system!='Emscripten'", "prettytable >= 3.6.0; platform_system!='Emscripten'",
+1 -1
View File
@@ -17,6 +17,6 @@ use pyo3::PyResult;
#[pyo3_asyncio::tokio::test] #[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> { async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len()); assert_eq!(13, DriverInfo::all_drivers()?.len());
Ok(()) Ok(())
} }
+214 -1
View File
@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import struct import struct
from collections.abc import Sequence from collections.abc import Sequence
@@ -233,7 +234,21 @@ def test_event(event: avrcp.Event):
feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING, feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING,
character_set_id=avrcp.CharacterSetId.UTF_8, character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Woo", displayable_name="Woo",
) ),
avrcp.FolderItem(
folder_uid=1,
folder_type=avrcp.FolderItem.FolderType.ALBUMS,
is_playable=avrcp.FolderItem.Playable.PLAYABLE,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Album",
),
avrcp.MediaElementItem(
media_element_uid=1,
media_type=avrcp.MediaElementItem.MediaType.AUDIO,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Song",
attribute_value_entry_list=[],
),
], ],
), ),
avrcp.ChangePathResponse( avrcp.ChangePathResponse(
@@ -408,6 +423,47 @@ def test_passthrough_commands():
assert bytes(parsed) == play_pressed_bytes assert bytes(parsed) == play_pressed_bytes
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_sdp_records():
two_devices = await TwoDevices.create_with_avdtp()
# Add SDP records to device 1
controller_record = avrcp.ControllerServiceSdpRecord(
service_record_handle=0x10001,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.ControllerFeatures.CATEGORY_1
| avrcp.ControllerFeatures.SUPPORTS_BROWSING
),
)
target_record = avrcp.TargetServiceSdpRecord(
service_record_handle=0x10002,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.TargetFeatures.CATEGORY_1 | avrcp.TargetFeatures.SUPPORTS_BROWSING
),
)
two_devices.devices[1].sdp_service_records = {
0x10001: controller_record.to_service_attributes(),
0x10002: target_record.to_service_attributes(),
}
# Find records from device 0
controller_records = await avrcp.ControllerServiceSdpRecord.find(
two_devices.connections[0]
)
assert len(controller_records) == 1
assert controller_records[0] == controller_record
target_records = await avrcp.TargetServiceSdpRecord.find(two_devices.connections[0])
assert len(target_records) == 1
assert target_records[0] == target_record
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_supported_events(): async def test_get_supported_events():
@@ -422,6 +478,163 @@ async def test_get_supported_events():
assert supported_events == [avrcp.EventId.VOLUME_CHANGED] assert supported_events == [avrcp.EventId.VOLUME_CHANGED]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event():
two_devices = await TwoDevices.create_with_avdtp()
q = asyncio.Queue[tuple[avc.PassThroughFrame.OperationId, bool, bytes]]()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
q.put_nowait((key, pressed, data))
two_devices.protocols[1].delegate = Delegate()
for key, pressed in [
(avc.PassThroughFrame.OperationId.PLAY, True),
(avc.PassThroughFrame.OperationId.PLAY, False),
(avc.PassThroughFrame.OperationId.PAUSE, True),
(avc.PassThroughFrame.OperationId.PAUSE, False),
]:
await two_devices.protocols[0].send_key_event(key, pressed)
assert (await q.get()) == (key, pressed, b'')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_rejected():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise avrcp.Delegate.AvcError(avc.ResponseFrame.ResponseCode.REJECTED)
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_exception():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise Exception()
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_volume():
two_devices = await TwoDevices.create_with_avdtp()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
response = await two_devices.protocols[1].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL, avrcp.SetAbsoluteVolumeCommand(volume)
)
assert isinstance(response.response, avrcp.SetAbsoluteVolumeResponse)
assert response.response.volume == volume
assert two_devices.protocols[0].delegate.volume == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate.playback_status = status
response = await two_devices.protocols[1].get_play_status()
assert response.play_status == status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_company_ids():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate = avrcp.Delegate(
supported_company_ids=[avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
)
supported_company_ids = await two_devices.protocols[
1
].get_supported_company_ids()
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_volume():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
volume_iter = two_devices.protocols[0].monitor_volume()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
# Interim
two_devices.protocols[1].delegate.volume = 0
assert (await anext(volume_iter)) == 0
# Changed
two_devices.protocols[1].notify_volume_changed(volume)
assert (await anext(volume_iter)) == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.PLAYBACK_STATUS_CHANGED]
)
playback_status_iter = two_devices.protocols[0].monitor_playback_status()
for playback_status in avrcp.PlayStatus:
# Interim
two_devices.protocols[1].delegate.playback_status = avrcp.PlayStatus.STOPPED
assert (await anext(playback_status_iter)) == avrcp.PlayStatus.STOPPED
# Changed
two_devices.protocols[1].notify_playback_status_changed(playback_status)
assert (await anext(playback_status_iter)) == playback_status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_now_playing_content():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.NOW_PLAYING_CONTENT_CHANGED]
)
now_playing_iter = two_devices.protocols[0].monitor_now_playing_content()
for _ in range(2):
# Interim
await anext(now_playing_iter)
# Changed
two_devices.protocols[1].notify_now_playing_content_changed()
await anext(now_playing_iter)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
test_frame_parser() test_frame_parser()
+34
View File
@@ -0,0 +1,34 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from bumble import device as device_module
from bumble.profiles import battery_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_battery_level():
devices = await test_utils.TwoDevices.create_with_connection()
service = battery_service.BatteryService(lambda _: 1)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(battery_service.BatteryServiceProxy)
assert client
assert await client.battery_level.read_value() == 1
+7 -6
View File
@@ -42,7 +42,6 @@ from bumble.hci import (
HCI_CREATE_CONNECTION_COMMAND, HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS, HCI_SUCCESS,
Address, Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_Connection_Complete_Event, HCI_Connection_Complete_Event,
HCI_Connection_Request_Event, HCI_Connection_Request_Event,
@@ -154,10 +153,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet( d1.host.on_hci_packet(
HCI_Command_Complete_Event( HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
) )
) )
@@ -188,10 +187,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND' assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet( d2.host.on_hci_packet(
HCI_Command_Complete_Event( HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
) )
) )
@@ -620,7 +619,9 @@ async def test_le_request_subrate():
def on_le_subrate_change(): def on_le_subrate_change():
q.put_nowait(lambda: None) q.put_nowait(lambda: None)
devices.connections[0].on(Connection.EVENT_LE_SUBRATE_CHANGE, on_le_subrate_change) devices.connections[0].on(
Connection.EVENT_CONNECTION_PARAMETERS_UPDATE, on_le_subrate_change
)
await devices[0].send_command( await devices[0].send_command(
hci.HCI_LE_Subrate_Request_Command( hci.HCI_LE_Subrate_Request_Command(
+99 -1
View File
@@ -28,7 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
import pytest import pytest
from typing_extensions import Self from typing_extensions import Self
from bumble import gatt_client, l2cap from bumble import att, gatt_client, l2cap
from bumble.att import ( from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU, ATT_PDU,
@@ -1638,6 +1638,104 @@ async def test_eatt_connection_failure():
await gatt_client.Client.connect_eatt(devices.connections[0]) await gatt_client.Client.connect_eatt(devices.connections[0])
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_multiple() -> None:
devices = await TwoDevices.create_with_connection()
characteristic1 = Characteristic(
'0001', Characteristic.Properties.READ, Characteristic.READABLE, b'1234'
)
characteristic2 = Characteristic(
'0002',
Characteristic.Properties.READ,
Characteristic.READABLE,
b'5678',
)
service = Service('0000', [characteristic1, characteristic2])
devices[1].add_service(service)
client = devices.connections[0].gatt_client
server = devices[1].gatt_server
await client.discover_services()
characteristics = await client.discover_characteristics(
[characteristic1.uuid, characteristic2.uuid], None
)
response = await client.send_request(
att.ATT_Read_Multiple_Request(
set_of_handles=[c.handle for c in characteristics]
)
)
assert isinstance(response, att.ATT_Read_Multiple_Response)
assert response.set_of_values == b'12345678'
response = await client.send_request(
att.ATT_Read_Multiple_Request(
set_of_handles=[
next(
handle
for handle in range(0x0001, 0xFFFF)
if not server.get_attribute(handle)
)
]
)
)
assert isinstance(response, att.ATT_Error_Response)
assert response.error_code == att.ATT_ATTRIBUTE_NOT_FOUND_ERROR
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_multiple_variable() -> None:
devices = await TwoDevices.create_with_connection()
characteristic1 = Characteristic(
'0001', Characteristic.Properties.READ, Characteristic.READABLE, b'1234'
)
characteristic2 = Characteristic(
'0002',
Characteristic.Properties.READ,
Characteristic.READABLE,
b'99',
)
service = Service('0000', [characteristic1, characteristic2])
devices[1].add_service(service)
client = devices.connections[0].gatt_client
server = devices[1].gatt_server
await client.discover_services()
characteristics = await client.discover_characteristics(
[characteristic1.uuid, characteristic2.uuid], None
)
response = await client.send_request(
att.ATT_Read_Multiple_Variable_Request(
set_of_handles=[c.handle for c in characteristics]
)
)
assert isinstance(response, att.ATT_Read_Multiple_Variable_Response)
assert response.length_value_tuple_list == [(4, b'1234'), (2, b'99')]
response = await client.send_request(
att.ATT_Read_Multiple_Variable_Request(
set_of_handles=[
next(
handle
for handle in range(0x0001, 0xFFFF)
if not server.get_attribute(handle)
)
]
)
)
assert isinstance(response, att.ATT_Error_Response)
assert response.error_code == att.ATT_ATTRIBUTE_NOT_FOUND_ERROR
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
+56 -45
View File
@@ -20,7 +20,7 @@ import struct
import pytest import pytest
from bumble import hci from bumble import hci, utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable=invalid-name # pylint: disable=invalid-name
@@ -136,43 +136,25 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command_Complete_Event(): def test_HCI_Command_Complete_Event():
# With a serializable object # With a serializable object
event = hci.HCI_Command_Complete_Event( event1 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=34, num_hci_command_packets=34,
command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND, command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.create_return_parameters( return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.return_parameters_class(
status=0, status=0,
le_acl_data_packet_length=1234, le_acl_data_packet_length=1234,
total_num_le_acl_data_packets=56, total_num_le_acl_data_packets=56,
), ),
) )
basic_check(event) basic_check(event1)
# With an arbitrary byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = hci.HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7
# With a simple status as an integer status # With a simple status as an integer status
event = hci.HCI_Command_Complete_Event( event3 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND, command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=9, return_parameters=hci.HCI_StatusReturnParameters(hci.HCI_ErrorCode(9)),
) )
basic_check(event) basic_check(event3)
assert event.return_parameters == 9 assert event3.return_parameters.status == 9
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -229,6 +211,36 @@ def test_HCI_Vendor_Event():
assert isinstance(parsed, hci.HCI_Vendor_Event) assert isinstance(parsed, hci.HCI_Vendor_Event)
# -----------------------------------------------------------------------------
def test_return_parameters() -> None:
params = hci.HCI_Reset_Command.parse_return_parameters(bytes.fromhex('3C'))
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('00001122334455')
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address)
params = hci.HCI_Read_Local_Name_Command.parse_return_parameters(
bytes.fromhex('0068656c6c6f') + bytes(248 - 5)
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.local_name, bytes)
assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# Some return parameters may be shorter than the full length
# (for Command Complete events with errors)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('010011223344')
)
assert isinstance(params, hci.HCI_StatusReturnParameters)
assert params.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command(): def test_HCI_Command():
command = hci.HCI_Command(op_code=0x5566) command = hci.HCI_Command(op_code=0x5566)
@@ -291,7 +303,7 @@ def test_custom_le_meta_event():
for clazz in inspect.getmembers(hci) for clazz in inspect.getmembers(hci)
if isinstance(clazz[1], type) if isinstance(clazz[1], type)
and issubclass(clazz[1], hci.HCI_Command) and issubclass(clazz[1], hci.HCI_Command)
and clazz[1] is not hci.HCI_Command and clazz[1] not in (hci.HCI_Command, hci.HCI_SyncCommand, hci.HCI_AsyncCommand)
], ],
) )
def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]): def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
@@ -620,21 +632,19 @@ def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete(): def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = ( returned_parameters = hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters( bytes(
bytes( [
[ hci.HCI_SUCCESS,
hci.HCI_SUCCESS, 3,
3, hci.CodecID.A_LOG,
hci.CodecID.A_LOG, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL, hci.CodecID.CVSD,
hci.CodecID.CVSD, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO, hci.CodecID.LINEAR_PCM,
hci.CodecID.LINEAR_PCM, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS, 0,
0, ]
]
)
) )
) )
assert returned_parameters.standard_codec_ids == [ assert returned_parameters.standard_codec_ids == [
@@ -643,9 +653,9 @@ def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
hci.CodecID.LINEAR_PCM, hci.CodecID.LINEAR_PCM,
] ]
assert returned_parameters.standard_codec_transports == [ assert returned_parameters.standard_codec_transports == [
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS, hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
] ]
@@ -737,6 +747,7 @@ def run_test_commands():
if __name__ == '__main__': if __name__ == '__main__':
run_test_events() run_test_events()
run_test_commands() run_test_commands()
test_return_parameters()
test_address() test_address()
test_custom() test_custom()
test_iso_data_packet() test_iso_data_packet()
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import itertools
from collections.abc import Sequence
import pytest
from bumble import device as device_module
from bumble.profiles import heart_rate_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"heart_rate, sensor_contact_detected, energy_expanded, rr_intervals",
itertools.product(
(1, 1000), (True, False, None), (2, None), ((3.0, 4.0, 5.0), None)
),
)
async def test_read_measurement(
heart_rate: int,
sensor_contact_detected: bool | None,
energy_expanded: int | None,
rr_intervals: Sequence[int] | None,
):
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(
heart_rate, sensor_contact_detected, energy_expanded, rr_intervals
)
service = heart_rate_service.HeartRateService(lambda _: measurement)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert await client.heart_rate_measurement.read_value() == measurement
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_body_sensor_location():
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(0)
location = heart_rate_service.HeartRateService.BodySensorLocation.FINGER
service = heart_rate_service.HeartRateService(
lambda _: measurement,
body_sensor_location=location,
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert client.body_sensor_location
assert await client.body_sensor_location.read_value() == location
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reset_energy_expended() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(1)
reset_energy_expended = asyncio.Queue[None]()
service = heart_rate_service.HeartRateService(
lambda _: measurement,
reset_energy_expended=lambda _: reset_energy_expended.put_nowait(None),
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
await client.reset_energy_expended()
await reset_energy_expended.get()
+127 -2
View File
@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import unittest import unittest
import unittest.mock import unittest.mock
@@ -22,9 +23,22 @@ import unittest.mock
import pytest import pytest
from bumble.controller import Controller from bumble.controller import Controller
from bumble.hci import HCI_AclDataPacket from bumble.hci import (
HCI_AclDataPacket,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_CommandStatus,
HCI_Disconnect_Command,
HCI_Error,
HCI_ErrorCode,
HCI_Event,
HCI_GenericReturnParameters,
HCI_LE_Terminate_BIG_Command,
HCI_Reset_Command,
HCI_StatusReturnParameters,
)
from bumble.host import DataPacketQueue, Host from bumble.host import DataPacketQueue, Host
from bumble.transport.common import AsyncPipeSink from bumble.transport.common import AsyncPipeSink, TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -151,3 +165,114 @@ def test_data_packet_queue():
assert drain_listener.on_flow.call_count == 1 assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15 assert queue.queued == 15
assert queue.completed == 15 assert queue.completed == 15
# -----------------------------------------------------------------------------
class Source:
terminated: asyncio.Future[None]
sink: TransportSink
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
class Sink:
response: HCI_Event
def __init__(self, source: Source, response: HCI_Event) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
async def test_send_sync_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.SUCCESS),
),
)
host = Host(source, sink)
host.ready = True
# Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command())
assert response1.status == HCI_ErrorCode.SUCCESS
# Sync command with error status should raise
error_response = HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.COMMAND_DISALLOWED_ERROR),
)
sink.response = error_response
with pytest.raises(HCI_Error) as excinfo:
await host.send_sync_command(HCI_Reset_Command())
assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with raw result
response2 = await host.send_sync_command_raw(HCI_Reset_Command())
assert response2.return_parameters.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR
# Sync command with a command that's not an HCI_SyncCommand
# (here, for convenience, we use an HCI_AsyncCommand instance)
command = HCI_Disconnect_Command(connection_handle=0x1234, reason=0x13)
sink.response = HCI_Command_Complete_Event(
1,
command.op_code,
HCI_GenericReturnParameters(data=bytes.fromhex("00112233")),
)
response3 = await host.send_sync_command_raw(command) # type: ignore
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Status_Event(
HCI_CommandStatus.PENDING,
1,
HCI_Reset_Command.op_code,
),
)
host = Host(source, sink)
host.ready = True
# Normal pending status
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0)
)
assert response == HCI_CommandStatus.PENDING
# Unknown HCI command result returned as a Command Status
sink.response = HCI_Command_Status_Event(
HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR,
1,
HCI_LE_Terminate_BIG_Command.op_code,
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# Unknown HCI command result returned as a Command Complete
sink.response = HCI_Command_Complete_Event(
1,
HCI_LE_Terminate_BIG_Command.op_code,
HCI_StatusReturnParameters(HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR),
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
+10 -17
View File
@@ -239,20 +239,7 @@ async def transfer_payload(
channels[1].sink = received.put_nowait channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523) sdu_lengths = (21, 70, 700, 5523)
if isinstance(channels[1], l2cap.LeCreditBasedChannel): messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
for message in messages: for message in messages:
channels[0].write(message) channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel): if isinstance(channels[0], l2cap.LeCreditBasedChannel):
@@ -334,20 +321,26 @@ async def test_mtu():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_enhanced_retransmission_mode(): @pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices() devices = TwoDevices()
await devices.setup_connection() await devices.setup_connection()
server_channels = asyncio.Queue[l2cap.ClassicChannel]() server_channels = asyncio.Queue[l2cap.ClassicChannel]()
server = devices.devices[1].create_l2cap_server( server = devices.devices[1].create_l2cap_server(
spec=l2cap.ClassicChannelSpec( spec=l2cap.ClassicChannelSpec(
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=256,
), ),
handler=server_channels.put_nowait, handler=server_channels.put_nowait,
) )
client_channel = await devices.connections[0].create_l2cap_channel( client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.ClassicChannelSpec( spec=l2cap.ClassicChannelSpec(
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION server.psm,
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=1024,
) )
) )
server_channel = await server_channels.get() server_channel = await server_channels.get()
+1 -1
View File
@@ -89,7 +89,7 @@ class HeartRateMonitor:
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('### Monitor stopped') print('### Monitor stopped')
+1 -1
View File
@@ -60,7 +60,7 @@ class Scanner(utils.EventEmitter):
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('### Scanner stopped') print('### Scanner stopped')
+1 -1
View File
@@ -311,7 +311,7 @@ class Speaker:
async def stop(self): async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib. # TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command()) await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off() await self.device.power_off()
print('Speaker stopped') print('Speaker stopped')