Compare commits

...

136 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
8e509c18c9 remove unused import 2025-02-22 13:34:57 -08:00
Gilles Boccon-Gibod
cc21ed27c7 use bis link API 2025-02-22 13:32:58 -08:00
Gilles Boccon-Gibod
b932bafe6d Merge pull request #655 from markusjellitsch/fix/acl_packet_queue
FIX - acl_packet_queue.flush
2025-02-22 11:33:37 -08:00
markus
4e35aba033 fix acl_packet_queue flush when controller does not support HCI_READ_BUFFER_SIZE_COMMAND 2025-02-22 08:37:12 +01:00
zxzxwu
0060ee8ee2 Merge pull request #646 from zxzxwu/ad
Improve AdvertisingData type annotations
2025-02-22 02:52:09 +08:00
zxzxwu
3263d71f54 Merge pull request #654 from zxzxwu/init
Fix mutable default values
2025-02-22 02:51:11 +08:00
Josh Wu
f321143837 Improve AdvertisingData type annotations
* Add overloads to provide better return type hints
* Make advertising data type enum so they can be considered constants
2025-02-21 17:12:14 +08:00
Josh Wu
bac6f5baaf Fix mutable default values 2025-02-21 16:00:18 +08:00
Gilles Boccon-Gibod
e027bcb57a Merge pull request #651 from google/gbg/update-lc3-and-rootcanal
update rootcanal and lc3 dependencies
2025-02-18 14:08:42 -08:00
Gilles Boccon-Gibod
eeb9de31ed update dependencies 2025-02-18 12:52:57 -08:00
Gilles Boccon-Gibod
2c3af5b2bb Merge pull request #649 from google/gbg/async-read-phy
make connection phy async
2025-02-18 07:25:06 -08:00
Gilles Boccon-Gibod
dfb92e8ed1 fix pandora connection waiting 2025-02-17 19:50:57 -08:00
Gilles Boccon-Gibod
73d2b54e30 make connection phy async 2025-02-17 19:24:18 -08:00
Gilles Boccon-Gibod
8315a60f24 Merge pull request #648 from pstrueb/fix/auracast_transmit
Add setup_data_path for iso queues
2025-02-17 17:22:47 -08:00
185d5fd577 reformatting 2025-02-17 20:11:17 +01:00
pstrueb
ae5f9cf690 Remove little accident 2025-02-17 18:21:16 +01:00
pstrueb
4b66a38fe6 Refractoring again 2025-02-17 09:49:09 +01:00
pstrueb
f526f549ee refractoring 2025-02-17 08:51:28 +01:00
pstrueb
8761129677 Add setup_data_path for iso queues 2025-02-16 14:38:05 +01:00
Gilles Boccon-Gibod
3f6f036270 Merge pull request #643 from google/gbg/auracast-audio-io
auracast audio io
2025-02-08 18:19:24 -05:00
Gilles Boccon-Gibod
859bb0609f fix support for float32 2025-02-08 18:12:45 -05:00
Gilles Boccon-Gibod
5f2d24570e larger queue size 2025-02-06 21:24:58 -05:00
Gilles Boccon-Gibod
dbf94c8f3e print encoding params 2025-02-06 18:23:31 -05:00
Gilles Boccon-Gibod
b6adc29365 python 3.9 compat 2025-02-06 18:17:13 -05:00
Gilles Boccon-Gibod
5caa7bfa90 fix type checker and linter errors 2025-02-06 17:05:56 -05:00
Gilles Boccon-Gibod
f39d706fa0 remove obsolete code 2025-02-06 16:45:37 -05:00
zxzxwu
c02c1f33d2 Merge pull request #642 from zxzxwu/btbench
Add missing permissions in btbench
2025-02-07 04:48:46 +08:00
Gilles Boccon-Gibod
33435c2980 better docs and GATT fixes 2025-02-06 15:48:39 -05:00
Josh Wu
c08449d9db Add missing permissions in btbench 2025-02-07 03:13:53 +08:00
Gilles Boccon-Gibod
3c8718bb5b Merge pull request #608 from google/gbg/bench-android-enhancements
add startupDelay and connectionPriority params to BtBench snippets
2025-02-06 10:37:55 -05:00
Gilles Boccon-Gibod
26e87f09fe better error message 2025-02-05 22:28:05 -05:00
Gilles Boccon-Gibod
7f5e0d190e fix import checking 2025-02-05 22:19:39 -05:00
Gilles Boccon-Gibod
efae307b3d wip 2025-02-05 16:23:47 -05:00
zxzxwu
26d38a855c Merge pull request #641 from zxzxwu/pasync
Receive Periodic Advertising Sync Transfer
2025-02-06 05:18:47 +08:00
Josh Wu
7360a887d9 Receive Periodic Advertising Sync Transfer 2025-02-06 05:12:22 +08:00
Gilles Boccon-Gibod
9756572c93 add audio module 2025-02-04 17:58:54 -05:00
Gilles Boccon-Gibod
d6100755b1 add bond listener 2025-02-04 17:47:55 -05:00
Gilles Boccon-Gibod
a66eef6630 Merge pull request #640 from whitevegagabriel/cleanup
Rust library cleanup
2025-02-04 12:35:37 -05:00
Gabriel White-Vega
ae23ef7b9b Rust library cleanup
* Fix error code extraction from Python to Rust
* Add documentation for dealing with HCI packets
2025-02-04 12:23:06 -05:00
Gilles Boccon-Gibod
f368b5e518 wip 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
5293d32dc6 fix linter config 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
6d9a0bf4e1 fix linter config 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
3c7b5df7c5 add startupDelay and connectionPriority params to BtBench snippets 2025-02-03 18:00:46 -05:00
Gilles Boccon-Gibod
70141c0439 improvements 2025-02-03 17:58:09 -05:00
zxzxwu
dedc0aca54 Merge pull request #639 from zxzxwu/sdp
Correct SDP_ALL_ATTRIBUTES_RANGE value
2025-02-04 00:53:27 +08:00
Gilles Boccon-Gibod
7c019b574f Merge pull request #633 from markusjellitsch/fix/legacy-adv-params
fix advertising parameter usage for legacy advertising
2025-02-03 10:29:52 -05:00
markus
9b485fd943 revert python-avatar.yml 2025-02-03 15:17:22 +01:00
Josh Wu
fdee8269ec Correct SDP_ALL_ATTRIBUTES_RANGE value 2025-02-03 21:40:39 +08:00
zxzxwu
0767f2d4ae Merge pull request #638 from zxzxwu/avatar
Update actions/upload-artifact to v4
2025-02-03 21:31:42 +08:00
Josh Wu
c4a0846727 Update actions/upload-artifact to v4 2025-02-03 16:41:09 +08:00
zxzxwu
83ac70e426 Merge pull request #619 from zxzxwu/cs
Channel Sounding
2025-02-01 03:46:59 +08:00
markus
01cce3525f update avatar to github actions v4 2025-01-30 23:55:15 +01:00
markus
b9d35aea47 revert advertising_interval to type int 2025-01-30 19:47:20 +01:00
zxzxwu
079cf6b896 Merge pull request #624 from zxzxwu/gatt
Support GATT Service
2025-01-28 20:02:43 +08:00
Markus Jellitsch
180655088c run linter 2025-01-27 22:17:31 +01:00
Gilles Boccon-Gibod
a1bade6f20 Merge pull request #632 from markusjellitsch/fix/adapt-param-types
Adapt scanning and connection parameters type
2025-01-27 10:46:08 -05:00
Gilles Boccon-Gibod
5d80e7fd80 Merge pull request #634 from jmdietrich-gcx/fix_missing_await_for_update_rpa
Add missing await for update_rpa()
2025-01-27 10:45:42 -05:00
Jan-Marcel Dietrich
2198692961 Add missing await for update_rpa() 2025-01-27 15:14:52 +01:00
Gilles Boccon-Gibod
55d3fd90f5 wip 2025-01-25 21:04:59 -05:00
Gilles Boccon-Gibod
afee659ca6 Merge pull request #630 from google/gbg/iso-packet-queue
add support for ACL and ISO HCI packet queues
2025-01-24 15:59:19 -05:00
Gilles Boccon-Gibod
6fe7931d7d rename drain event to flow 2025-01-24 11:05:02 -05:00
Markus Jellitsch
9023407ee4 fix advertising parameters for legacy advertising 2025-01-23 15:14:54 +01:00
Markus Jellitsch
54d961bbe5 adapt scanning and connection parameters type 2025-01-23 14:53:20 +01:00
Gilles Boccon-Gibod
cbd46adbcf add support for ACL and ISO HCI packet queues 2025-01-22 13:42:29 -05:00
Josh Wu
745e107849 Channel Sounding device handlers 2025-01-22 23:38:44 +08:00
Gilles Boccon-Gibod
af466c2970 Merge pull request #629 from google/gbg/sdp-enforce-mtu
SDP: enforce MTU limits
2025-01-21 12:29:18 -05:00
Gilles Boccon-Gibod
931e2de854 address PR comments 2025-01-21 12:18:06 -05:00
Gilles Boccon-Gibod
55eb7eb237 enforce MTU limits 2025-01-21 10:31:10 -05:00
zxzxwu
bade4502f9 Merge pull request #628 from zxzxwu/cs-hci
Channel Sounding HCI packet definitions
2025-01-19 16:14:08 +08:00
Josh Wu
9f952f202f Channel Sounding HCI packet definitions 2025-01-16 14:33:34 +08:00
Josh Wu
1eb9d8d055 Support GATT Service 2025-01-15 02:13:25 +08:00
Gilles Boccon-Gibod
5a477eb391 Merge pull request #626 from markusjellitsch/fix/set-ext-scan-param-cmd
Update device.py - Fix scan_interval param in hci.HCI_LE_Set_Extended_Scan_Parameters_Command
2025-01-14 11:04:15 -05:00
Markus Jellitsch
86cda8771d Update device.py 2025-01-14 10:43:49 +01:00
zxzxwu
c1ea0ddd35 Merge pull request #622 from markusjellitsch/main
Fix: _IsoLink.write() struct.exception
2025-01-13 16:21:41 +08:00
Markus Jellitsch
f567711a6c avoid struct.error exception when packet_sequence_number > 0xFFFF 2025-01-10 01:33:43 +01:00
Gilles Boccon-Gibod
509df4c676 Merge pull request #618 from google/gbg/hci-event-multi-vendor
support multiple event factories
2025-01-07 15:00:20 -05:00
Gilles Boccon-Gibod
b375ed07b4 add test 2025-01-07 14:54:59 -05:00
Gilles Boccon-Gibod
69d62d3dd1 support multiple event factories 2025-01-06 08:42:09 -05:00
zxzxwu
fe3fa3d505 Merge pull request #617 from zxzxwu/iso
Unify ISO methods
2025-01-06 14:31:47 +08:00
Josh Wu
27fcd43224 Unify ISO methods 2025-01-02 14:19:36 +08:00
zxzxwu
c3b2bb19d5 Merge pull request #589 from zxzxwu/auracast
Auracast support
2025-01-02 01:02:13 +08:00
Gilles Boccon-Gibod
34287177b9 Merge pull request #615 from google/gbg/bluetooth-6-constants
add bluetooth 6.0 constants
2024-12-23 08:46:13 -05:00
Josh Wu
d238dd4059 Use dynamic sample rate 2024-12-23 17:01:11 +08:00
Gilles Boccon-Gibod
865f3a249f add bluetooth 6.0 constants 2024-12-22 12:47:37 -05:00
Josh Wu
7324d322fe BIG 2024-12-20 13:45:12 +08:00
Gilles Boccon-Gibod
af148b476d Merge pull request #613 from google/gbg/update-cryptography-dependency
update cryptography dependency
2024-12-19 08:42:51 -05:00
zxzxwu
80d60aaf15 Merge pull request #612 from zxzxwu/lc3
Replace liblc3 wasm library
2024-12-19 15:06:22 +08:00
Gilles Boccon-Gibod
c80f89d20f update cryptography dependency 2024-12-18 22:01:42 -05:00
Josh Wu
a27f55a588 Replace liblc3 wasm library 2024-12-19 02:21:38 +08:00
Gilles Boccon-Gibod
62e4670a39 Merge pull request #606 from wpiet/gmap-wip
Add `Gaming Audio Profile`
2024-12-18 11:56:57 -05:00
zxzxwu
99695bb264 Merge pull request #610 from zxzxwu/cfg
Remove setup.py and setup.cfg
2024-12-19 00:53:12 +08:00
Josh Wu
eb54898106 Remove setup.py and setup.cfg 2024-12-19 00:45:13 +08:00
Gilles Boccon-Gibod
4f5ee204d2 Update code-check.yml
Hot fix because 3.13.1 somehow breaks the current version of pylint. Will revert to 3.13 without pining to 3.13.0 when pylint is fixed
2024-12-18 11:36:08 -05:00
Wojciech Pietraszewski
2552e21db1 Add characteristics initial values
Sets default values for characteristics if not specified explicitly
2024-12-04 17:00:29 +01:00
Wojciech Pietraszewski
6168f87e2f Add characteristics conditionally
Only adds a characteristic if the corresponding role has been set
2024-12-04 12:57:34 +01:00
Gilles Boccon-Gibod
ca7d2ca4df Merge pull request #607 from google/gbg/pandora-deps
move pandora deps to development
2024-12-03 09:42:44 -08:00
Gilles Boccon-Gibod
60723323e9 move pandora deps to development 2024-12-03 09:08:30 -08:00
Gilles Boccon-Gibod
3ce7b9255b Merge pull request #598 from google/gbg/gatt-class-adapter
Add a class-based GATT adapter
2024-12-03 08:46:30 -08:00
Gilles Boccon-Gibod
97fcfc2fa0 Merge pull request #604 from jmdietrich-gcx/add_encryption_key_size_to_pairing_config
Add maximum encryption key size to PairingDelegate
2024-12-03 08:30:53 -08:00
Wojciech Pietraszewski
19674e3758 Add Gaming Audio Profile
Adds initial support for `Gaming Audio Service`.
2024-12-02 11:15:10 +01:00
Jan-Marcel Dietrich
1130e1db8f Fix code formatting 2024-12-02 09:01:18 +01:00
Gilles Boccon-Gibod
37c7f3a58a Merge pull request #603 from google/gbg/fix-pair-oob
fix oob support in pair.py
2024-12-01 08:43:04 -08:00
Gilles Boccon-Gibod
0a12b2bf2e Merge pull request #585 from wpiet/vocs
Add `Volume Offset Control Service`
2024-11-29 10:41:30 -08:00
Gilles Boccon-Gibod
d014acbe63 Merge pull request #597 from google/gbg/intel-hci
intel hci
2024-11-29 10:41:10 -08:00
Jan-Marcel Dietrich
07f9997a49 Add maximum encryption key size to PairingDelegate
So far the maxmium encryption key size has been hardcoded to 16 bytes in
'send_pairing_request_command()' and 'send_pairing_response_comman()'. By
making this configurable via the PairingDelegate, one can test how devices
respond to smaller encryption key sizes. Default remains 16 bytes.
2024-11-28 14:15:51 +01:00
Gilles Boccon-Gibod
b9f91f695a fix oob support in pair.py 2024-11-27 12:58:03 -08:00
Gilles Boccon-Gibod
082d55af10 Merge pull request #599 from google/gbg/hfp-19
add super wide band constants
2024-11-25 07:47:40 -08:00
Gilles Boccon-Gibod
4c3fd5688d Merge pull request #600 from google/gbg/unify-to-bytes
only use `__bytes__` when not argument is needed.
2024-11-25 07:44:17 -08:00
Gilles Boccon-Gibod
9d3d5495ce only use __bytes__ when not argument is needed. 2024-11-23 15:56:14 -08:00
Gilles Boccon-Gibod
b3869f267c add super wide band constants 2024-11-23 09:27:03 -08:00
Gilles Boccon-Gibod
8715333706 Add a GATT adapter that uses from_bytes and __bytes__ as conversion methods. 2024-11-23 09:13:04 -08:00
Gilles Boccon-Gibod
b57096abe2 Merge pull request #595 from wpiet/aics-opcode-fix
Amend Opcode value in `Audio Input Control Service`
2024-11-23 08:56:23 -08:00
Gilles Boccon-Gibod
48685c8587 improve vendor event support 2024-11-23 08:55:50 -08:00
Wojciech Pietraszewski
100bea6b41 Fix typos
Amends the typo in the `INACTIVE` field in `Audio Input Status` characteristic.
Amends the typo in the log message of `_set_gain_settings` method.
2024-11-21 18:29:44 +01:00
Wojciech Pietraszewski
63819bf9dd Amend Opcode value in Audio Input Control Service
Corrects the Audio Input Control Point
Opcode value for `Set Gain Setting` field.
2024-11-21 16:40:49 +01:00
Wojciech Pietraszewski
6e55390930 Add Volume Offset Control Service
Adds initial support for VOCS.
2024-11-21 11:56:14 +01:00
zxzxwu
e3fdab4175 Merge pull request #593 from zxzxwu/periodic
Support Periodic Advertising
2024-11-19 17:22:37 +08:00
Josh Wu
bbcd14dbf0 Support Periodic Advertising 2024-11-19 16:27:13 +08:00
zxzxwu
01dc0d574b Merge pull request #590 from SergeantSerk/parse-scan-response-data
Correctly parse scan response from device config
2024-11-17 15:39:11 +08:00
zxzxwu
5e959d638e Merge pull request #591 from zxzxwu/auracast_scan
Improve Broadcast Scanning
2024-11-16 04:10:27 +08:00
Gilles Boccon-Gibod
8d908288c8 Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-15 10:19:20 -08:00
Josh Wu
c88b32a406 Improve Broadcast Scanning 2024-11-16 02:02:28 +08:00
zxzxwu
5a72eefb89 Merge pull request #587 from zxzxwu/device
Replace HCI member imports in device.py
2024-11-13 15:25:32 +08:00
Josh Wu
430046944b Replace HCI member import in device.py 2024-11-12 16:53:21 +08:00
zxzxwu
21d23320eb Merge pull request #584 from zxzxwu/commands6.0
Add Core Spec 6.0 new commands support mapping
2024-11-12 04:17:24 +00:00
Serkan
d0990ee04d Correctly parse scan response from device config
Parses scan response data correctly just like advertising data
2024-11-07 21:49:33 +03:00
Josh Wu
2d88e853e8 Add Core Spec 6.0 new commands support mapping 2024-11-07 14:36:54 +08:00
Gilles Boccon-Gibod
a060a70fba Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-04 13:03:57 -08:00
Gilles Boccon-Gibod
a06394ad4a Merge pull request #582 from google/gbg/580
fix #580
2024-11-04 13:03:15 -08:00
Gilles Boccon-Gibod
a1414c2b5b add unsubscribe test 2024-11-03 19:08:27 -08:00
Gilles Boccon-Gibod
b2864dac2d fix #580 2024-11-02 10:29:40 -07:00
Gilles Boccon-Gibod
b78f895143 Merge pull request #579 from jmdietrich-gcx/unsubscribe_characteristic_in_gatt_client
Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
2024-10-31 04:07:02 -07:00
zxzxwu
c4e9726828 Merge pull request #581 from zxzxwu/context
[BAP] Add missing Unspecified context type
2024-10-31 11:04:25 +00:00
Gilles Boccon-Gibod
d4b8e8348a Merge pull request #574 from google/gbg/update-python-versions
remove test for deprecated Python 3.8 and add 3.13
2024-10-31 03:44:01 -07:00
Josh Wu
19debaa52e [BAP] Add missing Unspecified context type 2024-10-31 18:11:40 +08:00
Jan-Marcel Dietrich
73fe564321 Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
GATT Client's subscribe() adds the characteristic itself as subscriber.
Therefore the characteristic has to be removed in unsubscribe(), if it's
the last subscriber. Otherwise the clean up does not work correctly and
the CCCD never is set back to 0 in the remote device.
2024-10-30 07:34:22 +01:00
106 changed files with 10015 additions and 2396 deletions

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
fail-fast: false fail-fast: false
steps: steps:
@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,pandora]" python -m pip install ".[build,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit

View File

@@ -32,7 +32,7 @@ jobs:
- name: Install - name: Install
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install .[avatar,pandora] python -m pip install .[avatar]
- name: Rootcanal - name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log & run: nohup python -m rootcanal > rootcanal.log &
- name: Test - name: Test
@@ -44,7 +44,7 @@ jobs:
run: cat rootcanal.log run: cat rootcanal.log
- name: Upload Mobly logs - name: Upload Mobly logs
if: always() if: always()
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: mobly-logs name: mobly-logs-${{ strategy.job-index }}
path: /tmp/logs/mobly/bumble.bumbles/ path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -14,9 +14,12 @@
"ASHA", "ASHA",
"asyncio", "asyncio",
"ATRAC", "ATRAC",
"auracast",
"avctp", "avctp",
"avdtp", "avdtp",
"avrcp", "avrcp",
"biginfo",
"bigs",
"bitpool", "bitpool",
"bitstruct", "bitstruct",
"BSCP", "BSCP",
@@ -36,6 +39,7 @@
"deregistration", "deregistration",
"dhkey", "dhkey",
"diversifier", "diversifier",
"ediv",
"endianness", "endianness",
"ESCO", "ESCO",
"Fitbit", "Fitbit",
@@ -47,6 +51,7 @@
"libc", "libc",
"liblc", "liblc",
"libusb", "libusb",
"maxs",
"MITM", "MITM",
"MSBC", "MSBC",
"NDIS", "NDIS",
@@ -54,8 +59,10 @@
"NONBLOCK", "NONBLOCK",
"NONCONN", "NONCONN",
"OXIMETER", "OXIMETER",
"PDUS",
"popleft", "popleft",
"PRAND", "PRAND",
"prefs",
"protobuf", "protobuf",
"psms", "psms",
"pyee", "pyee",

View File

@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC # Copyright 2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,29 +16,49 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import asyncio.subprocess
import collections
import contextlib import contextlib
import dataclasses import dataclasses
import functools
import logging import logging
import os import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple import struct
from typing import (
Any,
AsyncGenerator,
Coroutine,
Deque,
Optional,
Tuple,
)
import click import click
import pyee import pyee
try:
import lc3 # type: ignore # pylint: disable=E0401
except ImportError as e:
raise ImportError(
"Try `python -m pip install \"git+https://github.com/google/liblc3.git\"`."
) from e
from bumble.audio import io as audio_io
from bumble.colors import color from bumble.colors import color
import bumble.company_ids from bumble import company_ids
import bumble.core from bumble import core
from bumble import gatt
from bumble import hci
from bumble.profiles import bap
from bumble.profiles import le_audio
from bumble.profiles import pbp
from bumble.profiles import bass
import bumble.device import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp
import bumble.transport import bumble.transport
import bumble.utils import bumble.utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -49,9 +69,34 @@ logger = logging.getLogger(__name__)
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast' AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5') AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0 AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256 AURACAST_DEFAULT_ATT_MTU = 256
AURACAST_DEFAULT_FRAME_DURATION = 10000
AURACAST_DEFAULT_SAMPLE_RATE = 48000
AURACAST_DEFAULT_TRANSMIT_BITRATE = 80000
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def codec_config_string(
codec_config: bap.CodecSpecificConfiguration, indent: str
) -> str:
lines = []
if codec_config.sampling_frequency is not None:
lines.append(f'Sampling Frequency: {codec_config.sampling_frequency.hz} hz')
if codec_config.frame_duration is not None:
lines.append(f'Frame Duration: {codec_config.frame_duration.us} µs')
if codec_config.octets_per_codec_frame is not None:
lines.append(f'Frame Size: {codec_config.octets_per_codec_frame} bytes')
if codec_config.codec_frames_per_sdu is not None:
lines.append(f'Frames Per SDU: {codec_config.codec_frames_per_sdu}')
if codec_config.audio_channel_allocation is not None:
lines.append(
f'Audio Location: {codec_config.audio_channel_allocation.name}'
)
return '\n'.join(indent + line for line in lines)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -60,19 +105,14 @@ AURACAST_DEFAULT_ATT_MTU = 256
class BroadcastScanner(pyee.EventEmitter): class BroadcastScanner(pyee.EventEmitter):
@dataclasses.dataclass @dataclasses.dataclass
class Broadcast(pyee.EventEmitter): class Broadcast(pyee.EventEmitter):
name: str name: str | None
sync: bumble.device.PeriodicAdvertisingSync sync: bumble.device.PeriodicAdvertisingSync
broadcast_id: int
rssi: int = 0 rssi: int = 0
public_broadcast_announcement: Optional[ public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None
bumble.profiles.pbp.PublicBroadcastAnnouncement broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None
] = None basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None
broadcast_audio_announcement: Optional[ appearance: Optional[core.Appearance] = None
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None manufacturer_data: Optional[Tuple[str, bytes]] = None
@@ -86,42 +126,32 @@ class BroadcastScanner(pyee.EventEmitter):
def update(self, advertisement: bumble.device.Advertisement) -> None: def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all( for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA core.AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID
): ):
assert isinstance(service_data, tuple)
service_uuid, data = service_data service_uuid, data = service_data
assert isinstance(data, bytes)
if ( if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE:
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
self.public_broadcast_announcement = ( self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data) pbp.PublicBroadcastAnnouncement.from_bytes(data)
) )
continue continue
if ( if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE:
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
self.broadcast_audio_announcement = ( self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data) bap.BroadcastAudioAnnouncement.from_bytes(data)
) )
continue continue
self.appearance = advertisement.data.get( # type: ignore[assignment] self.appearance = advertisement.data.get(
bumble.core.AdvertisingData.APPEARANCE core.AdvertisingData.Type.APPEARANCE
) )
if manufacturer_data := advertisement.data.get( if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA core.AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA
): ):
assert isinstance(manufacturer_data, tuple) company_id, data = manufacturer_data
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = ( self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get( company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}' company_id, f'0x{company_id:04X}'
), ),
data, data,
@@ -135,7 +165,8 @@ class BroadcastScanner(pyee.EventEmitter):
self.sync.advertiser_address, self.sync.advertiser_address,
color(self.sync.state.name, 'green'), color(self.sync.state.name, 'green'),
) )
print(f' {color("Name", "cyan")}: {self.name}') if self.name is not None:
print(f' {color("Name", "cyan")}: {self.name}')
if self.appearance: if self.appearance:
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}') print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
print(f' {color("RSSI", "cyan")}: {self.rssi}') print(f' {color("RSSI", "cyan")}: {self.rssi}')
@@ -156,25 +187,24 @@ class BroadcastScanner(pyee.EventEmitter):
if self.public_broadcast_announcement: if self.public_broadcast_announcement:
print( print(
f' {color("Features", "cyan")}: ' f' {color("Features", "cyan")}: '
f'{self.public_broadcast_announcement.features}' f'{self.public_broadcast_announcement.features.name}'
)
print(
f' {color("Metadata", "cyan")}: '
f'{self.public_broadcast_announcement.metadata}'
) )
print(f' {color("Metadata", "cyan")}:')
print(self.public_broadcast_announcement.metadata.pretty_print(' '))
if self.basic_audio_announcement: if self.basic_audio_announcement:
print(color(' Audio:', 'cyan')) print(color(' Audio:', 'cyan'))
print( print(
color(' Presentation Delay:', 'magenta'), color(' Presentation Delay:', 'magenta'),
self.basic_audio_announcement.presentation_delay, self.basic_audio_announcement.presentation_delay,
"µs",
) )
for subgroup in self.basic_audio_announcement.subgroups: for subgroup in self.basic_audio_announcement.subgroups:
print(color(' Subgroup:', 'magenta')) print(color(' Subgroup:', 'magenta'))
print(color(' Codec ID:', 'yellow')) print(color(' Codec ID:', 'yellow'))
print( print(
color(' Coding Format: ', 'green'), color(' Coding Format: ', 'green'),
subgroup.codec_id.coding_format.name, subgroup.codec_id.codec_id.name,
) )
print( print(
color(' Company ID: ', 'green'), color(' Company ID: ', 'green'),
@@ -184,17 +214,22 @@ class BroadcastScanner(pyee.EventEmitter):
color(' Vendor Specific Codec ID:', 'green'), color(' Vendor Specific Codec ID:', 'green'),
subgroup.codec_id.vendor_specific_codec_id, subgroup.codec_id.vendor_specific_codec_id,
) )
print(color(' Codec Config:', 'yellow'))
print( print(
color(' Codec Config:', 'yellow'), codec_config_string(
subgroup.codec_specific_configuration, subgroup.codec_specific_configuration, ' '
),
) )
print(color(' Metadata: ', 'yellow'), subgroup.metadata) print(color(' Metadata: ', 'yellow'))
print(subgroup.metadata.pretty_print(' '))
for bis in subgroup.bis: for bis in subgroup.bis:
print(color(f' BIS [{bis.index}]:', 'yellow')) print(color(f' BIS [{bis.index}]:', 'yellow'))
print(color(' Codec Config:', 'green'))
print( print(
color(' Codec Config:', 'green'), codec_config_string(
bis.codec_specific_configuration, bis.codec_specific_configuration, ' '
),
) )
if self.biginfo: if self.biginfo:
@@ -231,15 +266,13 @@ class BroadcastScanner(pyee.EventEmitter):
return return
for service_data in advertisement.data.get_all( for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA core.AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID
): ):
assert isinstance(service_data, tuple)
service_uuid, data = service_data service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE: if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = ( self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data) bap.BasicAudioAnnouncement.from_bytes(data)
) )
break break
@@ -261,7 +294,7 @@ class BroadcastScanner(pyee.EventEmitter):
self.device = device self.device = device
self.filter_duplicates = filter_duplicates self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {} self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]()
device.on('advertisement', self.on_advertisement) device.on('advertisement', self.on_advertisement)
async def start(self) -> None: async def start(self) -> None:
@@ -274,24 +307,45 @@ class BroadcastScanner(pyee.EventEmitter):
await self.device.stop_scanning() await self.device.stop_scanning()
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if ( if not (
broadcast_name := advertisement.data.get( ads := advertisement.data.get_all(
bumble.core.AdvertisingData.BROADCAST_NAME core.AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID
) )
) is None: ) or not (
broadcast_audio_announcement := next(
(
ad
for ad in ads
if ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
),
None,
)
):
return return
assert isinstance(broadcast_name, str)
broadcast_name = advertisement.data.get_all(
core.AdvertisingData.Type.BROADCAST_NAME
)
if broadcast := self.broadcasts.get(advertisement.address): if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement) broadcast.update(advertisement)
return return
bumble.utils.AsyncRunner.spawn( bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement) self.on_new_broadcast(
broadcast_name[0] if broadcast_name else None,
advertisement,
bap.BroadcastAudioAnnouncement.from_bytes(
broadcast_audio_announcement[1]
).broadcast_id,
)
) )
async def on_new_broadcast( async def on_new_broadcast(
self, name: str, advertisement: bumble.device.Advertisement self,
name: str | None,
advertisement: bumble.device.Advertisement,
broadcast_id: int,
) -> None: ) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync( periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address, advertiser_address=advertisement.address,
@@ -299,10 +353,7 @@ class BroadcastScanner(pyee.EventEmitter):
sync_timeout=self.sync_timeout, sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates, filter_duplicates=self.filter_duplicates,
) )
broadcast = self.Broadcast( broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
name,
periodic_advertising_sync,
)
broadcast.update(advertisement) broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast)) periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
@@ -314,10 +365,11 @@ class BroadcastScanner(pyee.EventEmitter):
self.emit('broadcast_loss', broadcast) self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner: class PrintingBroadcastScanner(pyee.EventEmitter):
def __init__( def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None: ) -> None:
super().__init__()
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout) self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast) self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss) self.scanner.on('broadcast_loss', self.on_broadcast_loss)
@@ -452,27 +504,29 @@ async def run_assist(
await peer.request_mtu(mtu) await peer.request_mtu(mtu)
# Get the BASS service # Get the BASS service
bass = await peer.discover_service_and_create_proxy( bass_client = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy bass.BroadcastAudioScanServiceProxy
) )
# Check that the service was found # Check that the service was found
if not bass: if not bass_client:
print(color('!!! Broadcast Audio Scan Service not found', 'red')) print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return return
# Subscribe to and read the broadcast receive state characteristics # Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states): for i, broadcast_receive_state in enumerate(
bass_client.broadcast_receive_states
):
try: try:
await broadcast_receive_state.subscribe( await broadcast_receive_state.subscribe(
lambda value, i=i: print( lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}" f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
) )
) )
except bumble.core.ProtocolError as error: except core.ProtocolError as error:
print( print(
color( color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:', '!!! Failed to subscribe to Broadcast Receive State characteristic',
'red', 'red',
), ),
error, error,
@@ -488,7 +542,7 @@ async def run_assist(
if command == 'add-source': if command == 'add-source':
# Find the requested broadcast # Find the requested broadcast
await bass.remote_scan_started() await bass_client.remote_scan_started()
if broadcast_name: if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name) print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else: else:
@@ -508,15 +562,15 @@ async def run_assist(
# Add the source # Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address) print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source( await bass_client.add_source(
broadcast.sync.advertiser_address, broadcast.sync.advertiser_address,
broadcast.sync.sid, broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id, broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE, bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF, 0xFFFF,
[ [
bumble.profiles.bass.SubgroupInfo( bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS, bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
) )
], ],
@@ -526,7 +580,7 @@ async def run_assist(
await broadcast.sync.transfer(peer.connection) await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning. # Notify the sink that we're done scanning.
await bass.remote_scan_stopped() await bass_client.remote_scan_stopped()
await peer.sustain() await peer.sustain()
return return
@@ -537,7 +591,7 @@ async def run_assist(
return return
# Find the requested broadcast # Find the requested broadcast
await bass.remote_scan_started() await bass_client.remote_scan_started()
if broadcast_name: if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name) print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else: else:
@@ -560,13 +614,13 @@ async def run_assist(
color('Modifying source:', 'blue'), color('Modifying source:', 'blue'),
source_id, source_id,
) )
await bass.modify_source( await bass_client.modify_source(
source_id, source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF, 0xFFFF,
[ [
bumble.profiles.bass.SubgroupInfo( bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS, bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
) )
], ],
@@ -581,7 +635,7 @@ async def run_assist(
# Remove the source # Remove the source
print(color('Removing source:', 'blue'), source_id) print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id) await bass_client.remove_source(source_id)
await peer.sustain() await peer.sustain()
return return
@@ -601,14 +655,339 @@ async def run_pair(transport: str, address: str) -> None:
print("+++ Paired") print("+++ Paired")
async def run_receive(
transport: str,
broadcast_id: Optional[int],
output: str,
broadcast_code: str | None,
sync_timeout: float,
subgroup_index: int,
) -> None:
# Run a pre-flight check for the output.
try:
if not audio_io.check_audio_output(output):
return
except ValueError as error:
print(error)
return
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = BroadcastScanner(device, False, sync_timeout)
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
asyncio.get_running_loop().create_future()
)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if scan_result.done():
return
if broadcast_id is None or broadcast.broadcast_id == broadcast_id:
scan_result.set_result(broadcast)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
print('Start scanning...')
broadcast = await scan_result
print('Advertisement found:')
broadcast.print()
basic_audio_announcement_scanned = asyncio.Event()
def on_change() -> None:
if (
broadcast.basic_audio_announcement
and not basic_audio_announcement_scanned.is_set()
):
basic_audio_announcement_scanned.set()
broadcast.on('change', on_change)
if not broadcast.basic_audio_announcement:
print('Wait for Basic Audio Announcement...')
await basic_audio_announcement_scanned.wait()
print('Basic Audio Announcement found')
broadcast.print()
print('Stop scanning')
await scanner.stop()
print('Start sync to BIG')
assert broadcast.basic_audio_announcement
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
configuration = subgroup.codec_specific_configuration
assert configuration
assert (sampling_frequency := configuration.sampling_frequency)
assert (frame_duration := configuration.frame_duration)
big_sync = await device.create_big_sync(
broadcast.sync,
bumble.device.BigSyncParameters(
big_sync_timeout=0x4000,
bis=[bis.index for bis in subgroup.bis],
broadcast_code=(
bytes.fromhex(broadcast_code) if broadcast_code else None
),
),
)
num_bis = len(big_sync.bis_links)
decoder = lc3.Decoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=num_bis,
)
lc3_queues: list[Deque[bytes]] = [collections.deque() for i in range(num_bis)]
packet_stats = [0, 0]
audio_output = await audio_io.create_audio_output(output)
# This try should be replaced with contextlib.aclosing() when python 3.9 is no
# longer needed.
try:
await audio_output.open(
audio_io.PcmFormat(
audio_io.PcmFormat.Endianness.LITTLE,
audio_io.PcmFormat.SampleType.FLOAT32,
sampling_frequency.hz,
num_bis,
)
)
def sink(queue: Deque[bytes], packet: hci.HCI_IsoDataPacket):
# TODO: re-assemble fragments and detect errors
queue.append(packet.iso_sdu_fragment)
while all(lc3_queues):
# This assumes SDUs contain one LC3 frame each, which may not
# be correct for all cases. TODO: revisit this assumption.
frame = b''.join([lc3_queue.popleft() for lc3_queue in lc3_queues])
if not frame:
print(color('!!! received empty frame', 'red'))
continue
packet_stats[0] += len(frame)
packet_stats[1] += 1
print(
f'\rRECEIVED: {packet_stats[0]} bytes in '
f'{packet_stats[1]} packets',
end='',
)
try:
pcm = decoder.decode(frame).tobytes()
except lc3.BaseError as error:
print(color(f'!!! LC3 decoding error: {error}'))
continue
audio_output.write(pcm)
for i, bis_link in enumerate(big_sync.bis_links):
print(f'Setup ISO for BIS {bis_link.handle}')
bis_link.sink = functools.partial(sink, lc3_queues[i])
await bis_link.setup_data_path(
direction=bis_link.Direction.CONTROLLER_TO_HOST
)
terminated = asyncio.Event()
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
await terminated.wait()
finally:
await audio_output.aclose()
async def run_transmit(
transport: str,
broadcast_id: int,
broadcast_code: str | None,
broadcast_name: str,
bitrate: int,
manufacturer_data: tuple[int, bytes] | None,
input: str,
input_format: str,
) -> None:
# Run a pre-flight check for the input.
try:
if not audio_io.check_audio_input(input):
return
except ValueError as error:
print(error)
return
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
basic_audio_announcement = bap.BasicAudioAnnouncement(
presentation_delay=40000,
subgroups=[
bap.BasicAudioAnnouncement.Subgroup(
codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3),
codec_specific_configuration=bap.CodecSpecificConfiguration(
sampling_frequency=bap.SamplingFrequency.FREQ_48000,
frame_duration=bap.FrameDuration.DURATION_10000_US,
octets_per_codec_frame=100,
),
metadata=le_audio.Metadata(
[
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng'
),
le_audio.Metadata.Entry(
tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco'
),
]
),
bis=[
bap.BasicAudioAnnouncement.BIS(
index=1,
codec_specific_configuration=bap.CodecSpecificConfiguration(
audio_channel_allocation=bap.AudioLocation.FRONT_LEFT
),
),
bap.BasicAudioAnnouncement.BIS(
index=2,
codec_specific_configuration=bap.CodecSpecificConfiguration(
audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT
),
),
],
)
],
)
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
advertising_manufacturer_data = (
b''
if manufacturer_data is None
else bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
struct.pack('<H', manufacturer_data[0])
+ manufacturer_data[1],
)
]
)
)
)
advertising_set = await device.create_advertising_set(
advertising_parameters=bumble.device.AdvertisingParameters(
advertising_event_properties=bumble.device.AdvertisingEventProperties(
is_connectable=False
),
primary_advertising_interval_min=100,
primary_advertising_interval_max=200,
),
advertising_data=(
broadcast_audio_announcement.get_advertising_data()
+ bytes(
core.AdvertisingData(
[(core.AdvertisingData.BROADCAST_NAME, broadcast_name.encode())]
)
)
+ advertising_manufacturer_data
),
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
periodic_advertising_interval_min=80,
periodic_advertising_interval_max=160,
),
periodic_advertising_data=basic_audio_announcement.get_advertising_data(),
auto_restart=True,
auto_start=True,
)
print('Start Periodic Advertising')
await advertising_set.start_periodic()
audio_input = await audio_io.create_audio_input(input, input_format)
pcm_format = await audio_input.open()
# This try should be replaced with contextlib.aclosing() when python 3.9 is no
# longer needed.
try:
if pcm_format.channels != 2:
print("Only 2 channels PCM configurations are supported")
return
if pcm_format.sample_type == audio_io.PcmFormat.SampleType.INT16:
pcm_bit_depth = 16
elif pcm_format.sample_type == audio_io.PcmFormat.SampleType.FLOAT32:
pcm_bit_depth = None
else:
print("Only INT16 and FLOAT32 sample types are supported")
return
encoder = lc3.Encoder(
frame_duration_us=AURACAST_DEFAULT_FRAME_DURATION,
sample_rate_hz=AURACAST_DEFAULT_SAMPLE_RATE,
num_channels=pcm_format.channels,
input_sample_rate_hz=pcm_format.sample_rate,
)
lc3_frame_samples = encoder.get_frame_samples()
lc3_frame_size = encoder.get_frame_bytes(bitrate)
print(
f'Encoding with {lc3_frame_samples} '
f'PCM samples per {lc3_frame_size} byte frame'
)
print('Setup BIG')
big = await device.create_big(
advertising_set,
parameters=bumble.device.BigParameters(
num_bis=pcm_format.channels,
sdu_interval=AURACAST_DEFAULT_FRAME_DURATION,
max_sdu=lc3_frame_size,
max_transport_latency=65,
rtn=4,
broadcast_code=(
bytes.fromhex(broadcast_code) if broadcast_code else None
),
),
)
for bis_link in big.bis_links:
print(f'Setup ISO for BIS {bis_link.handle}')
await bis_link.setup_data_path(
direction=bis_link.Direction.HOST_TO_CONTROLLER
)
iso_queues = [
bumble.device.IsoPacketStream(bis_link, 64)
for bis_link in big.bis_links
]
def on_flow():
data_packet_queue = iso_queues[0].data_packet_queue
print(
f'\rPACKETS: pending={data_packet_queue.pending}, '
f'queued={data_packet_queue.queued}, '
f'completed={data_packet_queue.completed}',
end='',
)
iso_queues[0].data_packet_queue.on('flow', on_flow)
frame_count = 0
async for pcm_frame in audio_input.frames(lc3_frame_samples):
lc3_frame = encoder.encode(
pcm_frame, num_bytes=2 * lc3_frame_size, bit_depth=pcm_bit_depth
)
mid = len(lc3_frame) // 2
await iso_queues[0].write(lc3_frame[:mid])
await iso_queues[1].write(lc3_frame[mid:])
frame_count += 1
finally:
await audio_input.aclose()
def run_async(async_command: Coroutine) -> None: def run_async(async_command: Coroutine) -> None:
try: try:
asyncio.run(async_command) asyncio.run(async_command)
except bumble.core.ProtocolError as error: except core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list( if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError bass.ApplicationError
): ):
message = bumble.profiles.bass.ApplicationError(error.error_code).name message = bass.ApplicationError(error.error_code).name
else: else:
message = str(error) message = str(error)
@@ -622,9 +1001,7 @@ def run_async(async_command: Coroutine) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.group() @click.group()
@click.pass_context @click.pass_context
def auracast( def auracast(ctx):
ctx,
):
ctx.ensure_object(dict) ctx.ensure_object(dict)
@@ -669,7 +1046,7 @@ def scan(ctx, filter_duplicates, sync_timeout, transport):
@click.argument('address') @click.argument('address')
@click.pass_context @click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address): def assist(ctx, broadcast_name, source_id, command, transport, address):
"""Scan for broadcasts on behalf of a audio server""" """Scan for broadcasts on behalf of an audio server"""
run_async(run_assist(broadcast_name, source_id, command, transport, address)) run_async(run_assist(broadcast_name, source_id, command, transport, address))
@@ -682,6 +1059,166 @@ def pair(ctx, transport, address):
run_async(run_pair(transport, address)) run_async(run_pair(transport, address))
@auracast.command('receive')
@click.argument('transport')
@click.argument(
'broadcast_id',
type=int,
required=False,
)
@click.option(
'--output',
default='device',
help=(
"Audio output. "
"'device' -> use the host's default sound output device, "
"'device:<DEVICE_ID>' -> use one of the host's sound output device "
"(specify 'device:?' to get a list of available sound output devices), "
"'stdout' -> send audio to stdout, "
"'file:<filename> -> write audio to a raw float32 PCM file, "
"'ffplay' -> pipe the audio to ffplay"
),
)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
type=str,
help='Broadcast encryption code in hex format',
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)',
)
@click.option(
'--subgroup',
metavar='SUBGROUP',
type=int,
default=0,
help='Index of Subgroup',
)
@click.pass_context
def receive(
ctx,
transport,
broadcast_id,
output,
broadcast_code,
sync_timeout,
subgroup,
):
"""Receive a broadcast source"""
run_async(
run_receive(
transport,
broadcast_id,
output,
broadcast_code,
sync_timeout,
subgroup,
)
)
@auracast.command('transmit')
@click.argument('transport')
@click.option(
'--input',
required=True,
help=(
"Audio input. "
"'device' -> use the host's default sound input device, "
"'device:<DEVICE_ID>' -> use one of the host's sound input devices "
"(specify 'device:?' to get a list of available sound input devices), "
"'stdin' -> receive audio from stdin as int16 PCM, "
"'file:<filename> -> read audio from a .wav or raw int16 PCM file. "
"(The file: prefix may be omitted if the file path does not start with "
"the substring 'device:' or 'file:' and is not 'stdin')"
),
)
@click.option(
'--input-format',
metavar="FORMAT",
default='auto',
help=(
"Audio input format. "
"Use 'auto' for .wav files, or for the default setting with the devices. "
"For other inputs, the format is specified as "
"<sample-type>,<sample-rate>,<channels> (supported <sample-type>: 'int16le' "
"for 16-bit signed integers with little-endian byte order or 'float32le' for "
"32-bit floating point with little-endian byte order)"
),
)
@click.option(
'--broadcast-id',
metavar='BROADCAST_ID',
type=int,
default=123456,
help='Broadcast ID',
)
@click.option(
'--broadcast-code',
metavar='BROADCAST_CODE',
help='Broadcast encryption code in hex format',
)
@click.option(
'--broadcast-name',
metavar='BROADCAST_NAME',
default='Bumble Auracast',
help='Broadcast name',
)
@click.option(
'--bitrate',
type=int,
default=AURACAST_DEFAULT_TRANSMIT_BITRATE,
help='Bitrate, per channel, in bps',
)
@click.option(
'--manufacturer-data',
metavar='VENDOR-ID:DATA-HEX',
help='Manufacturer data (specify as <vendor-id>:<data-hex>)',
)
@click.pass_context
def transmit(
ctx,
transport,
broadcast_id,
broadcast_code,
manufacturer_data,
broadcast_name,
bitrate,
input,
input_format,
):
"""Transmit a broadcast source"""
if manufacturer_data:
vendor_id_str, data_hex = manufacturer_data.split(':')
vendor_id = int(vendor_id_str)
data = bytes.fromhex(data_hex)
manufacturer_data_tuple = (vendor_id, data)
else:
manufacturer_data_tuple = None
if (input == 'device' or input.startswith('device:')) and input_format == 'auto':
# Use a default format for device inputs
input_format = 'int16le,48000,1'
run_async(
run_transmit(
transport=transport,
broadcast_id=broadcast_id,
broadcast_code=broadcast_code,
broadcast_name=broadcast_name,
bitrate=bitrate,
manufacturer_data=manufacturer_data_tuple,
input=input,
input_format=input_format,
)
)
def main(): def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast() auracast()

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import dataclasses
import enum import enum
import logging import logging
import os import os
@@ -97,49 +98,22 @@ DEFAULT_RFCOMM_MTU = 2048
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def parse_packet(packet):
if len(packet) < 1:
logging.info(
color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = PacketType(packet[0])
except ValueError:
logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
raise
return (packet_type, packet[1:])
def parse_packet_sequence(packet_data):
if len(packet_data) < 5:
logging.info(
color(
f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
'red',
)
)
raise ValueError('packet too short')
return struct.unpack_from('>bI', packet_data, 0)
def le_phy_name(phy_id): def le_phy_name(phy_id):
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get( return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id) phy_id, HCI_Constant.le_phy_name(phy_id)
) )
def print_connection_phy(phy):
logging.info(
color('@@@ PHY: ', 'yellow') + f'TX:{le_phy_name(phy.tx_phy)}/'
f'RX:{le_phy_name(phy.rx_phy)}'
)
def print_connection(connection): def print_connection(connection):
params = [] params = []
if connection.transport == BT_LE_TRANSPORT: if connection.transport == BT_LE_TRANSPORT:
params.append(
'PHY='
f'TX:{le_phy_name(connection.phy.tx_phy)}/'
f'RX:{le_phy_name(connection.phy.rx_phy)}'
)
params.append( params.append(
'DL=(' 'DL=('
f'TX:{connection.data_length[0]}/{connection.data_length[1]},' f'TX:{connection.data_length[0]}/{connection.data_length[1]},'
@@ -199,7 +173,7 @@ def log_stats(title, stats, precision=2):
stats_min = min(stats) stats_min = min(stats)
stats_max = max(stats) stats_max = max(stats)
stats_avg = statistics.mean(stats) stats_avg = statistics.mean(stats)
stats_stdev = statistics.stdev(stats) stats_stdev = statistics.stdev(stats) if len(stats) >= 2 else 0
logging.info( logging.info(
color( color(
( (
@@ -225,13 +199,135 @@ async def switch_roles(connection, role):
logging.info(f'{color("### Role switch failed:", "red")} {error}') logging.info(f'{color("### Role switch failed:", "red")} {error}')
class PacketType(enum.IntEnum): # -----------------------------------------------------------------------------
RESET = 0 # Packet
SEQUENCE = 1 # -----------------------------------------------------------------------------
ACK = 2 @dataclasses.dataclass
class Packet:
class PacketType(enum.IntEnum):
RESET = 0
SEQUENCE = 1
ACK = 2
class PacketFlags(enum.IntFlag):
LAST = 1
packet_type: PacketType
flags: PacketFlags = PacketFlags(0)
sequence: int = 0
timestamp: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes):
if len(data) < 1:
logging.warning(
color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = cls.PacketType(data[0])
except ValueError:
logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red'))
raise
if packet_type == cls.PacketType.RESET:
return cls(packet_type)
flags = cls.PacketFlags(data[1])
(sequence,) = struct.unpack_from("<I", data, 2)
if packet_type == cls.PacketType.ACK:
if len(data) < 6:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 6)',
'red',
)
)
return cls(packet_type, flags, sequence)
if len(data) < 10:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 10)', 'red'
)
)
raise ValueError('packet too short')
(timestamp,) = struct.unpack_from("<I", data, 6)
return cls(packet_type, flags, sequence, timestamp, data[10:])
def __bytes__(self):
if self.packet_type == self.PacketType.RESET:
return bytes([self.packet_type])
if self.packet_type == self.PacketType.ACK:
return struct.pack("<BBI", self.packet_type, self.flags, self.sequence)
return (
struct.pack(
"<BBII", self.packet_type, self.flags, self.sequence, self.timestamp
)
+ self.payload
)
PACKET_FLAG_LAST = 1 # -----------------------------------------------------------------------------
# Jitter Stats
# -----------------------------------------------------------------------------
class JitterStats:
def __init__(self):
self.reset()
def reset(self):
self.packets = []
self.receive_times = []
self.jitter = []
def on_packet_received(self, packet):
now = time.time()
self.packets.append(packet)
self.receive_times.append(now)
if packet.timestamp and len(self.packets) > 1:
expected_time = (
self.receive_times[0]
+ (packet.timestamp - self.packets[0].timestamp) / 1000000
)
jitter = now - expected_time
else:
jitter = 0.0
self.jitter.append(jitter)
return jitter
def show_stats(self):
if len(self.jitter) < 3:
return
average = sum(self.jitter) / len(self.jitter)
adjusted = [jitter - average for jitter in self.jitter]
log_stats('Jitter (signed)', adjusted, 3)
log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(adjusted)
interval_max = max(adjusted)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count) for i in range(bin_count)
]
for jitter in adjusted:
for i in reversed(range(bin_count)):
if jitter >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -281,19 +377,37 @@ class Sender:
await asyncio.sleep(self.tx_start_delay) await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta')) logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET])) await self.packet_io.send_packet(
bytes(Packet(packet_type=Packet.PacketType.RESET))
)
self.start_time = time.time() self.start_time = time.time()
self.bytes_sent = 0 self.bytes_sent = 0
for tx_i in range(self.tx_packet_count): for tx_i in range(self.tx_packet_count):
packet_flags = ( if self.pace > 0:
PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0 # Wait until it is time to send the next packet
target_time = self.start_time + (tx_i * self.pace / 1000)
now = time.time()
if now < target_time:
await asyncio.sleep(target_time - now)
else:
await self.packet_io.drain()
packet = bytes(
Packet(
packet_type=Packet.PacketType.SEQUENCE,
flags=(
Packet.PacketFlags.LAST
if tx_i == self.tx_packet_count - 1
else 0
),
sequence=tx_i,
timestamp=int((time.time() - self.start_time) * 1000000),
payload=bytes(
self.tx_packet_size - 10 - self.packet_io.overhead_size
),
)
) )
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
logging.info( logging.info(
color( color(
f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow' f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
@@ -302,14 +416,6 @@ class Sender:
self.bytes_sent += len(packet) self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet) await self.packet_io.send_packet(packet)
if self.pace is None:
continue
if self.pace > 0:
await asyncio.sleep(self.pace / 1000)
else:
await self.packet_io.drain()
await self.done.wait() await self.done.wait()
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
@@ -321,13 +427,13 @@ class Sender:
if self.repeat: if self.repeat:
logging.info(color('--- End of runs', 'blue')) logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, _ = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
if packet_type == PacketType.ACK: if packet.packet_type == Packet.PacketType.ACK:
elapsed = time.time() - self.start_time elapsed = time.time() - self.start_time
average_tx_speed = self.bytes_sent / elapsed average_tx_speed = self.bytes_sent / elapsed
self.stats.append(average_tx_speed) self.stats.append(average_tx_speed)
@@ -350,52 +456,53 @@ class Receiver:
last_timestamp: float last_timestamp: float
def __init__(self, packet_io, linger): def __init__(self, packet_io, linger):
self.reset() self.jitter_stats = JitterStats()
self.packet_io = packet_io self.packet_io = packet_io
self.packet_io.packet_listener = self self.packet_io.packet_listener = self
self.linger = linger self.linger = linger
self.done = asyncio.Event() self.done = asyncio.Event()
self.reset()
def reset(self): def reset(self):
self.expected_packet_index = 0 self.expected_packet_index = 0
self.measurements = [(time.time(), 0)] self.measurements = [(time.time(), 0)]
self.total_bytes_received = 0 self.total_bytes_received = 0
self.jitter_stats.reset()
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
logging.exception("invalid packet")
return return
if packet_type == PacketType.RESET: if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta')) logging.info(color('=== Received RESET', 'magenta'))
self.reset() self.reset()
return return
try: jitter = self.jitter_stats.on_packet_received(packet)
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
logging.info( logging.info(
f'<<< Received packet {packet_index}: ' f'<<< Received packet {packet.sequence}: '
f'flags=0x{packet_flags:02X}, ' f'flags={packet.flags}, '
f'{len(packet) + self.packet_io.overhead_size} bytes' f'jitter={jitter:.4f}, '
f'{len(data) + self.packet_io.overhead_size} bytes',
) )
if packet_index != self.expected_packet_index: if packet.sequence != self.expected_packet_index:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
now = time.time() now = time.time()
elapsed_since_start = now - self.measurements[0][0] elapsed_since_start = now - self.measurements[0][0]
elapsed_since_last = now - self.measurements[-1][0] elapsed_since_last = now - self.measurements[-1][0]
self.measurements.append((now, len(packet))) self.measurements.append((now, len(data)))
self.total_bytes_received += len(packet) self.total_bytes_received += len(data)
instant_rx_speed = len(packet) / elapsed_since_last instant_rx_speed = len(data) / elapsed_since_last
average_rx_speed = self.total_bytes_received / elapsed_since_start average_rx_speed = self.total_bytes_received / elapsed_since_start
window = self.measurements[-64:] window = self.measurements[-64:]
windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / ( windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
@@ -411,15 +518,17 @@ class Receiver:
) )
) )
self.expected_packet_index = packet_index + 1 self.expected_packet_index = packet.sequence + 1
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
AsyncRunner.spawn( AsyncRunner.spawn(
self.packet_io.send_packet( self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
) )
) )
logging.info(color('@@@ Received last packet', 'green')) logging.info(color('@@@ Received last packet', 'green'))
self.jitter_stats.show_stats()
if not self.linger: if not self.linger:
self.done.set() self.done.set()
@@ -468,6 +577,7 @@ class Ping:
for run in range(self.repeat + 1): for run in range(self.repeat + 1):
self.done.clear() self.done.clear()
self.ping_times = []
if run > 0 and self.repeat and self.repeat_delay: if run > 0 and self.repeat and self.repeat_delay:
logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green')) logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
@@ -478,25 +588,32 @@ class Ping:
await asyncio.sleep(self.tx_start_delay) await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta')) logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET])) await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET)))
packet_interval = self.pace / 1000
start_time = time.time() start_time = time.time()
self.next_expected_packet_index = 0 self.next_expected_packet_index = 0
for i in range(self.tx_packet_count): for i in range(self.tx_packet_count):
target_time = start_time + (i * packet_interval) target_time = start_time + (i * self.pace / 1000)
now = time.time() now = time.time()
if now < target_time: if now < target_time:
await asyncio.sleep(target_time - now) await asyncio.sleep(target_time - now)
now = time.time()
packet = struct.pack( packet = bytes(
'>bbI', Packet(
PacketType.SEQUENCE, packet_type=Packet.PacketType.SEQUENCE,
(PACKET_FLAG_LAST if i == self.tx_packet_count - 1 else 0), flags=(
i, Packet.PacketFlags.LAST
) + bytes(self.tx_packet_size - 6) if i == self.tx_packet_count - 1
else 0
),
sequence=i,
timestamp=int((now - start_time) * 1000000),
payload=bytes(self.tx_packet_size - 10),
)
)
logging.info(color(f'Sending packet {i}', 'yellow')) logging.info(color(f'Sending packet {i}', 'yellow'))
self.ping_times.append(time.time()) self.ping_times.append(now)
await self.packet_io.send_packet(packet) await self.packet_io.send_packet(packet)
await self.done.wait() await self.done.wait()
@@ -530,40 +647,35 @@ class Ping:
if self.repeat: if self.repeat:
logging.info(color('--- End of runs', 'blue')) logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
try: if packet.packet_type == Packet.PacketType.ACK:
packet_flags, packet_index = parse_packet_sequence(packet_data) elapsed = time.time() - self.ping_times[packet.sequence]
except ValueError:
return
if packet_type == PacketType.ACK:
elapsed = time.time() - self.ping_times[packet_index]
rtt = elapsed * 1000 rtt = elapsed * 1000
self.rtts.append(rtt) self.rtts.append(rtt)
logging.info( logging.info(
color( color(
f'<<< Received ACK [{packet_index}], RTT={rtt:.2f}ms', f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms',
'green', 'green',
) )
) )
if packet_index == self.next_expected_packet_index: if packet.sequence == self.next_expected_packet_index:
self.next_expected_packet_index += 1 self.next_expected_packet_index += 1
else: else:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, ' f'!!! Unexpected packet, '
f'expected {self.next_expected_packet_index} ' f'expected {self.next_expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
self.done.set() self.done.set()
return return
@@ -575,89 +687,56 @@ class Pong:
expected_packet_index: int expected_packet_index: int
def __init__(self, packet_io, linger): def __init__(self, packet_io, linger):
self.reset() self.jitter_stats = JitterStats()
self.packet_io = packet_io self.packet_io = packet_io
self.packet_io.packet_listener = self self.packet_io.packet_listener = self
self.linger = linger self.linger = linger
self.done = asyncio.Event() self.done = asyncio.Event()
self.reset()
def reset(self): def reset(self):
self.expected_packet_index = 0 self.expected_packet_index = 0
self.receive_times = [] self.jitter_stats.reset()
def on_packet_received(self, packet):
self.receive_times.append(time.time())
def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
if packet_type == PacketType.RESET: if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta')) logging.info(color('=== Received RESET', 'magenta'))
self.reset() self.reset()
return return
try: jitter = self.jitter_stats.on_packet_received(packet)
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
interval = (
self.receive_times[-1] - self.receive_times[-2]
if len(self.receive_times) >= 2
else 0
)
logging.info( logging.info(
color( color(
f'<<< Received packet {packet_index}: ' f'<<< Received packet {packet.sequence}: '
f'flags=0x{packet_flags:02X}, {len(packet)} bytes, ' f'flags={packet.flags}, {len(data)} bytes, '
f'interval={interval:.4f}', f'jitter={jitter:.4f}',
'green', 'green',
) )
) )
if packet_index != self.expected_packet_index: if packet.sequence != self.expected_packet_index:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
self.expected_packet_index = packet_index + 1 self.expected_packet_index = packet.sequence + 1
AsyncRunner.spawn( AsyncRunner.spawn(
self.packet_io.send_packet( self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
) )
) )
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
if len(self.receive_times) >= 3: self.jitter_stats.show_stats()
# Show basic stats
intervals = [
self.receive_times[i + 1] - self.receive_times[i]
for i in range(len(self.receive_times) - 1)
]
log_stats('Packet intervals', intervals, 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(intervals)
interval_max = max(intervals)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count)
for i in range(bin_count)
]
for interval in intervals:
for i in reversed(range(bin_count)):
if interval >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
if not self.linger: if not self.linger:
self.done.set() self.done.set()
@@ -1210,6 +1289,8 @@ class Central(Connection.Listener):
logging.info(color('### Connected', 'cyan')) logging.info(color('### Connected', 'cyan'))
self.connection.listener = self self.connection.listener = self
print_connection(self.connection) print_connection(self.connection)
phy = await self.connection.get_phy()
print_connection_phy(phy)
# Switch roles if needed. # Switch roles if needed.
if self.role_switch: if self.role_switch:
@@ -1267,8 +1348,8 @@ class Central(Connection.Listener):
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
print_connection(self.connection) print_connection(self.connection)
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
print_connection(self.connection) print_connection_phy(phy)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
print_connection(self.connection) print_connection(self.connection)
@@ -1394,8 +1475,8 @@ class Peripheral(Device.Listener, Connection.Listener):
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
print_connection(self.connection) print_connection(self.connection)
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
print_connection(self.connection) print_connection_phy(phy)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
print_connection(self.connection) print_connection(self.connection)
@@ -1470,7 +1551,7 @@ def create_mode_factory(ctx, default_mode):
def create_scenario_factory(ctx, default_scenario): def create_scenario_factory(ctx, default_scenario):
scenario = ctx.obj['scenario'] scenario = ctx.obj['scenario']
if scenario is None: if scenario is None:
scenarion = default_scenario scenario = default_scenario
def create_scenario(packet_io): def create_scenario(packet_io):
if scenario == 'send': if scenario == 'send':
@@ -1529,6 +1610,7 @@ def create_scenario_factory(ctx, default_scenario):
'--att-mtu', '--att-mtu',
metavar='MTU', metavar='MTU',
type=click.IntRange(23, 517), type=click.IntRange(23, 517),
default=517,
help='GATT MTU (gatt-client mode)', help='GATT MTU (gatt-client mode)',
) )
@click.option( @click.option(
@@ -1604,7 +1686,7 @@ def create_scenario_factory(ctx, default_scenario):
'--packet-size', '--packet-size',
'-s', '-s',
metavar='SIZE', metavar='SIZE',
type=click.IntRange(8, 8192), type=click.IntRange(10, 8192),
default=500, default=500,
help='Packet size (send or ping scenario)', help='Packet size (send or ping scenario)',
) )

View File

@@ -22,7 +22,6 @@
import asyncio import asyncio
import logging import logging
import os import os
import random
import re import re
import humanize import humanize
from typing import Optional, Union from typing import Optional, Union
@@ -57,7 +56,13 @@ from bumble import __version__
import bumble.core import bumble.core
from bumble import colors from bumble import colors
from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer from bumble.device import (
ConnectionParametersPreferences,
ConnectionPHY,
Device,
Connection,
Peer,
)
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
@@ -125,6 +130,7 @@ def parse_phys(phys):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConsoleApp: class ConsoleApp:
connected_peer: Optional[Peer] connected_peer: Optional[Peer]
connection_phy: Optional[ConnectionPHY]
def __init__(self): def __init__(self):
self.known_addresses = set() self.known_addresses = set()
@@ -132,6 +138,7 @@ class ConsoleApp:
self.known_local_attributes = [] self.known_local_attributes = []
self.device = None self.device = None
self.connected_peer = None self.connected_peer = None
self.connection_phy = None
self.top_tab = 'device' self.top_tab = 'device'
self.monitor_rssi = False self.monitor_rssi = False
self.connection_rssi = None self.connection_rssi = None
@@ -332,10 +339,10 @@ class ConsoleApp:
f'{connection.parameters.peripheral_latency}/' f'{connection.parameters.peripheral_latency}/'
f'{connection.parameters.supervision_timeout}' f'{connection.parameters.supervision_timeout}'
) )
if connection.transport == BT_LE_TRANSPORT: if self.connection_phy is not None:
phy_state = ( phy_state = (
f' RX={le_phy_name(connection.phy.rx_phy)}/' f' RX={le_phy_name(self.connection_phy.rx_phy)}/'
f'TX={le_phy_name(connection.phy.tx_phy)}' f'TX={le_phy_name(self.connection_phy.tx_phy)}'
) )
else: else:
phy_state = '' phy_state = ''
@@ -654,11 +661,12 @@ class ConsoleApp:
self.append_to_output('connecting...') self.append_to_output('connecting...')
try: try:
await self.device.connect( connection = await self.device.connect(
params[0], params[0],
connection_parameters_preferences=connection_parameters_preferences, connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT, timeout=DEFAULT_CONNECTION_TIMEOUT,
) )
self.connection_phy = await connection.get_phy()
self.top_tab = 'services' self.top_tab = 'services'
except bumble.core.TimeoutError: except bumble.core.TimeoutError:
self.show_error('connection timed out') self.show_error('connection timed out')
@@ -838,8 +846,8 @@ class ConsoleApp:
phy = await self.connected_peer.connection.get_phy() phy = await self.connected_peer.connection.get_phy()
self.append_to_output( self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, ' f'PHY: RX={HCI_Constant.le_phy_name(phy.rx_phy)}, '
f'TX={HCI_Constant.le_phy_name(phy[1])}' f'TX={HCI_Constant.le_phy_name(phy.tx_phy)}'
) )
async def do_request_mtu(self, params): async def do_request_mtu(self, params):
@@ -1076,10 +1084,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
f'{self.app.connected_peer.connection.parameters}' f'{self.app.connected_peer.connection.parameters}'
) )
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
self.app.append_to_output( self.app.connection_phy = phy
f'connection phy update: {self.app.connected_peer.connection.phy}' self.app.append_to_output(f'connection phy update: {phy}')
)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
self.app.append_to_output( self.app.append_to_output(

View File

@@ -37,6 +37,8 @@ from bumble.hci import (
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command, HCI_Read_Buffer_Size_Command,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
@@ -75,7 +77,7 @@ async def get_classic_info(host: Host) -> None:
if command_succeeded(response): if command_succeeded(response):
print() print()
print( print(
color('Classic Address:', 'yellow'), color('Public Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False), response.return_parameters.bd_addr.to_string(False),
) )
@@ -147,7 +149,7 @@ async def get_le_info(host: Host) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None: async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
@@ -160,14 +162,28 @@ async def get_acl_flow_control_info(host: Host) -> None:
f'packets of size {response.return_parameters.hc_acl_data_packet_length}', f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} ' f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}', f'packets of size {response.return_parameters.le_acl_data_packet_length}',
) )
@@ -274,8 +290,8 @@ async def async_main(latency_probes, transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info # Print the flow control info
await get_acl_flow_control_info(host) await get_flow_control_info(host)
# Get codec info # Get codec info
await get_codecs_info(host) await get_codecs_info(host)

View File

@@ -29,7 +29,9 @@ from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.profiles.vcs import VolumeControlServiceProxy
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -126,14 +128,52 @@ async def show_tmas(
print(color('### Telephony And Media Audio Service', 'yellow')) print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role: if tmas.role:
print( role = await tmas.role.read_value()
color(' Role:', 'green'), print(color(' Role:', 'green'), role)
await tmas.role.read_value(),
)
print() print()
# -----------------------------------------------------------------------------
async def show_pacs(pacs: PublishedAudioCapabilitiesServiceProxy) -> None:
print(color('### Published Audio Capabilities Service', 'yellow'))
contexts = await pacs.available_audio_contexts.read_value()
print(color(' Available Audio Contexts:', 'green'), contexts)
contexts = await pacs.supported_audio_contexts.read_value()
print(color(' Supported Audio Contexts:', 'green'), contexts)
if pacs.sink_pac:
pac = await pacs.sink_pac.read_value()
print(color(' Sink PAC: ', 'green'), pac)
if pacs.sink_audio_locations:
audio_locations = await pacs.sink_audio_locations.read_value()
print(color(' Sink Audio Locations: ', 'green'), audio_locations)
if pacs.source_pac:
pac = await pacs.source_pac.read_value()
print(color(' Source PAC: ', 'green'), pac)
if pacs.source_audio_locations:
audio_locations = await pacs.source_audio_locations.read_value()
print(color(' Source Audio Locations: ', 'green'), audio_locations)
print()
# -----------------------------------------------------------------------------
async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
print(color('### Volume Control Service', 'yellow'))
volume_state = await vcs.volume_state.read_value()
print(color(' Volume State:', 'green'), volume_state)
volume_flags = await vcs.volume_flags.read_value()
print(color(' Volume Flags:', 'green'), volume_flags)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None: async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try: try:
@@ -161,6 +201,12 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy): if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas) await try_show(show_tmas, tmas)
if pacs := peer.create_service_proxy(PublishedAudioCapabilitiesServiceProxy):
await try_show(show_pacs, pacs)
if vcs := peer.create_service_proxy(VolumeControlServiceProxy):
await try_show(show_vcs, vcs)
if done is not None: if done is not None:
done.set_result(None) done.set_result(None)
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -83,7 +83,7 @@ async def async_main():
return_parameters=bytes([hci.HCI_SUCCESS]), return_parameters=bytes([hci.HCI_SUCCESS]),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True) return (bytes(response), True)
return None return None

View File

@@ -16,23 +16,22 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import enum
import functools import functools
from importlib import resources from importlib import resources
import json import json
import os import os
import logging import logging
import pathlib import pathlib
from typing import Optional, List, cast
import weakref import weakref
import struct import wave
import ctypes try:
import wasmtime import lc3 # type: ignore # pylint: disable=E0401
import wasmtime.loader except ImportError as e:
import liblc3 # type: ignore raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
import click import click
import aiohttp.web import aiohttp.web
@@ -40,11 +39,12 @@ import aiohttp.web
import bumble import bumble
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -54,6 +54,7 @@ logger = logging.getLogger(__name__)
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654 DEFAULT_UI_PORT = 7654
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
def _sink_pac_record() -> pacs.PacRecord: def _sink_pac_record() -> pacs.PacRecord:
@@ -100,153 +101,8 @@ def _source_pac_record() -> pacs.PacRecord:
) )
# ----------------------------------------------------------------------------- decoder: lc3.Decoder | None = None
# WASM - liblc3 encoding_config: bap.CodecSpecificConfiguration | None = None
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task( async def lc3_source_task(
@@ -254,44 +110,49 @@ async def lc3_source_task(
sdu_length: int, sdu_length: int,
frame_duration_us: int, frame_duration_us: int,
device: Device, device: Device,
cis_handle: int, cis_link: CisLink,
) -> None: ) -> None:
with open(filename, 'rb') as f: logger.info(
header = f.read(44) "lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
assert header[8:12] == b'WAVE' filename,
sdu_length,
frame_duration_us / 1000,
)
with wave.open(filename, 'rb') as wav:
bits_per_sample = wav.getsampwidth() * 8
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = ( encoder: lc3.Encoder | None = None
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True: while True:
next_round = datetime.datetime.now() + datetime.timedelta( next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us microseconds=frame_duration_us
) )
pcm_data = f.read(frame_bytes) if not encoder:
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data) if (
encoding_config
and (frame_duration := encoding_config.frame_duration)
and (sampling_frequency := encoding_config.sampling_frequency)
and (
audio_channel_allocation := encoding_config.audio_channel_allocation
)
):
logger.info("Use %s", encoding_config)
encoder = lc3.Encoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=audio_channel_allocation.channel_count,
input_sample_rate_hz=wav.getframerate(),
)
else:
sdu = encoder.encode(
pcm=wav.readframes(encoder.get_frame_samples()),
num_bytes=sdu_length,
bit_depth=bits_per_sample,
)
cis_link.write(sdu)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now() sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds()) await asyncio.sleep(sleep_time.total_seconds() * 0.9)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -410,7 +271,7 @@ class Speaker:
def __init__( def __init__(
self, self,
device_config_path: Optional[str], device_config_path: str | None,
ui_port: int, ui_port: int,
transport: str, transport: str,
lc3_input_file_path: str, lc3_input_file_path: str,
@@ -437,6 +298,7 @@ class Speaker:
advertising_interval_min=25, advertising_interval_min=25,
advertising_interval_max=25, advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'), address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
) )
device_config.le_enabled = True device_config.le_enabled = True
@@ -486,20 +348,31 @@ class Speaker:
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine): def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
codec_config = ase.codec_specific_configuration codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration) if (
pcm = decode( not isinstance(codec_config, bap.CodecSpecificConfiguration)
codec_config.frame_duration.us, or codec_config.frame_duration is None
codec_config.audio_channel_allocation.channel_count, or codec_config.audio_channel_allocation is None
pdu.iso_sdu_fragment, or decoder is None
or not pdu.iso_sdu_fragment
):
return
pcm = decoder.decode(
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
) )
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: ascs.AseStateMachine) -> None: def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
codec_config = ase.codec_specific_configuration
if ase.state == ascs.AseStateMachine.State.STREAMING: if ase.state == ascs.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == ascs.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
if (
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or ase.cis_link is None
or codec_config.octets_per_codec_frame is None
or codec_config.frame_duration is None
or codec_config.codec_frames_per_sdu is None
):
return
ase.cis_link.abort_on( ase.cis_link.abort_on(
'disconnection', 'disconnection',
lc3_source_task( lc3_source_task(
@@ -510,25 +383,30 @@ class Speaker:
), ),
frame_duration_us=codec_config.frame_duration.us, frame_duration_us=codec_config.frame_duration.us,
device=self.device, device=self.device,
cis_handle=ase.cis_link.handle, cis_link=ase.cis_link,
), ),
) )
else: else:
if not ase.cis_link:
return
ase.cis_link.sink = functools.partial(on_pdu, ase=ase) ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED: elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration if (
assert isinstance(codec_config, bap.CodecSpecificConfiguration) not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.sampling_frequency is None
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
):
return
if ase.role == ascs.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
setup_encoders( global encoding_config
codec_config.sampling_frequency.hz, encoding_config = codec_config
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else: else:
setup_decoders( global decoder
codec_config.sampling_frequency.hz, decoder = lc3.Decoder(
codec_config.frame_duration.us, frame_duration_us=codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count, sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
) )
for ase in ascs_service.ase_state_machines.values(): for ase in ascs_service.ase_state_machines.values():
@@ -567,7 +445,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
speaker() speaker()

Binary file not shown.

View File

@@ -373,7 +373,9 @@ async def pair(
shared_data = ( shared_data = (
None None
if oob == '-' if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob))) else OobData.from_ad(
AdvertisingData.from_bytes(bytes.fromhex(oob))
).shared_data
) )
legacy_context = OobLegacyContext() legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig( oob_contexts = PairingConfig.OobConfig(
@@ -381,16 +383,19 @@ async def pair(
peer_data=shared_data, peer_data=shared_data,
legacy_context=legacy_context, legacy_context=legacy_context,
) )
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow')) print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow')) print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow')) if shared_data is None:
oob_data = OobData(
address=device.random_address, shared_data=our_oob_context.share()
)
print(
color(
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
'yellow',
)
)
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow')) print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow')) print(color('@@@-----------------------------------', 'yellow'))
else: else:
oob_contexts = None oob_contexts = None

View File

@@ -144,18 +144,18 @@ class Printer:
help='Format of the input file', help='Format of the input file',
) )
@click.option( @click.option(
'--vendors', '--vendor',
type=click.Choice(['android', 'zephyr']), type=click.Choice(['android', 'zephyr']),
multiple=True, multiple=True,
help='Support vendor-specific commands (list one or more)', help='Support vendor-specific commands (list one or more)',
) )
@click.argument('filename') @click.argument('filename')
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
def main(format, vendors, filename): def main(format, vendor, filename):
for vendor in vendors: for vendor_name in vendor:
if vendor == 'android': if vendor_name == 'android':
import bumble.vendor.android.hci import bumble.vendor.android.hci
elif vendor == 'zephyr': elif vendor_name == 'zephyr':
import bumble.vendor.zephyr.hci import bumble.vendor.zephyr.hci
input = open(filename, 'rb') input = open(filename, 'rb')
@@ -180,7 +180,7 @@ def main(format, vendors, filename):
else: else:
printer.print(color("[TRUNCATED]", "red")) printer.print(color("[TRUNCATED]", "red"))
except Exception as error: except Exception as error:
logger.exception() logger.exception('')
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))

View File

@@ -57,6 +57,7 @@ if TYPE_CHECKING:
# pylint: disable=line-too-long # pylint: disable=line-too-long
ATT_CID = 0x04 ATT_CID = 0x04
ATT_PSM = 0x001F
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02 ATT_EXCHANGE_MTU_REQUEST = 0x02
@@ -291,9 +292,6 @@ class ATT_PDU:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
@property @property
def is_command(self): def is_command(self):
return ((self.op_code >> 6) & 1) == 1 return ((self.op_code >> 6) & 1) == 1
@@ -303,7 +301,7 @@ class ATT_PDU:
return ((self.op_code >> 7) & 1) == 1 return ((self.op_code >> 7) & 1) == 1
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
@@ -759,13 +757,13 @@ class AttributeValue:
def __init__( def __init__(
self, self,
read: Union[ read: Union[
Callable[[Optional[Connection]], bytes], Callable[[Optional[Connection]], Any],
Callable[[Optional[Connection]], Awaitable[bytes]], Callable[[Optional[Connection]], Awaitable[Any]],
None, None,
] = None, ] = None,
write: Union[ write: Union[
Callable[[Optional[Connection], bytes], None], Callable[[Optional[Connection], Any], None],
Callable[[Optional[Connection], bytes], Awaitable[None]], Callable[[Optional[Connection], Any], Awaitable[None]],
None, None,
] = None, ] = None,
): ):
@@ -824,13 +822,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue] value: Any
def __init__( def __init__(
self, self,
attribute_type: Union[str, bytes, UUID], attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'', value: Any = b'',
) -> None: ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
@@ -848,11 +846,7 @@ class Attribute(EventEmitter):
else: else:
self.type = attribute_type self.type = attribute_type
# Convert the value to a byte array self.value = value
if isinstance(value, str):
self.value = bytes(value, 'utf-8')
else:
self.value = value
def encode_value(self, value: Any) -> bytes: def encode_value(self, value: Any) -> bytes:
return value return value
@@ -895,6 +889,8 @@ class Attribute(EventEmitter):
else: else:
value = self.value value = self.value
self.emit('read', connection, value)
return self.encode_value(value) return self.encode_value(value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None: async def write_value(self, connection: Connection, value_bytes: bytes) -> None:

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from setuptools import setup # -----------------------------------------------------------------------------
# Imports
setup() # -----------------------------------------------------------------------------

553
bumble/audio/io.py Normal file
View File

@@ -0,0 +1,553 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import abc
from concurrent.futures import ThreadPoolExecutor
import dataclasses
import enum
import logging
import pathlib
from typing import (
AsyncGenerator,
BinaryIO,
TYPE_CHECKING,
)
import sys
import wave
from bumble.colors import color
if TYPE_CHECKING:
import sounddevice # type: ignore[import-untyped]
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PcmFormat:
class Endianness(enum.Enum):
LITTLE = 0
BIG = 1
class SampleType(enum.Enum):
FLOAT32 = 0
INT16 = 1
endianness: Endianness
sample_type: SampleType
sample_rate: int
channels: int
@classmethod
def from_str(cls, format_str: str) -> PcmFormat:
endianness = cls.Endianness.LITTLE # Others not yet supported.
sample_type_str, sample_rate_str, channels_str = format_str.split(',')
if sample_type_str == 'int16le':
sample_type = cls.SampleType.INT16
elif sample_type_str == 'float32le':
sample_type = cls.SampleType.FLOAT32
else:
raise ValueError(f'sample type {sample_type_str} not supported')
sample_rate = int(sample_rate_str)
channels = int(channels_str)
return cls(endianness, sample_type, sample_rate, channels)
@property
def bytes_per_sample(self) -> int:
return 2 if self.sample_type == self.SampleType.INT16 else 4
def check_audio_output(output: str) -> bool:
if output == 'device' or output.startswith('device:'):
try:
import sounddevice
except ImportError as exc:
raise ValueError(
'audio output not available (sounddevice python module not installed)'
) from exc
except OSError as exc:
raise ValueError(
'audio output not available '
'(sounddevice python module failed to load: '
f'{exc})'
) from exc
if output == 'device':
# Default device
return True
# Specific device
device = output[7:]
if device == '?':
print(color('Audio Devices:', 'yellow'))
for device_info in [
device_info
for device_info in sounddevice.query_devices()
if device_info['max_output_channels'] > 0
]:
device_index = device_info['index']
is_default = (
color(' [default]', 'green')
if sounddevice.default.device[1] == device_index
else ''
)
print(
f'{color(device_index, "cyan")}: {device_info["name"]}{is_default}'
)
return False
try:
device_info = sounddevice.query_devices(int(device))
except sounddevice.PortAudioError as exc:
raise ValueError('No such audio device') from exc
if device_info['max_output_channels'] < 1:
raise ValueError(
f'Device {device} ({device_info["name"]}) does not have an output'
)
return True
async def create_audio_output(output: str) -> AudioOutput:
if output == 'stdout':
return StreamAudioOutput(sys.stdout.buffer)
if output == 'device' or output.startswith('device:'):
device_name = '' if output == 'device' else output[7:]
return SoundDeviceAudioOutput(device_name)
if output == 'ffplay':
return SubprocessAudioOutput(
command=(
'ffplay -probesize 32 -fflags nobuffer -analyzeduration 0 '
'-ar {sample_rate} '
'-ch_layout {channel_layout} '
'-f f32le pipe:0'
)
)
if output.startswith('file:'):
return FileAudioOutput(output[5:])
raise ValueError('unsupported audio output')
class AudioOutput(abc.ABC):
"""Audio output to which PCM samples can be written."""
async def open(self, pcm_format: PcmFormat) -> None:
"""Start the output."""
@abc.abstractmethod
def write(self, pcm_samples: bytes) -> None:
"""Write PCM samples. Must not block."""
async def aclose(self) -> None:
"""Close the output."""
class ThreadedAudioOutput(AudioOutput):
"""Base class for AudioOutput classes that may need to call blocking functions.
The actual writing is performed in a thread, so as to ensure that calling write()
does not block the caller.
"""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
self._write_task = asyncio.create_task(self._write_loop())
async def _write_loop(self) -> None:
while True:
pcm_samples = await self._pcm_samples.get()
await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._write, pcm_samples
)
@abc.abstractmethod
def _write(self, pcm_samples: bytes) -> None:
"""This method does the actual writing and can block."""
def write(self, pcm_samples: bytes) -> None:
self._pcm_samples.put_nowait(pcm_samples)
def _close(self) -> None:
"""This method does the actual closing and can block."""
async def aclose(self) -> None:
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
self._write_task.cancel()
self._thread_pool.shutdown()
class SoundDeviceAudioOutput(ThreadedAudioOutput):
def __init__(self, device_name: str) -> None:
super().__init__()
self._device = int(device_name) if device_name else None
self._stream: sounddevice.RawOutputStream | None = None
async def open(self, pcm_format: PcmFormat) -> None:
import sounddevice # pylint: disable=import-error
self._stream = sounddevice.RawOutputStream(
samplerate=pcm_format.sample_rate,
device=self._device,
channels=pcm_format.channels,
dtype='float32',
)
self._stream.start()
def _write(self, pcm_samples: bytes) -> None:
if self._stream is None:
return
try:
self._stream.write(pcm_samples)
except Exception as error:
print(f'Sound device error: {error}')
raise
def _close(self):
self._stream.stop()
self._stream = None
class StreamAudioOutput(ThreadedAudioOutput):
"""AudioOutput where PCM samples are written to a stream that may block."""
def __init__(self, stream: BinaryIO) -> None:
super().__init__()
self._stream = stream
def _write(self, pcm_samples: bytes) -> None:
self._stream.write(pcm_samples)
self._stream.flush()
class FileAudioOutput(StreamAudioOutput):
"""AudioOutput where PCM samples are written to a file."""
def __init__(self, filename: str) -> None:
self._file = open(filename, "wb")
super().__init__(self._file)
async def shutdown(self):
self._file.close()
return await super().shutdown()
class SubprocessAudioOutput(AudioOutput):
"""AudioOutput where audio samples are written to a subprocess via stdin."""
def __init__(self, command: str) -> None:
self._command = command
self._subprocess: asyncio.subprocess.Process | None
async def open(self, pcm_format: PcmFormat) -> None:
if pcm_format.channels == 1:
channel_layout = 'mono'
elif pcm_format.channels == 2:
channel_layout = 'stereo'
else:
raise ValueError(f'{pcm_format.channels} channels not supported')
command = self._command.format(
sample_rate=pcm_format.sample_rate, channel_layout=channel_layout
)
self._subprocess = await asyncio.create_subprocess_shell(
command,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
def write(self, pcm_samples: bytes) -> None:
if self._subprocess is None or self._subprocess.stdin is None:
return
self._subprocess.stdin.write(pcm_samples)
async def aclose(self):
if self._subprocess:
self._subprocess.terminate()
def check_audio_input(input: str) -> bool:
if input == 'device' or input.startswith('device:'):
try:
import sounddevice # pylint: disable=import-error
except ImportError as exc:
raise ValueError(
'audio input not available (sounddevice python module not installed)'
) from exc
except OSError as exc:
raise ValueError(
'audio input not available '
'(sounddevice python module failed to load: '
f'{exc})'
) from exc
if input == 'device':
# Default device
return True
# Specific device
device = input[7:]
if device == '?':
print(color('Audio Devices:', 'yellow'))
for device_info in [
device_info
for device_info in sounddevice.query_devices()
if device_info['max_input_channels'] > 0
]:
device_index = device_info["index"]
is_mono = device_info['max_input_channels'] == 1
max_channels = color(f'[{"mono" if is_mono else "stereo"}]', 'cyan')
is_default = (
color(' [default]', 'green')
if sounddevice.default.device[0] == device_index
else ''
)
print(
f'{color(device_index, "cyan")}: {device_info["name"]}'
f' {max_channels}{is_default}'
)
return False
try:
device_info = sounddevice.query_devices(int(device))
except sounddevice.PortAudioError as exc:
raise ValueError('No such audio device') from exc
if device_info['max_input_channels'] < 1:
raise ValueError(
f'Device {device} ({device_info["name"]}) does not have an input'
)
return True
async def create_audio_input(input: str, input_format: str) -> AudioInput:
pcm_format: PcmFormat | None
if input_format == 'auto':
pcm_format = None
else:
pcm_format = PcmFormat.from_str(input_format)
if input == 'stdin':
if not pcm_format:
raise ValueError('input format details required for stdin')
return StreamAudioInput(sys.stdin.buffer, pcm_format)
if input == 'device' or input.startswith('device:'):
if not pcm_format:
raise ValueError('input format details required for device')
device_name = '' if input == 'device' else input[7:]
return SoundDeviceAudioInput(device_name, pcm_format)
# If there's no file: prefix, check if we can assume it is a file.
if pathlib.Path(input).is_file():
input = 'file:' + input
if input.startswith('file:'):
filename = input[5:]
if filename.endswith('.wav'):
if input_format != 'auto':
raise ValueError(".wav file only supported with 'auto' format")
return WaveAudioInput(filename)
if pcm_format is None:
raise ValueError('input format details required for raw PCM files')
return FileAudioInput(filename, pcm_format)
raise ValueError('input not supported')
class AudioInput(abc.ABC):
"""Audio input that produces PCM samples."""
@abc.abstractmethod
async def open(self) -> PcmFormat:
"""Open the input."""
@abc.abstractmethod
def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
"""Generate one frame of PCM samples. Must not block."""
async def aclose(self) -> None:
"""Close the input."""
class ThreadedAudioInput(AudioInput):
"""Base class for AudioInput implementation where reading samples may block."""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
@abc.abstractmethod
def _read(self, frame_size: int) -> bytes:
pass
@abc.abstractmethod
def _open(self) -> PcmFormat:
pass
def _close(self) -> None:
pass
async def open(self) -> PcmFormat:
return await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._open
)
async def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
while pcm_sample := await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._read, frame_size
):
yield pcm_sample
async def aclose(self) -> None:
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
self._thread_pool.shutdown()
class WaveAudioInput(ThreadedAudioInput):
"""Audio input that reads PCM samples from a .wav file."""
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._wav: wave.Wave_read | None = None
self._bytes_read = 0
def _open(self) -> PcmFormat:
self._wav = wave.open(self._filename, 'rb')
if self._wav.getsampwidth() != 2:
raise ValueError('sample width not supported')
return PcmFormat(
PcmFormat.Endianness.LITTLE,
PcmFormat.SampleType.INT16,
self._wav.getframerate(),
self._wav.getnchannels(),
)
def _read(self, frame_size: int) -> bytes:
if not self._wav:
return b''
pcm_samples = self._wav.readframes(frame_size)
if not pcm_samples and self._bytes_read:
# Loop around.
self._wav.rewind()
self._bytes_read = 0
pcm_samples = self._wav.readframes(frame_size)
self._bytes_read += len(pcm_samples)
return pcm_samples
def _close(self) -> None:
if self._wav:
self._wav.close()
class StreamAudioInput(ThreadedAudioInput):
"""AudioInput where samples are read from a raw PCM stream that may block."""
def __init__(self, stream: BinaryIO, pcm_format: PcmFormat) -> None:
super().__init__()
self._stream = stream
self._pcm_format = pcm_format
def _open(self) -> PcmFormat:
return self._pcm_format
def _read(self, frame_size: int) -> bytes:
return self._stream.read(
frame_size * self._pcm_format.channels * self._pcm_format.bytes_per_sample
)
class FileAudioInput(StreamAudioInput):
"""AudioInput where PCM samples are read from a raw PCM file."""
def __init__(self, filename: str, pcm_format: PcmFormat) -> None:
self._stream = open(filename, "rb")
super().__init__(self._stream, pcm_format)
def _close(self) -> None:
self._stream.close()
class SoundDeviceAudioInput(ThreadedAudioInput):
def __init__(self, device_name: str, pcm_format: PcmFormat) -> None:
super().__init__()
self._device = int(device_name) if device_name else None
self._pcm_format = pcm_format
self._stream: sounddevice.RawInputStream | None = None
def _open(self) -> PcmFormat:
import sounddevice # pylint: disable=import-error
self._stream = sounddevice.RawInputStream(
samplerate=self._pcm_format.sample_rate,
device=self._device,
channels=self._pcm_format.channels,
dtype='int16',
)
self._stream.start()
return PcmFormat(
PcmFormat.Endianness.LITTLE,
PcmFormat.SampleType.INT16,
self._pcm_format.sample_rate,
2,
)
def _read(self, frame_size: int) -> bytes:
if not self._stream:
return b''
pcm_buffer, overflowed = self._stream.read(frame_size)
if overflowed:
logger.warning("input overflow")
# Convert the buffer to stereo if needed
if self._pcm_format.channels == 1:
stereo_buffer = bytearray()
for i in range(frame_size):
sample = pcm_buffer[i * 2 : i * 2 + 2]
stereo_buffer += sample + sample
return stereo_buffer
return bytes(pcm_buffer)
def _close(self):
self._stream.stop()
self._stream = None

View File

@@ -154,15 +154,17 @@ class Controller:
'0000000060000000' '0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller) ) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27 self.acl_data_packet_length = 27
self.hc_total_num_data_packets = 64 self.total_num_acl_data_packets = 64
self.hc_le_data_packet_length = 27 self.le_acl_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.total_num_le_acl_data_packets = 64
self.iso_data_packet_length = 960
self.total_num_iso_data_packets = 64
self.event_mask = 0 self.event_mask = 0
self.event_mask_page_2 = 0 self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex( self.supported_commands = bytes.fromhex(
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000' '2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000' '30f0f9ff01008004002000000000000000000000000000000000000000000000'
) )
self.le_event_mask = 0 self.le_event_mask = 0
self.advertising_parameters = None self.advertising_parameters = None
@@ -314,7 +316,7 @@ class Controller:
f'{color("CONTROLLER -> HOST", "green")}: {packet}' f'{color("CONTROLLER -> HOST", "green")}: {packet}'
) )
if self.host: if self.host:
self.host.on_packet(packet.to_bytes()) self.host.on_packet(bytes(packet))
# This method allows the controller to emulate the same API as a transport source # This method allows the controller to emulate the same API as a transport source
async def wait_for_termination(self): async def wait_for_termination(self):
@@ -1181,9 +1183,9 @@ class Controller:
return struct.pack( return struct.pack(
'<BHBHH', '<BHBHH',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_data_packet_length, self.acl_data_packet_length,
0, 0,
self.hc_total_num_data_packets, self.total_num_acl_data_packets,
0, 0,
) )
@@ -1192,7 +1194,7 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
''' '''
bd_addr = ( bd_addr = (
self._public_address.to_bytes() bytes(self._public_address)
if self._public_address is not None if self._public_address is not None
else bytes(6) else bytes(6)
) )
@@ -1212,8 +1214,21 @@ class Controller:
return struct.pack( return struct.pack(
'<BHB', '<BHB',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_le_data_packet_length, self.le_acl_data_packet_length,
self.hc_total_num_le_data_packets, self.total_num_le_acl_data_packets,
)
def on_hci_le_read_buffer_size_v2_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack(
'<BHBHB',
HCI_SUCCESS,
self.le_acl_data_packet_length,
self.total_num_le_acl_data_packets,
self.iso_data_packet_length,
self.total_num_iso_data_packets,
) )
def on_hci_le_read_local_supported_features_command(self, _command): def on_hci_le_read_local_supported_features_command(self, _command):
@@ -1543,6 +1558,41 @@ class Controller:
} }
return bytes([HCI_SUCCESS]) return bytes([HCI_SUCCESS])
def on_hci_le_set_advertising_set_random_address_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS, 0])
def on_hci_le_set_extended_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_scan_response_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_advertising_data_length_command(self, _command): def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
@@ -1557,6 +1607,27 @@ class Controller:
''' '''
return struct.pack('<BB', HCI_SUCCESS, 0xF0) return struct.pack('<BB', HCI_SUCCESS, 0xF0)
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_transmit_power_command(self, _command): def on_hci_le_read_transmit_power_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,10 +16,10 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum import enum
import struct import struct
from typing import List, Optional, Tuple, Union, cast, Dict from typing import cast, overload, Literal, Union, Optional
from typing_extensions import Self from typing_extensions import Self
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
@@ -57,7 +57,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
return names return names
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str: def name_or_number(dictionary: dict[int, str], number: int, width: int = 2) -> str:
name = dictionary.get(number) name = dictionary.get(number)
if name is not None: if name is not None:
return name return name
@@ -200,7 +200,7 @@ class UUID:
''' '''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created UUIDS: list[UUID] = [] # Registry of all instances created
uuid_bytes: bytes uuid_bytes: bytes
name: Optional[str] name: Optional[str]
@@ -259,11 +259,11 @@ class UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name) return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod @classmethod
def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]: def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:]) return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod @classmethod
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]: def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2]) return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128: bool = False) -> bytes: def to_bytes(self, force_128: bool = False) -> bytes:
@@ -1280,13 +1280,13 @@ class Appearance:
# Advertising Data # Advertising Data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
AdvertisingDataObject = Union[ AdvertisingDataObject = Union[
List[UUID], list[UUID],
Tuple[UUID, bytes], tuple[UUID, bytes],
bytes, bytes,
str, str,
int, int,
Tuple[int, int], tuple[int, int],
Tuple[int, bytes], tuple[int, bytes],
Appearance, Appearance,
] ]
@@ -1295,116 +1295,116 @@ class AdvertisingData:
# fmt: off # fmt: off
# pylint: disable=line-too-long # pylint: disable=line-too-long
FLAGS = 0x01 class Type(OpenIntEnum):
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 FLAGS = 0x01
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04 COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05 INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06 COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07 INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06
SHORTENED_LOCAL_NAME = 0x08 COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07
COMPLETE_LOCAL_NAME = 0x09 SHORTENED_LOCAL_NAME = 0x08
TX_POWER_LEVEL = 0x0A COMPLETE_LOCAL_NAME = 0x09
CLASS_OF_DEVICE = 0x0D TX_POWER_LEVEL = 0x0A
SIMPLE_PAIRING_HASH_C = 0x0E CLASS_OF_DEVICE = 0x0D
SIMPLE_PAIRING_HASH_C_192 = 0x0E SIMPLE_PAIRING_HASH_C = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R = 0x0F SIMPLE_PAIRING_HASH_C_192 = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F SIMPLE_PAIRING_RANDOMIZER_R = 0x0F
DEVICE_ID = 0x10 SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F
SECURITY_MANAGER_TK_VALUE = 0x10 DEVICE_ID = 0x10
SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11 SECURITY_MANAGER_TK_VALUE = 0x10
PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12 SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14 PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15 LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14
SERVICE_DATA = 0x16 LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15
SERVICE_DATA_16_BIT_UUID = 0x16 SERVICE_DATA_16_BIT_UUID = 0x16
PUBLIC_TARGET_ADDRESS = 0x17 PUBLIC_TARGET_ADDRESS = 0x17
RANDOM_TARGET_ADDRESS = 0x18 RANDOM_TARGET_ADDRESS = 0x18
APPEARANCE = 0x19 APPEARANCE = 0x19
ADVERTISING_INTERVAL = 0x1A ADVERTISING_INTERVAL = 0x1A
LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B
LE_ROLE = 0x1C LE_ROLE = 0x1C
SIMPLE_PAIRING_HASH_C_256 = 0x1D SIMPLE_PAIRING_HASH_C_256 = 0x1D
SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F
SERVICE_DATA_32_BIT_UUID = 0x20 SERVICE_DATA_32_BIT_UUID = 0x20
SERVICE_DATA_128_BIT_UUID = 0x21 SERVICE_DATA_128_BIT_UUID = 0x21
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22 LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22
LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23 LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23
URI = 0x24 URI = 0x24
INDOOR_POSITIONING = 0x25 INDOOR_POSITIONING = 0x25
TRANSPORT_DISCOVERY_DATA = 0x26 TRANSPORT_DISCOVERY_DATA = 0x26
LE_SUPPORTED_FEATURES = 0x27 LE_SUPPORTED_FEATURES = 0x27
CHANNEL_MAP_UPDATE_INDICATION = 0x28 CHANNEL_MAP_UPDATE_INDICATION = 0x28
PB_ADV = 0x29 PB_ADV = 0x29
MESH_MESSAGE = 0x2A MESH_MESSAGE = 0x2A
MESH_BEACON = 0x2B MESH_BEACON = 0x2B
BIGINFO = 0x2C BIGINFO = 0x2C
BROADCAST_CODE = 0x2D BROADCAST_CODE = 0x2D
RESOLVABLE_SET_IDENTIFIER = 0x2E RESOLVABLE_SET_IDENTIFIER = 0x2E
ADVERTISING_INTERVAL_LONG = 0x2F ADVERTISING_INTERVAL_LONG = 0x2F
BROADCAST_NAME = 0x30 BROADCAST_NAME = 0x30
ENCRYPTED_ADVERTISING_DATA = 0X31 ENCRYPTED_ADVERTISING_DATA = 0x31
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = 0X32 PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = 0x32
ELECTRONIC_SHELF_LABEL = 0X34 ELECTRONIC_SHELF_LABEL = 0x34
THREE_D_INFORMATION_DATA = 0x3D THREE_D_INFORMATION_DATA = 0x3D
MANUFACTURER_SPECIFIC_DATA = 0xFF MANUFACTURER_SPECIFIC_DATA = 0xFF
AD_TYPE_NAMES = { # For backward-compatibility
FLAGS: 'FLAGS', FLAGS = Type.FLAGS
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
SHORTENED_LOCAL_NAME: 'SHORTENED_LOCAL_NAME', SHORTENED_LOCAL_NAME = Type.SHORTENED_LOCAL_NAME
COMPLETE_LOCAL_NAME: 'COMPLETE_LOCAL_NAME', COMPLETE_LOCAL_NAME = Type.COMPLETE_LOCAL_NAME
TX_POWER_LEVEL: 'TX_POWER_LEVEL', TX_POWER_LEVEL = Type.TX_POWER_LEVEL
CLASS_OF_DEVICE: 'CLASS_OF_DEVICE', CLASS_OF_DEVICE = Type.CLASS_OF_DEVICE
SIMPLE_PAIRING_HASH_C: 'SIMPLE_PAIRING_HASH_C', SIMPLE_PAIRING_HASH_C = Type.SIMPLE_PAIRING_HASH_C
SIMPLE_PAIRING_HASH_C_192: 'SIMPLE_PAIRING_HASH_C_192', SIMPLE_PAIRING_HASH_C_192 = Type.SIMPLE_PAIRING_HASH_C_192
SIMPLE_PAIRING_RANDOMIZER_R: 'SIMPLE_PAIRING_RANDOMIZER_R', SIMPLE_PAIRING_RANDOMIZER_R = Type.SIMPLE_PAIRING_RANDOMIZER_R
SIMPLE_PAIRING_RANDOMIZER_R_192: 'SIMPLE_PAIRING_RANDOMIZER_R_192', SIMPLE_PAIRING_RANDOMIZER_R_192 = Type.SIMPLE_PAIRING_RANDOMIZER_R_192
DEVICE_ID: 'DEVICE_ID', DEVICE_ID = Type.DEVICE_ID
SECURITY_MANAGER_TK_VALUE: 'SECURITY_MANAGER_TK_VALUE', SECURITY_MANAGER_TK_VALUE = Type.SECURITY_MANAGER_TK_VALUE
SECURITY_MANAGER_OUT_OF_BAND_FLAGS: 'SECURITY_MANAGER_OUT_OF_BAND_FLAGS', SECURITY_MANAGER_OUT_OF_BAND_FLAGS = Type.SECURITY_MANAGER_OUT_OF_BAND_FLAGS
PERIPHERAL_CONNECTION_INTERVAL_RANGE: 'PERIPHERAL_CONNECTION_INTERVAL_RANGE', PERIPHERAL_CONNECTION_INTERVAL_RANGE = Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS', LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS', LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS
SERVICE_DATA_16_BIT_UUID: 'SERVICE_DATA_16_BIT_UUID', SERVICE_DATA = Type.SERVICE_DATA_16_BIT_UUID
PUBLIC_TARGET_ADDRESS: 'PUBLIC_TARGET_ADDRESS', SERVICE_DATA_16_BIT_UUID = Type.SERVICE_DATA_16_BIT_UUID
RANDOM_TARGET_ADDRESS: 'RANDOM_TARGET_ADDRESS', PUBLIC_TARGET_ADDRESS = Type.PUBLIC_TARGET_ADDRESS
APPEARANCE: 'APPEARANCE', RANDOM_TARGET_ADDRESS = Type.RANDOM_TARGET_ADDRESS
ADVERTISING_INTERVAL: 'ADVERTISING_INTERVAL', APPEARANCE = Type.APPEARANCE
LE_BLUETOOTH_DEVICE_ADDRESS: 'LE_BLUETOOTH_DEVICE_ADDRESS', ADVERTISING_INTERVAL = Type.ADVERTISING_INTERVAL
LE_ROLE: 'LE_ROLE', LE_BLUETOOTH_DEVICE_ADDRESS = Type.LE_BLUETOOTH_DEVICE_ADDRESS
SIMPLE_PAIRING_HASH_C_256: 'SIMPLE_PAIRING_HASH_C_256', LE_ROLE = Type.LE_ROLE
SIMPLE_PAIRING_RANDOMIZER_R_256: 'SIMPLE_PAIRING_RANDOMIZER_R_256', SIMPLE_PAIRING_HASH_C_256 = Type.SIMPLE_PAIRING_HASH_C_256
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS', SIMPLE_PAIRING_RANDOMIZER_R_256 = Type.SIMPLE_PAIRING_RANDOMIZER_R_256
SERVICE_DATA_32_BIT_UUID: 'SERVICE_DATA_32_BIT_UUID', LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS
SERVICE_DATA_128_BIT_UUID: 'SERVICE_DATA_128_BIT_UUID', SERVICE_DATA_32_BIT_UUID = Type.SERVICE_DATA_32_BIT_UUID
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE: 'LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE', SERVICE_DATA_128_BIT_UUID = Type.SERVICE_DATA_128_BIT_UUID
LE_SECURE_CONNECTIONS_RANDOM_VALUE: 'LE_SECURE_CONNECTIONS_RANDOM_VALUE', LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = Type.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE
URI: 'URI', LE_SECURE_CONNECTIONS_RANDOM_VALUE = Type.LE_SECURE_CONNECTIONS_RANDOM_VALUE
INDOOR_POSITIONING: 'INDOOR_POSITIONING', URI = Type.URI
TRANSPORT_DISCOVERY_DATA: 'TRANSPORT_DISCOVERY_DATA', INDOOR_POSITIONING = Type.INDOOR_POSITIONING
LE_SUPPORTED_FEATURES: 'LE_SUPPORTED_FEATURES', TRANSPORT_DISCOVERY_DATA = Type.TRANSPORT_DISCOVERY_DATA
CHANNEL_MAP_UPDATE_INDICATION: 'CHANNEL_MAP_UPDATE_INDICATION', LE_SUPPORTED_FEATURES = Type.LE_SUPPORTED_FEATURES
PB_ADV: 'PB_ADV', CHANNEL_MAP_UPDATE_INDICATION = Type.CHANNEL_MAP_UPDATE_INDICATION
MESH_MESSAGE: 'MESH_MESSAGE', PB_ADV = Type.PB_ADV
MESH_BEACON: 'MESH_BEACON', MESH_MESSAGE = Type.MESH_MESSAGE
BIGINFO: 'BIGINFO', MESH_BEACON = Type.MESH_BEACON
BROADCAST_CODE: 'BROADCAST_CODE', BIGINFO = Type.BIGINFO
RESOLVABLE_SET_IDENTIFIER: 'RESOLVABLE_SET_IDENTIFIER', BROADCAST_CODE = Type.BROADCAST_CODE
ADVERTISING_INTERVAL_LONG: 'ADVERTISING_INTERVAL_LONG', RESOLVABLE_SET_IDENTIFIER = Type.RESOLVABLE_SET_IDENTIFIER
BROADCAST_NAME: 'BROADCAST_NAME', ADVERTISING_INTERVAL_LONG = Type.ADVERTISING_INTERVAL_LONG
ENCRYPTED_ADVERTISING_DATA: 'ENCRYPTED_ADVERTISING_DATA', BROADCAST_NAME = Type.BROADCAST_NAME
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION: 'PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION', ENCRYPTED_ADVERTISING_DATA = Type.ENCRYPTED_ADVERTISING_DATA
ELECTRONIC_SHELF_LABEL: 'ELECTRONIC_SHELF_LABEL', PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = Type.PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION
THREE_D_INFORMATION_DATA: 'THREE_D_INFORMATION_DATA', ELECTRONIC_SHELF_LABEL = Type.ELECTRONIC_SHELF_LABEL
MANUFACTURER_SPECIFIC_DATA: 'MANUFACTURER_SPECIFIC_DATA' THREE_D_INFORMATION_DATA = Type.THREE_D_INFORMATION_DATA
} MANUFACTURER_SPECIFIC_DATA = Type.MANUFACTURER_SPECIFIC_DATA
LE_LIMITED_DISCOVERABLE_MODE_FLAG = 0x01 LE_LIMITED_DISCOVERABLE_MODE_FLAG = 0x01
LE_GENERAL_DISCOVERABLE_MODE_FLAG = 0x02 LE_GENERAL_DISCOVERABLE_MODE_FLAG = 0x02
@@ -1412,12 +1412,12 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08 BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10 BR_EDR_HOST_FLAG = 0x10
ad_structures: List[Tuple[int, bytes]] ad_structures: list[tuple[int, bytes]]
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None: def __init__(self, ad_structures: Optional[list[tuple[int, bytes]]] = None) -> None:
if ad_structures is None: if ad_structures is None:
ad_structures = [] ad_structures = []
self.ad_structures = ad_structures[:] self.ad_structures = ad_structures[:]
@@ -1444,7 +1444,7 @@ class AdvertisingData:
return ','.join(bit_flags_to_strings(flags, flag_names)) return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod @staticmethod
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]: def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> list[UUID]:
uuids = [] uuids = []
offset = 0 offset = 0
while (offset + uuid_size) <= len(ad_data): while (offset + uuid_size) <= len(ad_data):
@@ -1461,8 +1461,8 @@ class AdvertisingData:
] ]
) )
@staticmethod @classmethod
def ad_data_to_string(ad_type, ad_data): def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
if ad_type == AdvertisingData.FLAGS: if ad_type == AdvertisingData.FLAGS:
ad_type_str = 'Flags' ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True) ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
@@ -1501,7 +1501,10 @@ class AdvertisingData:
ad_data_str = f'"{ad_data.decode("utf-8")}"' ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME: elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name' ad_type_str = 'Complete Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"' try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL: elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level' ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0]) ad_data_str = str(ad_data[0])
@@ -1518,72 +1521,72 @@ class AdvertisingData:
ad_type_str = 'Broadcast Name' ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8') ad_data_str = ad_data.decode('utf-8')
else: else:
ad_type_str = AdvertisingData.AD_TYPE_NAMES.get(ad_type, f'0x{ad_type:02X}') ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex() ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}' return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
@staticmethod @classmethod
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingDataObject: def ad_data_to_object(cls, ad_type: int, ad_data: bytes) -> AdvertisingDataObject:
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 2) return AdvertisingData.uuid_list_to_objects(ad_data, 2)
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 4) return AdvertisingData.uuid_list_to_objects(ad_data, 4)
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 16) return AdvertisingData.uuid_list_to_objects(ad_data, 16)
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:]) return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:]) return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:]) return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
if ad_type in ( if ad_type in (
AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME, AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.URI, AdvertisingData.Type.URI,
AdvertisingData.BROADCAST_NAME, AdvertisingData.Type.BROADCAST_NAME,
): ):
return ad_data.decode("utf-8") return ad_data.decode("utf-8")
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS): if ad_type in (AdvertisingData.Type.TX_POWER_LEVEL, AdvertisingData.Type.FLAGS):
return cast(int, struct.unpack('B', ad_data)[0]) return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (AdvertisingData.ADVERTISING_INTERVAL,): if ad_type in (AdvertisingData.Type.ADVERTISING_INTERVAL,):
return cast(int, struct.unpack('<H', ad_data)[0]) return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE: if ad_type == AdvertisingData.Type.CLASS_OF_DEVICE:
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0]) return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE: if ad_type == AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return cast(Tuple[int, int], struct.unpack('<HH', ad_data)) return cast(tuple[int, int], struct.unpack('<HH', ad_data))
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA: if ad_type == AdvertisingData.Type.APPEARANCE:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
if ad_type == AdvertisingData.APPEARANCE:
return Appearance.from_int( return Appearance.from_int(
cast(int, struct.unpack_from('<H', ad_data, 0)[0]) cast(int, struct.unpack_from('<H', ad_data, 0)[0])
) )
if ad_type == AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
return ad_data return ad_data
def append(self, data: bytes) -> None: def append(self, data: bytes) -> None:
@@ -1597,7 +1600,80 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data)) self.ad_structures.append((ad_type, ad_data))
offset += length offset += length
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingDataObject]: @overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> list[list[UUID]]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> list[tuple[UUID, bytes]]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.Type.URI,
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> list[str]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.TX_POWER_LEVEL,
AdvertisingData.Type.FLAGS,
AdvertisingData.Type.ADVERTISING_INTERVAL,
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> list[int]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> list[tuple[int, int]]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> list[tuple[int, bytes]]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> list[Appearance]: ...
@overload
def get_all(self, type_id: int, raw: Literal[True]) -> list[bytes]: ...
@overload
def get_all(
self, type_id: int, raw: bool = False
) -> list[AdvertisingDataObject]: ...
def get_all(self, type_id: int, raw: bool = False) -> list[AdvertisingDataObject]: # type: ignore[misc]
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
@@ -1609,6 +1685,79 @@ class AdvertisingData:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> Optional[list[UUID]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> Optional[tuple[UUID, bytes]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.Type.URI,
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> Optional[Optional[str]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.TX_POWER_LEVEL,
AdvertisingData.Type.FLAGS,
AdvertisingData.Type.ADVERTISING_INTERVAL,
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> Optional[int]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> Optional[tuple[int, int]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> Optional[tuple[int, bytes]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> Optional[Appearance]: ...
@overload
def get(self, type_id: int, raw: Literal[True]) -> Optional[bytes]: ...
@overload
def get(
self, type_id: int, raw: bool = False
) -> Optional[AdvertisingDataObject]: ...
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]: def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]:
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type

File diff suppressed because it is too large Load Diff

View File

@@ -20,6 +20,8 @@ Common types for drivers.
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import abc import abc
from bumble import core
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes

View File

@@ -11,18 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Support for Intel USB controllers.
Loosely based on the Fuchsia OS implementation.
"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging import logging
import os
import pathlib
import platform
import struct
from typing import Any, Deque, Optional, TYPE_CHECKING
from bumble import core
from bumble.drivers import common from bumble.drivers import common
from bumble.hci import ( from bumble import hci
hci_vendor_command_op_code, # type: ignore from bumble import utils
HCI_Command,
HCI_Reset_Command, if TYPE_CHECKING:
) from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -34,39 +49,328 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = { INTEL_USB_PRODUCTS = {
# Intel AX210 (0x8087, 0x0032), # AX210
(0x8087, 0x0032), (0x8087, 0x0036), # BE200
# Intel BE200
(0x8087, 0x0036),
} }
INTEL_FW_IMAGE_NAMES = [
"ibt-0040-0041",
"ibt-0040-1020",
"ibt-0040-1050",
"ibt-0040-2120",
"ibt-0040-4150",
"ibt-0041-0041",
"ibt-0180-0041",
"ibt-0180-1050",
"ibt-0180-4150",
"ibt-0291-0291",
"ibt-1040-0041",
"ibt-1040-1020",
"ibt-1040-1050",
"ibt-1040-2120",
"ibt-1040-4150",
]
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
_MAX_FRAGMENT_SIZE = 252
_POST_RESET_DELAY = 0.2
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HCI Commands # HCI Commands
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00] HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore @hci.HCI_Command.command(
fields=[("params", "*")], fields=[
("param0", 1),
],
return_parameters_fields=[ return_parameters_fields=[
("params", "*"), ("status", hci.STATUS_SPEC),
("tlv", "*"),
], ],
) )
class Hci_Intel_DDC_Config_Write_Command(HCI_Command): class HCI_Intel_Read_Version_Command(hci.HCI_Command):
pass pass
@hci.HCI_Command.command(
fields=[("data_type", 1), ("data", "*")],
return_parameters_fields=[
("status", 1),
],
)
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[
("reset_type", 1),
("patch_enable", 1),
("ddc_reload", 1),
("boot_option", 1),
("boot_address", 4),
],
return_parameters_fields=[
("data", "*"),
],
)
class HCI_Intel_Reset_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[("data", "*")],
return_parameters_fields=[
("status", hci.STATUS_SPEC),
("params", "*"),
],
)
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
pass
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
def intel_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Intel firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "intel"
p.mkdir(parents=True, exist_ok=True)
return p
def _find_binary_path(file_name: str) -> pathlib.Path | None:
# First check if an environment variable is set
if INTEL_FIRMWARE_DIR_ENV in os.environ:
if (
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
).is_file():
logger.debug(f"{file_name} found in env dir")
return path
# When the environment variable is set, don't look elsewhere
return None
# Then, look where the firmware download tool writes by default
if (path := intel_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
logger.debug(f"{file_name} found in package dir")
return path
# On Linux, check the system's FW directory
if (
platform.system() == "Linux"
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
):
logger.debug(f"{file_name} found in Linux system FW dir")
return path
# Finally look in the current directory
if (path := pathlib.Path.cwd() / file_name).is_file():
logger.debug(f"{file_name} found in CWD")
return path
return None
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
result: list[tuple[ValueType, Any]] = []
while len(data) >= 2:
value_type = ValueType(data[0])
value_length = data[1]
value = data[2 : 2 + value_length]
typed_value: Any
if value_type == ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]
return result
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class DriverError(core.BaseBumbleError):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
def __str__(self) -> str:
return f"IntelDriverError({self.message})"
class ValueType(utils.OpenIntEnum):
END = 0x00
CNVI = 0x10
CNVR = 0x11
HARDWARE_INFO = 0x12
DEVICE_REVISION = 0x16
CURRENT_MODE_OF_OPERATION = 0x1C
USB_VENDOR_ID = 0x17
USB_PRODUCT_ID = 0x18
TIMESTAMP = 0x1D
BUILD_TYPE = 0x1E
BUILD_NUMBER = 0x1F
SECURE_BOOT = 0x28
OTP_LOCK = 0x2A
API_LOCK = 0x2B
DEBUG_LOCK = 0x2C
FIRMWARE_BUILD = 0x2D
SECURE_BOOT_ENGINE_TYPE = 0x2F
BLUETOOTH_ADDRESS = 0x30
class HardwarePlatform(utils.OpenIntEnum):
INTEL_37 = 0x37
class HardwareVariant(utils.OpenIntEnum):
# This is a just a partial list.
# Add other constants here as new hardware is encountered and tested.
TYPHOON_PEAK = 0x17
GALE_PEAK = 0x1C
@dataclasses.dataclass
class HardwareInfo:
platform: HardwarePlatform
variant: HardwareVariant
@dataclasses.dataclass
class Timestamp:
week: int
year: int
@dataclasses.dataclass
class FirmwareBuild:
build_number: int
timestamp: Timestamp
class ModeOfOperation(utils.OpenIntEnum):
BOOTLOADER = 0x01
INTERMEDIATE = 0x02
OPERATIONAL = 0x03
class SecureBootEngineType(utils.OpenIntEnum):
RSA = 0x00
ECDSA = 0x01
@dataclasses.dataclass
class BootParams:
css_header_offset: int
css_header_size: int
pki_offset: int
pki_size: int
sig_offset: int
sig_size: int
write_offset: int
_BOOT_PARAMS = {
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
}
class Driver(common.Driver): class Driver(common.Driver):
def __init__(self, host): def __init__(self, host: Host) -> None:
self.host = host self.host = host
self.max_in_flight_firmware_load_commands = 1
self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
collections.deque()
)
self.can_send_firmware_load_command = asyncio.Event()
self.can_send_firmware_load_command.set()
self.firmware_load_complete = asyncio.Event()
self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name.
self.ddc_addon: Optional[bytes] = None
self.ddc_override: Optional[bytes] = None
driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"):
for key, value in [
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
]:
if key == "ddc_addon":
self.ddc_addon = bytes.fromhex(value)
elif key == "ddc_override":
self.ddc_override = bytes.fromhex(value)
@staticmethod @staticmethod
def check(host): def check(host: Host) -> bool:
driver = host.hci_metadata.get("driver") driver = host.hci_metadata.get("driver")
if driver == "intel": if driver == "intel" or driver is not None and driver.startswith("intel/"):
return True return True
vendor_id = host.hci_metadata.get("vendor_id") vendor_id = host.hci_metadata.get("vendor_id")
@@ -85,18 +389,283 @@ class Driver(common.Driver):
return True return True
@classmethod @classmethod
async def for_host(cls, host, force=False): # type: ignore async def for_host(cls, host: Host, force: bool = False):
# Only instantiate this driver if explicitly selected # Only instantiate this driver if explicitly selected
if not force and not cls.check(host): if not force and not cls.check(host):
return None return None
return cls(host) return cls(host)
async def init_controller(self): def on_packet(self, packet: bytes) -> None:
"""Handler for event packets that are received from an ACL channel"""
event = hci.HCI_Event.from_bytes(packet)
if not isinstance(event, hci.HCI_Command_Complete_Event):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
logger.debug(
"max_in_flight_firmware_load_commands update: "
f"{event.num_hci_command_packets}"
)
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
logger.debug(f"event: {event}")
self.pending_firmware_load_commands.popleft()
in_flight = len(self.pending_firmware_load_commands)
logger.debug(f"event received, {in_flight} still in flight")
if in_flight < self.max_in_flight_firmware_load_commands:
self.can_send_firmware_load_command.set()
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
# Wait until we can send.
await self.can_send_firmware_load_command.wait()
# Send the command and adjust counters.
self.host.send_hci_packet(command)
self.pending_firmware_load_commands.append(command)
in_flight = len(self.pending_firmware_load_commands)
if in_flight >= self.max_in_flight_firmware_load_commands:
logger.debug(f"max commands in flight reached [{in_flight}]")
self.can_send_firmware_load_command.clear()
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
while data:
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
fragment = data[:fragment_size]
data = data[fragment_size:]
await self.send_firmware_load_command(
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
)
async def load_firmware(self) -> None:
self.host.ready = True self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True) device_info = await self.read_device_info()
await self.host.send_command( logger.debug(
Hci_Intel_DDC_Config_Write_Command( "device info: \n%s",
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD "\n".join(
[
f" {value_type.name}: {value}"
for value_type, value in device_info.items()
]
),
)
# Check if the firmware is already loaded.
if (
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
== ModeOfOperation.OPERATIONAL
):
logger.debug("firmware already loaded")
return
# We only support some platforms and variants.
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
if hardware_info is None:
raise DriverError("hardware info missing")
if hardware_info.platform != HardwarePlatform.INTEL_37:
raise DriverError("hardware platform not supported")
if hardware_info.variant not in (
HardwareVariant.TYPHOON_PEAK,
HardwareVariant.GALE_PEAK,
):
raise DriverError("hardware variant not supported")
# Compute the firmware name.
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = (
"ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
)
logger.debug(f"FW base name: {firmware_base_name}")
firmware_name = f"{firmware_base_name}.sfi"
firmware_path = _find_binary_path(firmware_name)
if not firmware_path:
logger.warning(f"Firmware file {firmware_name} not found")
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
return None
logger.debug(f"loading firmware from {firmware_path}")
firmware_image = firmware_path.read_bytes()
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
if engine_type is None:
raise DriverError("secure boot engine type missing")
if engine_type not in _BOOT_PARAMS:
raise DriverError("secure boot engine type not supported")
boot_params = _BOOT_PARAMS[engine_type]
if len(firmware_image) < boot_params.write_offset:
raise DriverError("firmware image too small")
# Register to receive vendor events.
def on_vendor_event(event: hci.HCI_Vendor_Event):
logger.debug(f"vendor event: {event}")
event_type = event.parameters[0]
if event_type == 0x02:
# Boot event
logger.debug("boot complete")
self.reset_complete.set()
elif event_type == 0x06:
# Firmware load event
logger.debug("download complete")
self.firmware_load_complete.set()
else:
logger.debug(f"ignoring vendor event type {event_type}")
self.host.on("vendor_event", on_vendor_event)
# We need to temporarily intercept packets from the controller,
# because they are formatted as HCI event packets but are received
# on the ACL channel, so the host parser would get confused.
saved_on_packet = self.host.on_packet
self.host.on_packet = self.on_packet # type: ignore
self.firmware_load_complete.clear()
# Send the CSS header
data = firmware_image[
boot_params.css_header_offset : boot_params.css_header_offset
+ boot_params.css_header_size
]
await self.send_firmware_data(0x00, data)
# Send the PKI header
data = firmware_image[
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
]
await self.send_firmware_data(0x03, data)
# Send the Signature header
data = firmware_image[
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
]
await self.send_firmware_data(0x02, data)
# Send the rest of the image.
# The payload consists of command objects, which are sent when they add up
# to a multiple of 4 bytes.
boot_address = 0
offset = boot_params.write_offset
fragment_size = 0
while offset + 3 < len(firmware_image):
(command_opcode,) = struct.unpack_from(
"<H", firmware_image, offset + fragment_size
)
command_size = firmware_image[offset + fragment_size + 2]
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
(boot_address,) = struct.unpack_from(
"<I", firmware_image, offset + fragment_size + 3
)
logger.debug(
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
f"boot_address={boot_address}"
)
fragment_size += 3 + command_size
if fragment_size % 4 == 0:
await self.send_firmware_data(
0x01, firmware_image[offset : offset + fragment_size]
)
logger.debug(f"sent {fragment_size} bytes")
offset += fragment_size
fragment_size = 0
# Wait for the firmware loading to be complete.
logger.debug("waiting for firmware to be loaded")
await self.firmware_load_complete.wait()
logger.debug("firmware loaded")
# Restore the original packet handler.
self.host.on_packet = saved_on_packet # type: ignore
# Reset
self.reset_complete.clear()
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x00,
patch_enable=0x01,
ddc_reload=0x00,
boot_option=0x01,
boot_address=boot_address,
) )
) )
logger.debug("waiting for reset completion")
await self.reset_complete.wait()
logger.debug("reset complete")
# Load the device config if there is one.
if self.ddc_override:
logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon:
logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon)
async def load_device_config(self, ddc_data: bytes) -> None:
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
async def reboot_bootloader(self) -> None:
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x01,
patch_enable=0x01,
ddc_reload=0x01,
boot_option=0x00,
boot_address=0,
)
)
await asyncio.sleep(_POST_RESET_DELAY)
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF)
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.
return dict(tlvs)
async def init_controller(self):
await self.load_firmware()

View File

@@ -28,23 +28,26 @@ import functools
import logging import logging
import struct import struct
from typing import ( from typing import (
Any,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence, Sequence,
SupportsBytes,
Type,
Union, Union,
TYPE_CHECKING, TYPE_CHECKING,
) )
from bumble.colors import color from bumble.colors import color
from bumble.core import BaseBumbleError, UUID from bumble.core import BaseBumbleError, InvalidOperationError, UUID
from bumble.att import Attribute, AttributeValue from bumble.att import Attribute, AttributeValue
from bumble.utils import ByteSerializable
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -275,6 +278,13 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts') GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts') GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# Gaming Audio Service (GMAS)
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
# Hearing Access Service # Hearing Access Service
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features') GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point') GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
@@ -304,6 +314,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features') GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash') GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features') GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
GATT_LE_GATT_SECURITY_LEVELS_CHARACTERISTIC = UUID.from_16_bits(0x2BF5, 'E GATT Security Levels')
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -312,8 +323,6 @@ GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bi
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None: def show_services(services: Iterable[Service]) -> None:
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -343,7 +352,7 @@ class Service(Attribute):
def __init__( def __init__(
self, self,
uuid: Union[str, UUID], uuid: Union[str, UUID],
characteristics: List[Characteristic], characteristics: Iterable[Characteristic],
primary=True, primary=True,
included_services: Iterable[Service] = (), included_services: Iterable[Service] = (),
) -> None: ) -> None:
@@ -362,7 +371,7 @@ class Service(Attribute):
) )
self.uuid = uuid self.uuid = uuid
self.included_services = list(included_services) self.included_services = list(included_services)
self.characteristics = characteristics[:] self.characteristics = list(characteristics)
self.primary = primary self.primary = primary
def get_advertising_data(self) -> Optional[bytes]: def get_advertising_data(self) -> Optional[bytes]:
@@ -393,7 +402,7 @@ class TemplateService(Service):
def __init__( def __init__(
self, self,
characteristics: List[Characteristic], characteristics: Iterable[Characteristic],
primary: bool = True, primary: bool = True,
included_services: Iterable[Service] = (), included_services: Iterable[Service] = (),
) -> None: ) -> None:
@@ -410,7 +419,7 @@ class IncludedServiceDeclaration(Attribute):
def __init__(self, service: Service) -> None: def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack( declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() '<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
) )
super().__init__( super().__init__(
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
@@ -490,7 +499,7 @@ class Characteristic(Attribute):
uuid: Union[str, bytes, UUID], uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties, properties: Characteristic.Properties,
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'', value: Any = b'',
descriptors: Sequence[Descriptor] = (), descriptors: Sequence[Descriptor] = (),
): ):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
@@ -525,7 +534,11 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic characteristic: Characteristic
def __init__(self, characteristic: Characteristic, value_handle: int) -> None: def __init__(
self,
characteristic: Characteristic,
value_handle: int,
) -> None:
declaration_bytes = ( declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle) struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes() + characteristic.uuid.to_pdu_bytes()
@@ -665,10 +678,14 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
self.decode = decode self.decode = decode
def encode_value(self, value): def encode_value(self, value):
return self.encode(value) if self.encode else value if self.encode is None:
raise InvalidOperationError('delegated adapter does not have an encoder')
return self.encode(value)
def decode_value(self, value): def decode_value(self, value):
return self.decode(value) if self.decode else value if self.decode is None:
raise InvalidOperationError('delegate adapter does not have a decoder')
return self.decode(value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -705,7 +722,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
''' '''
Adapter that packs/unpacks characteristic values according to a standard Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format. Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which The adapted `read_value` and `write_value` methods return/accept a dictionary which
is packed/unpacked according to format, with the arguments extracted from the is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter. dictionary by key, in the same order as they occur in the `keys` parameter.
''' '''
@@ -735,6 +752,24 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
return value.decode('utf-8') return value.decode('utf-8')
# -----------------------------------------------------------------------------
class SerializableCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts any class to/from bytes using the class'
`to_bytes` and `__bytes__` methods, respectively.
'''
def __init__(self, characteristic, cls: Type[ByteSerializable]):
super().__init__(characteristic)
self.cls = cls
def encode_value(self, value: SupportsBytes) -> bytes:
return bytes(value)
def decode_value(self, value: bytes) -> Any:
return self.cls.from_bytes(value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Descriptor(Attribute): class Descriptor(Attribute):
''' '''
@@ -769,3 +804,23 @@ class ClientCharacteristicConfigurationBits(enum.IntFlag):
DEFAULT = 0x0000 DEFAULT = 0x0000
NOTIFICATION = 0x0001 NOTIFICATION = 0x0001
INDICATION = 0x0002 INDICATION = 0x0002
# -----------------------------------------------------------------------------
class ClientSupportedFeatures(enum.IntFlag):
'''
See Vol 3, Part G - 7.2 - Table 7.6: Client Supported Features bit assignments.
'''
ROBUST_CACHING = 0x01
ENHANCED_ATT_BEARER = 0x02
MULTIPLE_HANDLE_VALUE_NOTIFICATIONS = 0x04
# -----------------------------------------------------------------------------
class ServerSupportedFeatures(enum.IntFlag):
'''
See Vol 3, Part G - 7.4 - Table 7.11: Server Supported Features bit assignments.
'''
EATT_SUPPORTED = 0x01

View File

@@ -78,6 +78,7 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE, GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
ClientCharacteristicConfigurationBits, ClientCharacteristicConfigurationBits,
InvalidServiceError,
TemplateService, TemplateService,
) )
@@ -162,12 +163,23 @@ class ServiceProxy(AttributeProxy):
self.uuid = uuid self.uuid = uuid
self.characteristics = [] self.characteristics = []
async def discover_characteristics(self, uuids=()): async def discover_characteristics(self, uuids=()) -> list[CharacteristicProxy]:
return await self.client.discover_characteristics(uuids, self) return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid): def get_characteristics_by_uuid(self, uuid: UUID) -> list[CharacteristicProxy]:
"""Get all the characteristics with a specified UUID."""
return self.client.get_characteristics_by_uuid(uuid, self) return self.client.get_characteristics_by_uuid(uuid, self)
def get_required_characteristic_by_uuid(self, uuid: UUID) -> CharacteristicProxy:
"""
Get the first characteristic with a specified UUID.
If no characteristic with that UUID is found, an InvalidServiceError is raised.
"""
if not (characteristics := self.get_characteristics_by_uuid(uuid)):
raise InvalidServiceError(f'{uuid} characteristic not found')
return characteristics[0]
def __str__(self) -> str: def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
@@ -292,7 +304,7 @@ class Client:
logger.debug( logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
) )
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(bytes(command))
async def send_request(self, request: ATT_PDU): async def send_request(self, request: ATT_PDU):
logger.debug( logger.debug(
@@ -310,7 +322,7 @@ class Client:
self.pending_request = request self.pending_request = request
try: try:
self.send_gatt_pdu(request.to_bytes()) self.send_gatt_pdu(bytes(request))
response = await asyncio.wait_for( response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT self.pending_response, GATT_REQUEST_TIMEOUT
) )
@@ -328,7 +340,7 @@ class Client:
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}' f'{confirmation}'
) )
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
@@ -898,6 +910,12 @@ class Client:
) and subscriber in subscribers: ) and subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
# The characteristic itself is added as subscriber. If it is the
# last remaining subscriber, we remove it, such that the clean up
# works correctly. Otherwise the CCCD never is set back to 0.
if len(subscribers) == 1 and characteristic in subscribers:
subscribers.remove(characteristic)
# Cleanup if we removed the last one # Cleanup if we removed the last one
if not subscribers: if not subscribers:
del subscriber_set[characteristic.handle] del subscriber_set[characteristic.handle]

View File

@@ -28,7 +28,17 @@ import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
import struct import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from typing import (
Dict,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
from bumble.colors import color from bumble.colors import color
@@ -68,6 +78,7 @@ from bumble.gatt import (
GATT_REQUEST_TIMEOUT, GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
CharacteristicAdapter,
CharacteristicDeclaration, CharacteristicDeclaration,
CharacteristicValue, CharacteristicValue,
IncludedServiceDeclaration, IncludedServiceDeclaration,
@@ -353,7 +364,7 @@ class Server(EventEmitter):
logger.debug( logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}' f'GATT Response from server: [0x{connection.handle:04X}] {response}'
) )
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, bytes(response))
async def notify_subscriber( async def notify_subscriber(
self, self,
@@ -450,7 +461,7 @@ class Server(EventEmitter):
) )
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))

File diff suppressed because it is too large Load Diff

View File

@@ -141,7 +141,7 @@ class HfFeature(enum.IntFlag):
""" """
HF supported features (AT+BRSF=) (normative). HF supported features (AT+BRSF=) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007. Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
""" """
EC_NR = 0x001 # Echo Cancel & Noise reduction EC_NR = 0x001 # Echo Cancel & Noise reduction
@@ -155,14 +155,14 @@ class HfFeature(enum.IntFlag):
HF_INDICATORS = 0x100 HF_INDICATORS = 0x100
ESCO_S4_SETTINGS_SUPPORTED = 0x200 ESCO_S4_SETTINGS_SUPPORTED = 0x200
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400 ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
VOICE_RECOGNITION_TEST = 0x800 VOICE_RECOGNITION_TEXT = 0x800
class AgFeature(enum.IntFlag): class AgFeature(enum.IntFlag):
""" """
AG supported features (+BRSF:) (normative). AG supported features (+BRSF:) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007. Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
""" """
THREE_WAY_CALLING = 0x001 THREE_WAY_CALLING = 0x001
@@ -178,7 +178,7 @@ class AgFeature(enum.IntFlag):
HF_INDICATORS = 0x400 HF_INDICATORS = 0x400
ESCO_S4_SETTINGS_SUPPORTED = 0x800 ESCO_S4_SETTINGS_SUPPORTED = 0x800
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000 ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
VOICE_RECOGNITION_TEST = 0x2000 VOICE_RECOGNITION_TEXT = 0x2000
class AudioCodec(enum.IntEnum): class AudioCodec(enum.IntEnum):
@@ -1390,6 +1390,7 @@ class AgProtocol(pyee.EventEmitter):
def _on_bac(self, *args) -> None: def _on_bac(self, *args) -> None:
self.supported_audio_codecs = [AudioCodec(int(value)) for value in args] self.supported_audio_codecs = [AudioCodec(int(value)) for value in args]
self.emit('supported_audio_codecs', self.supported_audio_codecs)
self.send_ok() self.send_ok()
def _on_bcs(self, codec: bytes) -> None: def _on_bcs(self, codec: bytes) -> None:
@@ -1618,7 +1619,7 @@ class ProfileVersion(enum.IntEnum):
""" """
Profile version (normative). Profile version (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.8, 6.3 SDP Interoperability Requirements.
""" """
V1_5 = 0x0105 V1_5 = 0x0105
@@ -1632,7 +1633,7 @@ class HfSdpFeature(enum.IntFlag):
""" """
HF supported features (normative). HF supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
""" """
EC_NR = 0x01 # Echo Cancel & Noise reduction EC_NR = 0x01 # Echo Cancel & Noise reduction
@@ -1640,16 +1641,17 @@ class HfSdpFeature(enum.IntFlag):
CLI_PRESENTATION_CAPABILITY = 0x04 CLI_PRESENTATION_CAPABILITY = 0x04
VOICE_RECOGNITION_ACTIVATION = 0x08 VOICE_RECOGNITION_ACTIVATION = 0x08
REMOTE_VOLUME_CONTROL = 0x10 REMOTE_VOLUME_CONTROL = 0x10
WIDE_BAND = 0x20 # Wide band speech WIDE_BAND_SPEECH = 0x20
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40 ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80 VOICE_RECOGNITION_TEXT = 0x80
SUPER_WIDE_BAND = 0x100
class AgSdpFeature(enum.IntFlag): class AgSdpFeature(enum.IntFlag):
""" """
AG supported features (normative). AG supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
""" """
THREE_WAY_CALLING = 0x01 THREE_WAY_CALLING = 0x01
@@ -1657,9 +1659,10 @@ class AgSdpFeature(enum.IntFlag):
VOICE_RECOGNITION_FUNCTION = 0x04 VOICE_RECOGNITION_FUNCTION = 0x04
IN_BAND_RING_TONE_CAPABILITY = 0x08 IN_BAND_RING_TONE_CAPABILITY = 0x08
VOICE_TAG = 0x10 # Attach a number to voice tag VOICE_TAG = 0x10 # Attach a number to voice tag
WIDE_BAND = 0x20 # Wide band speech WIDE_BAND_SPEECH = 0x20
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40 ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80 VOICE_RECOGNITION_TEXT = 0x80
SUPER_WIDE_BAND_SPEED_SPEECH = 0x100
def make_hf_sdp_records( def make_hf_sdp_records(
@@ -1692,11 +1695,11 @@ def make_hf_sdp_records(
in configuration.supported_hf_features in configuration.supported_hf_features
): ):
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features: if HfFeature.VOICE_RECOGNITION_TEXT in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEXT
if AudioCodec.MSBC in configuration.supported_audio_codecs: if AudioCodec.MSBC in configuration.supported_audio_codecs:
hf_supported_features |= HfSdpFeature.WIDE_BAND hf_supported_features |= HfSdpFeature.WIDE_BAND_SPEECH
return [ return [
sdp.ServiceAttribute( sdp.ServiceAttribute(
@@ -1772,14 +1775,14 @@ def make_ag_sdp_records(
in configuration.supported_ag_features in configuration.supported_ag_features
): ):
ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if AgFeature.VOICE_RECOGNITION_TEST in configuration.supported_ag_features: if AgFeature.VOICE_RECOGNITION_TEXT in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEST ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEXT
if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features: if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY
if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features: if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION
if AudioCodec.MSBC in configuration.supported_audio_codecs: if AudioCodec.MSBC in configuration.supported_audio_codecs:
ag_supported_features |= AgSdpFeature.WIDE_BAND ag_supported_features |= AgSdpFeature.WIDE_BAND_SPEECH
return [ return [
sdp.ServiceAttribute( sdp.ServiceAttribute(

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -34,6 +34,8 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
) )
import pyee
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper from bumble.snoop import Snooper
@@ -59,7 +61,19 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AclPacketQueue: class DataPacketQueue(pyee.EventEmitter):
"""
Flow-control queue for host->controller data packets (ACL, ISO).
The queue holds packets associated with a connection handle. The packets
are sent to the controller, up to a maximum total number of packets in flight.
A packet is considered to be "in flight" when it has been sent to the controller
but not completed yet. Packets are no longer "in flight" when the controller
declares them as completed.
The queue emits a 'flow' event whenever one or more packets are completed.
"""
max_packet_size: int max_packet_size: int
def __init__( def __init__(
@@ -68,40 +82,105 @@ class AclPacketQueue:
max_in_flight: int, max_in_flight: int,
send: Callable[[hci.HCI_Packet], None], send: Callable[[hci.HCI_Packet], None],
) -> None: ) -> None:
super().__init__()
self.max_packet_size = max_packet_size self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight self.max_in_flight = max_in_flight
self.in_flight = 0 self._in_flight = 0 # Total number of packets in flight across all connections
self.send = send self._in_flight_per_connection: dict[int, int] = collections.defaultdict(
self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque() int
) # Number of packets in flight per connection
self._send = send
self._packets: Deque[tuple[hci.HCI_Packet, int]] = collections.deque()
self._queued = 0
self._completed = 0
def enqueue(self, packet: hci.HCI_AclDataPacket) -> None: @property
self.packets.appendleft(packet) def queued(self) -> int:
self.check_queue() """Total number of packets queued since creation."""
return self._queued
if self.packets: @property
def completed(self) -> int:
"""Total number of packets completed since creation."""
return self._completed
@property
def pending(self) -> int:
"""Number of packets that have been queued but not completed."""
return self._queued - self._completed
def enqueue(self, packet: hci.HCI_Packet, connection_handle: int) -> None:
"""Enqueue a packet associated with a connection"""
self._packets.appendleft((packet, connection_handle))
self._queued += 1
self._check_queue()
if self._packets:
logger.debug( logger.debug(
f'{self.in_flight} ACL packets in flight, ' f'{self._in_flight} packets in flight, '
f'{len(self.packets)} in queue' f'{len(self._packets)} in queue'
) )
def check_queue(self) -> None: def flush(self, connection_handle: int) -> None:
while self.packets and self.in_flight < self.max_in_flight: """
packet = self.packets.pop() Remove all packets associated with a connection.
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None: All packets associated with the connection that are in flight are implicitly
if packet_count > self.in_flight: marked as completed, but no 'flow' event is emitted.
"""
packets_to_keep = [
(packet, handle)
for (packet, handle) in self._packets
if handle != connection_handle
]
if flushed_count := len(self._packets) - len(packets_to_keep):
self._completed += flushed_count
self._packets = collections.deque(packets_to_keep)
if connection_handle in self._in_flight_per_connection:
in_flight = self._in_flight_per_connection[connection_handle]
self._completed += in_flight
self._in_flight -= in_flight
del self._in_flight_per_connection[connection_handle]
def _check_queue(self) -> None:
while self._packets and self._in_flight < self.max_in_flight:
packet, connection_handle = self._packets.pop()
self._send(packet)
self._in_flight += 1
self._in_flight_per_connection[connection_handle] += 1
def on_packets_completed(self, packet_count: int, connection_handle: int) -> None:
"""Mark one or more packets associated with a connection as completed."""
if connection_handle not in self._in_flight_per_connection:
logger.warning( logger.warning(
color( f'received completion for unknown connection {connection_handle}'
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
) )
packet_count = self.in_flight return
self.in_flight -= packet_count in_flight_for_connection = self._in_flight_per_connection[connection_handle]
self.check_queue() if packet_count <= in_flight_for_connection:
self._in_flight_per_connection[connection_handle] -= packet_count
else:
logger.warning(
f'{packet_count} completed for {connection_handle} '
f'but only {in_flight_for_connection} in flight'
)
self._in_flight_per_connection[connection_handle] = 0
if packet_count <= self._in_flight:
self._in_flight -= packet_count
self._completed += packet_count
else:
logger.warning(
f'{packet_count} completed but only {self._in_flight} in flight'
)
self._in_flight = 0
self._completed = self._queued
self._check_queue()
self.emit('flow')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -114,7 +193,7 @@ class Connection:
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = ( acl_packet_queue: Optional[DataPacketQueue] = (
host.le_acl_packet_queue host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT if transport == BT_LE_TRANSPORT
else host.acl_packet_queue else host.acl_packet_queue
@@ -129,28 +208,37 @@ class Connection:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
def __str__(self) -> str:
return (
f'Connection(transport={self.transport}, peer_address={self.peer_address})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class ScoLink: class ScoLink:
peer_address: hci.Address peer_address: hci.Address
handle: int connection_handle: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class CisLink: class IsoLink:
peer_address: hci.Address
handle: int handle: int
packet_queue: DataPacketQueue = dataclasses.field(repr=False)
packet_sequence_number: int = 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
connections: Dict[int, Connection] connections: Dict[int, Connection]
cis_links: Dict[int, CisLink] cis_links: Dict[int, IsoLink]
bis_links: Dict[int, IsoLink]
sco_links: Dict[int, ScoLink] sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None bigs: dict[int, set[int]]
le_acl_packet_queue: Optional[AclPacketQueue] = None acl_packet_queue: Optional[DataPacketQueue] = None
le_acl_packet_queue: Optional[DataPacketQueue] = None
iso_packet_queue: Optional[DataPacketQueue] = None
hci_sink: Optional[TransportSink] = None hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any] hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[ long_term_key_provider: Optional[
@@ -169,7 +257,9 @@ class Host(AbortableEventEmitter):
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle self.cis_links = {} # CIS links, by connection handle
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None self.pending_response: Optional[asyncio.Future[Any]] = None
self.number_of_supported_advertising_sets = 0 self.number_of_supported_advertising_sets = 0
@@ -199,7 +289,7 @@ class Host(AbortableEventEmitter):
check_address_type: bool = False, check_address_type: bool = False,
) -> Optional[Connection]: ) -> Optional[Connection]:
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address.to_bytes() == bd_addr.to_bytes(): if bytes(connection.peer_address) == bytes(bd_addr):
if ( if (
check_address_type check_address_type
and connection.peer_address.address_type != bd_addr.address_type and connection.peer_address.address_type != bd_addr.address_type
@@ -387,6 +477,12 @@ class Host(AbortableEventEmitter):
hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT, hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT, hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_SUBRATE_CHANGE_EVENT, hci.HCI_LE_SUBRATE_CHANGE_EVENT,
hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT,
hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT,
] ]
) )
@@ -411,39 +507,70 @@ class Host(AbortableEventEmitter):
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}' f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
) )
self.acl_packet_queue = AclPacketQueue( self.acl_packet_queue = DataPacketQueue(
max_packet_size=hc_acl_data_packet_length, max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets, max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
hc_le_acl_data_packet_length = 0 le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0 total_num_le_acl_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
)
logger.debug(
'HCI LE flow control: '
f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
f'iso_data_packet_length={iso_data_packet_length},'
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
hc_le_acl_data_packet_length = ( le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length response.return_parameters.le_acl_data_packet_length
) )
hc_total_num_le_acl_data_packets = ( total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets response.return_parameters.total_num_le_acl_data_packets
) )
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},' f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}' f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
) )
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0: if le_acl_data_packet_length == 0 or total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue # LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue self.le_acl_packet_queue = self.acl_packet_queue
else: else:
# Create a separate queue for LE # Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue( self.le_acl_packet_queue = DataPacketQueue(
max_packet_size=hc_le_acl_data_packet_length, max_packet_size=le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets, max_in_flight=total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if iso_data_packet_length and total_num_iso_data_packets:
self.iso_packet_queue = DataPacketQueue(
max_packet_size=iso_data_packet_length,
max_in_flight=total_num_iso_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
@@ -552,7 +679,7 @@ class Host(AbortableEventEmitter):
return response return response
except Exception as error: except Exception as error:
logger.warning( logger.exception(
f'{color("!!! Exception while sending command:", "red")} {error}' f'{color("!!! Exception while sending command:", "red")} {error}'
) )
raise error raise error
@@ -595,11 +722,78 @@ class Host(AbortableEventEmitter):
data=l2cap_pdu[offset : offset + data_total_length], data=l2cap_pdu[offset : offset + data_total_length],
) )
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}') logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet) packet_queue.enqueue(acl_packet, connection_handle)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
bytes_remaining -= data_total_length bytes_remaining -= data_total_length
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
if connection := self.connections.get(connection_handle):
return connection.acl_packet_queue
if iso_link := self.cis_links.get(connection_handle) or self.bis_links.get(
connection_handle
):
return iso_link.packet_queue
return None
def send_iso_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (
iso_link := self.cis_links.get(connection_handle)
or self.bis_links.get(connection_handle)
):
logger.warning(f"no ISO link for connection handle {connection_handle}")
return
if iso_link.packet_queue is None:
logger.warning("ISO link has no data packet queue")
return
bytes_remaining = len(sdu)
offset = 0
while bytes_remaining:
is_first_fragment = offset == 0
header_length = 4 if is_first_fragment else 0
assert iso_link.packet_queue.max_packet_size > header_length
fragment_length = min(
bytes_remaining, iso_link.packet_queue.max_packet_size - header_length
)
is_last_fragment = bytes_remaining == fragment_length
iso_sdu_fragment = sdu[offset : offset + fragment_length]
iso_link.packet_queue.enqueue(
(
hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=header_length + fragment_length,
packet_sequence_number=iso_link.packet_sequence_number,
pb_flag=0b10 if is_last_fragment else 0b00,
packet_status_flag=0,
iso_sdu_length=len(sdu),
iso_sdu_fragment=iso_sdu_fragment,
)
if is_first_fragment
else hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=fragment_length,
pb_flag=0b11 if is_last_fragment else 0b01,
iso_sdu_fragment=iso_sdu_fragment,
)
),
connection_handle,
)
offset += fragment_length
bytes_remaining -= fragment_length
iso_link.packet_sequence_number = (iso_link.packet_sequence_number + 1) & 0xFFFF
def remove_big(self, big_handle: int) -> None:
if big := self.bigs.pop(big_handle, None):
for connection_handle in big:
if bis_link := self.bis_links.pop(connection_handle, None):
bis_link.packet_queue.flush(bis_link.handle)
def supports_command(self, op_code: int) -> bool: def supports_command(self, op_code: int) -> bool:
return ( return (
self.local_supported_commands self.local_supported_commands
@@ -727,16 +921,17 @@ class Host(AbortableEventEmitter):
def on_hci_command_status_event(self, event): def on_hci_command_status_event(self, event):
return self.on_command_processed(event) return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event): def on_hci_number_of_completed_packets_event(
self, event: hci.HCI_Number_Of_Completed_Packets_Event
) -> None:
for connection_handle, num_completed_packets in zip( for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets event.connection_handles, event.num_completed_packets
): ):
if connection := self.connections.get(connection_handle): if queue := self.get_data_packet_queue(connection_handle):
connection.acl_packet_queue.on_packets_completed(num_completed_packets) queue.on_packets_completed(num_completed_packets, connection_handle)
elif not ( continue
self.cis_links.get(connection_handle)
or self.sco_links.get(connection_handle) if connection_handle not in self.sco_links:
):
logger.warning( logger.warning(
'received packet completion event for unknown handle ' 'received packet completion event for unknown handle '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -854,11 +1049,7 @@ class Host(AbortableEventEmitter):
return return
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
logger.debug( logger.debug(f'### DISCONNECTION: {connection}, reason={event.reason}')
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
# Notify the listeners # Notify the listeners
self.emit('disconnection', handle, event.reason) self.emit('disconnection', handle, event.reason)
@@ -869,6 +1060,14 @@ class Host(AbortableEventEmitter):
or self.cis_links.pop(handle, 0) or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0) or self.sco_links.pop(handle, 0)
) )
# Flush the data queues
if self.acl_packet_queue:
self.acl_packet_queue.flush(handle)
if self.le_acl_packet_queue:
self.le_acl_packet_queue.flush(handle)
if self.iso_packet_queue:
self.iso_packet_queue.flush(handle)
else: else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
@@ -902,8 +1101,11 @@ class Host(AbortableEventEmitter):
# Notify the client # Notify the client
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy) self.emit(
self.emit('connection_phy_update', connection.handle, connection_phy) 'connection_phy_update',
connection.handle,
ConnectionPHY(event.tx_phy, event.rx_phy),
)
else: else:
self.emit('connection_phy_update_failure', connection.handle, event.status) self.emit('connection_phy_update_failure', connection.handle, event.status)
@@ -953,12 +1155,94 @@ class Host(AbortableEventEmitter):
event.cis_id, event.cis_id,
) )
def on_hci_le_create_big_complete_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
logger.warning("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit(
'big_establishment',
event.status,
event.big_handle,
event.connection_handle,
event.big_sync_delay,
event.transport_latency_big,
event.phy,
event.nse,
event.bn,
event.pto,
event.irc,
event.max_pdu,
event.iso_interval,
)
def on_hci_le_big_sync_established_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle)
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit(
'big_sync_establishment',
event.status,
event.big_handle,
event.transport_latency_big,
event.nse,
event.bn,
event.pto,
event.irc,
event.max_pdu,
event.iso_interval,
event.connection_handle,
)
def on_hci_le_big_sync_lost_event(self, event):
self.remove_big(event.big_handle)
self.emit('big_sync_lost', event.big_handle, event.reason)
def on_hci_le_terminate_big_complete_event(self, event):
self.remove_big(event.big_handle)
self.emit('big_termination', event.reason, event.big_handle)
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
event.connection_handle,
event.sync_handle,
event.advertising_sid,
event.advertiser_address,
event.advertiser_phy,
event.periodic_advertising_interval,
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
event.connection_handle,
event.sync_handle,
event.advertising_sid,
event.advertiser_address,
event.advertiser_phy,
event.periodic_advertising_interval,
event.advertiser_clock_accuracy,
)
def on_hci_le_cis_established_event(self, event): def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now. # The remaining parameters are unused for now.
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink( if self.iso_packet_queue is None:
handle=event.connection_handle, logger.warning("CIS established but ISO packets not supported")
peer_address=hci.Address.ANY, self.cis_links[event.connection_handle] = IsoLink(
handle=event.connection_handle, packet_queue=self.iso_packet_queue
) )
self.emit('cis_establishment', event.connection_handle) self.emit('cis_establishment', event.connection_handle)
else: else:
@@ -1028,7 +1312,7 @@ class Host(AbortableEventEmitter):
self.sco_links[event.connection_handle] = ScoLink( self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr, peer_address=event.bd_addr,
handle=event.connection_handle, connection_handle=event.connection_handle,
) )
# Notify the client # Notify the client
@@ -1248,3 +1532,24 @@ class Host(AbortableEventEmitter):
event.connection_handle, event.connection_handle,
int.from_bytes(event.le_features, 'little'), int.from_bytes(event.le_features, 'little'),
) )
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
self.emit('cs_remote_supported_capabilities', event)
def on_hci_le_cs_security_enable_complete_event(self, event):
self.emit('cs_security', event)
def on_hci_le_cs_config_complete_event(self, event):
self.emit('cs_config', event)
def on_hci_le_cs_procedure_enable_complete_event(self, event):
self.emit('cs_procedure', event)
def on_hci_le_cs_subevent_result_event(self, event):
self.emit('cs_subevent_result', event)
def on_hci_le_cs_subevent_result_continue_event(self, event):
self.emit('cs_subevent_result_continue', event)
def on_hci_vendor_event(self, event):
self.emit('vendor_event', event)

View File

@@ -225,7 +225,7 @@ class L2CAP_PDU:
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload) return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
def to_bytes(self) -> bytes: def __bytes__(self) -> bytes:
header = struct.pack('<HH', len(self.payload), self.cid) header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload return header + self.payload
@@ -233,9 +233,6 @@ class L2CAP_PDU:
self.cid = cid self.cid = cid
self.payload = payload self.payload = payload
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self) -> str: def __str__(self) -> str:
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}' return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
@@ -333,11 +330,8 @@ class L2CAP_Control_Frame:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self) -> bytes:
return self.pdu
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.to_bytes() return self.pdu
def __str__(self) -> str: def __str__(self) -> str:
result = f'{color(self.name, "yellow")} [ID={self.identifier}]' result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
@@ -779,7 +773,6 @@ class ClassicChannel(EventEmitter):
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = 0 self.destination_cid = 0
self.response = None
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.sink = None self.sink = None
@@ -789,27 +782,15 @@ class ClassicChannel(EventEmitter):
self.state = new_state self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame) self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request: SupportsBytes) -> bytes:
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
self.send_pdu(request)
return await self.response
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.response: if self.sink:
self.response.set_result(pdu)
self.response = None
elif self.sink:
# pylint: disable=not-callable # pylint: disable=not-callable
self.sink(pdu) self.sink(pdu)
else: else:

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -76,18 +76,18 @@ class OobData:
return instance return instance
def to_ad(self) -> AdvertisingData: def to_ad(self) -> AdvertisingData:
ad_structures = [] ad_structures: list[tuple[int, bytes]] = []
if self.address is not None: if self.address is not None:
ad_structures.append( ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address)) (AdvertisingData.Type.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
) )
if self.role is not None: if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role]))) ad_structures.append((AdvertisingData.Type.LE_ROLE, bytes([self.role])))
if self.shared_data is not None: if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures) ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None: if self.legacy_context is not None:
ad_structures.append( ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk) (AdvertisingData.Type.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
) )
return AdvertisingData(ad_structures) return AdvertisingData(ad_structures)
@@ -139,16 +139,19 @@ class PairingDelegate:
io_capability: IoCapability io_capability: IoCapability
local_initiator_key_distribution: KeyDistribution local_initiator_key_distribution: KeyDistribution
local_responder_key_distribution: KeyDistribution local_responder_key_distribution: KeyDistribution
maximum_encryption_key_size: int
def __init__( def __init__(
self, self,
io_capability: IoCapability = NO_OUTPUT_NO_INPUT, io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
maximum_encryption_key_size: int = 16,
) -> None: ) -> None:
self.io_capability = io_capability self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution self.local_responder_key_distribution = local_responder_key_distribution
self.maximum_encryption_key_size = maximum_encryption_key_size
@property @property
def classic_io_capability(self) -> int: def classic_io_capability(self) -> int:

View File

@@ -39,7 +39,6 @@ from bumble.device import (
AdvertisingEventProperties, AdvertisingEventProperties,
AdvertisingType, AdvertisingType,
Device, Device,
Phy,
) )
from bumble.gatt import Service from bumble.gatt import Service
from bumble.hci import ( from bumble.hci import (
@@ -47,6 +46,7 @@ from bumble.hci import (
HCI_PAGE_TIMEOUT_ERROR, HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address, Address,
Phy,
) )
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
@@ -371,9 +371,7 @@ class HostService(HostServicer):
scan_response_data=scan_response_data, scan_response_data=scan_response_data,
) )
pending_connection: asyncio.Future[bumble.device.Connection] = ( connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
asyncio.get_running_loop().create_future()
)
if request.connectable: if request.connectable:
@@ -382,7 +380,7 @@ class HostService(HostServicer):
connection.transport == BT_LE_TRANSPORT connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE and connection.role == BT_PERIPHERAL_ROLE
): ):
pending_connection.set_result(connection) connections.put_nowait(connection)
self.device.on('connection', on_connection) self.device.on('connection', on_connection)
@@ -397,8 +395,7 @@ class HostService(HostServicer):
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
connection = await pending_connection connection = await connections.get()
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie)) yield AdvertiseResponse(connection=Connection(cookie=cookie))
@@ -492,6 +489,8 @@ class HostService(HostServicer):
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS) target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
if request.connectable: if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None: def on_connection(connection: bumble.device.Connection) -> None:
@@ -499,7 +498,7 @@ class HostService(HostServicer):
connection.transport == BT_LE_TRANSPORT connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE and connection.role == BT_PERIPHERAL_ROLE
): ):
pending_connection.set_result(connection) connections.put_nowait(connection)
self.device.on('connection', on_connection) self.device.on('connection', on_connection)
@@ -517,12 +516,8 @@ class HostService(HostServicer):
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
self.log.debug('Wait for LE connection...') self.log.debug('Wait for LE connection...')
connection = await pending_connection connection = await connections.get()
self.log.debug( self.log.debug(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})" f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"

View File

@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import struct import struct
@@ -28,10 +29,11 @@ from bumble.device import Connection
from bumble.att import ATT_Error from bumble.att import ATT_Error
from bumble.gatt import ( from bumble.gatt import (
Characteristic, Characteristic,
DelegatedCharacteristicAdapter, SerializableCharacteristicAdapter,
PackedCharacteristicAdapter,
TemplateService, TemplateService,
CharacteristicValue, CharacteristicValue,
PackedCharacteristicAdapter, UTF8CharacteristicAdapter,
GATT_AUDIO_INPUT_CONTROL_SERVICE, GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
@@ -95,7 +97,7 @@ class AudioInputStatus(OpenIntEnum):
Cf. 3.4 Audio Input Status Cf. 3.4 Audio Input Status
''' '''
INATIVE = 0x00 INACTIVE = 0x00
ACTIVE = 0x01 ACTIVE = 0x01
@@ -104,7 +106,7 @@ class AudioInputControlPointOpCode(OpenIntEnum):
Cf. 3.5.1 Audio Input Control Point procedure requirements Cf. 3.5.1 Audio Input Control Point procedure requirements
''' '''
SET_GAIN_SETTING = 0x00 SET_GAIN_SETTING = 0x01
UNMUTE = 0x02 UNMUTE = 0x02
MUTE = 0x03 MUTE = 0x03
SET_MANUAL_GAIN_MODE = 0x04 SET_MANUAL_GAIN_MODE = 0x04
@@ -154,9 +156,6 @@ class AudioInputState:
attribute=self.attribute_value, value=bytes(self) attribute=self.attribute_value, value=bytes(self)
) )
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass @dataclass
class GainSettingsProperties: class GainSettingsProperties:
@@ -173,7 +172,7 @@ class GainSettingsProperties:
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = ( (gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
struct.unpack('BBB', data) struct.unpack('BBB', data)
) )
GainSettingsProperties( return GainSettingsProperties(
gain_settings_unit, gain_settings_minimum, gain_settings_maximum gain_settings_unit, gain_settings_minimum, gain_settings_maximum
) )
@@ -186,9 +185,6 @@ class GainSettingsProperties:
] ]
) )
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass @dataclass
class AudioInputControlPoint: class AudioInputControlPoint:
@@ -239,7 +235,7 @@ class AudioInputControlPoint:
or gain_settings_operand or gain_settings_operand
> self.gain_settings_properties.gain_settings_maximum > self.gain_settings_properties.gain_settings_maximum
): ):
logger.error("gain_seetings value out of range") logger.error("gain_settings value out of range")
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE) raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
if self.audio_input_state.gain_settings != gain_settings_operand: if self.audio_input_state.gain_settings != gain_settings_operand:
@@ -321,21 +317,14 @@ class AudioInputDescription:
audio_input_description: str = "Bluetooth" audio_input_description: str = "Bluetooth"
attribute_value: Optional[CharacteristicValue] = None attribute_value: Optional[CharacteristicValue] = None
@classmethod def on_read(self, _connection: Optional[Connection]) -> str:
def from_bytes(cls, data: bytes): return self.audio_input_description
return cls(audio_input_description=data.decode('utf-8'))
def __bytes__(self) -> bytes: async def on_write(self, connection: Optional[Connection], value: str) -> None:
return self.audio_input_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return self.audio_input_description.encode('utf-8')
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection assert connection
assert self.attribute_value assert self.attribute_value
self.audio_input_description = value.decode('utf-8') self.audio_input_description = value
await connection.device.notify_subscribers( await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value attribute=self.attribute_value, value=value
) )
@@ -375,26 +364,29 @@ class AICSService(TemplateService):
self.audio_input_state, self.gain_settings_properties self.audio_input_state, self.gain_settings_properties
) )
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter( self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
Characteristic( Characteristic(
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
properties=Characteristic.Properties.READ properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY, | Characteristic.Properties.NOTIFY,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.audio_input_state.on_read), value=self.audio_input_state,
), ),
encode=lambda value: bytes(value), AudioInputState,
) )
self.audio_input_state.attribute_value = ( self.audio_input_state.attribute_value = (
self.audio_input_state_characteristic.value self.audio_input_state_characteristic.value
) )
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter( self.gain_settings_properties_characteristic = (
Characteristic( SerializableCharacteristicAdapter(
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, Characteristic(
properties=Characteristic.Properties.READ, uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, properties=Characteristic.Properties.READ,
value=CharacteristicValue(read=self.gain_settings_properties.on_read), permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=self.gain_settings_properties,
),
GainSettingsProperties,
) )
) )
@@ -402,7 +394,7 @@ class AICSService(TemplateService):
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
properties=Characteristic.Properties.READ, properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=audio_input_type, value=bytes(audio_input_type, 'utf-8'),
) )
self.audio_input_status_characteristic = Characteristic( self.audio_input_status_characteristic = Characteristic(
@@ -412,18 +404,14 @@ class AICSService(TemplateService):
value=bytes([self.audio_input_status]), value=bytes([self.audio_input_status]),
) )
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter( self.audio_input_control_point_characteristic = Characteristic(
Characteristic( uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC, properties=Characteristic.Properties.WRITE,
properties=Characteristic.Properties.WRITE, permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, value=CharacteristicValue(write=self.audio_input_control_point.on_write),
value=CharacteristicValue(
write=self.audio_input_control_point.on_write
),
)
) )
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter( self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
Characteristic( Characteristic(
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
properties=Characteristic.Properties.READ properties=Characteristic.Properties.READ
@@ -463,58 +451,35 @@ class AICSServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy: ServiceProxy) -> None: def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not ( self.audio_input_state = SerializableCharacteristicAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
) ),
): AudioInputState,
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
self.audio_input_state = DelegatedCharacteristicAdapter(
characteristic=characteristics[0], decode=AudioInputState.from_bytes
) )
if not ( self.gain_settings_properties = SerializableCharacteristicAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
) ),
): GainSettingsProperties,
raise gatt.InvalidServiceError(
"Gain Settings Attribute Characteristic not found"
)
self.gain_settings_properties = PackedCharacteristicAdapter(
characteristics[0],
'BBB',
) )
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Status Characteristic not found"
)
self.audio_input_status = PackedCharacteristicAdapter( self.audio_input_status = PackedCharacteristicAdapter(
characteristics[0], service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
),
'B', 'B',
) )
if not ( self.audio_input_control_point = (
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Audio Input Control Point Characteristic not found"
)
self.audio_input_control_point = characteristics[0]
if not ( self.audio_input_description = UTF8CharacteristicAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Audio Input Description Characteristic not found"
)
self.audio_input_description = characteristics[0]

View File

@@ -17,6 +17,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import logging import logging
import struct import struct
@@ -258,8 +259,8 @@ class AseReasonCode(enum.IntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum): class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -300,7 +301,7 @@ class AseStateMachine(gatt.Characteristic):
presentation_delay = 0 presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State # Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata() metadata: le_audio.Metadata
def __init__( def __init__(
self, self,
@@ -312,6 +313,7 @@ class AseStateMachine(gatt.Characteristic):
self.ase_id = ase_id self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE self._state = AseStateMachine.State.IDLE
self.role = role self.role = role
self.metadata = le_audio.Metadata()
uuid = ( uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC gatt.GATT_SINK_ASE_CHARACTERISTIC
@@ -354,16 +356,7 @@ class AseStateMachine(gatt.Characteristic):
cis_link.on('disconnection', self.on_cis_disconnection) cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established(): async def post_cis_established():
await self.service.device.send_command( await cis_link.setup_data_path(direction=self.role)
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK: if self.role == AudioRole.SINK:
self.state = self.State.STREAMING self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
@@ -511,12 +504,8 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.RELEASING self.state = self.State.RELEASING
async def remove_cis_async(): async def remove_cis_async():
await self.service.device.send_command( if self.cis_link:
hci.HCI_LE_Remove_ISO_Data_Path_Command( await self.cis_link.remove_data_path(self.role)
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)

View File

@@ -288,8 +288,8 @@ class AshaServiceProxy(gatt_client.ProfileServiceProxy):
'psm_characteristic', 'psm_characteristic',
), ),
): ):
if not ( setattr(
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid) self,
): attribute_name,
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic") self.service_proxy.get_required_characteristic_by_uuid(uuid),
setattr(self, attribute_name, characteristics[0]) )

View File

@@ -102,6 +102,7 @@ class ContextType(enum.IntFlag):
# fmt: off # fmt: off
PROHIBITED = 0x0000 PROHIBITED = 0x0000
UNSPECIFIED = 0x0001
CONVERSATIONAL = 0x0002 CONVERSATIONAL = 0x0002
MEDIA = 0x0004 MEDIA = 0x0004
GAME = 0x0008 GAME = 0x0008
@@ -264,7 +265,7 @@ class UnicastServerAdvertisingData:
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
struct.pack( struct.pack(
'<2sBIB', '<2sBIB',
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE.to_bytes(), bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
self.announcement_type, self.announcement_type,
self.available_audio_contexts, self.available_audio_contexts,
len(self.metadata), len(self.metadata),
@@ -397,18 +398,21 @@ class CodecSpecificConfiguration:
OCTETS_PER_FRAME = 0x04 OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05 CODEC_FRAMES_PER_SDU = 0x05
sampling_frequency: SamplingFrequency sampling_frequency: SamplingFrequency | None = None
frame_duration: FrameDuration frame_duration: FrameDuration | None = None
audio_channel_allocation: AudioLocation audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int codec_frames_per_sdu: int | None = None
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration: def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
offset = 0 offset = 0
# Allowed default values. sampling_frequency: SamplingFrequency | None = None
audio_channel_allocation = AudioLocation.NOT_ALLOWED frame_duration: FrameDuration | None = None
codec_frames_per_sdu = 1 audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int | None = None
while offset < len(data): while offset < len(data):
length, type = struct.unpack_from('BB', data, offset) length, type = struct.unpack_from('BB', data, offset)
offset += 2 offset += 2
@@ -426,8 +430,6 @@ class CodecSpecificConfiguration:
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU: elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
codec_frames_per_sdu = value codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment
return CodecSpecificConfiguration( return CodecSpecificConfiguration(
sampling_frequency=sampling_frequency, sampling_frequency=sampling_frequency,
frame_duration=frame_duration, frame_duration=frame_duration,
@@ -437,23 +439,43 @@ class CodecSpecificConfiguration:
) )
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return struct.pack( return b''.join(
'<BBBBBBBBIBBHBBB', [
2, struct.pack(fmt, length, tag, value)
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY, for fmt, length, tag, value in [
self.sampling_frequency, (
2, '<BBB',
CodecSpecificConfiguration.Type.FRAME_DURATION, 2,
self.frame_duration, CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
5, self.sampling_frequency,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION, ),
self.audio_channel_allocation, (
3, '<BBB',
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME, 2,
self.octets_per_codec_frame, CodecSpecificConfiguration.Type.FRAME_DURATION,
2, self.frame_duration,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU, ),
self.codec_frames_per_sdu, (
'<BBI',
5,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
self.audio_channel_allocation,
),
(
'<BBH',
3,
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
self.octets_per_codec_frame,
),
(
'<BBB',
2,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
self.codec_frames_per_sdu,
),
]
if value is not None
]
) )
@@ -465,6 +487,24 @@ class BroadcastAudioAnnouncement:
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls, data: bytes) -> Self:
return cls(int.from_bytes(data[:3], 'little')) return cls(int.from_bytes(data[:3], 'little'))
def __bytes__(self) -> bytes:
return self.broadcast_id.to_bytes(3, 'little')
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)
@dataclasses.dataclass @dataclasses.dataclass
class BasicAudioAnnouncement: class BasicAudioAnnouncement:
@@ -473,26 +513,37 @@ class BasicAudioAnnouncement:
index: int index: int
codec_specific_configuration: CodecSpecificConfiguration codec_specific_configuration: CodecSpecificConfiguration
@dataclasses.dataclass def __bytes__(self) -> bytes:
class CodecInfo: codec_specific_configuration_bytes = bytes(
coding_format: hci.CodecID self.codec_specific_configuration
company_id: int )
vendor_specific_codec_id: int return (
bytes([self.index, len(codec_specific_configuration_bytes)])
@classmethod + codec_specific_configuration_bytes
def from_bytes(cls, data: bytes) -> Self: )
coding_format = hci.CodecID(data[0])
company_id = int.from_bytes(data[1:3], 'little')
vendor_specific_codec_id = int.from_bytes(data[3:5], 'little')
return cls(coding_format, company_id, vendor_specific_codec_id)
@dataclasses.dataclass @dataclasses.dataclass
class Subgroup: class Subgroup:
codec_id: BasicAudioAnnouncement.CodecInfo codec_id: hci.CodingFormat
codec_specific_configuration: CodecSpecificConfiguration codec_specific_configuration: CodecSpecificConfiguration
metadata: le_audio.Metadata metadata: le_audio.Metadata
bis: List[BasicAudioAnnouncement.BIS] bis: List[BasicAudioAnnouncement.BIS]
def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata)
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
return (
bytes([len(self.bis)])
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
+ b''.join(map(bytes, self.bis))
)
presentation_delay: int presentation_delay: int
subgroups: List[BasicAudioAnnouncement.Subgroup] subgroups: List[BasicAudioAnnouncement.Subgroup]
@@ -504,7 +555,7 @@ class BasicAudioAnnouncement:
for _ in range(data[3]): for _ in range(data[3]):
num_bis = data[offset] num_bis = data[offset]
offset += 1 offset += 1
codec_id = cls.CodecInfo.from_bytes(data[offset : offset + 5]) codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
offset += 5 offset += 5
codec_specific_configuration_length = data[offset] codec_specific_configuration_length = data[offset]
offset += 1 offset += 1
@@ -548,3 +599,25 @@ class BasicAudioAnnouncement:
) )
return cls(presentation_delay, subgroups) return cls(presentation_delay, subgroups)
def __bytes__(self) -> bytes:
return (
self.presentation_delay.to_bytes(3, 'little')
+ bytes([len(self.subgroups)])
+ b''.join(map(bytes, self.subgroups))
)
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)

View File

@@ -276,10 +276,7 @@ class BroadcastReceiveState:
subgroups: List[SubgroupInfo] subgroups: List[SubgroupInfo]
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]: def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
if not data:
return None
source_id = data[0] source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2) _, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8] source_adv_sid = data[8]
@@ -362,29 +359,20 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not ( self.broadcast_audio_scan_control_point = (
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [ self.broadcast_receive_states = [
gatt.DelegatedCharacteristicAdapter( gatt.DelegatedCharacteristicAdapter(
characteristic, decode=BroadcastReceiveState.from_bytes characteristic,
decode=lambda x: BroadcastReceiveState.from_bytes(x) if x else None,
)
for characteristic in service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
) )
for characteristic in characteristics
] ]
async def send_control_point_operation( async def send_control_point_operation(

View File

@@ -64,7 +64,10 @@ class DeviceInformationService(TemplateService):
): ):
characteristics = [ characteristics = [
Characteristic( Characteristic(
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field uuid,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(field, 'utf-8'),
) )
for (field, uuid) in ( for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), (manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),

View File

@@ -0,0 +1,166 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import struct
from typing import TYPE_CHECKING
from bumble import att
from bumble import gatt
from bumble import gatt_client
from bumble import crypto
if TYPE_CHECKING:
from bumble import device
# -----------------------------------------------------------------------------
class GenericAttributeProfileService(gatt.TemplateService):
'''See Vol 3, Part G - 7 - DEFINED GENERIC ATTRIBUTE PROFILE SERVICE.'''
UUID = gatt.GATT_GENERIC_ATTRIBUTE_SERVICE
client_supported_features_characteristic: gatt.Characteristic | None = None
server_supported_features_characteristic: gatt.Characteristic | None = None
database_hash_characteristic: gatt.Characteristic | None = None
service_changed_characteristic: gatt.Characteristic | None = None
def __init__(
self,
server_supported_features: gatt.ServerSupportedFeatures | None = None,
database_hash_enabled: bool = True,
service_change_enabled: bool = True,
) -> None:
if server_supported_features is not None:
self.server_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([server_supported_features]),
)
if database_hash_enabled:
self.database_hash_characteristic = gatt.Characteristic(
uuid=gatt.GATT_DATABASE_HASH_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.get_database_hash),
)
if service_change_enabled:
self.service_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.INDICATE,
permissions=gatt.Characteristic.Permissions(0),
value=b'',
)
if (database_hash_enabled and service_change_enabled) or (
server_supported_features
and (
server_supported_features & gatt.ServerSupportedFeatures.EATT_SUPPORTED
)
): # TODO: Support Multiple Handle Value Notifications
self.client_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
),
permissions=(
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE
),
value=bytes(1),
)
super().__init__(
characteristics=[
c
for c in (
self.service_changed_characteristic,
self.client_supported_features_characteristic,
self.database_hash_characteristic,
self.server_supported_features_characteristic,
)
if c is not None
],
primary=True,
)
@classmethod
def get_attribute_data(cls, attribute: att.Attribute) -> bytes:
if attribute.type in (
gatt.GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
gatt.GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
gatt.GATT_INCLUDE_ATTRIBUTE_TYPE,
gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
gatt.GATT_CHARACTERISTIC_EXTENDED_PROPERTIES_DESCRIPTOR,
):
return (
struct.pack("<H", attribute.handle)
+ attribute.type.to_bytes()
+ attribute.value
)
elif attribute.type in (
gatt.GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
gatt.GATT_SERVER_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
gatt.GATT_CHARACTERISTIC_PRESENTATION_FORMAT_DESCRIPTOR,
gatt.GATT_CHARACTERISTIC_AGGREGATE_FORMAT_DESCRIPTOR,
):
return struct.pack("<H", attribute.handle) + attribute.type.to_bytes()
return b''
def get_database_hash(self, connection: device.Connection | None) -> bytes:
assert connection
m = b''.join(
[
self.get_attribute_data(attribute)
for attribute in connection.device.gatt_server.attributes
]
)
return crypto.aes_cmac(m=m, k=bytes(16))
class GenericAttributeProfileServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = GenericAttributeProfileService
client_supported_features_characteristic: gatt_client.CharacteristicProxy | None = (
None
)
server_supported_features_characteristic: gatt_client.CharacteristicProxy | None = (
None
)
database_hash_characteristic: gatt_client.CharacteristicProxy | None = None
service_changed_characteristic: gatt_client.CharacteristicProxy | None = None
_CHARACTERISTICS = {
gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC: 'client_supported_features_characteristic',
gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC: 'server_supported_features_characteristic',
gatt.GATT_DATABASE_HASH_CHARACTERISTIC: 'database_hash_characteristic',
gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC: 'service_changed_characteristic',
}
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
for uuid, attribute_name in self._CHARACTERISTICS.items():
if characteristics := self.service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, attribute_name, characteristics[0])

193
bumble/profiles/gmap.py Normal file
View File

@@ -0,0 +1,193 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Gaming Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional
from bumble.gatt import (
TemplateService,
DelegatedCharacteristicAdapter,
Characteristic,
GATT_GAMING_AUDIO_SERVICE,
GATT_GMAP_ROLE_CHARACTERISTIC,
GATT_UGG_FEATURES_CHARACTERISTIC,
GATT_UGT_FEATURES_CHARACTERISTIC,
GATT_BGS_FEATURES_CHARACTERISTIC,
GATT_BGR_FEATURES_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from enum import IntFlag
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class GmapRole(IntFlag):
UNICAST_GAME_GATEWAY = 1 << 0
UNICAST_GAME_TERMINAL = 1 << 1
BROADCAST_GAME_SENDER = 1 << 2
BROADCAST_GAME_RECEIVER = 1 << 3
class UggFeatures(IntFlag):
UGG_MULTIPLEX = 1 << 0
UGG_96_KBPS_SOURCE = 1 << 1
UGG_MULTISINK = 1 << 2
class UgtFeatures(IntFlag):
UGT_SOURCE = 1 << 0
UGT_80_KBPS_SOURCE = 1 << 1
UGT_SINK = 1 << 2
UGT_64_KBPS_SINK = 1 << 3
UGT_MULTIPLEX = 1 << 4
UGT_MULTISINK = 1 << 5
UGT_MULTISOURCE = 1 << 6
class BgsFeatures(IntFlag):
BGS_96_KBPS = 1 << 0
class BgrFeatures(IntFlag):
BGR_MULTISINK = 1 << 0
BGR_MULTIPLEX = 1 << 1
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic
ugg_features: Optional[Characteristic] = None
ugt_features: Optional[Characteristic] = None
bgs_features: Optional[Characteristic] = None
bgr_features: Optional[Characteristic] = None
def __init__(
self,
gmap_role: GmapRole,
ugg_features: Optional[UggFeatures] = None,
ugt_features: Optional[UgtFeatures] = None,
bgs_features: Optional[BgsFeatures] = None,
bgr_features: Optional[BgrFeatures] = None,
) -> None:
characteristics = []
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
self.gmap_role = Characteristic(
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', gmap_role),
)
characteristics.append(self.gmap_role)
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
self.ugg_features = Characteristic(
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugg_features),
)
characteristics.append(self.ugg_features)
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
self.ugt_features = Characteristic(
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugt_features),
)
characteristics.append(self.ugt_features)
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
self.bgs_features = Characteristic(
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgs_features),
)
characteristics.append(self.bgs_features)
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
self.bgr_features = Characteristic(
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgr_features),
)
characteristics.append(self.bgr_features)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
self.gmap_role = DelegatedCharacteristicAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_GMAP_ROLE_CHARACTERISTIC
),
decode=lambda value: GmapRole(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGG_FEATURES_CHARACTERISTIC
):
self.ugg_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: UggFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGT_FEATURES_CHARACTERISTIC
):
self.ugt_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: UgtFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGS_FEATURES_CHARACTERISTIC
):
self.bgs_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: BgsFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGR_FEATURES_CHARACTERISTIC
):
self.bgr_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: BgrFeatures(value[0]),
)

View File

@@ -30,6 +30,7 @@ from ..gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
SerializableCharacteristicAdapter,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter, PackedCharacteristicAdapter,
) )
@@ -150,15 +151,14 @@ class HeartRateService(TemplateService):
body_sensor_location=None, body_sensor_location=None,
reset_energy_expended=None, reset_energy_expended=None,
): ):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter( self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic( Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY, Characteristic.Properties.NOTIFY,
0, 0,
CharacteristicValue(read=read_heart_rate_measurement), CharacteristicValue(read=read_heart_rate_measurement),
), ),
# pylint: disable=unnecessary-lambda HeartRateService.HeartRateMeasurement,
encode=lambda value: bytes(value),
) )
characteristics = [self.heart_rate_measurement_characteristic] characteristics = [self.heart_rate_measurement_characteristic]
@@ -204,9 +204,8 @@ class HeartRateServiceProxy(ProfileServiceProxy):
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
): ):
self.heart_rate_measurement = DelegatedCharacteristicAdapter( self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], characteristics[0], HeartRateService.HeartRateMeasurement
decode=HeartRateService.HeartRateMeasurement.from_bytes,
) )
else: else:
self.heart_rate_measurement = None self.heart_rate_measurement = None

View File

@@ -17,23 +17,35 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import enum
import struct import struct
from typing import List, Type from typing import Any, List, Type
from typing_extensions import Self from typing_extensions import Self
from bumble.profiles import bap
from bumble import utils from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioActiveState(utils.OpenIntEnum):
NO_AUDIO_DATA_TRANSMITTED = 0x00
AUDIO_DATA_TRANSMITTED = 0x01
class AssistedListeningStream(utils.OpenIntEnum):
UNSPECIFIED_AUDIO_ENHANCEMENT = 0x00
@dataclasses.dataclass @dataclasses.dataclass
class Metadata: class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures. '''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse As Metadata fields may extend, and the spec may not guarantee the uniqueness of
Metadata into a key-value style dataclass here. Rather, we encourage users to parse tags, we don't automatically parse the Metadata data into specific classes.
again outside the lib. Users of this class may decode the data by themselves, or use the Entry.decode
method.
''' '''
class Tag(utils.OpenIntEnum): class Tag(utils.OpenIntEnum):
@@ -57,6 +69,44 @@ class Metadata:
tag: Metadata.Tag tag: Metadata.Tag
data: bytes data: bytes
def decode(self) -> Any:
"""
Decode the data into an object, if possible.
If no specific object class exists to represent the data, the raw data
bytes are returned.
"""
if self.tag in (
Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
):
return bap.ContextType(struct.unpack("<H", self.data)[0])
if self.tag in (
Metadata.Tag.PROGRAM_INFO,
Metadata.Tag.PROGRAM_INFO_URI,
Metadata.Tag.BROADCAST_NAME,
):
return self.data.decode("utf-8")
if self.tag == Metadata.Tag.LANGUAGE:
return self.data.decode("ascii")
if self.tag == Metadata.Tag.CCID_LIST:
return list(self.data)
if self.tag == Metadata.Tag.PARENTAL_RATING:
return self.data[0]
if self.tag == Metadata.Tag.AUDIO_ACTIVE_STATE:
return AudioActiveState(self.data[0])
if self.tag == Metadata.Tag.ASSISTED_LISTENING_STREAM:
return AssistedListeningStream(self.data[0])
return self.data
@classmethod @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:]) return cls(tag=Metadata.Tag(data[0]), data=data[1:])
@@ -66,6 +116,29 @@ class Metadata:
entries: List[Entry] = dataclasses.field(default_factory=list) entries: List[Entry] = dataclasses.field(default_factory=list)
def pretty_print(self, indent: str) -> str:
"""Convenience method to generate a string with one key-value pair per line."""
max_key_length = 0
keys = []
values = []
for entry in self.entries:
key = entry.tag.name
max_key_length = max(max_key_length, len(key))
keys.append(key)
decoded = entry.decode()
if isinstance(decoded, enum.Enum):
values.append(decoded.name)
elif isinstance(decoded, bytes):
values.append(decoded.hex())
else:
values.append(str(decoded))
return '\n'.join(
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
for key, value in zip(keys, values)
)
@classmethod @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = [] entries = []
@@ -81,3 +154,13 @@ class Metadata:
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries]) return b''.join([bytes(entry) for entry in self.entries])
def __str__(self) -> str:
entries_str = []
for entry in self.entries:
decoded = entry.decode()
entries_str.append(
f'{entry.tag.name}: '
f'{decoded.hex() if isinstance(decoded, bytes) else decoded!r}'
)
return f'Metadata(entries={", ".join(entry_str for entry_str in entries_str)})'

View File

@@ -72,6 +72,19 @@ class PacRecord:
metadata=metadata, metadata=metadata,
) )
@classmethod
def list_from_bytes(cls, data: bytes) -> list[PacRecord]:
"""Parse a serialized list of records preceded by a one byte list length."""
record_count = data[0]
records = []
offset = 1
for _ in range(record_count):
record = PacRecord.from_bytes(data[offset:])
offset += len(bytes(record))
records.append(record)
return records
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities) capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata) metadata_bytes = bytes(self.metadata)
@@ -172,39 +185,58 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy): class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None sink_pac: Optional[gatt.DelegatedCharacteristicAdapter] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None sink_audio_locations: Optional[gatt.DelegatedCharacteristicAdapter] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None source_pac: Optional[gatt.DelegatedCharacteristicAdapter] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None source_audio_locations: Optional[gatt.DelegatedCharacteristicAdapter] = None
available_audio_contexts: gatt_client.CharacteristicProxy available_audio_contexts: gatt.DelegatedCharacteristicAdapter
supported_audio_contexts: gatt_client.CharacteristicProxy supported_audio_contexts: gatt.DelegatedCharacteristicAdapter
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid( self.available_audio_contexts = gatt.DelegatedCharacteristicAdapter(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC service_proxy.get_required_characteristic_by_uuid(
)[0] gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid( ),
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
)[0] )
self.supported_audio_contexts = gatt.DelegatedCharacteristicAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
),
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC gatt.GATT_SINK_PAC_CHARACTERISTIC
): ):
self.sink_pac = characteristics[0] self.sink_pac = gatt.DelegatedCharacteristicAdapter(
characteristics[0],
decode=PacRecord.list_from_bytes,
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC gatt.GATT_SOURCE_PAC_CHARACTERISTIC
): ):
self.source_pac = characteristics[0] self.source_pac = gatt.DelegatedCharacteristicAdapter(
characteristics[0],
decode=PacRecord.list_from_bytes,
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
): ):
self.sink_audio_locations = characteristics[0] self.sink_audio_locations = gatt.DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
): ):
self.source_audio_locations = characteristics[0] self.source_audio_locations = gatt.DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
)

View File

@@ -25,7 +25,6 @@ from bumble.gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE, GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC, GATT_TMAP_ROLE_CHARACTERISTIC,
) )
@@ -74,15 +73,10 @@ class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter( self.role = DelegatedCharacteristicAdapter(
characteristics[0], service_proxy.get_required_characteristic_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
),
decode=lambda value: Role( decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0], struct.unpack_from('<H', value, 0)[0],
), ),

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2024 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -17,14 +17,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum import enum
from typing import Optional, Sequence
from bumble import att from bumble import att
from bumble import device from bumble import device
from bumble import gatt from bumble import gatt
from bumble import gatt_client from bumble import gatt_client
from typing import Optional, Sequence
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -67,6 +69,20 @@ class VolumeControlPointOpcode(enum.IntEnum):
MUTE = 0x06 MUTE = 0x06
@dataclasses.dataclass
class VolumeState:
volume_setting: int
mute: int
change_counter: int
@classmethod
def from_bytes(cls, data: bytes) -> VolumeState:
return cls(data[0], data[1], data[2])
def __bytes__(self) -> bytes:
return bytes([self.volume_setting, self.mute, self.change_counter])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Server # Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -126,16 +142,8 @@ class VolumeControlService(gatt.TemplateService):
included_services=list(included_services), included_services=list(included_services),
) )
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes: def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
def _on_write_volume_control_point( def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes self, connection: Optional[device.Connection], value: bytes
@@ -153,14 +161,9 @@ class VolumeControlService(gatt.TemplateService):
self.change_counter = (self.change_counter + 1) % 256 self.change_counter = (self.change_counter + 1) % 256
connection.abort_on( connection.abort_on(
'disconnection', 'disconnection',
connection.device.notify_subscribers( connection.device.notify_subscribers(attribute=self.volume_state),
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
) )
self.emit('volume_state_change')
def _on_relative_volume_down(self) -> bool: def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting old_volume = self.volume_setting
@@ -207,24 +210,26 @@ class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy volume_control_point: gatt_client.CharacteristicProxy
volume_state: gatt.SerializableCharacteristicAdapter
volume_flags: gatt.DelegatedCharacteristicAdapter
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter( self.volume_state = gatt.SerializableCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0], ),
'BBB', VolumeState,
) )
self.volume_control_point = service_proxy.get_characteristics_by_uuid( self.volume_control_point = service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0] )
self.volume_flags = gatt.PackedCharacteristicAdapter( self.volume_flags = gatt.DelegatedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0], ),
'B', decode=lambda data: VolumeFlags(data[0]),
) )

299
bumble/profiles/vocs.py Normal file
View File

@@ -0,0 +1,299 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from dataclasses import dataclass
from typing import Optional
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
DelegatedCharacteristicAdapter,
TemplateService,
CharacteristicValue,
SerializableCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_VOLUME_OFFSET_CONTROL_SERVICE,
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
GATT_AUDIO_LOCATION_CHARACTERISTIC,
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum
from bumble.profiles.bap import AudioLocation
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME_OFFSET = -255
MAX_VOLUME_OFFSET = 255
CHANGE_COUNTER_MAX_VALUE = 0xFF
class SetVolumeOffsetOpCode(OpenIntEnum):
SET_VOLUME_OFFSET = 0x01
class ErrorCode(OpenIntEnum):
"""
See Volume Offset Control Service 1.6. Application error codes.
"""
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
VALUE_OUT_OF_RANGE = 0x82
# -----------------------------------------------------------------------------
@dataclass
class VolumeOffsetState:
volume_offset: int = 0
change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return struct.pack('<hB', self.volume_offset, self.change_counter)
@classmethod
def from_bytes(cls, data: bytes):
volume_offset, change_counter = struct.unpack('<hB', data)
return cls(volume_offset, change_counter)
def increment_change_counter(self) -> None:
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None
await connection.device.notify_subscribers(attribute=self.attribute_value)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class VocsAudioLocation:
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return struct.pack('<I', self.audio_location)
@classmethod
def from_bytes(cls, data: bytes):
audio_location = AudioLocation(struct.unpack('<I', data)[0])
return cls(audio_location)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
await connection.device.notify_subscribers(attribute=self.attribute_value)
@dataclass
class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
change_counter, volume_offset = struct.unpack('<Bh', value[1:])
await self._set_volume_offset(connection, change_counter, volume_offset)
async def _set_volume_offset(
self,
connection: Connection,
change_counter_operand: int,
volume_offset_operand: int,
) -> None:
change_counter = self.volume_offset_state.change_counter
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if not MIN_VOLUME_OFFSET <= volume_offset_operand <= MAX_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
self.volume_offset_state.volume_offset = volume_offset_operand
self.volume_offset_state.increment_change_counter()
await self.volume_offset_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioOutputDescription:
audio_output_description: str = ''
attribute_value: Optional[CharacteristicValue] = None
@classmethod
def from_bytes(cls, data: bytes):
return cls(audio_output_description=data.decode('utf-8'))
def __bytes__(self) -> bytes:
return self.audio_output_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_output_description = value.decode('utf-8')
await connection.device.notify_subscribers(attribute=self.attribute_value)
# -----------------------------------------------------------------------------
class VolumeOffsetControlService(TemplateService):
UUID = GATT_VOLUME_OFFSET_CONTROL_SERVICE
def __init__(
self,
volume_offset_state: Optional[VolumeOffsetState] = None,
audio_location: Optional[VocsAudioLocation] = None,
audio_output_description: Optional[AudioOutputDescription] = None,
) -> None:
self.volume_offset_state = (
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
)
self.audio_location = (
VocsAudioLocation() if audio_location is None else audio_location
)
self.audio_output_description = (
AudioOutputDescription()
if audio_output_description is None
else audio_output_description
)
self.volume_offset_control_point: VolumeOffsetControlPoint = (
VolumeOffsetControlPoint(self.volume_offset_state)
)
self.volume_offset_state_characteristic = Characteristic(
uuid=GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY
),
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.volume_offset_state.on_read),
)
self.audio_location_characteristic = Characteristic(
uuid=GATT_AUDIO_LOCATION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_location.on_read,
write=self.audio_location.on_write,
),
)
self.audio_location.attribute_value = self.audio_location_characteristic.value
self.volume_offset_control_point_characteristic = Characteristic(
uuid=GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(write=self.volume_offset_control_point.on_write),
)
self.audio_output_description_characteristic = Characteristic(
uuid=GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_output_description.on_read,
write=self.audio_output_description.on_write,
),
)
self.audio_output_description.attribute_value = (
self.audio_output_description_characteristic.value
)
super().__init__(
characteristics=[
self.volume_offset_state_characteristic, # type: ignore
self.audio_location_characteristic, # type: ignore
self.volume_offset_control_point_characteristic, # type: ignore
self.audio_output_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeOffsetControlServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = VolumeOffsetControlService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_offset_state = SerializableCharacteristicAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC
),
VolumeOffsetState,
)
self.audio_location = DelegatedCharacteristicAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_LOCATION_CHARACTERISTIC
),
encode=lambda value: bytes([int(value)]),
decode=lambda data: AudioLocation(data[0]),
)
self.volume_offset_control_point = (
service_proxy.get_required_characteristic_by_uuid(
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC
)
)
self.audio_output_description = UTF8CharacteristicAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC
)
)

View File

@@ -16,15 +16,21 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import struct import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from . import core, l2cap from bumble import core, l2cap
from .colors import color from bumble.colors import color
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError from bumble.core import (
from .hci import HCI_Object, name_or_number, key_with_value InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
ProtocolError,
)
from bumble.hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING: if TYPE_CHECKING:
from .device import Device, Connection from .device import Device, Connection
@@ -124,7 +130,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified # To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF)
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -242,11 +248,11 @@ class DataElement:
return DataElement(DataElement.BOOLEAN, value) return DataElement(DataElement.BOOLEAN, value)
@staticmethod @staticmethod
def sequence(value: List[DataElement]) -> DataElement: def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value) return DataElement(DataElement.SEQUENCE, value)
@staticmethod @staticmethod
def alternative(value: List[DataElement]) -> DataElement: def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value) return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod @staticmethod
@@ -344,9 +350,6 @@ class DataElement:
] # Keep a copy so we can re-serialize to an exact replica ] # Keep a copy so we can re-serialize to an exact replica
return result return result
def to_bytes(self):
return bytes(self)
def __bytes__(self): def __bytes__(self):
# Return early if we have a cache # Return early if we have a cache
if self.bytes: if self.bytes:
@@ -476,7 +479,9 @@ class ServiceAttribute:
self.value = value self.value = value
@staticmethod @staticmethod
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: def list_from_data_elements(
elements: Sequence[DataElement],
) -> list[ServiceAttribute]:
attribute_list = [] attribute_list = []
for i in range(0, len(elements) // 2): for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -489,7 +494,7 @@ class ServiceAttribute:
@staticmethod @staticmethod
def find_attribute_in_list( def find_attribute_in_list(
attribute_list: List[ServiceAttribute], attribute_id: int attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]: ) -> Optional[DataElement]:
return next( return next(
( (
@@ -537,7 +542,12 @@ class SDP_PDU:
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
''' '''
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {} RESPONSE_PDU_IDS = {
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
}
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
name = None name = None
pdu_id = 0 pdu_id = 0
@@ -561,7 +571,7 @@ class SDP_PDU:
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count( def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int data: bytes, offset: int
) -> Tuple[int, List[int]]: ) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [ handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
@@ -623,11 +633,8 @@ class SDP_PDU:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]' result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
@@ -645,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
''' '''
error_code: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -681,7 +690,7 @@ class SDP_ServiceSearchResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
''' '''
service_record_handle_list: List[int] service_record_handle_list: list[int]
total_service_record_count: int total_service_record_count: int
current_service_record_count: int current_service_record_count: int
continuation_state: bytes continuation_state: bytes
@@ -758,31 +767,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
''' '''
attribute_list_byte_count: int attribute_lists_byte_count: int
attribute_list: bytes attribute_lists: bytes
continuation_state: bytes continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
channel: Optional[l2cap.ClassicChannel] def __init__(self, connection: Connection, mtu: int = 0) -> None:
def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
self.pending_request = None self.channel: Optional[l2cap.ClassicChannel] = None
self.channel = None self.mtu = mtu
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request: Optional[SDP_PDU] = None
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
self.next_transaction_id = 0
async def connect(self) -> None: async def connect(self) -> None:
self.channel = await self.connection.create_l2cap_channel( self.channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM) spec=(
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
if self.mtu
else l2cap.ClassicChannelSpec(SDP_PSM)
)
) )
self.channel.sink = self.on_pdu
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self.channel: if self.channel:
await self.channel.disconnect() await self.channel.disconnect()
self.channel = None self.channel = None
async def search_services(self, uuids: List[core.UUID]) -> List[int]: def make_transaction_id(self) -> int:
transaction_id = self.next_transaction_id
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
return transaction_id
def on_pdu(self, pdu: bytes) -> None:
if not self.pending_request:
logger.warning('received response with no pending request')
return
assert self.pending_response is not None
response = SDP_PDU.from_bytes(pdu)
# Check that the transaction ID is what we expect
if self.pending_request.transaction_id != response.transaction_id:
logger.warning(
f"received response with transaction ID {response.transaction_id} "
f"but expected {self.pending_request.transaction_id}"
)
return
# Check if the response is an error
if isinstance(response, SDP_ErrorResponse):
self.pending_response.set_exception(
ProtocolError(error_code=response.error_code)
)
return
# Check that the type of the response matches the request
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
logger.warning("response type mismatch")
return
self.pending_response.set_result(response)
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
assert self.channel is not None
async with self.request_semaphore:
assert self.pending_request is None
assert self.pending_response is None
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request
try:
self.channel.send_pdu(bytes(request))
return await self.pending_response
finally:
self.pending_request = None
self.pending_response = None
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
"""
Search for services by UUID.
Args:
uuids: service the UUIDs to search for.
Returns:
A list of matching service record handles.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -797,16 +874,16 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchRequest( SDP_ServiceSearchRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF, maximum_service_record_count=0xFFFF,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchResponse)
service_record_handle_list += response.service_record_handle_list service_record_handle_list += response.service_record_handle_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -817,8 +894,21 @@ class Client:
return service_record_handle_list return service_record_handle_list
async def search_attributes( async def search_attributes(
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] self,
) -> List[List[ServiceAttribute]]: uuids: Iterable[core.UUID],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[list[ServiceAttribute]]:
"""
Search for attributes by UUID and attribute IDs.
Args:
uuids: the service UUIDs to search for.
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
(use (0, 0xFFFF) to include all attributes)
Returns:
A list of list of attributes, one list per matching service.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -830,8 +920,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -845,17 +935,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchAttributeRequest( SDP_ServiceSearchAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
accumulator += response.attribute_lists accumulator += response.attribute_lists
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -878,8 +968,18 @@ class Client:
async def get_attributes( async def get_attributes(
self, self,
service_record_handle: int, service_record_handle: int,
attribute_ids: List[Union[int, Tuple[int, int]]], attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> List[ServiceAttribute]: ) -> list[ServiceAttribute]:
"""
Get attributes for a service.
Args:
service_record_handle: the handle for a service
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
Returns:
A list of attributes.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -888,8 +988,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -903,17 +1003,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceAttributeRequest( SDP_ServiceAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_record_handle=service_record_handle, service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceAttributeResponse)
accumulator += response.attribute_list accumulator += response.attribute_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -939,17 +1039,17 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x00])
channel: Optional[l2cap.ClassicChannel] channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', List[ServiceAttribute]) Service = NewType('Service', list[ServiceAttribute])
service_records: Dict[int, Service] service_records: dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]] current_response: Union[None, bytes, tuple[int, list[int]]]
def __init__(self, device: Device) -> None: def __init__(self, device: Device) -> None:
self.device = device self.device = device
self.service_records = {} # Service records maps, by record handle self.service_records = {} # Service records maps, by record handle
self.channel = None self.channel = None
self.current_response = None self.current_response = None # Current response data, used for continuations
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.create_classic_server( l2cap_channel_manager.create_classic_server(
@@ -960,7 +1060,7 @@ class Server:
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response) self.channel.send_pdu(response)
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the # Find the services for which the attributes in the pattern is a subset of the
# service's attribute values (NOTE: the value search recurses into sequences) # service's attribute values (NOTE: the value search recurses into sequences)
matching_services = {} matching_services = {}
@@ -1017,6 +1117,31 @@ class Server:
) )
) )
def check_continuation(
self,
continuation_state: bytes,
transaction_id: int,
) -> Optional[bool]:
# Check if this is a valid continuation
if len(continuation_state) > 1:
if (
self.current_response is None
or continuation_state != self.CONTINUATION_STATE
):
self.send_response(
SDP_ErrorResponse(
transaction_id=transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return None
return True
# Cleanup any partial response leftover
self.current_response = None
return False
def get_next_response_payload(self, maximum_size): def get_next_response_payload(self, maximum_size):
if len(self.current_response) > maximum_size: if len(self.current_response) > maximum_size:
payload = self.current_response[:maximum_size] payload = self.current_response[:maximum_size]
@@ -1031,7 +1156,7 @@ class Server:
@staticmethod @staticmethod
def get_service_attributes( def get_service_attributes(
service: Service, attribute_ids: List[DataElement] service: Service, attribute_ids: Iterable[DataElement]
) -> DataElement: ) -> DataElement:
attributes = [] attributes = []
for attribute_id in attribute_ids: for attribute_id in attribute_ids:
@@ -1059,30 +1184,24 @@ class Server:
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services(request.service_search_pattern) matching_services = self.match_services(request.service_search_pattern)
service_record_handles = list(matching_services.keys()) service_record_handles = list(matching_services.keys())
logger.debug(f'Service Record Handles: {service_record_handles}')
# Only return up to the maximum requested # Only return up to the maximum requested
service_record_handles_subset = service_record_handles[ service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count : request.maximum_service_record_count
] ]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = ( self.current_response = (
len(service_record_handles), len(service_record_handles),
service_record_handles_subset, service_record_handles_subset,
@@ -1090,15 +1209,21 @@ class Server:
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple) assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][ assert self.channel is not None
: request.maximum_service_record_count total_service_record_count, service_record_handles = self.current_response
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
service_record_handles_remaining = service_record_handles[
maximum_service_record_count:
] ]
service_record_handles = service_record_handles[:maximum_service_record_count]
self.current_response = ( self.current_response = (
self.current_response[0], total_service_record_count,
self.current_response[1][request.maximum_service_record_count :], service_record_handles_remaining,
) )
continuation_state = ( continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) Server.CONTINUATION_STATE
if service_record_handles_remaining
else bytes([0])
) )
service_record_handle_list = b''.join( service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles] [struct.pack('>I', handle) for handle in service_record_handles]
@@ -1106,7 +1231,7 @@ class Server:
self.send_response( self.send_response(
SDP_ServiceSearchResponse( SDP_ServiceSearchResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0], total_service_record_count=total_service_record_count,
current_service_record_count=len(service_record_handles), current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list, service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state, continuation_state=continuation_state,
@@ -1117,19 +1242,14 @@ class Server:
self, request: SDP_ServiceAttributeRequest self, request: SDP_ServiceAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Check that the service exists # Check that the service exists
service = self.service_records.get(request.service_record_handle) service = self.service_records.get(request.service_record_handle)
if service is None: if service is None:
@@ -1151,14 +1271,18 @@ class Server:
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_list_response, continuation_state = self.get_next_response_payload( attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list_response), attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list, attribute_list=attribute_list_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
@@ -1167,18 +1291,14 @@ class Server:
self, request: SDP_ServiceSearchAttributeRequest self, request: SDP_ServiceSearchAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services( matching_services = self.match_services(
request.service_search_pattern request.service_search_pattern
@@ -1198,14 +1318,18 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_lists_response, continuation_state = self.get_next_response_payload( attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists_response), attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists, attribute_lists=attribute_lists_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )

View File

@@ -298,11 +298,8 @@ class SMP_Command:
def init_from_bytes(self, pdu: bytes, offset: int) -> None: def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
@@ -698,6 +695,7 @@ class Session:
self.ltk_ediv = 0 self.ltk_ediv = 0
self.ltk_rand = bytes(8) self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None self.link_key: Optional[bytes] = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0 self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0 self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None self.peer_random_value: Optional[bytes] = None
@@ -744,6 +742,10 @@ class Session:
else: else:
self.pairing_result = None self.pairing_result = None
self.maximum_encryption_key_size = (
pairing_config.delegate.maximum_encryption_key_size
)
# Key Distribution (default values before negotiation) # Key Distribution (default values before negotiation)
self.initiator_key_distribution = ( self.initiator_key_distribution = (
pairing_config.delegate.local_initiator_key_distribution pairing_config.delegate.local_initiator_key_distribution
@@ -996,7 +998,7 @@ class Session:
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution, responder_key_distribution=self.responder_key_distribution,
) )
@@ -1008,7 +1010,7 @@ class Session:
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution, responder_key_distribution=self.responder_key_distribution,
) )
@@ -1324,7 +1326,7 @@ class Session:
self.connection.abort_on('disconnection', self.on_pairing()) self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self) -> None: def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted: if self.connection.is_encrypted and not self.completed:
if self.is_responder: if self.is_responder:
# The responder distributes its keys first, the initiator later # The responder distributes its keys first, the initiator later
self.distribute_keys() self.distribute_keys()
@@ -1839,7 +1841,7 @@ class Session:
if self.is_initiator: if self.is_initiator:
if self.pairing_method == PairingMethod.OOB: if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command() self.send_pairing_random_command()
else: elif self.pairing_method == PairingMethod.PASSKEY:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
else: else:
if self.pairing_method == PairingMethod.PASSKEY: if self.pairing_method == PairingMethod.PASSKEY:
@@ -1949,7 +1951,7 @@ class Manager(EventEmitter):
f'{connection.peer_address}: {command}' f'{connection.peer_address}: {command}'
) )
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes()) connection.send_l2cap_pdu(cid, bytes(command))
def on_smp_security_request_command( def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command self, connection: Connection, request: SMP_Security_Request_Command

View File

@@ -370,11 +370,13 @@ class PumpedPacketSource(ParserSource):
self.parser.feed_data(packet) self.parser.feed_data(packet)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug('source pump task done') logger.debug('source pump task done')
self.terminated.set_result(None) if not self.terminated.done():
self.terminated.set_result(None)
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_exception(error) if not self.terminated.done():
self.terminated.set_exception(error)
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())

View File

@@ -149,7 +149,10 @@ async def open_usb_transport(spec: str) -> Transport:
if status != usb1.TRANSFER_COMPLETED: if status != usb1.TRANSFER_COMPLETED:
logger.warning( logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red') color(
f'!!! OUT transfer not completed: status={status}',
'red',
)
) )
async def process_queue(self): async def process_queue(self):
@@ -275,7 +278,10 @@ async def open_usb_transport(spec: str) -> Transport:
) )
else: else:
logger.warning( logger.warning(
color(f'!!! IN transfer not completed: status={status}', 'red') color(
f'!!! IN[{packet_type}] transfer not completed: status={status}',
'red',
)
) )
self.loop.call_soon_threadsafe(self.on_transport_lost) self.loop.call_soon_threadsafe(self.on_transport_lost)

View File

@@ -24,17 +24,19 @@ import logging
import sys import sys
import warnings import warnings
from typing import ( from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any, Any,
Awaitable,
Callable,
List,
Optional, Optional,
Protocol,
Set,
Tuple,
TypeVar,
Union, Union,
overload, overload,
) )
from typing_extensions import Self
from pyee import EventEmitter from pyee import EventEmitter
@@ -445,7 +447,7 @@ def deprecated(msg: str):
def wrapper(function): def wrapper(function):
@functools.wraps(function) @functools.wraps(function)
def inner(*args, **kwargs): def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning, stacklevel=2)
return function(*args, **kwargs) return function(*args, **kwargs)
return inner return inner
@@ -462,7 +464,7 @@ def experimental(msg: str):
def wrapper(function): def wrapper(function):
@functools.wraps(function) @functools.wraps(function)
def inner(*args, **kwargs): def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning) warnings.warn(msg, FutureWarning, stacklevel=2)
return function(*args, **kwargs) return function(*args, **kwargs)
return inner return inner
@@ -487,3 +489,16 @@ class OpenIntEnum(enum.IntEnum):
obj._value_ = value obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]" obj._name_ = f"{cls.__name__}[{value}]"
return obj return obj
# -----------------------------------------------------------------------------
class ByteSerializable(Protocol):
"""
Type protocol for classes that can be instantiated from bytes and serialized
to bytes.
"""
@classmethod
def from_bytes(cls, data: bytes) -> Self: ...
def __bytes__(self) -> bytes: ...

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
from typing import Dict, Optional, Type
from bumble.hci import ( from bumble.hci import (
name_or_number, name_or_number,
@@ -24,7 +25,9 @@ from bumble.hci import (
HCI_Constant, HCI_Constant,
HCI_Object, HCI_Object,
HCI_Command, HCI_Command,
HCI_Vendor_Event, HCI_Event,
HCI_Extended_Event,
HCI_VENDOR_EVENT,
STATUS_SPEC, STATUS_SPEC,
) )
@@ -48,7 +51,6 @@ HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58 HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
HCI_Command.register_commands(globals()) HCI_Command.register_commands(globals())
HCI_Vendor_Event.register_subevents(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -279,7 +281,29 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Vendor_Event.event( class HCI_Android_Vendor_Event(HCI_Extended_Event):
event_code: int = HCI_VENDOR_EVENT
subevent_classes: Dict[int, Type[HCI_Extended_Event]] = {}
@classmethod
def subclass_from_parameters(
cls, parameters: bytes
) -> Optional[HCI_Extended_Event]:
subevent_code = parameters[0]
if subevent_code == HCI_BLUETOOTH_QUALITY_REPORT_EVENT:
quality_report_id = parameters[1]
if quality_report_id in (0x01, 0x02, 0x03, 0x04, 0x07, 0x08, 0x09):
return HCI_Bluetooth_Quality_Report_Event.from_parameters(parameters)
return None
HCI_Android_Vendor_Event.register_subevents(globals())
HCI_Event.add_vendor_factory(HCI_Android_Vendor_Event.subclass_from_parameters)
# -----------------------------------------------------------------------------
@HCI_Extended_Event.event(
fields=[ fields=[
('quality_report_id', 1), ('quality_report_id', 1),
('packet_types', 1), ('packet_types', 1),
@@ -308,10 +332,11 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
('tx_last_subevent_packets', 4), ('tx_last_subevent_packets', 4),
('crc_error_packets', 4), ('crc_error_packets', 4),
('rx_duplicate_packets', 4), ('rx_duplicate_packets', 4),
('rx_unreceived_packets', 4),
('vendor_specific_parameters', '*'), ('vendor_specific_parameters', '*'),
] ]
) )
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event): class HCI_Bluetooth_Quality_Report_Event(HCI_Android_Vendor_Event):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event

View File

@@ -39,12 +39,14 @@ nav:
- Drivers: - Drivers:
- drivers/index.md - drivers/index.md
- Realtek: drivers/realtek.md - Realtek: drivers/realtek.md
- Intel: drivers/intel.md
- API: - API:
- Guide: api/guide.md - Guide: api/guide.md
- Examples: api/examples.md - Examples: api/examples.md
- Reference: api/reference.md - Reference: api/reference.md
- Apps & Tools: - Apps & Tools:
- apps_and_tools/index.md - apps_and_tools/index.md
- Auracast: apps_and_tools/auracast.md
- Console: apps_and_tools/console.md - Console: apps_and_tools/console.md
- Bench: apps_and_tools/bench.md - Bench: apps_and_tools/bench.md
- Speaker: apps_and_tools/speaker.md - Speaker: apps_and_tools/speaker.md
@@ -108,8 +110,8 @@ markdown_extensions:
- pymdownx.details - pymdownx.details
- pymdownx.superfences - pymdownx.superfences
- pymdownx.emoji: - pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg emoji_generator: !!python/name:material.extensions.emoji.to_svg
- pymdownx.tabbed: - pymdownx.tabbed:
alternate_style: true alternate_style: true
- codehilite: - codehilite:

View File

@@ -4,12 +4,13 @@ APPS & TOOLS
Included in the project are a few apps and tools, built on top of the core libraries. Included in the project are a few apps and tools, built on top of the core libraries.
These include: These include:
* [Console](console.md) - an interactive text-based console * [Auracast](auracast.md) - Commands to broadcast, receive and/or control LE Audio.
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic) * [Console](console.md) - An interactive text-based console.
* [Pair](pair.md) - Pair/bond two devices (LE and Classic) * [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic).
* [Unbond](unbond.md) - Remove a previously established bond * [Pair](pair.md) - Pair/bond two devices (LE and Classic).
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets * [Unbond](unbond.md) - Remove a previously established bond.
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool" * [HCI Bridge](hci_bridge.md) - An HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets.
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form * [Golden Gate Bridge](gg_bridge.md) - Bridge between GATT and UDP to use with the Golden Gate "stack tool".
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form.
* [Speaker](speaker.md) - Virtual Bluetooth speaker, with a command line and browser-based UI. * [Speaker](speaker.md) - Virtual Bluetooth speaker, with a command line and browser-based UI.
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other. * [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

View File

@@ -16,4 +16,5 @@ USB vendor ID and product ID.
Drivers included in the module are: Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
* [Intel](intel.md): Loading of Firmware and Config for Intel USB controllers.

View File

@@ -0,0 +1,73 @@
INTEL DRIVER
==============
This driver supports loading firmware images and optional config data to
Intel USB controllers.
A number of USB dongles are supported, but likely not all.
The initial implementation has been tested on BE200 and AX210 controllers.
When using a USB controller, the USB product ID and vendor ID are used
to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=intel`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=intel]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for the firmware and config files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_INTEL_FIRMWARE_DIR`
if set.
* The directory `<package-dir>/drivers/intel_fw` where `<package-dir>` is the directory
where the `bumble` package is installed.
* The current directory.
It is also possible to override or extend the config data with parameters passed via the
transport name. The driver name `intel` may be suffixed with `/<param:value>[+<param:value>]...`
The supported params are:
* `ddc_addon`: configuration data to add to the data loaded from the config data file
* `ddc_override`: configuration data to use instead of the data loaded from the config data file.
With both `dcc_addon` and `dcc_override`, the param value is a hex-encoded byte array containing
the config data (same format as the config file).
Example transport name:
`usb:[driver=intel/dcc_addon:03E40200]0`
Obtaining Firmware Images and Config Data
-----------------------------------------
Firmware images and config data may be obtained from a variety of online
sources.
To facilitate finding a downloading the, the utility program `bumble-intel-fw-download`
may be used.
```
Usage: bumble-intel-fw-download [OPTIONS]
Download Intel firmware images and configs.
Options:
--output-dir TEXT Output directory where the files will be saved.
Defaults to the OS-specific app data dir, which the
driver will check when trying to find firmware
--source [linux-kernel] [default: linux-kernel]
--single TEXT Only download a single image set, by its base name
--force Overwrite files if they already exist
--help Show this message and exit.
```
Utility
-------
The `bumble-intel-util` utility may be used to interact with an Intel USB controller.
```
Usage: bumble-intel-util [OPTIONS] COMMAND [ARGS]...
Options:
--help Show this message and exit.
Commands:
bootloader Reboot in bootloader mode.
info Get the firmware info.
load Load a firmware image.
```

View File

@@ -9,9 +9,9 @@ for your platform.
Throughout the documentation, when shell commands are shown, it is assumed that you can Throughout the documentation, when shell commands are shown, it is assumed that you can
invoke Python as invoke Python as
``` ```
$ python $ python3
``` ```
If invoking python is different on your platform (it may be `python3` for example, or just `py` or `py.exe`), If invoking python is different on your platform (it may be `python` for example, or just `py` or `py.exe`),
adjust accordingly. adjust accordingly.
You may be simply using Bumble as a module for your own application or as a dependency to your own You may be simply using Bumble as a module for your own application or as a dependency to your own
@@ -30,12 +30,18 @@ manager, or from source.
python environment, or in a virtual environment, such as a `venv`, `pyenv` or `conda` environment. python environment, or in a virtual environment, such as a `venv`, `pyenv` or `conda` environment.
See the [Python Environments page](development/python_environments.md) page for details. See the [Python Environments page](development/python_environments.md) page for details.
### Install from PyPI
```
$ python3 -m pip install bumble
```
### Install From Source ### Install From Source
Install with `pip`. Run in a command shell in the directory where you downloaded the source Install with `pip`. Run in a command shell in the directory where you downloaded the source
distribution distribution
``` ```
$ python -m pip install -e . $ python3 -m pip install -e .
``` ```
### Install from GitHub ### Install from GitHub
@@ -44,21 +50,21 @@ You can install directly from GitHub without first downloading the repo.
Install the latest commit from the main branch with `pip`: Install the latest commit from the main branch with `pip`:
``` ```
$ python -m pip install git+https://github.com/google/bumble.git $ python3 -m pip install git+https://github.com/google/bumble.git
``` ```
You can specify a specific tag. You can specify a specific tag.
Install tag `v0.0.1` with `pip`: Install tag `v0.0.1` with `pip`:
``` ```
$ python -m pip install git+https://github.com/google/bumble.git@v0.0.1 $ python3 -m pip install git+https://github.com/google/bumble.git@v0.0.1
``` ```
You can also specify a specific commit. You can also specify a specific commit.
Install commit `27c0551` with `pip`: Install commit `27c0551` with `pip`:
``` ```
$ python -m pip install git+https://github.com/google/bumble.git@27c0551 $ python3 -m pip install git+https://github.com/google/bumble.git@27c0551
``` ```
# Working On The Bumble Code # Working On The Bumble Code
@@ -78,21 +84,21 @@ directory of the project.
```bash ```bash
$ export PYTHONPATH=. $ export PYTHONPATH=.
$ python apps/console.py serial:/dev/tty.usbmodem0006839912171 $ python3 apps/console.py serial:/dev/tty.usbmodem0006839912171
``` ```
or running an example, with the working directory set to the `examples` subdirectory or running an example, with the working directory set to the `examples` subdirectory
```bash ```bash
$ cd examples $ cd examples
$ export PYTHONPATH=.. $ export PYTHONPATH=..
$ python run_scanner.py usb:0 $ python3 run_scanner.py usb:0
``` ```
Or course, `export PYTHONPATH` only needs to be invoked once, not before each app/script execution. Or course, `export PYTHONPATH` only needs to be invoked once, not before each app/script execution.
Setting `PYTHONPATH` locally with each command would look something like: Setting `PYTHONPATH` locally with each command would look something like:
``` ```
$ PYTHONPATH=. python examples/run_advertiser.py examples/device1.json serial:/dev/tty.usbmodem0006839912171 $ PYTHONPATH=. python3 examples/run_advertiser.py examples/device1.json serial:/dev/tty.usbmodem0006839912171
``` ```
# Where To Go Next # Where To Go Next

View File

@@ -35,11 +35,11 @@ the command line.
visit [this Android Studio user guide page](https://developer.android.com/studio/run/emulator-commandline) visit [this Android Studio user guide page](https://developer.android.com/studio/run/emulator-commandline)
The `-packet-streamer-endpoint <endpoint>` command line option may be used to enable The `-packet-streamer-endpoint <endpoint>` command line option may be used to enable
Bluetooth emulation and tell the emulator which virtual controller to connect to. Bluetooth emulation and tell the emulator which virtual controller to connect to.
## Connecting to Netsim ## Connecting to Netsim
If the emulator doesn't have Bluetooth emulation enabled by default, use the If the emulator doesn't have Bluetooth emulation enabled by default, use the
`-packet-streamer-endpoint default` option to tell it to connect to Netsim. `-packet-streamer-endpoint default` option to tell it to connect to Netsim.
If Netsim is not running, the emulator will start it automatically. If Netsim is not running, the emulator will start it automatically.
@@ -60,17 +60,17 @@ the Bumble `android-netsim` transport in `host` mode (the default).
!!! example "Run the example GATT server connected to the emulator via Netsim" !!! example "Run the example GATT server connected to the emulator via Netsim"
``` shell ``` shell
$ python run_gatt_server.py device1.json android-netsim $ python3 run_gatt_server.py device1.json android-netsim
``` ```
By default, the Bumble `android-netsim` transport will try to automatically discover By default, the Bumble `android-netsim` transport will try to automatically discover
the port number on which the netsim process is exposing its gRPC server interface. If the port number on which the netsim process is exposing its gRPC server interface. If
that discovery process fails, or if you want to specify the interface manually, you that discovery process fails, or if you want to specify the interface manually, you
can pass a `hostname` and `port` as parameters to the transport, as: `android-netsim:<host>:<port>`. can pass a `hostname` and `port` as parameters to the transport, as: `android-netsim:<host>:<port>`.
!!! example "Run the example GATT server connected to the emulator via Netsim on a localhost, port 8877" !!! example "Run the example GATT server connected to the emulator via Netsim on a localhost, port 8877"
``` shell ``` shell
$ python run_gatt_server.py device1.json android-netsim:localhost:8877 $ python3 run_gatt_server.py device1.json android-netsim:localhost:8877
``` ```
### Multiple Instances ### Multiple Instances
@@ -84,7 +84,7 @@ For example: `android-netsim:localhost:8877,name=bumble1`
This is an advanced use case, which may not be officially supported, but should work in recent This is an advanced use case, which may not be officially supported, but should work in recent
versions of the emulator. versions of the emulator.
The first step is to run the Bumble HCI bridge, specifying netsim as the "host" end of the The first step is to run the Bumble HCI bridge, specifying netsim as the "host" end of the
bridge, and another controller (typically a USB Bluetooth dongle, but any other supported bridge, and another controller (typically a USB Bluetooth dongle, but any other supported
transport can work as well) as the "controller" end of the bridge. transport can work as well) as the "controller" end of the bridge.

View File

@@ -0,0 +1,9 @@
{
"name": "Bumble CS Initiator",
"address": "F0:F1:F2:F3:F4:F5",
"advertising_interval": 100,
"keystore": "JsonKeyStore",
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
"identity_address_type": 1,
"channel_sounding_enabled": true
}

View File

@@ -0,0 +1,9 @@
{
"name": "Bumble CS Reflector",
"address": "F0:F1:F2:F3:F4:F6",
"advertising_interval": 100,
"keystore": "JsonKeyStore",
"irk": "0c7d74db03a1c98e7be691f76141d53d",
"identity_address_type": 1,
"channel_sounding_enabled": true
}

View File

@@ -282,7 +282,7 @@ async def keyboard_device(device, command):
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.Properties.READ, Characteristic.Properties.READ,
Characteristic.READABLE, Characteristic.READABLE,
'Bumble', bytes('Bumble', 'utf-8'),
) )
], ],
), ),

View File

@@ -28,7 +28,7 @@ class OneDeviceBenchTest(base_test.BaseTestClass):
def test_l2cap_client_ping(self): def test_l2cap_client_ping(self):
runner = self.dut.bench.runL2capClient( runner = self.dut.bench.runL2capClient(
"ping", "4B:2A:67:76:2B:E3", 128, True, 100, 970, 100 "ping", "4B:2A:67:76:2B:E3", 128, True, 100, 970, 100, "HIGH"
) )
print("### Initial status:", runner) print("### Initial status:", runner)
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"]) final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
@@ -36,12 +36,34 @@ class OneDeviceBenchTest(base_test.BaseTestClass):
def test_l2cap_client_send(self): def test_l2cap_client_send(self):
runner = self.dut.bench.runL2capClient( runner = self.dut.bench.runL2capClient(
"send", "7E:90:D0:F2:7A:11", 131, True, 100, 970, 0 "send",
"F1:F1:F1:F1:F1:F1",
128,
True,
100,
970,
0,
"HIGH",
10000,
) )
print("### Initial status:", runner) print("### Initial status:", runner)
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"]) final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
print("### Final status:", final_status) print("### Final status:", final_status)
def test_gatt_client_send(self):
runner = self.dut.bench.runGattClient(
"send", "F1:F1:F1:F1:F1:F1", 128, True, 100, 970, 100, "HIGH"
)
print("### Initial status:", runner)
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
print("### Final status:", final_status)
def test_gatt_server_receive(self):
runner = self.dut.bench.runGattServer("receive")
print("### Initial status:", runner)
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
print("### Final status:", final_status)
if __name__ == "__main__": if __name__ == "__main__":
test_runner.main() test_runner.main()

View File

@@ -2,8 +2,8 @@ TestBeds:
- Name: BenchTestBed - Name: BenchTestBed
Controllers: Controllers:
AndroidDevice: AndroidDevice:
- serial: 37211FDJG000DJ - serial: emulator-5554
local_bt_address: 94:45:60:5E:03:B0 local_bt_address: 94:45:60:5E:03:B0
- serial: 23071FDEE001F7 #- serial: 23071FDEE001F7
local_bt_address: DC:E5:5B:E5:51:2C # local_bt_address: DC:E5:5B:E5:51:2C

View File

@@ -0,0 +1,154 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import sys
import os
import functools
from bumble import core
from bumble import hci
from bumble.device import Connection, Device, ChannelSoundingCapabilities
from bumble.transport import open_transport_or_link
# From https://cs.android.com/android/platform/superproject/main/+/main:packages/modules/Bluetooth/system/gd/hci/distance_measurement_manager.cc.
CS_TONE_ANTENNA_CONFIG_MAPPING_TABLE = [
[0, 4, 5, 6],
[1, 7, 7, 7],
[2, 7, 7, 7],
[3, 7, 7, 7],
]
CS_PREFERRED_PEER_ANTENNA_MAPPING_TABLE = [1, 1, 1, 1, 3, 7, 15, 3]
CS_ANTENNA_PERMUTATION_ARRAY = [
[1, 2, 3, 4],
[2, 1, 3, 4],
[1, 3, 2, 4],
[3, 1, 2, 4],
[3, 2, 1, 4],
[2, 3, 1, 4],
[1, 2, 4, 3],
[2, 1, 4, 3],
[1, 4, 2, 3],
[4, 1, 2, 3],
[4, 2, 1, 3],
[2, 4, 1, 3],
[1, 4, 3, 2],
[4, 1, 3, 2],
[1, 3, 4, 2],
[3, 1, 4, 2],
[3, 4, 1, 2],
[4, 3, 1, 2],
[4, 2, 3, 1],
[2, 4, 3, 1],
[4, 3, 2, 1],
[3, 4, 2, 1],
[3, 2, 4, 1],
[2, 3, 4, 1],
]
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_channel_sounding.py <config-file> <transport-spec-for-device>'
'[target_address](If missing, run as reflector)'
)
print('example: run_channel_sounding.py cs_reflector.json usb:0')
print(
'example: run_channel_sounding.py cs_initiator.json usb:0 F0:F1:F2:F3:F4:F5'
)
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
await device.power_on()
assert (local_cs_capabilities := device.cs_capabilities)
if len(sys.argv) == 3:
print('<<< Start Advertising')
await device.start_advertising(
own_address_type=hci.OwnAddressType.RANDOM, auto_restart=True
)
def on_cs_capabilities(
connection: Connection, capabilities: ChannelSoundingCapabilities
):
del capabilities
print('<<< Set CS Settings')
asyncio.create_task(device.set_default_cs_settings(connection))
device.on(
'connection',
lambda connection: connection.on(
'channel_sounding_capabilities',
functools.partial(on_cs_capabilities, connection),
),
)
else:
target_address = hci.Address(sys.argv[3])
print(f'<<< Connecting to {target_address}')
connection = await device.connect(
target_address, transport=core.BT_LE_TRANSPORT
)
print('<<< ACL Connected')
if not (await device.get_long_term_key(connection.handle, b'', 0)):
print('<<< No bond, start pairing')
await connection.pair()
print('<<< Pairing complete')
print('<<< Encrypting Connection')
await connection.encrypt()
print('<<< Getting remote CS Capabilities...')
remote_capabilities = await device.get_remote_cs_capabilities(connection)
print('<<< Set CS Settings...')
await device.set_default_cs_settings(connection)
print('<<< Set CS Config...')
config = await device.create_cs_config(connection)
print('<<< Enable CS Security...')
await device.enable_cs_security(connection)
tone_antenna_config_selection = CS_TONE_ANTENNA_CONFIG_MAPPING_TABLE[
local_cs_capabilities.num_antennas_supported - 1
][remote_capabilities.num_antennas_supported - 1]
print('<<< Set CS Procedure Parameters...')
await device.set_cs_procedure_parameters(
connection=connection,
config=config,
tone_antenna_config_selection=tone_antenna_config_selection,
preferred_peer_antenna=CS_PREFERRED_PEER_ANTENNA_MAPPING_TABLE[
tone_antenna_config_selection
],
)
print('<<< Enable CS Procedure...')
await device.enable_cs_procedure(connection=connection, config=config)
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -127,7 +127,7 @@ async def main() -> None:
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
'hello', bytes('hello', 'utf-8'),
), ),
], ],
) )

View File

@@ -0,0 +1,319 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
import random
import struct
import sys
from typing import Any, List, Union
from bumble.device import Device, Peer
from bumble import transport
from bumble import gatt
from bumble import hci
from bumble import core
# -----------------------------------------------------------------------------
SERVICE_UUID = core.UUID("50DB505C-8AC4-4738-8448-3B1D9CC09CC5")
CHARACTERISTIC_UUID_BASE = "D901B45B-4916-412E-ACCA-0000000000"
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CustomSerializableClass:
x: int
y: int
@classmethod
def from_bytes(cls, data: bytes) -> CustomSerializableClass:
return cls(*struct.unpack(">II", data))
def __bytes__(self) -> bytes:
return struct.pack(">II", self.x, self.y)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CustomClass:
a: int
b: int
@classmethod
def decode(cls, data: bytes) -> CustomClass:
return cls(*struct.unpack(">II", data))
def encode(self) -> bytes:
return struct.pack(">II", self.a, self.b)
# -----------------------------------------------------------------------------
async def client(device: Device, address: hci.Address) -> None:
print(f'=== Connecting to {address}...')
connection = await device.connect(address)
print('=== Connected')
# Discover all characteristics.
peer = Peer(connection)
print("*** Discovering services and characteristics...")
await peer.discover_all()
print("*** Discovery complete")
service = peer.get_services_by_uuid(SERVICE_UUID)[0]
characteristics = []
for index in range(1, 9):
characteristics.append(
service.get_characteristics_by_uuid(
core.UUID(CHARACTERISTIC_UUID_BASE + f"{index:02X}")
)[0]
)
# Read all characteristics as raw bytes.
for characteristic in characteristics:
value = await characteristic.read_value()
print(f"### {characteristic} = {value!r} ({value.hex()})")
# Static characteristic with a bytes value.
c1 = characteristics[0]
c1_value = await c1.read_value()
print(f"@@@ C1 {c1} value = {c1_value!r} (type={type(c1_value)})")
await c1.write_value("happy π day".encode("utf-8"))
# Static characteristic with a string value.
c2 = gatt.UTF8CharacteristicAdapter(characteristics[1])
c2_value = await c2.read_value()
print(f"@@@ C2 {c2} value = {c2_value} (type={type(c2_value)})")
await c2.write_value("happy π day")
# Static characteristic with a tuple value.
c3 = gatt.PackedCharacteristicAdapter(characteristics[2], ">III")
c3_value = await c3.read_value()
print(f"@@@ C3 {c3} value = {c3_value} (type={type(c3_value)})")
await c3.write_value((2001, 2002, 2003))
# Static characteristic with a named tuple value.
c4 = gatt.MappedCharacteristicAdapter(
characteristics[3], ">III", ["f1", "f2", "f3"]
)
c4_value = await c4.read_value()
print(f"@@@ C4 {c4} value = {c4_value} (type={type(c4_value)})")
await c4.write_value({"f1": 4001, "f2": 4002, "f3": 4003})
# Static characteristic with a serializable value.
c5 = gatt.SerializableCharacteristicAdapter(
characteristics[4], CustomSerializableClass
)
c5_value = await c5.read_value()
print(f"@@@ C5 {c5} value = {c5_value} (type={type(c5_value)})")
await c5.write_value(CustomSerializableClass(56, 57))
# Static characteristic with a delegated value.
c6 = gatt.DelegatedCharacteristicAdapter(
characteristics[5], encode=CustomClass.encode, decode=CustomClass.decode
)
c6_value = await c6.read_value()
print(f"@@@ C6 {c6} value = {c6_value} (type={type(c6_value)})")
await c6.write_value(CustomClass(6, 7))
# Dynamic characteristic with a bytes value.
c7 = characteristics[6]
c7_value = await c7.read_value()
print(f"@@@ C7 {c7} value = {c7_value!r} (type={type(c7_value)})")
await c7.write_value(bytes.fromhex("01020304"))
# Dynamic characteristic with a string value.
c8 = gatt.UTF8CharacteristicAdapter(characteristics[7])
c8_value = await c8.read_value()
print(f"@@@ C8 {c8} value = {c8_value} (type={type(c8_value)})")
await c8.write_value("howdy")
# -----------------------------------------------------------------------------
def dynamic_read(selector: str) -> Union[bytes, str]:
if selector == "bytes":
print("$$$ Returning random bytes")
return random.randbytes(7)
elif selector == "string":
print("$$$ Returning random string")
return random.randbytes(7).hex()
raise ValueError("invalid selector")
# -----------------------------------------------------------------------------
def dynamic_write(selector: str, value: Any) -> None:
print(f"$$$ Received[{selector}]: {value} (type={type(value)})")
# -----------------------------------------------------------------------------
def on_characteristic_read(characteristic: gatt.Characteristic, value: Any) -> None:
"""Event listener invoked when a characteristic is read."""
print(f"<<< READ: {characteristic} -> {value} ({type(value)})")
# -----------------------------------------------------------------------------
def on_characteristic_write(characteristic: gatt.Characteristic, value: Any) -> None:
"""Event listener invoked when a characteristic is written."""
print(f"<<< WRITE: {characteristic} <- {value} ({type(value)})")
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 2:
print("Usage: run_gatt_with_adapters.py <transport-spec> [<bluetooth-address>]")
print("example: run_gatt_with_adapters.py usb:0 E1:CA:72:48:C4:E8")
return
async with await transport.open_transport(sys.argv[1]) as hci_transport:
# Create a device to manage the host
device = Device.with_hci(
"Bumble",
hci.Address("F0:F1:F2:F3:F4:F5"),
hci_transport.source,
hci_transport.sink,
)
# Static characteristic with a bytes value.
c1 = gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "01",
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
b'hello',
)
# Static characteristic with a string value.
c2 = gatt.UTF8CharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "02",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
'hello',
)
)
# Static characteristic with a tuple value.
c3 = gatt.PackedCharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "03",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
(1007, 1008, 1009),
),
">III",
)
# Static characteristic with a named tuple value.
c4 = gatt.MappedCharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "04",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
{"f1": 3007, "f2": 3008, "f3": 3009},
),
">III",
["f1", "f2", "f3"],
)
# Static characteristic with a serializable value.
c5 = gatt.SerializableCharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "05",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
CustomSerializableClass(11, 12),
),
CustomSerializableClass,
)
# Static characteristic with a delegated value.
c6 = gatt.DelegatedCharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "06",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
CustomClass(1, 2),
),
encode=CustomClass.encode,
decode=CustomClass.decode,
)
# Dynamic characteristic with a bytes value.
c7 = gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "07",
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
read=lambda connection: dynamic_read("bytes"),
write=lambda connection, value: dynamic_write("bytes", value),
),
)
# Dynamic characteristic with a string value.
c8 = gatt.UTF8CharacteristicAdapter(
gatt.Characteristic(
CHARACTERISTIC_UUID_BASE + "08",
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE,
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
read=lambda connection: dynamic_read("string"),
write=lambda connection, value: dynamic_write("string", value),
),
)
)
characteristics: List[
Union[gatt.Characteristic, gatt.CharacteristicAdapter]
] = [c1, c2, c3, c4, c5, c6, c7, c8]
# Listen for read and write events.
for characteristic in characteristics:
characteristic.on(
"read",
lambda _, value, c=characteristic: on_characteristic_read(c, value),
)
characteristic.on(
"write",
lambda _, value, c=characteristic: on_characteristic_write(c, value),
)
device.add_service(gatt.Service(SERVICE_UUID, characteristics)) # type: ignore
# Get things going
await device.power_on()
# Connect to a peer
if len(sys.argv) > 2:
await client(device, hci.Address(sys.argv[2]))
else:
await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(main())

View File

@@ -21,9 +21,9 @@ import sys
import os import os
import io import io
import logging import logging
import websockets from typing import Iterable, Optional
from typing import Optional import websockets
import bumble.core import bumble.core
from bumble.device import Device, ScoLink from bumble.device import Device, ScoLink
@@ -82,6 +82,10 @@ def on_microphone_volume(level: int):
send_message(type='microphone_volume', level=level) send_message(type='microphone_volume', level=level)
def on_supported_audio_codecs(codecs: Iterable[hfp.AudioCodec]):
send_message(type='supported_audio_codecs', codecs=[codec.name for codec in codecs])
def on_sco_state_change(codec: int): def on_sco_state_change(codec: int):
if codec == hfp.AudioCodec.CVSD: if codec == hfp.AudioCodec.CVSD:
sample_rate = 8000 sample_rate = 8000
@@ -207,6 +211,7 @@ async def main() -> None:
ag_protocol = hfp.AgProtocol(dlc, configuration) ag_protocol = hfp.AgProtocol(dlc, configuration)
ag_protocol.on('speaker_volume', on_speaker_volume) ag_protocol.on('speaker_volume', on_speaker_volume)
ag_protocol.on('microphone_volume', on_microphone_volume) ag_protocol.on('microphone_volume', on_microphone_volume)
ag_protocol.on('supported_audio_codecs', on_supported_audio_codecs)
on_hfp_state_change(True) on_hfp_state_change(True)
dlc.multiplexer.l2cap_channel.on( dlc.multiplexer.l2cap_channel.on(
'close', lambda: on_hfp_state_change(False) 'close', lambda: on_hfp_state_change(False)
@@ -241,7 +246,7 @@ async def main() -> None:
# Pick the first one # Pick the first one
channel, version, hf_sdp_features = hfp_record channel, version, hf_sdp_features = hfp_record
print(f'HF version: {version}') print(f'HF version: {version}')
print(f'HF features: {hf_sdp_features}') print(f'HF features: {hf_sdp_features.name}')
# Request authentication # Request authentication
print('*** Authenticating...') print('*** Authenticating...')

View File

@@ -161,7 +161,13 @@ async def main() -> None:
else: else:
file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb') file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb')
codec_configuration = ase.codec_specific_configuration codec_configuration = ase.codec_specific_configuration
assert isinstance(codec_configuration, CodecSpecificConfiguration) if (
not isinstance(codec_configuration, CodecSpecificConfiguration)
or codec_configuration.sampling_frequency is None
or codec_configuration.audio_channel_allocation is None
or codec_configuration.frame_duration is None
):
return
# Write a LC3 header. # Write a LC3 header.
file_output.write( file_output.write(
bytes([0x1C, 0xCC]) # Header. bytes([0x1C, 0xCC]) # Header.

View File

@@ -42,7 +42,7 @@ from bumble.profiles.bap import (
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
from bumble.profiles.cap import CommonAudioServiceService from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.profiles.vcp import VolumeControlService from bumble.profiles.vcs import VolumeControlService
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -117,13 +117,17 @@ async def main() -> None:
ws: Optional[websockets.WebSocketServerProtocol] = None ws: Optional[websockets.WebSocketServerProtocol] = None
def on_volume_state(volume_setting: int, muted: int, change_counter: int): def on_volume_state_change():
if ws: if ws:
asyncio.create_task( asyncio.create_task(
ws.send(dumps_volume_state(volume_setting, muted, change_counter)) ws.send(
dumps_volume_state(
vcs.volume_setting, vcs.muted, vcs.change_counter
)
)
) )
vcs.on('volume_state', on_volume_state) vcs.on('volume_state_change', on_volume_state_change)
advertising_data = ( advertising_data = (
bytes( bytes(
@@ -170,16 +174,10 @@ async def main() -> None:
ws = websocket ws = websocket
async for message in websocket: async for message in websocket:
volume_state = json.loads(message) volume_state = json.loads(message)
vcs.volume_state_bytes = bytes( vcs.volume_setting = volume_state['volume_setting']
[ vcs.muted = volume_state['muted']
volume_state['volume_setting'], vcs.change_counter = volume_state['change_counter']
volume_state['muted'], await device.notify_subscribers(vcs.volume_state)
volume_state['change_counter'],
]
)
await device.notify_subscribers(
vcs.volume_state, vcs.volume_state_bytes
)
ws = None ws = None
await websockets.serve(serve, 'localhost', 8989) await websockets.serve(serve, 'localhost', 8989)

View File

@@ -10,7 +10,7 @@ android {
defaultConfig { defaultConfig {
applicationId = "com.github.google.bumble.btbench" applicationId = "com.github.google.bumble.btbench"
minSdk = 30 minSdk = 33
targetSdk = 34 targetSdk = 34
versionCode = 1 versionCode = 1
versionName = "1.0" versionName = "1.0"

View File

@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.github.google.bumble.btbench"> package="com.github.google.bumble.btbench">
<uses-sdk android:minSdkVersion="30" android:targetSdkVersion="34" /> <uses-sdk android:minSdkVersion="33" android:targetSdkVersion="34" />
<!-- Request legacy Bluetooth permissions on older devices. --> <!-- Request legacy Bluetooth permissions on older devices. -->
<uses-permission android:name="android.permission.BLUETOOTH" android:maxSdkVersion="30" /> <uses-permission android:name="android.permission.BLUETOOTH" android:maxSdkVersion="30" />
<uses-permission android:name="android.permission.BLUETOOTH_ADMIN" android:maxSdkVersion="30" /> <uses-permission android:name="android.permission.BLUETOOTH_ADMIN" android:maxSdkVersion="30" />
@@ -9,6 +9,8 @@
<uses-permission android:name="android.permission.BLUETOOTH_SCAN" android:usesPermissionFlags="neverForLocation"/> <uses-permission android:name="android.permission.BLUETOOTH_SCAN" android:usesPermissionFlags="neverForLocation"/>
<uses-permission android:name="android.permission.BLUETOOTH_ADVERTISE" /> <uses-permission android:name="android.permission.BLUETOOTH_ADVERTISE" />
<uses-permission android:name="android.permission.BLUETOOTH_CONNECT" /> <uses-permission android:name="android.permission.BLUETOOTH_CONNECT" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<uses-permission android:name="android.permission.INTERNET" />
<uses-feature android:name="android.hardware.bluetooth" android:required="true"/> <uses-feature android:name="android.hardware.bluetooth" android:required="true"/>
<uses-feature android:name="android.hardware.bluetooth_le" android:required="true"/> <uses-feature android:name="android.hardware.bluetooth_le" android:required="true"/>

View File

@@ -0,0 +1,39 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.le.AdvertiseCallback
import android.bluetooth.le.AdvertiseData
import android.bluetooth.le.AdvertiseSettings
import android.bluetooth.le.AdvertiseSettings.ADVERTISE_MODE_LOW_LATENCY
import android.os.Build
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.advertiser")
class Advertiser(private val bluetoothAdapter: BluetoothAdapter) : AdvertiseCallback() {
@SuppressLint("MissingPermission")
fun start() {
val advertiseSettingsBuilder = AdvertiseSettings.Builder()
.setAdvertiseMode(ADVERTISE_MODE_LOW_LATENCY)
.setConnectable(true)
advertiseSettingsBuilder.setDiscoverable(true)
val advertiseSettings = advertiseSettingsBuilder.build()
val advertiseData = AdvertiseData.Builder().build()
val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build()
bluetoothAdapter.bluetoothLeAdvertiser.startAdvertising(advertiseSettings, advertiseData, scanData, this)
}
@SuppressLint("MissingPermission")
fun stop() {
bluetoothAdapter.bluetoothLeAdvertiser.stopAdvertising(this)
}
override fun onStartFailure(errorCode: Int) {
Log.warning("failed to start advertising: $errorCode")
}
override fun onStartSuccess(settingsInEffect: AdvertiseSettings) {
Log.info("advertising started: $settingsInEffect")
}
}

View File

@@ -22,11 +22,13 @@ import androidx.test.core.app.ApplicationProvider;
import com.google.android.mobly.snippet.Snippet; import com.google.android.mobly.snippet.Snippet;
import com.google.android.mobly.snippet.rpc.Rpc; import com.google.android.mobly.snippet.rpc.Rpc;
import com.google.android.mobly.snippet.rpc.RpcOptional;
import org.json.JSONArray; import org.json.JSONArray;
import org.json.JSONException; import org.json.JSONException;
import org.json.JSONObject; import org.json.JSONObject;
import java.io.IOException;
import java.security.InvalidParameterException; import java.security.InvalidParameterException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.UUID; import java.util.UUID;
@@ -71,12 +73,15 @@ public class AutomationSnippet implements Snippet {
private final Context mContext; private final Context mContext;
private final ArrayList<Runner> mRunners = new ArrayList<>(); private final ArrayList<Runner> mRunners = new ArrayList<>();
public AutomationSnippet() { public AutomationSnippet() throws IOException {
mContext = ApplicationProvider.getApplicationContext(); mContext = ApplicationProvider.getApplicationContext();
BluetoothManager bluetoothManager = mContext.getSystemService(BluetoothManager.class); BluetoothManager bluetoothManager = mContext.getSystemService(BluetoothManager.class);
mBluetoothAdapter = bluetoothManager.getAdapter(); mBluetoothAdapter = bluetoothManager.getAdapter();
if (mBluetoothAdapter == null) { if (mBluetoothAdapter == null) {
throw new RuntimeException("bluetooth not supported"); throw new IOException("bluetooth not supported");
}
if (!mBluetoothAdapter.isEnabled()) {
throw new IOException("bluetooth not enabled");
} }
} }
@@ -85,32 +90,46 @@ public class AutomationSnippet implements Snippet {
switch (mode) { switch (mode) {
case "rfcomm-client": case "rfcomm-client":
runnable = new RfcommClient(model, mBluetoothAdapter, runnable = new RfcommClient(model, mBluetoothAdapter,
(PacketIO packetIO) -> createIoClient(model, scenario, (PacketIO packetIO) -> createIoClient(model, scenario,
packetIO)); packetIO));
break; break;
case "rfcomm-server": case "rfcomm-server":
runnable = new RfcommServer(model, mBluetoothAdapter, runnable = new RfcommServer(model, mBluetoothAdapter,
(PacketIO packetIO) -> createIoClient(model, scenario, (PacketIO packetIO) -> createIoClient(model, scenario,
packetIO)); packetIO));
break; break;
case "l2cap-client": case "l2cap-client":
runnable = new L2capClient(model, mBluetoothAdapter, mContext, runnable = new L2capClient(model, mBluetoothAdapter, mContext,
(PacketIO packetIO) -> createIoClient(model, scenario, (PacketIO packetIO) -> createIoClient(model, scenario,
packetIO)); packetIO));
break; break;
case "l2cap-server": case "l2cap-server":
runnable = new L2capServer(model, mBluetoothAdapter, runnable = new L2capServer(model, mBluetoothAdapter,
(PacketIO packetIO) -> createIoClient(model, scenario, (PacketIO packetIO) -> createIoClient(model, scenario,
packetIO)); packetIO));
break;
case "gatt-client":
runnable = new GattClient(model, mBluetoothAdapter, mContext,
(PacketIO packetIO) -> createIoClient(model, scenario,
packetIO));
break;
case "gatt-server":
runnable = new GattServer(model, mBluetoothAdapter, mContext,
(PacketIO packetIO) -> createIoClient(model, scenario,
packetIO));
break; break;
default: default:
return null; return null;
} }
model.setMode(mode);
model.setScenario(scenario);
runnable.run(); runnable.run();
Runner runner = new Runner(runnable, mode, scenario, model); Runner runner = new Runner(runnable, mode, scenario, model);
mRunners.add(runner); mRunners.add(runner);
@@ -140,7 +159,21 @@ public class AutomationSnippet implements Snippet {
JSONObject result = new JSONObject(); JSONObject result = new JSONObject();
result.put("status", model.getStatus()); result.put("status", model.getStatus());
result.put("running", model.getRunning()); result.put("running", model.getRunning());
result.put("peer_bluetooth_address", model.getPeerBluetoothAddress());
result.put("mode", model.getMode());
result.put("scenario", model.getScenario());
result.put("sender_packet_size", model.getSenderPacketSize());
result.put("sender_packet_count", model.getSenderPacketCount());
result.put("sender_packet_interval", model.getSenderPacketInterval());
result.put("packets_sent", model.getPacketsSent());
result.put("packets_received", model.getPacketsReceived());
result.put("l2cap_psm", model.getL2capPsm()); result.put("l2cap_psm", model.getL2capPsm());
result.put("use_2m_phy", model.getUse2mPhy());
result.put("connection_priority", model.getConnectionPriority());
result.put("mtu", model.getMtu());
result.put("rx_phy", model.getRxPhy());
result.put("tx_phy", model.getTxPhy());
result.put("startup_delay", model.getStartupDelay());
if (model.getStatus().equals("OK")) { if (model.getStatus().equals("OK")) {
JSONObject stats = new JSONObject(); JSONObject stats = new JSONObject();
result.put("stats", stats); result.put("stats", stats);
@@ -167,12 +200,12 @@ public class AutomationSnippet implements Snippet {
@Rpc(description = "Run a scenario in RFComm Client mode") @Rpc(description = "Run a scenario in RFComm Client mode")
public JSONObject runRfcommClient(String scenario, String peerBluetoothAddress, int packetCount, public JSONObject runRfcommClient(String scenario, String peerBluetoothAddress, int packetCount,
int packetSize, int packetInterval) throws JSONException { int packetSize, int packetInterval,
assert (mBluetoothAdapter != null); @RpcOptional Integer startupDelay) throws JSONException {
// We only support "send" and "ping" for this mode for now // We only support "send" and "ping" for this mode for now
if (!(scenario.equals("send") || scenario.equals("ping"))) { if (!(scenario.equals("send") || scenario.equals("ping"))) {
throw new InvalidParameterException("only 'send' and 'ping' are supported for this mode"); throw new InvalidParameterException(
"only 'send' and 'ping' are supported for this mode");
} }
AppViewModel model = new AppViewModel(); AppViewModel model = new AppViewModel();
@@ -180,6 +213,9 @@ public class AutomationSnippet implements Snippet {
model.setSenderPacketCount(packetCount); model.setSenderPacketCount(packetCount);
model.setSenderPacketSize(packetSize); model.setSenderPacketSize(packetSize);
model.setSenderPacketInterval(packetInterval); model.setSenderPacketInterval(packetInterval);
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "rfcomm-client", scenario); Runner runner = runScenario(model, "rfcomm-client", scenario);
assert runner != null; assert runner != null;
@@ -187,15 +223,18 @@ public class AutomationSnippet implements Snippet {
} }
@Rpc(description = "Run a scenario in RFComm Server mode") @Rpc(description = "Run a scenario in RFComm Server mode")
public JSONObject runRfcommServer(String scenario) throws JSONException { public JSONObject runRfcommServer(String scenario,
assert (mBluetoothAdapter != null); @RpcOptional Integer startupDelay) throws JSONException {
// We only support "receive" and "pong" for this mode for now // We only support "receive" and "pong" for this mode for now
if (!(scenario.equals("receive") || scenario.equals("pong"))) { if (!(scenario.equals("receive") || scenario.equals("pong"))) {
throw new InvalidParameterException("only 'receive' and 'pong' are supported for this mode"); throw new InvalidParameterException(
"only 'receive' and 'pong' are supported for this mode");
} }
AppViewModel model = new AppViewModel(); AppViewModel model = new AppViewModel();
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "rfcomm-server", scenario); Runner runner = runScenario(model, "rfcomm-server", scenario);
assert runner != null; assert runner != null;
@@ -205,12 +244,12 @@ public class AutomationSnippet implements Snippet {
@Rpc(description = "Run a scenario in L2CAP Client mode") @Rpc(description = "Run a scenario in L2CAP Client mode")
public JSONObject runL2capClient(String scenario, String peerBluetoothAddress, int psm, public JSONObject runL2capClient(String scenario, String peerBluetoothAddress, int psm,
boolean use_2m_phy, int packetCount, int packetSize, boolean use_2m_phy, int packetCount, int packetSize,
int packetInterval) throws JSONException { int packetInterval, @RpcOptional String connectionPriority,
assert (mBluetoothAdapter != null); @RpcOptional Integer startupDelay) throws JSONException {
// We only support "send" and "ping" for this mode for now // We only support "send" and "ping" for this mode for now
if (!(scenario.equals("send") || scenario.equals("ping"))) { if (!(scenario.equals("send") || scenario.equals("ping"))) {
throw new InvalidParameterException("only 'send' and 'ping' are supported for this mode"); throw new InvalidParameterException(
"only 'send' and 'ping' are supported for this mode");
} }
AppViewModel model = new AppViewModel(); AppViewModel model = new AppViewModel();
@@ -220,28 +259,83 @@ public class AutomationSnippet implements Snippet {
model.setSenderPacketCount(packetCount); model.setSenderPacketCount(packetCount);
model.setSenderPacketSize(packetSize); model.setSenderPacketSize(packetSize);
model.setSenderPacketInterval(packetInterval); model.setSenderPacketInterval(packetInterval);
if (connectionPriority != null) {
model.setConnectionPriority(connectionPriority);
}
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "l2cap-client", scenario); Runner runner = runScenario(model, "l2cap-client", scenario);
assert runner != null; assert runner != null;
return runner.toJson(); return runner.toJson();
} }
@Rpc(description = "Run a scenario in L2CAP Server mode") @Rpc(description = "Run a scenario in L2CAP Server mode")
public JSONObject runL2capServer(String scenario) throws JSONException { public JSONObject runL2capServer(String scenario,
assert (mBluetoothAdapter != null); @RpcOptional Integer startupDelay) throws JSONException {
// We only support "receive" and "pong" for this mode for now // We only support "receive" and "pong" for this mode for now
if (!(scenario.equals("receive") || scenario.equals("pong"))) { if (!(scenario.equals("receive") || scenario.equals("pong"))) {
throw new InvalidParameterException("only 'receive' and 'pong' are supported for this mode"); throw new InvalidParameterException(
"only 'receive' and 'pong' are supported for this mode");
} }
AppViewModel model = new AppViewModel(); AppViewModel model = new AppViewModel();
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "l2cap-server", scenario); Runner runner = runScenario(model, "l2cap-server", scenario);
assert runner != null; assert runner != null;
return runner.toJson(); return runner.toJson();
} }
@Rpc(description = "Run a scenario in GATT Client mode")
public JSONObject runGattClient(String scenario, String peerBluetoothAddress,
boolean use_2m_phy, int packetCount, int packetSize,
int packetInterval, @RpcOptional String connectionPriority,
@RpcOptional Integer startupDelay) throws JSONException {
// We only support "send" and "ping" for this mode for now
if (!(scenario.equals("send") || scenario.equals("ping"))) {
throw new InvalidParameterException(
"only 'send' and 'ping' are supported for this mode");
}
AppViewModel model = new AppViewModel();
model.setPeerBluetoothAddress(peerBluetoothAddress);
model.setUse2mPhy(use_2m_phy);
model.setSenderPacketCount(packetCount);
model.setSenderPacketSize(packetSize);
model.setSenderPacketInterval(packetInterval);
if (connectionPriority != null) {
model.setConnectionPriority(connectionPriority);
}
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "gatt-client", scenario);
assert runner != null;
return runner.toJson();
}
@Rpc(description = "Run a scenario in GATT Server mode")
public JSONObject runGattServer(String scenario,
@RpcOptional Integer startupDelay) throws JSONException {
// We only support "receive" and "pong" for this mode for now
if (!(scenario.equals("receive") || scenario.equals("pong"))) {
throw new InvalidParameterException(
"only 'receive' and 'pong' are supported for this mode");
}
AppViewModel model = new AppViewModel();
if (startupDelay != null) {
model.setStartupDelay(startupDelay);
}
Runner runner = runScenario(model, "gatt-server", scenario);
assert runner != null;
return runner.toJson();
}
@Rpc(description = "Stop a Runner") @Rpc(description = "Stop a Runner")
public JSONObject stopRunner(String runnerId) throws JSONException { public JSONObject stopRunner(String runnerId) throws JSONException {
Runner runner = findRunner(runnerId); Runner runner = findRunner(runnerId);
@@ -276,7 +370,7 @@ public class AutomationSnippet implements Snippet {
JSONObject result = new JSONObject(); JSONObject result = new JSONObject();
JSONArray runners = new JSONArray(); JSONArray runners = new JSONArray();
result.put("runners", runners); result.put("runners", runners);
for (Runner runner: mRunners) { for (Runner runner : mRunners) {
runners.put(runner.toJson()); runners.put(runner.toJson());
} }

View File

@@ -0,0 +1,109 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCallback
import android.bluetooth.BluetoothManager
import android.bluetooth.BluetoothProfile
import android.content.Context
import android.os.Build
import androidx.core.content.ContextCompat
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.connection")
open class Connection(
private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter,
private val context: Context
) : BluetoothGattCallback() {
var remoteDevice: BluetoothDevice? = null
var gatt: BluetoothGatt? = null
@SuppressLint("MissingPermission")
open fun connect() {
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
val address = viewModel.peerBluetoothAddress.take(17)
remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
bluetoothAdapter.getRemoteLeDevice(
address,
if (addressIsPublic) {
BluetoothDevice.ADDRESS_TYPE_PUBLIC
} else {
BluetoothDevice.ADDRESS_TYPE_RANDOM
}
)
} else {
bluetoothAdapter.getRemoteDevice(address)
}
gatt = remoteDevice?.connectGatt(
context,
false,
this,
BluetoothDevice.TRANSPORT_LE,
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
)
}
@SuppressLint("MissingPermission")
open fun disconnect() {
gatt?.disconnect()
}
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
Log.info("MTU update: mtu=$mtu status=$status")
viewModel.mtu = mtu
}
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
@SuppressLint("MissingPermission")
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("onConnectionStateChange status=$status")
}
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
Log.info("requesting 2M PHY")
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
// Request a specific connection priority
val connectionPriority = when (viewModel.connectionPriority) {
"BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED
"LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER
"HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH
"DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK
else -> 0
}
if (!gatt.requestConnectionPriority(connectionPriority)) {
Log.warning("requestConnectionPriority failed")
}
}
}
}

View File

@@ -0,0 +1,23 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.github.google.bumble.btbench
import java.util.UUID
var CCCD_UUID = UUID.fromString("00002902-0000-1000-8000-00805F9B34FB")
val BENCH_SERVICE_UUID = UUID.fromString("50DB505C-8AC4-4738-8448-3B1D9CC09CC5")
val BENCH_TX_UUID = UUID.fromString("E789C754-41A1-45F4-A948-A0A1A90DBA53")
val BENCH_RX_UUID = UUID.fromString("016A2CC7-E14B-4819-935F-1F56EAE4098D")

View File

@@ -0,0 +1,224 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCharacteristic
import android.bluetooth.BluetoothGattDescriptor
import android.bluetooth.BluetoothProfile
import android.content.Context
import java.io.IOException
import java.util.UUID
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Semaphore
import java.util.logging.Logger
import kotlin.concurrent.thread
private val Log = Logger.getLogger("btbench.gatt-client")
class GattClientConnection(
viewModel: AppViewModel,
bluetoothAdapter: BluetoothAdapter,
context: Context
) : Connection(viewModel, bluetoothAdapter, context), PacketIO {
override var packetSink: PacketSink? = null
private val discoveryDone: CountDownLatch = CountDownLatch(1)
private val writeSemaphore: Semaphore = Semaphore(1)
var rxCharacteristic: BluetoothGattCharacteristic? = null
var txCharacteristic: BluetoothGattCharacteristic? = null
override fun connect() {
super.connect()
// Check if we're already connected and have discovered the services
if (gatt?.getService(BENCH_SERVICE_UUID) != null) {
Log.fine("already connected")
onServicesDiscovered(gatt, BluetoothGatt.GATT_SUCCESS)
}
}
@SuppressLint("MissingPermission")
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
super.onConnectionStateChange(gatt, status, newState)
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("onConnectionStateChange status=$status")
discoveryDone.countDown()
return
}
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (!gatt.discoverServices()) {
Log.warning("discoverServices could not start")
discoveryDone.countDown()
}
}
}
@SuppressLint("MissingPermission")
override fun onServicesDiscovered(gatt: BluetoothGatt?, status: Int) {
Log.fine("onServicesDiscovered")
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("failed to discover services: ${status}")
discoveryDone.countDown()
return
}
// Find the service
val service = gatt!!.getService(BENCH_SERVICE_UUID)
if (service == null) {
Log.warning("GATT Service not found")
discoveryDone.countDown()
return
}
// Find the RX and TX characteristics
rxCharacteristic = service.getCharacteristic(BENCH_RX_UUID)
if (rxCharacteristic == null) {
Log.warning("GATT RX Characteristics not found")
discoveryDone.countDown()
return
}
txCharacteristic = service.getCharacteristic(BENCH_TX_UUID)
if (txCharacteristic == null) {
Log.warning("GATT TX Characteristics not found")
discoveryDone.countDown()
return
}
// Subscribe to the RX characteristic
Log.fine("subscribing to RX")
gatt.setCharacteristicNotification(rxCharacteristic, true)
val cccdDescriptor = rxCharacteristic!!.getDescriptor(CCCD_UUID)
gatt.writeDescriptor(cccdDescriptor, BluetoothGattDescriptor.ENABLE_NOTIFICATION_VALUE);
Log.info("GATT discovery complete")
discoveryDone.countDown()
}
override fun onCharacteristicWrite(
gatt: BluetoothGatt?,
characteristic: BluetoothGattCharacteristic?,
status: Int
) {
// Now we can write again
writeSemaphore.release()
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("onCharacteristicWrite failed: $status")
return
}
}
override fun onCharacteristicChanged(
gatt: BluetoothGatt,
characteristic: BluetoothGattCharacteristic,
value: ByteArray
) {
if (characteristic.uuid == BENCH_RX_UUID && packetSink != null) {
val packet = Packet.from(value)
packetSink!!.onPacket(packet)
}
}
@SuppressLint("MissingPermission")
override fun sendPacket(packet: Packet) {
if (txCharacteristic == null) {
Log.warning("No TX characteristic, dropping")
return
}
// Wait until we can write
writeSemaphore.acquire()
// Write the data
val data = packet.toBytes()
val clampedData = if (data.size > 512) {
// Clamp the data to the maximum allowed characteristic data size
data.copyOf(512)
} else {
data
}
gatt?.writeCharacteristic(
txCharacteristic!!,
clampedData,
BluetoothGattCharacteristic.WRITE_TYPE_NO_RESPONSE
)
}
override
fun disconnect() {
super.disconnect()
discoveryDone.countDown()
}
fun waitForDiscoveryCompletion() {
discoveryDone.await()
}
}
class GattClient(
private val viewModel: AppViewModel,
bluetoothAdapter: BluetoothAdapter,
context: Context,
private val createIoClient: (packetIo: PacketIO) -> IoClient
) : Mode {
private var connection: GattClientConnection =
GattClientConnection(viewModel, bluetoothAdapter, context)
private var clientThread: Thread? = null
@SuppressLint("MissingPermission")
override fun run() {
viewModel.running = true
clientThread = thread(name = "GattClient") {
connection.connect()
viewModel.aborter = {
connection.disconnect()
}
// Discover the rx and tx characteristics
connection.waitForDiscoveryCompletion()
if (connection.rxCharacteristic == null || connection.txCharacteristic == null) {
connection.disconnect()
viewModel.running = false
return@thread
}
val ioClient = createIoClient(connection)
try {
ioClient.run()
viewModel.status = "OK"
} catch (error: IOException) {
Log.info("run ended abruptly")
viewModel.status = "ABORTED"
viewModel.lastError = "IO_ERROR"
} finally {
connection.disconnect()
viewModel.running = false
}
}
}
override fun waitForCompletion() {
clientThread?.join()
}
}

View File

@@ -0,0 +1,243 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCharacteristic
import android.bluetooth.BluetoothGattDescriptor
import android.bluetooth.BluetoothGattServer
import android.bluetooth.BluetoothGattServerCallback
import android.bluetooth.BluetoothGattService
import android.bluetooth.BluetoothManager
import android.bluetooth.BluetoothStatusCodes
import android.content.Context
import androidx.core.content.ContextCompat
import java.io.IOException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.Semaphore
import java.util.logging.Logger
import kotlin.concurrent.thread
import kotlin.experimental.and
private val Log = Logger.getLogger("btbench.gatt-server")
@SuppressLint("MissingPermission")
class GattServer(
private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter,
context: Context,
private val createIoClient: (packetIo: PacketIO) -> IoClient
) : Mode, PacketIO, BluetoothGattServerCallback() {
override var packetSink: PacketSink? = null
private val gattServer: BluetoothGattServer
private val rxCharacteristic: BluetoothGattCharacteristic?
private val txCharacteristic: BluetoothGattCharacteristic?
private val notifySemaphore: Semaphore = Semaphore(1)
private val ready: CountDownLatch = CountDownLatch(1)
private var peerDevice: BluetoothDevice? = null
private var clientThread: Thread? = null
private var sinkQueue: LinkedBlockingQueue<Packet>? = null
init {
val bluetoothManager = ContextCompat.getSystemService(context, BluetoothManager::class.java)
gattServer = bluetoothManager!!.openGattServer(context, this)
val benchService = gattServer.getService(BENCH_SERVICE_UUID)
if (benchService == null) {
rxCharacteristic = BluetoothGattCharacteristic(
BENCH_RX_UUID,
BluetoothGattCharacteristic.PROPERTY_NOTIFY,
0
)
txCharacteristic = BluetoothGattCharacteristic(
BENCH_TX_UUID,
BluetoothGattCharacteristic.PROPERTY_WRITE_NO_RESPONSE,
BluetoothGattCharacteristic.PERMISSION_WRITE
)
val rxCCCD = BluetoothGattDescriptor(
CCCD_UUID,
BluetoothGattDescriptor.PERMISSION_READ or BluetoothGattDescriptor.PERMISSION_WRITE
)
rxCharacteristic.addDescriptor(rxCCCD)
val service =
BluetoothGattService(BENCH_SERVICE_UUID, BluetoothGattService.SERVICE_TYPE_PRIMARY)
service.addCharacteristic(rxCharacteristic)
service.addCharacteristic(txCharacteristic)
gattServer.addService(service)
} else {
rxCharacteristic = benchService.getCharacteristic(BENCH_RX_UUID)
txCharacteristic = benchService.getCharacteristic(BENCH_TX_UUID)
}
}
override fun onCharacteristicWriteRequest(
device: BluetoothDevice?,
requestId: Int,
characteristic: BluetoothGattCharacteristic?,
preparedWrite: Boolean,
responseNeeded: Boolean,
offset: Int,
value: ByteArray?
) {
Log.info("onCharacteristicWriteRequest")
if (characteristic != null && characteristic.uuid == BENCH_TX_UUID) {
if (packetSink == null) {
Log.warning("no sink, dropping")
} else if (offset != 0) {
Log.warning("offset != 0")
} else if (value == null) {
Log.warning("no value")
} else {
// Deliver the packet in a separate thread so that we don't block this
// callback.
sinkQueue?.put(Packet.from(value))
}
}
if (responseNeeded) {
gattServer.sendResponse(device, requestId, BluetoothGatt.GATT_SUCCESS, offset, value)
}
}
override fun onNotificationSent(device: BluetoothDevice?, status: Int) {
if (status == BluetoothGatt.GATT_SUCCESS) {
notifySemaphore.release()
}
}
override fun onDescriptorWriteRequest(
device: BluetoothDevice?,
requestId: Int,
descriptor: BluetoothGattDescriptor?,
preparedWrite: Boolean,
responseNeeded: Boolean,
offset: Int,
value: ByteArray?
) {
if (descriptor?.uuid == CCCD_UUID && descriptor?.characteristic?.uuid == BENCH_RX_UUID) {
if (offset == 0 && value?.size == 2) {
if (value[0].and(1).toInt() != 0) {
// Subscription
Log.fine("peer subscribed to RX")
peerDevice = device
ready.countDown()
}
}
}
if (responseNeeded) {
gattServer.sendResponse(device, requestId, BluetoothGatt.GATT_SUCCESS, offset, value)
}
}
@SuppressLint("MissingPermission")
override fun sendPacket(packet: Packet) {
if (peerDevice == null) {
Log.warning("no peer device, cannot send")
return
}
if (rxCharacteristic == null) {
Log.warning("no RX characteristic, cannot send")
return
}
// Wait until we can notify
notifySemaphore.acquire()
// Send the packet via a notification
val result = gattServer.notifyCharacteristicChanged(
peerDevice!!,
rxCharacteristic,
false,
packet.toBytes()
)
if (result != BluetoothStatusCodes.SUCCESS) {
Log.warning("notifyCharacteristicChanged failed: $result")
notifySemaphore.release()
}
}
override fun run() {
viewModel.running = true
// Start advertising
Log.fine("starting advertiser")
val advertiser = Advertiser(bluetoothAdapter)
advertiser.start()
clientThread = thread(name = "GattServer") {
// Wait for a subscriber
Log.info("waiting for RX subscriber")
viewModel.aborter = {
ready.countDown()
}
ready.await()
if (peerDevice == null) {
Log.warning("server interrupted")
viewModel.running = false
gattServer.close()
return@thread
}
Log.info("RX subscriber accepted")
// Stop advertising
Log.info("stopping advertiser")
advertiser.stop()
sinkQueue = LinkedBlockingQueue()
val sinkWriterThread = thread(name = "SinkWriter") {
while (true) {
try {
val packet = sinkQueue!!.take()
if (packetSink == null) {
Log.warning("no sink, dropping packet")
continue
}
packetSink!!.onPacket(packet)
} catch (error: InterruptedException) {
Log.warning("sink writer interrupted")
break
}
}
}
val ioClient = createIoClient(this)
try {
ioClient.run()
viewModel.status = "OK"
} catch (error: IOException) {
Log.info("run ended abruptly")
viewModel.status = "ABORTED"
viewModel.lastError = "IO_ERROR"
} finally {
sinkWriterThread.interrupt()
sinkWriterThread.join()
gattServer.close()
viewModel.running = false
}
}
}
override fun waitForCompletion() {
clientThread?.join()
Log.info("server thread completed")
}
}

View File

@@ -16,89 +16,25 @@ package com.github.google.bumble.btbench
import android.annotation.SuppressLint import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCallback
import android.bluetooth.BluetoothProfile
import android.content.Context import android.content.Context
import android.os.Build
import java.util.logging.Logger import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.l2cap-client") private val Log = Logger.getLogger("btbench.l2cap-client")
class L2capClient( class L2capClient(
private val viewModel: AppViewModel, private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter, bluetoothAdapter: BluetoothAdapter,
private val context: Context, context: Context,
private val createIoClient: (packetIo: PacketIO) -> IoClient private val createIoClient: (packetIo: PacketIO) -> IoClient
) : Mode { ) : Mode {
private var connection: Connection = Connection(viewModel, bluetoothAdapter, context)
private var socketClient: SocketClient? = null private var socketClient: SocketClient? = null
@SuppressLint("MissingPermission") @SuppressLint("MissingPermission")
override fun run() { override fun run() {
viewModel.running = true viewModel.running = true
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P") connection.connect()
val address = viewModel.peerBluetoothAddress.take(17) val socket = connection.remoteDevice!!.createInsecureL2capChannel(viewModel.l2capPsm)
val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
bluetoothAdapter.getRemoteLeDevice(
address,
if (addressIsPublic) {
BluetoothDevice.ADDRESS_TYPE_PUBLIC
} else {
BluetoothDevice.ADDRESS_TYPE_RANDOM
}
)
} else {
bluetoothAdapter.getRemoteDevice(address)
}
val gatt = remoteDevice.connectGatt(
context,
false,
object : BluetoothGattCallback() {
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
Log.info("MTU update: mtu=$mtu status=$status")
viewModel.mtu = mtu
}
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
Log.info("requesting 2M PHY")
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
}
}
},
BluetoothDevice.TRANSPORT_LE,
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
)
val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm)
socketClient = SocketClient(viewModel, socket, createIoClient) socketClient = SocketClient(viewModel, socket, createIoClient)
socketClient!!.run() socketClient!!.run()
} }

View File

@@ -37,34 +37,15 @@ class L2capServer(
@SuppressLint("MissingPermission") @SuppressLint("MissingPermission")
override fun run() { override fun run() {
// Advertise so that the peer can find us and connect. // Advertise so that the peer can find us and connect.
val callback = object : AdvertiseCallback() { val advertiser = Advertiser(bluetoothAdapter)
override fun onStartFailure(errorCode: Int) {
Log.warning("failed to start advertising: $errorCode")
}
override fun onStartSuccess(settingsInEffect: AdvertiseSettings) {
Log.info("advertising started: $settingsInEffect")
}
}
val advertiseSettingsBuilder = AdvertiseSettings.Builder()
.setAdvertiseMode(ADVERTISE_MODE_LOW_LATENCY)
.setConnectable(true)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
advertiseSettingsBuilder.setDiscoverable(true)
}
val advertiseSettings = advertiseSettingsBuilder.build()
val advertiseData = AdvertiseData.Builder().build()
val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build()
val advertiser = bluetoothAdapter.bluetoothLeAdvertiser
val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel() val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel()
viewModel.l2capPsm = serverSocket.psm viewModel.l2capPsm = serverSocket.psm
Log.info("psm = $serverSocket.psm") Log.info("psm = $serverSocket.psm")
socketServer = SocketServer(viewModel, serverSocket, createIoClient) socketServer = SocketServer(viewModel, serverSocket, createIoClient)
socketServer!!.run( socketServer!!.run(
{ advertiser.stopAdvertising(callback) }, { advertiser.stop() },
{ advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) } { advertiser.start() }
) )
} }

View File

@@ -17,9 +17,12 @@ package com.github.google.bumble.btbench
import android.Manifest import android.Manifest
import android.annotation.SuppressLint import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothManager import android.bluetooth.BluetoothManager
import android.content.BroadcastReceiver
import android.content.Context import android.content.Context
import android.content.Intent import android.content.Intent
import android.content.IntentFilter
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.os.Build import android.os.Build
import android.os.Bundle import android.os.Bundle
@@ -66,6 +69,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import com.github.google.bumble.btbench.ui.theme.BTBenchTheme import com.github.google.bumble.btbench.ui.theme.BTBenchTheme
import java.io.IOException
import java.util.logging.Logger import java.util.logging.Logger
private val Log = Logger.getLogger("bumble.main-activity") private val Log = Logger.getLogger("bumble.main-activity")
@@ -76,6 +80,7 @@ const val SENDER_PACKET_SIZE_PREF_KEY = "sender_packet_size"
const val SENDER_PACKET_INTERVAL_PREF_KEY = "sender_packet_interval" const val SENDER_PACKET_INTERVAL_PREF_KEY = "sender_packet_interval"
const val SCENARIO_PREF_KEY = "scenario" const val SCENARIO_PREF_KEY = "scenario"
const val MODE_PREF_KEY = "mode" const val MODE_PREF_KEY = "mode"
const val CONNECTION_PRIORITY_PREF_KEY = "connection_priority"
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
private val appViewModel = AppViewModel() private val appViewModel = AppViewModel()
@@ -84,6 +89,47 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
appViewModel.loadPreferences(getPreferences(Context.MODE_PRIVATE)) appViewModel.loadPreferences(getPreferences(Context.MODE_PRIVATE))
checkPermissions() checkPermissions()
registerReceivers()
}
private fun registerReceivers() {
val pairingRequestIntentFilter = IntentFilter(BluetoothDevice.ACTION_PAIRING_REQUEST)
registerReceiver(object: BroadcastReceiver() {
@SuppressLint("MissingPermission")
override fun onReceive(context: Context, intent: Intent) {
Log.info("ACTION_PAIRING_REQUEST")
val extras = intent.extras
if (extras != null) {
for (key in extras.keySet()) {
Log.info("$key: ${extras.get(key)}")
}
}
val device: BluetoothDevice? = intent.getParcelableExtra(BluetoothDevice.EXTRA_DEVICE)
if (device != null) {
if (checkSelfPermission(Manifest.permission.BLUETOOTH_PRIVILEGED) == PackageManager.PERMISSION_GRANTED) {
Log.info("confirming pairing")
device.setPairingConfirmation(true)
} else {
Log.info("we don't have BLUETOOTH_PRIVILEGED, not confirming")
}
}
}
}, pairingRequestIntentFilter)
val bondStateChangedIntentFilter = IntentFilter(BluetoothDevice.ACTION_BOND_STATE_CHANGED)
registerReceiver(object: BroadcastReceiver() {
@SuppressLint("MissingPermission")
override fun onReceive(context: Context, intent: Intent) {
Log.info("ACTION_BOND_STATE_CHANGED")
val extras = intent.extras
if (extras != null) {
for (key in extras.keySet()) {
Log.info("$key: ${extras.get(key)}")
}
}
}
}, bondStateChangedIntentFilter)
} }
private fun checkPermissions() { private fun checkPermissions() {
@@ -144,9 +190,7 @@ class MainActivity : ComponentActivity() {
initBluetooth() initBluetooth()
setContent { setContent {
MainView( MainView(
appViewModel, appViewModel, ::becomeDiscoverable, ::runScenario
::becomeDiscoverable,
::runScenario
) )
} }
@@ -182,6 +226,8 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> appViewModel.mode = RFCOMM_SERVER_MODE "rfcomm-server" -> appViewModel.mode = RFCOMM_SERVER_MODE
"l2cap-client" -> appViewModel.mode = L2CAP_CLIENT_MODE "l2cap-client" -> appViewModel.mode = L2CAP_CLIENT_MODE
"l2cap-server" -> appViewModel.mode = L2CAP_SERVER_MODE "l2cap-server" -> appViewModel.mode = L2CAP_SERVER_MODE
"gatt-client" -> appViewModel.mode = GATT_CLIENT_MODE
"gatt-server" -> appViewModel.mode = GATT_SERVER_MODE
} }
} }
intent.getStringExtra("autostart")?.let { intent.getStringExtra("autostart")?.let {
@@ -195,19 +241,24 @@ class MainActivity : ComponentActivity() {
private fun runScenario() { private fun runScenario() {
if (bluetoothAdapter == null) { if (bluetoothAdapter == null) {
return throw IOException("bluetooth not enabled")
} }
val runner = when (appViewModel.mode) { val runner = when (appViewModel.mode) {
RFCOMM_CLIENT_MODE -> RfcommClient(appViewModel, bluetoothAdapter!!, ::createIoClient) RFCOMM_CLIENT_MODE -> RfcommClient(appViewModel, bluetoothAdapter!!, ::createIoClient)
RFCOMM_SERVER_MODE -> RfcommServer(appViewModel, bluetoothAdapter!!, ::createIoClient) RFCOMM_SERVER_MODE -> RfcommServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
L2CAP_CLIENT_MODE -> L2capClient( L2CAP_CLIENT_MODE -> L2capClient(
appViewModel, appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
bluetoothAdapter!!,
baseContext,
::createIoClient
) )
L2CAP_SERVER_MODE -> L2capServer(appViewModel, bluetoothAdapter!!, ::createIoClient) L2CAP_SERVER_MODE -> L2capServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
GATT_CLIENT_MODE -> GattClient(
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
)
GATT_SERVER_MODE -> GattServer(
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
)
else -> throw IllegalStateException() else -> throw IllegalStateException()
} }
runner.run() runner.run()
@@ -281,7 +332,7 @@ fun MainView(
keyboardController?.hide() keyboardController?.hide()
focusManager.clearFocus() focusManager.clearFocus()
}), }),
enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE) or (appViewModel.mode == L2CAP_CLIENT_MODE) enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE || appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == GATT_CLIENT_MODE)
) )
Divider() Divider()
TextField( TextField(
@@ -349,24 +400,45 @@ fun MainView(
keyboardController?.hide() keyboardController?.hide()
focusManager.clearFocus() focusManager.clearFocus()
}), }),
enabled = (appViewModel.scenario == PING_SCENARIO) enabled = (appViewModel.scenario == PING_SCENARIO || appViewModel.scenario == SEND_SCENARIO)
) )
Divider() Divider()
ActionButton(
text = "Become Discoverable", onClick = becomeDiscoverable, true
)
Row( Row(
horizontalArrangement = Arrangement.SpaceBetween, horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
Text(text = "2M PHY") Text(text = "2M PHY")
Spacer(modifier = Modifier.padding(start = 8.dp)) Spacer(modifier = Modifier.padding(start = 8.dp))
Switch( Switch(enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE),
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE),
checked = appViewModel.use2mPhy, checked = appViewModel.use2mPhy,
onCheckedChange = { appViewModel.use2mPhy = it } onCheckedChange = { appViewModel.use2mPhy = it })
) Column(Modifier.selectableGroup()) {
listOf(
"BALANCED", "LOW", "HIGH", "DCK"
).forEach { text ->
Row(
Modifier
.selectable(
selected = (text == appViewModel.connectionPriority),
onClick = { appViewModel.updateConnectionPriority(text) },
role = Role.RadioButton,
)
.padding(horizontal = 16.dp),
verticalAlignment = Alignment.CenterVertically
) {
RadioButton(
selected = (text == appViewModel.connectionPriority),
onClick = null,
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE)
)
Text(
text = text,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.padding(start = 16.dp)
)
}
}
}
} }
Row { Row {
Column(Modifier.selectableGroup()) { Column(Modifier.selectableGroup()) {
@@ -374,7 +446,9 @@ fun MainView(
RFCOMM_CLIENT_MODE, RFCOMM_CLIENT_MODE,
RFCOMM_SERVER_MODE, RFCOMM_SERVER_MODE,
L2CAP_CLIENT_MODE, L2CAP_CLIENT_MODE,
L2CAP_SERVER_MODE L2CAP_SERVER_MODE,
GATT_CLIENT_MODE,
GATT_SERVER_MODE
).forEach { text -> ).forEach { text ->
Row( Row(
Modifier Modifier
@@ -387,8 +461,7 @@ fun MainView(
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
RadioButton( RadioButton(
selected = (text == appViewModel.mode), selected = (text == appViewModel.mode), onClick = null
onClick = null
) )
Text( Text(
text = text, text = text,
@@ -400,10 +473,7 @@ fun MainView(
} }
Column(Modifier.selectableGroup()) { Column(Modifier.selectableGroup()) {
listOf( listOf(
SEND_SCENARIO, SEND_SCENARIO, RECEIVE_SCENARIO, PING_SCENARIO, PONG_SCENARIO
RECEIVE_SCENARIO,
PING_SCENARIO,
PONG_SCENARIO
).forEach { text -> ).forEach { text ->
Row( Row(
Modifier Modifier
@@ -416,8 +486,7 @@ fun MainView(
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
RadioButton( RadioButton(
selected = (text == appViewModel.scenario), selected = (text == appViewModel.scenario), onClick = null
onClick = null
) )
Text( Text(
text = text, text = text,
@@ -435,20 +504,29 @@ fun MainView(
ActionButton( ActionButton(
text = "Stop", onClick = appViewModel::abort, enabled = appViewModel.running text = "Stop", onClick = appViewModel::abort, enabled = appViewModel.running
) )
ActionButton(
text = "Become Discoverable", onClick = becomeDiscoverable, true
)
} }
Divider() Divider()
Text( if (appViewModel.mtu != 0) {
text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else "" Text(
) text = "MTU: ${appViewModel.mtu}"
Text( )
text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else "" }
) if (appViewModel.rxPhy != 0) {
Text(
text = "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}"
)
}
Text( Text(
text = "Status: ${appViewModel.status}" text = "Status: ${appViewModel.status}"
) )
Text( if (appViewModel.lastError.isNotEmpty()) {
text = "Last Error: ${appViewModel.lastError}" Text(
) text = "Last Error: ${appViewModel.lastError}"
)
}
Text( Text(
text = "Packets Sent: ${appViewModel.packetsSent}" text = "Packets Sent: ${appViewModel.packetsSent}"
) )

View File

@@ -25,6 +25,7 @@ import java.util.UUID
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE") val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF" const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
const val DEFAULT_STARTUP_DELAY = 3000
const val DEFAULT_SENDER_PACKET_COUNT = 100 const val DEFAULT_SENDER_PACKET_COUNT = 100
const val DEFAULT_SENDER_PACKET_SIZE = 1024 const val DEFAULT_SENDER_PACKET_SIZE = 1024
const val DEFAULT_SENDER_PACKET_INTERVAL = 100 const val DEFAULT_SENDER_PACKET_INTERVAL = 100
@@ -34,6 +35,8 @@ const val L2CAP_CLIENT_MODE = "L2CAP Client"
const val L2CAP_SERVER_MODE = "L2CAP Server" const val L2CAP_SERVER_MODE = "L2CAP Server"
const val RFCOMM_CLIENT_MODE = "RFCOMM Client" const val RFCOMM_CLIENT_MODE = "RFCOMM Client"
const val RFCOMM_SERVER_MODE = "RFCOMM Server" const val RFCOMM_SERVER_MODE = "RFCOMM Server"
const val GATT_CLIENT_MODE = "GATT Client"
const val GATT_SERVER_MODE = "GATT Server"
const val SEND_SCENARIO = "Send" const val SEND_SCENARIO = "Send"
const val RECEIVE_SCENARIO = "Receive" const val RECEIVE_SCENARIO = "Receive"
@@ -47,8 +50,10 @@ class AppViewModel : ViewModel() {
var mode by mutableStateOf(RFCOMM_SERVER_MODE) var mode by mutableStateOf(RFCOMM_SERVER_MODE)
var scenario by mutableStateOf(RECEIVE_SCENARIO) var scenario by mutableStateOf(RECEIVE_SCENARIO)
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS) var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
var startupDelay by mutableIntStateOf(DEFAULT_STARTUP_DELAY)
var l2capPsm by mutableIntStateOf(DEFAULT_PSM) var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
var use2mPhy by mutableStateOf(true) var use2mPhy by mutableStateOf(true)
var connectionPriority by mutableStateOf("BALANCED")
var mtu by mutableIntStateOf(0) var mtu by mutableIntStateOf(0)
var rxPhy by mutableIntStateOf(0) var rxPhy by mutableIntStateOf(0)
var txPhy by mutableIntStateOf(0) var txPhy by mutableIntStateOf(0)
@@ -98,6 +103,11 @@ class AppViewModel : ViewModel() {
if (savedScenario != null) { if (savedScenario != null) {
scenario = savedScenario scenario = savedScenario
} }
val savedConnectionPriority = preferences.getString(CONNECTION_PRIORITY_PREF_KEY, null)
if (savedConnectionPriority != null) {
connectionPriority = savedConnectionPriority
}
} }
fun updatePeerBluetoothAddress(peerBluetoothAddress: String) { fun updatePeerBluetoothAddress(peerBluetoothAddress: String) {
@@ -220,6 +230,14 @@ class AppViewModel : ViewModel() {
} }
} }
fun updateConnectionPriority(connectionPriority: String) {
this.connectionPriority = connectionPriority
with(preferences!!.edit()) {
putString(CONNECTION_PRIORITY_PREF_KEY, connectionPriority)
apply()
}
}
fun clear() { fun clear() {
status = "" status = ""
lastError = "" lastError = ""

View File

@@ -17,6 +17,7 @@ package com.github.google.bumble.btbench
import android.bluetooth.BluetoothSocket import android.bluetooth.BluetoothSocket
import java.io.IOException import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.logging.Logger import java.util.logging.Logger
import kotlin.math.min import kotlin.math.min
@@ -37,11 +38,16 @@ abstract class Packet(val type: Int, val payload: ByteArray = ByteArray(0)) {
RESET -> ResetPacket() RESET -> ResetPacket()
SEQUENCE -> SequencePacket( SEQUENCE -> SequencePacket(
data[1].toInt(), data[1].toInt(),
ByteBuffer.wrap(data, 2, 4).getInt(), ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
data.sliceArray(6..<data.size) ByteBuffer.wrap(data, 6, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
data.sliceArray(10..<data.size)
)
ACK -> AckPacket(
data[1].toInt(),
ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt()
) )
ACK -> AckPacket(data[1].toInt(), ByteBuffer.wrap(data, 2, 4).getInt())
else -> GenericPacket(data[0].toInt(), data.sliceArray(1..<data.size)) else -> GenericPacket(data[0].toInt(), data.sliceArray(1..<data.size))
} }
} }
@@ -57,16 +63,24 @@ class ResetPacket : Packet(RESET)
class AckPacket(val flags: Int, val sequenceNumber: Int) : Packet(ACK) { class AckPacket(val flags: Int, val sequenceNumber: Int) : Packet(ACK) {
override fun toBytes(): ByteArray { override fun toBytes(): ByteArray {
return ByteBuffer.allocate(1 + 1 + 4).put(type.toByte()).put(flags.toByte()) return ByteBuffer.allocate(6).order(
ByteOrder.LITTLE_ENDIAN
).put(type.toByte()).put(flags.toByte())
.putInt(sequenceNumber).array() .putInt(sequenceNumber).array()
} }
} }
class SequencePacket(val flags: Int, val sequenceNumber: Int, payload: ByteArray) : class SequencePacket(
val flags: Int,
val sequenceNumber: Int,
val timestamp: Int,
payload: ByteArray
) :
Packet(SEQUENCE, payload) { Packet(SEQUENCE, payload) {
override fun toBytes(): ByteArray { override fun toBytes(): ByteArray {
return ByteBuffer.allocate(1 + 1 + 4 + payload.size).put(type.toByte()).put(flags.toByte()) return ByteBuffer.allocate(10 + payload.size).order(ByteOrder.LITTLE_ENDIAN)
.putInt(sequenceNumber).put(payload).array() .put(type.toByte()).put(flags.toByte())
.putInt(sequenceNumber).putInt(timestamp).put(payload).array()
} }
} }

View File

@@ -19,8 +19,6 @@ import java.util.logging.Logger
import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.TimeSource import kotlin.time.TimeSource
private const val DEFAULT_STARTUP_DELAY = 3000
private val Log = Logger.getLogger("btbench.pinger") private val Log = Logger.getLogger("btbench.pinger")
class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient, class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient,
@@ -36,8 +34,8 @@ class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO
override fun run() { override fun run() {
viewModel.clear() viewModel.clear()
Log.info("startup delay: $DEFAULT_STARTUP_DELAY") Log.info("startup delay: ${viewModel.startupDelay}")
Thread.sleep(DEFAULT_STARTUP_DELAY.toLong()); Thread.sleep(viewModel.startupDelay.toLong());
Log.info("running") Log.info("running")
Log.info("sending reset") Log.info("sending reset")
@@ -48,19 +46,23 @@ class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO
val startTime = TimeSource.Monotonic.markNow() val startTime = TimeSource.Monotonic.markNow()
for (i in 0..<packetCount) { for (i in 0..<packetCount) {
val now = TimeSource.Monotonic.markNow() var now = TimeSource.Monotonic.markNow()
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds if (viewModel.senderPacketInterval > 0) {
val delay = targetTime - now val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
if (delay.isPositive()) { val delay = targetTime - now
Log.info("sleeping ${delay.inWholeMilliseconds} ms") if (delay.isPositive()) {
Thread.sleep(delay.inWholeMilliseconds) Log.info("sleeping ${delay.inWholeMilliseconds} ms")
Thread.sleep(delay.inWholeMilliseconds)
now = TimeSource.Monotonic.markNow()
}
} }
pingTimes.add(TimeSource.Monotonic.markNow()) pingTimes.add(TimeSource.Monotonic.markNow())
packetIO.sendPacket( packetIO.sendPacket(
SequencePacket( SequencePacket(
if (i < packetCount - 1) 0 else Packet.LAST_FLAG, if (i < packetCount - 1) 0 else Packet.LAST_FLAG,
i, i,
ByteArray(packetSize - 6) (now - startTime).inWholeMicroseconds.toInt(),
ByteArray(packetSize - 10)
) )
) )
viewModel.packetsSent = i + 1 viewModel.packetsSent = i + 1

View File

@@ -14,6 +14,7 @@
package com.github.google.bumble.btbench package com.github.google.bumble.btbench
import java.util.concurrent.CountDownLatch
import java.util.logging.Logger import java.util.logging.Logger
import kotlin.time.TimeSource import kotlin.time.TimeSource
@@ -23,6 +24,7 @@ class Ponger(private val viewModel: AppViewModel, private val packetIO: PacketIO
private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow() private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow() private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
private var expectedSequenceNumber: Int = 0 private var expectedSequenceNumber: Int = 0
private val done = CountDownLatch(1)
init { init {
packetIO.packetSink = this packetIO.packetSink = this
@@ -30,6 +32,7 @@ class Ponger(private val viewModel: AppViewModel, private val packetIO: PacketIO
override fun run() { override fun run() {
viewModel.clear() viewModel.clear()
done.await()
} }
override fun abort() {} override fun abort() {}
@@ -58,5 +61,10 @@ class Ponger(private val viewModel: AppViewModel, private val packetIO: PacketIO
packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber)) packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber))
viewModel.packetsSent += 1 viewModel.packetsSent += 1
if (packet.flags and Packet.LAST_FLAG != 0) {
Log.info("received last packet")
done.countDown()
}
} }
} }

View File

@@ -14,6 +14,7 @@
package com.github.google.bumble.btbench package com.github.google.bumble.btbench
import java.util.concurrent.CountDownLatch
import java.util.logging.Logger import java.util.logging.Logger
import kotlin.time.DurationUnit import kotlin.time.DurationUnit
import kotlin.time.TimeSource import kotlin.time.TimeSource
@@ -24,6 +25,7 @@ class Receiver(private val viewModel: AppViewModel, private val packetIO: Packet
private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow() private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow() private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
private var bytesReceived = 0 private var bytesReceived = 0
private val done = CountDownLatch(1)
init { init {
packetIO.packetSink = this packetIO.packetSink = this
@@ -31,6 +33,7 @@ class Receiver(private val viewModel: AppViewModel, private val packetIO: Packet
override fun run() { override fun run() {
viewModel.clear() viewModel.clear()
done.await()
} }
override fun abort() {} override fun abort() {}
@@ -62,6 +65,7 @@ class Receiver(private val viewModel: AppViewModel, private val packetIO: Packet
Log.info("throughput: $throughput") Log.info("throughput: $throughput")
viewModel.throughput = throughput viewModel.throughput = throughput
packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber)) packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber))
done.countDown()
} }
} }
} }

View File

@@ -16,11 +16,10 @@ package com.github.google.bumble.btbench
import java.util.concurrent.Semaphore import java.util.concurrent.Semaphore
import java.util.logging.Logger import java.util.logging.Logger
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.DurationUnit import kotlin.time.DurationUnit
import kotlin.time.TimeSource import kotlin.time.TimeSource
private const val DEFAULT_STARTUP_DELAY = 3000
private val Log = Logger.getLogger("btbench.sender") private val Log = Logger.getLogger("btbench.sender")
class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient, class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient,
@@ -36,8 +35,8 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
override fun run() { override fun run() {
viewModel.clear() viewModel.clear()
Log.info("startup delay: $DEFAULT_STARTUP_DELAY") Log.info("startup delay: ${viewModel.startupDelay}")
Thread.sleep(DEFAULT_STARTUP_DELAY.toLong()); Thread.sleep(viewModel.startupDelay.toLong());
Log.info("running") Log.info("running")
Log.info("sending reset") Log.info("sending reset")
@@ -47,20 +46,32 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
val packetCount = viewModel.senderPacketCount val packetCount = viewModel.senderPacketCount
val packetSize = viewModel.senderPacketSize val packetSize = viewModel.senderPacketSize
for (i in 0..<packetCount - 1) { for (i in 0..<packetCount) {
packetIO.sendPacket(SequencePacket(0, i, ByteArray(packetSize - 6))) var now = TimeSource.Monotonic.markNow()
if (viewModel.senderPacketInterval > 0) {
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
val delay = targetTime - now
if (delay.isPositive()) {
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
Thread.sleep(delay.inWholeMilliseconds)
}
now = TimeSource.Monotonic.markNow()
}
val flags = when (i) {
packetCount - 1 -> Packet.LAST_FLAG
else -> 0
}
packetIO.sendPacket(
SequencePacket(
flags,
i,
(now - startTime).inWholeMicroseconds.toInt(),
ByteArray(packetSize - 10)
)
)
bytesSent += packetSize bytesSent += packetSize
viewModel.packetsSent = i + 1 viewModel.packetsSent = i + 1
} }
packetIO.sendPacket(
SequencePacket(
Packet.LAST_FLAG,
packetCount - 1,
ByteArray(packetSize - 6)
)
)
bytesSent += packetSize
viewModel.packetsSent = packetCount
// Wait for the ACK // Wait for the ACK
Log.info("waiting for ACK") Log.info("waiting for ACK")

View File

@@ -1,21 +1,135 @@
[build-system] [build-system]
requires = ["setuptools>=52", "wheel", "setuptools_scm>=6.2"] requires = ["setuptools>=61", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project]
name = "bumble"
dynamic = ["version"]
description = "Bluetooth Stack for Apps, Emulation, Test and Experimentation"
readme = "README.md"
authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.8"
dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 39; platform_system!='Emscripten'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
"cryptography >= 39.0; platform_system=='Emscripten'",
"grpcio >= 1.62.1; platform_system!='Emscripten'",
"humanize >= 4.6.0; platform_system!='Emscripten'",
"libusb1 >= 2.0.1; platform_system!='Emscripten'",
"libusb-package == 1.0.26.1; platform_system!='Emscripten'",
"platformdirs >= 3.10.0; platform_system!='Emscripten'",
"prompt_toolkit >= 3.0.16; platform_system!='Emscripten'",
"prettytable >= 3.6.0; platform_system!='Emscripten'",
"protobuf >= 3.12.4; platform_system!='Emscripten'",
"pyee >= 8.2.2",
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
"pyserial >= 3.5; platform_system!='Emscripten'",
"pyusb >= 1.2; platform_system!='Emscripten'",
"websockets == 13.1; platform_system!='Emscripten'",
]
[project.optional-dependencies]
build = ["build >= 0.7"]
test = [
"pytest >= 8.2",
"pytest-asyncio >= 0.23.5",
"pytest-html >= 3.2.0",
"coverage >= 6.4",
]
development = [
"black == 24.3",
"bt-test-interfaces >= 0.0.6",
"grpcio-tools >= 1.62.1",
"invoke >= 1.7.3",
"mobly >= 1.12.2",
"mypy == 1.12.0",
"nox >= 2022",
"pylint == 3.3.1",
"pyyaml >= 6.0",
"types-appdirs >= 1.4.3",
"types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0",
]
avatar = [
"pandora-avatar == 0.0.10",
"rootcanal == 1.11.1 ; python_version>='3.10'",
]
pandora = ["bt-test-interfaces >= 0.0.6"]
documentation = [
"mkdocs >= 1.6.0",
"mkdocs-material >= 9.6",
"mkdocstrings[python] >= 0.27.0",
]
auracast = [
"lc3py >= 1.1.3; python_version>='3.10' and ((platform_system=='Linux' and platform_machine=='x86_64') or (platform_system=='Darwin' and platform_machine=='arm64'))",
"sounddevice >= 0.5.1",
]
[project.scripts]
bumble-auracast = "bumble.apps.auracast:main"
bumble-ble-rpa-tool = "bumble.apps.ble_rpa_tool:main"
bumble-console = "bumble.apps.console:main"
bumble-controller-info = "bumble.apps.controller_info:main"
bumble-controller-loopback = "bumble.apps.controller_loopback:main"
bumble-gatt-dump = "bumble.apps.gatt_dump:main"
bumble-hci-bridge = "bumble.apps.hci_bridge:main"
bumble-l2cap-bridge = "bumble.apps.l2cap_bridge:main"
bumble-rfcomm-bridge = "bumble.apps.rfcomm_bridge:main"
bumble-pair = "bumble.apps.pair:main"
bumble-scan = "bumble.apps.scan:main"
bumble-show = "bumble.apps.show:main"
bumble-unbond = "bumble.apps.unbond:main"
bumble-usb-probe = "bumble.apps.usb_probe:main"
bumble-link-relay = "bumble.apps.link_relay.link_relay:main"
bumble-bench = "bumble.apps.bench:main"
bumble-player = "bumble.apps.player.player:main"
bumble-speaker = "bumble.apps.speaker.speaker:main"
bumble-pandora-server = "bumble.apps.pandora_server:main"
bumble-rtk-util = "bumble.tools.rtk_util:main"
bumble-rtk-fw-download = "bumble.tools.rtk_fw_download:main"
bumble-intel-util = "bumble.tools.intel_util:main"
bumble-intel-fw-download = "bumble.tools.intel_fw_download:main"
[project.urls]
Homepage = "https://github.com/google/bumble"
[tool.setuptools]
packages = [
"bumble",
"bumble.transport",
"bumble.transport.grpc_protobuf",
"bumble.drivers",
"bumble.profiles",
"bumble.apps",
"bumble.apps.link_relay",
"bumble.pandora",
"bumble.tools",
]
[tool.setuptools.package-dir]
"bumble" = "bumble"
"bumble.apps" = "apps"
"bumble.tools" = "tools"
[tool.setuptools_scm] [tool.setuptools_scm]
write_to = "bumble/_version.py" write_to = "bumble/_version.py"
[tool.setuptools.package-data]
"*" = ["*.pyi", "py.typed"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
pythonpath = "." pythonpath = "."
testpaths = [ testpaths = ["tests"]
"tests"
]
[tool.pylint.master] [tool.pylint.master]
init-hook = 'import sys; sys.path.append(".")' init-hook = 'import sys; sys.path.append(".")'
ignore-paths = [ ignore-paths = ['.*_pb2(_grpc)?.py']
'.*_pb2(_grpc)?.py'
]
[tool.pylint.messages_control] [tool.pylint.messages_control]
max-line-length = "88" max-line-length = "88"
@@ -25,8 +139,8 @@ disable = [
"fixme", "fixme",
"logging-fstring-interpolation", "logging-fstring-interpolation",
"logging-not-lazy", "logging-not-lazy",
"no-member", # Temporary until pylint works better with class/method decorators "no-member", # Temporary until pylint works better with class/method decorators
"no-value-for-parameter", # Temporary until pylint works better with class/method decorators "no-value-for-parameter", # Temporary until pylint works better with class/method decorators
"missing-class-docstring", "missing-class-docstring",
"missing-function-docstring", "missing-function-docstring",
"missing-module-docstring", "missing-module-docstring",
@@ -41,11 +155,11 @@ disable = [
] ]
[tool.pylint.main] [tool.pylint.main]
ignore="pandora" # FIXME: pylint does not support stubs yet: ignore=["pandora", "mobly"] # FIXME: pylint does not support stubs yet
[tool.pylint.typecheck] [tool.pylint.typecheck]
signature-mutators="AsyncRunner.run_in_task" signature-mutators = "AsyncRunner.run_in_task"
disable="not-callable" disable = "not-callable"
[tool.black] [tool.black]
skip-string-normalization = true skip-string-normalization = true
@@ -78,6 +192,10 @@ ignore_missing_imports = true
module = "serial_asyncio.*" module = "serial_asyncio.*"
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "sounddevice.*"
ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "usb.*" module = "usb.*"
ignore_missing_imports = true ignore_missing_imports = true
@@ -85,4 +203,3 @@ ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "usb1.*" module = "usb1.*"
ignore_missing_imports = true ignore_missing_imports = true

View File

@@ -69,3 +69,68 @@ To regenerate the assigned number tables based on the Python codebase:
``` ```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
``` ```
## HCI packets
Sending a command packet from a device is composed to of two major steps.
There are more generalized ways of dealing with packets in other scenarios.
### Construct the command
Pick a command from `src/internal/hci/packets.pdl` and construct its associated "builder" struct.
```rust
// The "LE Set Scan Enable" command can be found in the Core Bluetooth Spec.
// It can also be found in `packets.pdl` as `packet LeSetScanEnable : Command`
fn main() {
let device = init_device_as_desired();
let le_set_scan_enable_command_builder = LeSetScanEnableBuilder {
filter_duplicates: Enable::Disabled,
le_scan_enable: Enable::Enabled,
};
}
```
### Send the command and interpret the event response
Send the command from an initialized device, and then receive the response.
```rust
fn main() {
// ...
// `check_result` to false to receive the event response even if the controller returns a failure code
let event = device.send_command(le_set_scan_enable_command_builder.into(), /*check_result*/ false);
// Coerce the event into the expected format. A `Command` should have an associated event response
// "<command name>Complete".
let le_set_scan_enable_complete_event: LeSetScanEnableComplete = event.try_into().unwrap();
}
```
### Generic packet handling
At the very least, you should expect to at least know _which_ kind of base packet you are dealing with. Base packets in
`packets.pdl` can be identified because they do not extend any other packet. They are easily found with the regex:
`^packet [^:]* \{`. For Bluetooth LE (BLE) HCI, one should find some kind of header preceding the packet with the purpose of
packet disambiguation. We do some of that disambiguation for H4 BLE packets using the `WithPacketHeader` trait at `internal/hci/mod.rs`.
Say you've identified a series of bytes that are certainly an `Acl` packet. They can be parsed using the `Acl` struct.
```rust
fn main() {
let bytes = bytes_that_are_certainly_acl();
let acl_packet = Acl::parse(bytes).unwrap();
}
```
Since you don't yet know what kind of `Acl` packet it is, you need to specialize it and then handle the various
potential cases.
```rust
fn main() {
// ...
match acl_packet.specialize() {
Payload(bytes) => do_something(bytes),
None => do_something_else(),
}
}
```
Some packets may yet further embed other packets, in which case you may need to further specialize until no more
specialization is needed.

View File

@@ -25,7 +25,6 @@ use clap::Parser as _;
use pyo3::PyResult; use pyo3::PyResult;
use rand::Rng; use rand::Rng;
use std::path; use std::path;
#[pyo3_asyncio::tokio::main] #[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> { async fn main() -> PyResult<()> {
env_logger::builder() env_logger::builder()

View File

@@ -28,7 +28,7 @@ use bumble::wrapper::{
}; };
use pyo3::{ use pyo3::{
exceptions::PyException, exceptions::PyException,
{PyErr, PyResult}, FromPyObject, IntoPy, Python, {PyErr, PyResult},
}; };
#[pyo3_asyncio::tokio::test] #[pyo3_asyncio::tokio::test]
@@ -78,6 +78,28 @@ async fn test_hci_roundtrip_success_and_failure() -> PyResult<()> {
Ok(()) Ok(())
} }
#[pyo3_asyncio::tokio::test]
fn valid_error_code_extraction_succeeds() -> PyResult<()> {
let error_code = Python::with_gil(|py| {
let python_error_code_success = 0x00_u8.into_py(py);
ErrorCode::extract(python_error_code_success.as_ref(py))
})?;
assert_eq!(ErrorCode::Success, error_code);
Ok(())
}
#[pyo3_asyncio::tokio::test]
fn invalid_error_code_extraction_fails() -> PyResult<()> {
let failed_extraction = Python::with_gil(|py| {
let python_invalid_error_code = 0xFE_u8.into_py(py);
ErrorCode::extract(python_invalid_error_code.as_ref(py))
});
assert!(failed_extraction.is_err());
Ok(())
}
async fn create_local_device(address: Address) -> PyResult<Device> { async fn create_local_device(address: Address) -> PyResult<Device> {
let link = Link::new_local_link()?; let link = Link::new_local_link()?;
let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?; let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?;

View File

@@ -80,7 +80,7 @@ impl Address {
/// Creates a new [Address] object. /// Creates a new [Address] object.
pub fn new(address: &str, address_type: AddressType) -> PyResult<Self> { pub fn new(address: &str, address_type: AddressType) -> PyResult<Self> {
Python::with_gil(|py| { Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))? PyModule::import(py, intern!(py, "bumble.hci"))?
.getattr(intern!(py, "Address"))? .getattr(intern!(py, "Address"))?
.call1((address, address_type)) .call1((address, address_type))
.map(|any| Self(any.into())) .map(|any| Self(any.into()))
@@ -178,7 +178,11 @@ impl IntoPy<PyObject> for AddressType {
impl<'source> FromPyObject<'source> for ErrorCode { impl<'source> FromPyObject<'source> for ErrorCode {
fn extract(ob: &'source PyAny) -> PyResult<Self> { fn extract(ob: &'source PyAny) -> PyResult<Self> {
ob.extract() // Bumble represents error codes simply as a single-byte number (in Rust, u8)
let value: u8 = ob.extract()?;
ErrorCode::try_from(value).map_err(|b| {
PyErr::new::<PyException, _>(format!("Failed to map {b} to an error code"))
})
} }
} }

111
setup.cfg
View File

@@ -1,111 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[metadata]
name = bumble
use_scm_version = True
description = Bluetooth Stack for Apps, Emulation, Test and Experimentation
long_description = file: README.md
long_description_content_type = text/markdown
author = Google
author_email = tbd@tbd.com
url = https://github.com/google/bumble
[options]
python_requires = >=3.8
packages = bumble, bumble.transport, bumble.transport.grpc_protobuf, bumble.drivers, bumble.profiles, bumble.apps, bumble.apps.link_relay, bumble.pandora, bumble.tools
package_dir =
bumble = bumble
bumble.apps = apps
bumble.tools = tools
include_package_data = True
install_requires =
aiohttp ~= 3.8; platform_system!='Emscripten'
appdirs >= 1.4; platform_system!='Emscripten'
click >= 8.1.3; platform_system!='Emscripten'
cryptography == 39; platform_system!='Emscripten'
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
cryptography >= 39.0; platform_system=='Emscripten'
grpcio >= 1.62.1; platform_system!='Emscripten'
humanize >= 4.6.0; platform_system!='Emscripten'
libusb1 >= 2.0.1; platform_system!='Emscripten'
libusb-package == 1.0.26.1; platform_system!='Emscripten'
platformdirs >= 3.10.0; platform_system!='Emscripten'
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
prettytable >= 3.6.0; platform_system!='Emscripten'
protobuf >= 3.12.4; platform_system!='Emscripten'
pyee >= 8.2.2
pyserial-asyncio >= 0.5; platform_system!='Emscripten'
pyserial >= 3.5; platform_system!='Emscripten'
pyusb >= 1.2; platform_system!='Emscripten'
websockets >= 12.0; platform_system!='Emscripten'
[options.entry_points]
console_scripts =
bumble-ble-rpa-tool = bumble.apps.ble_rpa_tool:main
bumble-console = bumble.apps.console:main
bumble-controller-info = bumble.apps.controller_info:main
bumble-controller-loopback = bumble.apps.controller_loopback:main
bumble-gatt-dump = bumble.apps.gatt_dump:main
bumble-hci-bridge = bumble.apps.hci_bridge:main
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
bumble-rfcomm-bridge = bumble.apps.rfcomm_bridge:main
bumble-pair = bumble.apps.pair:main
bumble-scan = bumble.apps.scan:main
bumble-show = bumble.apps.show:main
bumble-unbond = bumble.apps.unbond:main
bumble-usb-probe = bumble.apps.usb_probe:main
bumble-link-relay = bumble.apps.link_relay.link_relay:main
bumble-bench = bumble.apps.bench:main
bumble-player = bumble.apps.player.player:main
bumble-speaker = bumble.apps.speaker.speaker:main
bumble-pandora-server = bumble.apps.pandora_server:main
bumble-rtk-util = bumble.tools.rtk_util:main
bumble-rtk-fw-download = bumble.tools.rtk_fw_download:main
[options.package_data]
* = py.typed, *.pyi
[options.extras_require]
build =
build >= 0.7
test =
pytest >= 8.2
pytest-asyncio >= 0.23.5
pytest-html >= 3.2.0
coverage >= 6.4
development =
black == 24.3
grpcio-tools >= 1.62.1
invoke >= 1.7.3
mobly >= 1.12.2
mypy == 1.12.0
nox >= 2022
pylint == 3.3.1
pyyaml >= 6.0
types-appdirs >= 1.4.3
types-invoke >= 1.7.3
types-protobuf >= 4.21.0
wasmtime == 20.0.0
avatar =
pandora-avatar == 0.0.10
rootcanal == 1.10.0 ; python_version>='3.10'
pandora =
bt-test-interfaces >= 0.0.6
documentation =
mkdocs >= 1.4.0
mkdocs-material >= 8.5.6
mkdocstrings[python] >= 0.19.0

View File

@@ -28,11 +28,12 @@ from bumble.profiles.aics import (
AudioInputState, AudioInputState,
AICSServiceProxy, AICSServiceProxy,
GainMode, GainMode,
GainSettingsProperties,
AudioInputStatus, AudioInputStatus,
AudioInputControlPointOpCode, AudioInputControlPointOpCode,
ErrorCode, ErrorCode,
) )
from bumble.profiles.vcp import VolumeControlService, VolumeControlServiceProxy from bumble.profiles.vcs import VolumeControlService, VolumeControlServiceProxy
from .test_utils import TwoDevices from .test_utils import TwoDevices
@@ -82,7 +83,12 @@ async def test_init_service(aics_client: AICSServiceProxy):
gain_mode=GainMode.MANUAL, gain_mode=GainMode.MANUAL,
change_counter=0, change_counter=0,
) )
assert await aics_client.gain_settings_properties.read_value() == (1, 0, 255) assert (
await aics_client.gain_settings_properties.read_value()
== GainSettingsProperties(
gain_settings_unit=1, gain_settings_minimum=0, gain_settings_maximum=255
)
)
assert await aics_client.audio_input_status.read_value() == ( assert await aics_client.audio_input_status.read_value() == (
AudioInputStatus.ACTIVE AudioInputStatus.ACTIVE
) )
@@ -481,12 +487,12 @@ async def test_set_automatic_gain_mode_when_automatic_only(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_audio_input_description_initial_value(aics_client: AICSServiceProxy): async def test_audio_input_description_initial_value(aics_client: AICSServiceProxy):
description = await aics_client.audio_input_description.read_value() description = await aics_client.audio_input_description.read_value()
assert description.decode('utf-8') == "Bluetooth" assert description == "Bluetooth"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_audio_input_description_write_and_read(aics_client: AICSServiceProxy): async def test_audio_input_description_write_and_read(aics_client: AICSServiceProxy):
new_description = "Line Input".encode('utf-8') new_description = "Line Input"
await aics_client.audio_input_description.write_value(new_description) await aics_client.audio_input_description.write_value(new_description)

View File

@@ -39,6 +39,8 @@ from bumble.profiles.ascs import (
) )
from bumble.profiles.bap import ( from bumble.profiles.bap import (
AudioLocation, AudioLocation,
BasicAudioAnnouncement,
BroadcastAudioAnnouncement,
SupportedFrameDuration, SupportedFrameDuration,
SupportedSamplingFrequency, SupportedSamplingFrequency,
SamplingFrequency, SamplingFrequency,
@@ -200,6 +202,56 @@ def test_codec_specific_configuration() -> None:
assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config
# -----------------------------------------------------------------------------
def test_broadcast_audio_announcement() -> None:
broadcast_audio_announcement = BroadcastAudioAnnouncement(123456)
assert (
BroadcastAudioAnnouncement.from_bytes(bytes(broadcast_audio_announcement))
== broadcast_audio_announcement
)
# -----------------------------------------------------------------------------
def test_basic_audio_announcement() -> None:
basic_audio_announcement = BasicAudioAnnouncement(
presentation_delay=40000,
subgroups=[
BasicAudioAnnouncement.Subgroup(
codec_id=CodingFormat(codec_id=CodecID.LC3),
codec_specific_configuration=CodecSpecificConfiguration(
sampling_frequency=SamplingFrequency.FREQ_48000,
frame_duration=FrameDuration.DURATION_10000_US,
octets_per_codec_frame=100,
),
metadata=Metadata(
[
Metadata.Entry(tag=Metadata.Tag.LANGUAGE, data=b'eng'),
Metadata.Entry(tag=Metadata.Tag.PROGRAM_INFO, data=b'Disco'),
]
),
bis=[
BasicAudioAnnouncement.BIS(
index=0,
codec_specific_configuration=CodecSpecificConfiguration(
audio_channel_allocation=AudioLocation.FRONT_LEFT
),
),
BasicAudioAnnouncement.BIS(
index=1,
codec_specific_configuration=CodecSpecificConfiguration(
audio_channel_allocation=AudioLocation.FRONT_RIGHT
),
),
],
)
],
)
assert (
BasicAudioAnnouncement.from_bytes(bytes(basic_audio_announcement))
== basic_audio_announcement
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pacs(): async def test_pacs():

View File

@@ -19,9 +19,7 @@ import asyncio
import functools import functools
import logging import logging
import os import os
from types import LambdaType
import pytest import pytest
from unittest import mock
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
@@ -29,8 +27,14 @@ from bumble.core import (
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
ConnectionParameters, ConnectionParameters,
) )
from bumble.device import AdvertisingParameters, Connection, Device from bumble.device import (
from bumble.host import AclPacketQueue, Host AdvertisingEventProperties,
AdvertisingParameters,
Connection,
Device,
PeriodicAdvertisingParameters,
)
from bumble.host import DataPacketQueue, Host
from bumble.hci import ( from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING, HCI_COMMAND_STATUS_PENDING,
@@ -46,12 +50,7 @@ from bumble.hci import (
HCI_Error, HCI_Error,
HCI_Packet, HCI_Packet,
) )
from bumble.gatt import ( from bumble import gatt
GATT_GENERIC_ACCESS_SERVICE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from .test_utils import TwoDevices, async_barrier from .test_utils import TwoDevices, async_barrier
@@ -86,9 +85,9 @@ async def test_device_connect_parallel():
def _send(packet): def _send(packet):
pass pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d0.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d1.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send) d2.host.acl_packet_queue = DataPacketQueue(0, 0, _send)
# enable classic # enable classic
d0.classic_enabled = True d0.classic_enabled = True
@@ -265,7 +264,8 @@ async def test_flush():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_legacy_advertising(): async def test_legacy_advertising():
device = Device(host=mock.AsyncMock(Host)) device = TwoDevices()[0]
await device.power_on()
# Start advertising # Start advertising
await device.start_advertising() await device.start_advertising()
@@ -283,7 +283,10 @@ async def test_legacy_advertising():
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart): async def test_legacy_advertising_disconnection(auto_restart):
device = Device(host=mock.AsyncMock(spec=Host)) devices = TwoDevices()
device = devices[0]
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff')
await device.power_on()
peer_address = Address('F0:F1:F2:F3:F4:F5') peer_address = Address('F0:F1:F2:F3:F4:F5')
await device.start_advertising(auto_restart=auto_restart) await device.start_advertising(auto_restart=auto_restart)
device.on_connection( device.on_connection(
@@ -305,6 +308,11 @@ async def test_legacy_advertising_disconnection(auto_restart):
await async_barrier() await async_barrier()
if auto_restart: if auto_restart:
assert device.legacy_advertising_set
started = asyncio.Event()
if not device.is_advertising:
device.legacy_advertising_set.once('start', started.set)
await asyncio.wait_for(started.wait(), _TIMEOUT)
assert device.is_advertising assert device.is_advertising
else: else:
assert not device.is_advertising assert not device.is_advertising
@@ -313,7 +321,8 @@ async def test_legacy_advertising_disconnection(auto_restart):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extended_advertising(): async def test_extended_advertising():
device = Device(host=mock.AsyncMock(Host)) device = TwoDevices()[0]
await device.power_on()
# Start advertising # Start advertising
advertising_set = await device.create_advertising_set() advertising_set = await device.create_advertising_set()
@@ -332,7 +341,8 @@ async def test_extended_advertising():
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type): async def test_extended_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(spec=Host)) device = TwoDevices()[0]
await device.power_on()
peer_address = Address('F0:F1:F2:F3:F4:F5') peer_address = Address('F0:F1:F2:F3:F4:F5')
advertising_set = await device.create_advertising_set( advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)
@@ -368,8 +378,10 @@ async def test_extended_advertising_connection(own_address_type):
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extended_advertising_connection_out_of_order(own_address_type): async def test_extended_advertising_connection_out_of_order(own_address_type):
device = Device(host=mock.AsyncMock(spec=Host)) devices = TwoDevices()
peer_address = Address('F0:F1:F2:F3:F4:F5') device = devices[0]
devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff')
await device.power_on()
advertising_set = await device.create_advertising_set( advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) advertising_parameters=AdvertisingParameters(own_address_type=own_address_type)
) )
@@ -382,7 +394,7 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
device.on_connection( device.on_connection(
0x0001, 0x0001,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
peer_address, Address('F0:F1:F2:F3:F4:F5'),
None, None,
None, None,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
@@ -397,6 +409,34 @@ async def test_extended_advertising_connection_out_of_order(own_address_type):
await async_barrier() await async_barrier()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_periodic_advertising():
device = TwoDevices()[0]
await device.power_on()
# Start advertising
advertising_set = await device.create_advertising_set(
advertising_parameters=AdvertisingParameters(
advertising_event_properties=AdvertisingEventProperties(
is_connectable=False
)
),
advertising_data=b'123',
periodic_advertising_parameters=PeriodicAdvertisingParameters(),
periodic_advertising_data=b'abc',
)
assert device.extended_advertising_sets
assert advertising_set.enabled
assert not advertising_set.periodic_enabled
await advertising_set.start_periodic()
assert advertising_set.periodic_enabled
await advertising_set.stop_periodic()
assert not advertising_set.periodic_enabled
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_remote_le_features(): async def test_get_remote_le_features():
@@ -547,32 +587,54 @@ async def test_power_on_default_static_address_should_not_be_any():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_gatt_services_with_gas(): def test_gatt_services_with_gas_and_gatt():
device = Device(host=Host(None, None)) device = Device(host=Host(None, None))
# there should be one service and two chars, therefore 5 attributes # there should be 2 service, 5 chars, and 1 descriptors, therefore 13 attributes
assert len(device.gatt_server.attributes) == 5 assert len(device.gatt_server.attributes) == 13
assert device.gatt_server.attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE assert device.gatt_server.attributes[0].uuid == gatt.GATT_GENERIC_ACCESS_SERVICE
assert device.gatt_server.attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE assert (
assert device.gatt_server.attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC device.gatt_server.attributes[1].type == gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
assert device.gatt_server.attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE )
assert device.gatt_server.attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC assert device.gatt_server.attributes[2].uuid == gatt.GATT_DEVICE_NAME_CHARACTERISTIC
assert (
device.gatt_server.attributes[3].type == gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
)
assert device.gatt_server.attributes[4].uuid == gatt.GATT_APPEARANCE_CHARACTERISTIC
assert device.gatt_server.attributes[5].uuid == gatt.GATT_GENERIC_ATTRIBUTE_SERVICE
# ----------------------------------------------------------------------------- assert (
def test_gatt_services_without_gas(): device.gatt_server.attributes[6].type == gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
device = Device(host=Host(None, None), generic_access_service=False) )
assert (
# there should be no services device.gatt_server.attributes[7].uuid
assert len(device.gatt_server.attributes) == 0 == gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC
)
assert (
device.gatt_server.attributes[8].type
== gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
assert (
device.gatt_server.attributes[9].type == gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
)
assert (
device.gatt_server.attributes[10].uuid
== gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC
)
assert (
device.gatt_server.attributes[11].type
== gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
)
assert (
device.gatt_server.attributes[12].uuid == gatt.GATT_DATABASE_HASH_CHARACTERISTIC
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run_test_device(): async def run_test_device():
await test_device_connect_parallel() await test_device_connect_parallel()
await test_flush() await test_flush()
await test_gatt_services_with_gas() await test_gatt_services_with_gas_and_gatt()
await test_gatt_services_without_gas()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

140
tests/gatt_service_test.py Normal file
View File

@@ -0,0 +1,140 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from . import test_utils
from bumble import device
from bumble import gatt
from bumble.profiles import gatt_service
# -----------------------------------------------------------------------------
async def test_database_hash():
devices = await test_utils.TwoDevices.create_with_connection()
devices[0].gatt_server.services.clear()
devices[0].gatt_server.attributes.clear()
devices[0].gatt_server.attributes_by_handle.clear()
devices[0].add_service(
gatt.Service(
gatt.GATT_GENERIC_ACCESS_SERVICE,
characteristics=[
gatt.Characteristic(
gatt.GATT_DEVICE_NAME_CHARACTERISTIC,
(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
),
gatt.Characteristic.Permissions.READ_REQUIRES_AUTHENTICATION,
),
gatt.Characteristic(
gatt.GATT_APPEARANCE_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.Permissions.READ_REQUIRES_AUTHENTICATION,
),
],
)
)
devices[0].add_service(
gatt_service.GenericAttributeProfileService(
server_supported_features=None,
database_hash_enabled=True,
service_change_enabled=True,
)
)
devices[0].gatt_server.add_attribute(
gatt.Service(gatt.GATT_GLUCOSE_SERVICE, characteristics=[])
)
# There is a special attribute order in the spec, so we need to add attribute one by
# one here.
battery_service = gatt.Service(
gatt.GATT_BATTERY_SERVICE,
characteristics=[
gatt.Characteristic(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_AUTHENTICATION,
)
],
primary=False,
)
battery_service.handle = 0x0014
battery_service.end_group_handle = 0x0016
devices[0].gatt_server.add_attribute(
gatt.IncludedServiceDeclaration(battery_service)
)
c = gatt.Characteristic(
'2A18',
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.INDICATE
| gatt.Characteristic.Properties.EXTENDED_PROPERTIES
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_AUTHENTICATION,
)
devices[0].gatt_server.add_attribute(
gatt.CharacteristicDeclaration(c, devices[0].gatt_server.next_handle() + 1)
)
devices[0].gatt_server.add_attribute(c)
devices[0].gatt_server.add_attribute(
gatt.Descriptor(
gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
gatt.Descriptor.Permissions.READ_REQUIRES_AUTHENTICATION,
b'\x02\x00',
),
)
devices[0].gatt_server.add_attribute(
gatt.Descriptor(
gatt.GATT_CHARACTERISTIC_EXTENDED_PROPERTIES_DESCRIPTOR,
gatt.Descriptor.Permissions.READ_REQUIRES_AUTHENTICATION,
b'\x00\x00',
),
)
devices[0].add_service(battery_service)
peer = device.Peer(devices.connections[1])
client = await peer.discover_service_and_create_proxy(
gatt_service.GenericAttributeProfileServiceProxy
)
assert client.database_hash_characteristic
assert await client.database_hash_characteristic.read_value() == bytes.fromhex(
'F1CA2D48ECF58BAC8A8830BBB9FBA990'
)
# -----------------------------------------------------------------------------
async def test_service_changed():
devices = await test_utils.TwoDevices.create_with_connection()
assert (service := devices[0].gatt_service)
peer = device.Peer(devices.connections[1])
assert (
client := await peer.discover_service_and_create_proxy(
gatt_service.GenericAttributeProfileServiceProxy
)
)
assert client.service_changed_characteristic
indications = []
await client.service_changed_characteristic.subscribe(
indications.append, prefer_notify=False
)
await devices[0].indicate_subscribers(
service.service_changed_characteristic, b'1234'
)
await test_utils.async_barrier()
assert indications[0] == b'1234'

View File

@@ -15,11 +15,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import struct import struct
import pytest import pytest
from typing_extensions import Self
from unittest.mock import AsyncMock, Mock, ANY from unittest.mock import AsyncMock, Mock, ANY
from bumble.controller import Controller from bumble.controller import Controller
@@ -31,6 +33,7 @@ from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
CharacteristicAdapter, CharacteristicAdapter,
SerializableCharacteristicAdapter,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter, PackedCharacteristicAdapter,
MappedCharacteristicAdapter, MappedCharacteristicAdapter,
@@ -57,7 +60,7 @@ from .test_utils import async_barrier
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def basic_check(x): def basic_check(x):
pdu = x.to_bytes() pdu = bytes(x)
parsed = ATT_PDU.from_bytes(pdu) parsed = ATT_PDU.from_bytes(pdu)
x_str = str(x) x_str = str(x)
parsed_str = str(parsed) parsed_str = str(parsed)
@@ -74,7 +77,7 @@ def test_UUID():
assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' assert str(u) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
v = UUID(str(u)) v = UUID(str(u))
assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' assert str(v) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
w = UUID.from_bytes(v.to_bytes()) w = UUID.from_bytes(bytes(v))
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6' assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
u1 = UUID.from_16_bits(0x1234) u1 = UUID.from_16_bits(0x1234)
@@ -310,7 +313,7 @@ async def test_attribute_getters():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_CharacteristicAdapter(): async def test_CharacteristicAdapter() -> None:
# Check that the CharacteristicAdapter base class is transparent # Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3]) v = bytes([1, 2, 3])
c = Characteristic( c = Characteristic(
@@ -329,67 +332,94 @@ async def test_CharacteristicAdapter():
assert c.value == v assert c.value == v
# Simple delegated adapter # Simple delegated adapter
a = DelegatedCharacteristicAdapter( delegated = DelegatedCharacteristicAdapter(
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
) )
value = await a.read_value(None) delegated_value = await delegated.read_value(None)
assert value == bytes(reversed(v)) assert delegated_value == bytes(reversed(v))
v = bytes([3, 4, 5]) delegated_value2 = bytes([3, 4, 5])
await a.write_value(None, v) await delegated.write_value(None, delegated_value2)
assert a.value == bytes(reversed(v)) assert delegated.value == bytes(reversed(delegated_value2))
# Packed adapter with single element format # Packed adapter with single element format
v = 1234 packed_value_ref = 1234
pv = struct.pack('>H', v) packed_value_bytes = struct.pack('>H', packed_value_ref)
c.value = v c.value = packed_value_ref
a = PackedCharacteristicAdapter(c, '>H') packed = PackedCharacteristicAdapter(c, '>H')
value = await a.read_value(None) packed_value_read = await packed.read_value(None)
assert value == pv assert packed_value_read == packed_value_bytes
c.value = None c.value = b''
await a.write_value(None, pv) await packed.write_value(None, packed_value_bytes)
assert a.value == v assert packed.value == packed_value_ref
# Packed adapter with multi-element format # Packed adapter with multi-element format
v1 = 1234 v1 = 1234
v2 = 5678 v2 = 5678
pv = struct.pack('>HH', v1, v2) packed_multi_value_bytes = struct.pack('>HH', v1, v2)
c.value = (v1, v2) c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH') packed_multi = PackedCharacteristicAdapter(c, '>HH')
value = await a.read_value(None) packed_multi_read_value = await packed_multi.read_value(None)
assert value == pv assert packed_multi_read_value == packed_multi_value_bytes
c.value = None packed_multi.value = b''
await a.write_value(None, pv) await packed_multi.write_value(None, packed_multi_value_bytes)
assert a.value == (v1, v2) assert packed_multi.value == (v1, v2)
# Mapped adapter # Mapped adapter
v1 = 1234 v1 = 1234
v2 = 5678 v2 = 5678
pv = struct.pack('>HH', v1, v2) packed_mapped_value_bytes = struct.pack('>HH', v1, v2)
mapped = {'v1': v1, 'v2': v2} mapped = {'v1': v1, 'v2': v2}
c.value = mapped c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) packed_mapped = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = await a.read_value(None) packed_mapped_read_value = await packed_mapped.read_value(None)
assert value == pv assert packed_mapped_read_value == packed_mapped_value_bytes
c.value = None c.value = b''
await a.write_value(None, pv) await packed_mapped.write_value(None, packed_mapped_value_bytes)
assert a.value == mapped assert packed_mapped.value == mapped
# UTF-8 adapter # UTF-8 adapter
v = 'Hello π' string_value = 'Hello π'
ev = v.encode('utf-8') string_value_bytes = string_value.encode('utf-8')
c.value = v c.value = string_value
a = UTF8CharacteristicAdapter(c) string_c = UTF8CharacteristicAdapter(c)
value = await a.read_value(None) string_read_value = await string_c.read_value(None)
assert value == ev assert string_read_value == string_value_bytes
c.value = None c.value = b''
await a.write_value(None, ev) await string_c.write_value(None, string_value_bytes)
assert a.value == v assert string_c.value == string_value
# Class adapter
class BlaBla:
def __init__(self, a: int, b: int) -> None:
self.a = a
self.b = b
@classmethod
def from_bytes(cls, data: bytes) -> Self:
a, b = struct.unpack(">II", data)
return cls(a, b)
def __bytes__(self) -> bytes:
return struct.pack(">II", self.a, self.b)
class_value = BlaBla(3, 4)
class_value_bytes = struct.pack(">II", 3, 4)
c.value = class_value
class_c = SerializableCharacteristicAdapter(c, BlaBla)
class_read_value = await class_c.read_value(None)
assert class_read_value == class_value_bytes
c.value = b''
await class_c.write_value(None, class_value_bytes)
assert isinstance(c.value, BlaBla)
assert c.value.a == 3
assert c.value.b == 4
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -851,7 +881,12 @@ async def test_unsubscribe():
await async_barrier() await async_barrier()
mock1.assert_called_once_with(ANY, True, False) mock1.assert_called_once_with(ANY, True, False)
await c2.subscribe() assert len(server.gatt_server.subscribers) == 1
def callback(_):
pass
await c2.subscribe(callback)
await async_barrier() await async_barrier()
mock2.assert_called_once_with(ANY, True, False) mock2.assert_called_once_with(ANY, True, False)
@@ -861,10 +896,16 @@ async def test_unsubscribe():
mock1.assert_called_once_with(ANY, False, False) mock1.assert_called_once_with(ANY, False, False)
mock2.reset_mock() mock2.reset_mock()
await c2.unsubscribe() await c2.unsubscribe(callback)
await async_barrier() await async_barrier()
mock2.assert_called_once_with(ANY, False, False) mock2.assert_called_once_with(ANY, False, False)
# All CCCDs should be zeros now
assert list(server.gatt_server.subscribers.values())[0] == {
c1.handle: bytes([0, 0]),
c2.handle: bytes([0, 0]),
}
mock1.reset_mock() mock1.reset_mock()
await c1.unsubscribe() await c1.unsubscribe()
await async_barrier() await async_barrier()
@@ -916,11 +957,12 @@ async def test_discover_all():
peer = Peer(connection) peer = Peer(connection)
await peer.discover_all() await peer.discover_all()
assert len(peer.gatt_client.services) == 3 assert len(peer.gatt_client.services) == 4
# service 1800 gets added automatically # service 1800 and 1801 get added automatically
assert peer.gatt_client.services[0].uuid == UUID('1800') assert peer.gatt_client.services[0].uuid == UUID('1800')
assert peer.gatt_client.services[1].uuid == service1.uuid assert peer.gatt_client.services[1].uuid == UUID('1801')
assert peer.gatt_client.services[2].uuid == service2.uuid assert peer.gatt_client.services[2].uuid == service1.uuid
assert peer.gatt_client.services[3].uuid == service2.uuid
s = peer.get_services_by_uuid(service1.uuid) s = peer.get_services_by_uuid(service1.uuid)
assert len(s) == 1 assert len(s) == 1
assert len(s[0].characteristics) == 2 assert len(s[0].characteristics) == 2
@@ -1043,10 +1085,18 @@ CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), READ) Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), READ)
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), READ) CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ) Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829) Service(handle=0x0006, end=0x000D, uuid=UUID-16:1801 (Generic Attribute))
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY) CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=UUID-16:2A05 (Service Changed), INDICATE)
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY) Characteristic(handle=0x0008, end=0x0009, uuid=UUID-16:2A05 (Service Changed), INDICATE)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)""" Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)
CharacteristicDeclaration(handle=0x000A, value_handle=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
Characteristic(handle=0x000B, end=0x000B, uuid=UUID-16:2B29 (Client Supported Features), READ|WRITE)
CharacteristicDeclaration(handle=0x000C, value_handle=0x000D, uuid=UUID-16:2B2A (Database Hash), READ)
Characteristic(handle=0x000D, end=0x000D, uuid=UUID-16:2B2A (Database Hash), READ)
Service(handle=0x000E, end=0x0011, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
CharacteristicDeclaration(handle=0x000F, value_handle=0x0010, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Characteristic(handle=0x0010, end=0x0011, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Descriptor(handle=0x0011, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
) )

84
tests/gmap_test.py Normal file
View File

@@ -0,0 +1,84 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import pytest
import pytest_asyncio
from bumble import device
from bumble.profiles.gmap import (
GamingAudioService,
GamingAudioServiceProxy,
GmapRole,
UggFeatures,
UgtFeatures,
BgrFeatures,
BgsFeatures,
)
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
gmas_service = GamingAudioService(
gmap_role=GmapRole.UNICAST_GAME_GATEWAY
| GmapRole.UNICAST_GAME_TERMINAL
| GmapRole.BROADCAST_GAME_RECEIVER
| GmapRole.BROADCAST_GAME_SENDER,
ugg_features=UggFeatures.UGG_MULTISINK,
ugt_features=UgtFeatures.UGT_SOURCE,
bgr_features=BgrFeatures.BGR_MULTISINK,
bgs_features=BgsFeatures.BGS_96_KBPS,
)
@pytest_asyncio.fixture
async def gmap_client():
devices = TwoDevices()
devices[0].add_service(gmas_service)
await devices.setup_connection()
assert devices.connections[0]
assert devices.connections[1]
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
peer = device.Peer(devices.connections[1])
gmap_client = await peer.discover_service_and_create_proxy(GamingAudioServiceProxy)
assert gmap_client
yield gmap_client
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_service(gmap_client: GamingAudioServiceProxy):
assert (
await gmap_client.gmap_role.read_value()
== GmapRole.UNICAST_GAME_GATEWAY
| GmapRole.UNICAST_GAME_TERMINAL
| GmapRole.BROADCAST_GAME_RECEIVER
| GmapRole.BROADCAST_GAME_SENDER
)
assert await gmap_client.ugg_features.read_value() == UggFeatures.UGG_MULTISINK
assert await gmap_client.ugt_features.read_value() == UgtFeatures.UGT_SOURCE
assert await gmap_client.bgr_features.read_value() == BgrFeatures.BGR_MULTISINK
assert await gmap_client.bgs_features.read_value() == BgsFeatures.BGS_96_KBPS

View File

@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct
from bumble.hci import ( from bumble.hci import (
HCI_DISCONNECT_COMMAND, HCI_DISCONNECT_COMMAND,
@@ -22,6 +23,7 @@ from bumble.hci import (
HCI_LE_CODED_PHY_BIT, HCI_LE_CODED_PHY_BIT,
HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_RESET_COMMAND, HCI_RESET_COMMAND,
HCI_VENDOR_EVENT,
HCI_SUCCESS, HCI_SUCCESS,
HCI_LE_CONNECTION_COMPLETE_EVENT, HCI_LE_CONNECTION_COMPLETE_EVENT,
HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT, HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT,
@@ -67,6 +69,7 @@ from bumble.hci import (
HCI_Read_Local_Version_Information_Command, HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command, HCI_Reset_Command,
HCI_Set_Event_Mask_Command, HCI_Set_Event_Mask_Command,
HCI_Vendor_Event,
) )
@@ -75,13 +78,13 @@ from bumble.hci import (
def basic_check(x): def basic_check(x):
packet = x.to_bytes() packet = bytes(x)
print(packet.hex()) print(packet.hex())
parsed = HCI_Packet.from_bytes(packet) parsed = HCI_Packet.from_bytes(packet)
x_str = str(x) x_str = str(x)
parsed_str = str(parsed) parsed_str = str(parsed)
print(x_str) print(x_str)
parsed_bytes = parsed.to_bytes() parsed_bytes = bytes(parsed)
assert x_str == parsed_str assert x_str == parsed_str
assert packet == parsed_bytes assert packet == parsed_bytes
@@ -167,8 +170,8 @@ def test_HCI_Command_Complete_Event():
command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND, command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters( return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
status=0, status=0,
hc_le_acl_data_packet_length=1234, le_acl_data_packet_length=1234,
hc_total_num_le_acl_data_packets=56, total_num_le_acl_data_packets=56,
), ),
) )
basic_check(event) basic_check(event)
@@ -188,7 +191,7 @@ def test_HCI_Command_Complete_Event():
return_parameters=bytes([7]), return_parameters=bytes([7]),
) )
basic_check(event) basic_check(event)
event = HCI_Packet.from_bytes(event.to_bytes()) event = HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7 assert event.return_parameters == 7
# With a simple status as an integer status # With a simple status as an integer status
@@ -213,6 +216,41 @@ def test_HCI_Number_Of_Completed_Packets_Event():
basic_check(event) basic_check(event)
# -----------------------------------------------------------------------------
def test_HCI_Vendor_Event():
data = bytes.fromhex('01020304')
event = HCI_Vendor_Event(data=data)
event_bytes = bytes(event)
parsed = HCI_Packet.from_bytes(event_bytes)
assert isinstance(parsed, HCI_Vendor_Event)
assert parsed.data == data
class HCI_Custom_Event(HCI_Event):
def __init__(self, blabla):
super().__init__(HCI_VENDOR_EVENT, parameters=struct.pack("<I", blabla))
self.name = 'HCI_CUSTOM_EVENT'
self.blabla = blabla
def create_event(payload):
if payload[0] == 1:
return HCI_Custom_Event(blabla=struct.unpack('<I', payload)[0])
return None
HCI_Event.add_vendor_factory(create_event)
parsed = HCI_Packet.from_bytes(event_bytes)
assert isinstance(parsed, HCI_Custom_Event)
assert parsed.blabla == 0x04030201
event_bytes2 = event_bytes[:3] + bytes([7]) + event_bytes[4:]
parsed = HCI_Packet.from_bytes(event_bytes2)
assert not isinstance(parsed, HCI_Custom_Event)
assert isinstance(parsed, HCI_Vendor_Event)
HCI_Event.remove_vendor_factory(create_event)
parsed = HCI_Packet.from_bytes(event_bytes)
assert not isinstance(parsed, HCI_Custom_Event)
assert isinstance(parsed, HCI_Vendor_Event)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_HCI_Command(): def test_HCI_Command():
command = HCI_Command(0x5566) command = HCI_Command(0x5566)
@@ -562,7 +600,7 @@ def test_iso_data_packet():
'6281bc77ed6a3206d984bcdabee6be831c699cb50e2' '6281bc77ed6a3206d984bcdabee6be831c699cb50e2'
) )
assert packet.to_bytes() == data assert bytes(packet) == data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -576,6 +614,7 @@ def run_test_events():
test_HCI_Command_Complete_Event() test_HCI_Command_Complete_Event()
test_HCI_Command_Status_Event() test_HCI_Command_Status_Event()
test_HCI_Number_Of_Completed_Packets_Event() test_HCI_Number_Of_Completed_Packets_Event()
test_HCI_Vendor_Event()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -61,7 +61,7 @@ def _default_hf_configuration() -> hfp.HfConfiguration:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def _default_hf_sdp_features() -> hfp.HfSdpFeature: def _default_hf_sdp_features() -> hfp.HfSdpFeature:
return ( return (
hfp.HfSdpFeature.WIDE_BAND hfp.HfSdpFeature.WIDE_BAND_SPEECH
| hfp.HfSdpFeature.THREE_WAY_CALLING | hfp.HfSdpFeature.THREE_WAY_CALLING
| hfp.HfSdpFeature.CLI_PRESENTATION_CAPABILITY | hfp.HfSdpFeature.CLI_PRESENTATION_CAPABILITY
) )
@@ -108,7 +108,7 @@ def _default_ag_configuration() -> hfp.AgConfiguration:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def _default_ag_sdp_features() -> hfp.AgSdpFeature: def _default_ag_sdp_features() -> hfp.AgSdpFeature:
return ( return (
hfp.AgSdpFeature.WIDE_BAND hfp.AgSdpFeature.WIDE_BAND_SPEECH
| hfp.AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY | hfp.AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY
| hfp.AgSdpFeature.THREE_WAY_CALLING | hfp.AgSdpFeature.THREE_WAY_CALLING
) )

View File

@@ -16,11 +16,14 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import unittest.mock
import pytest import pytest
import unittest
from bumble.controller import Controller from bumble.controller import Controller
from bumble.host import Host from bumble.host import Host, DataPacketQueue
from bumble.transport import AsyncPipeSink from bumble.transport import AsyncPipeSink
from bumble.hci import HCI_AclDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -60,3 +63,90 @@ async def test_reset(supported_commands: str, lmp_features: str):
assert host.local_lmp_features == int.from_bytes( assert host.local_lmp_features == int.from_bytes(
bytes.fromhex(lmp_features), 'little' bytes.fromhex(lmp_features), 'little'
) )
# -----------------------------------------------------------------------------
def test_data_packet_queue():
controller = unittest.mock.Mock()
queue = DataPacketQueue(10, 2, controller.send)
assert queue.queued == 0
assert queue.completed == 0
packet = HCI_AclDataPacket(
connection_handle=123, pb_flag=0, bc_flag=0, data_total_length=0, data=b''
)
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 1
assert queue.completed == 0
assert controller.send.call_count == 1
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 2
assert queue.completed == 0
assert controller.send.call_count == 2
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 3
assert queue.completed == 0
assert controller.send.call_count == 2
queue.on_packets_completed(1, 8000)
assert queue.queued == 3
assert queue.completed == 0
assert controller.send.call_count == 2
queue.on_packets_completed(1, 123)
assert queue.queued == 3
assert queue.completed == 1
assert controller.send.call_count == 3
queue.enqueue(packet, packet.connection_handle)
assert queue.queued == 4
assert queue.completed == 1
assert controller.send.call_count == 3
queue.on_packets_completed(2, 123)
assert queue.queued == 4
assert queue.completed == 3
assert controller.send.call_count == 4
queue.on_packets_completed(1, 123)
assert queue.queued == 4
assert queue.completed == 4
assert controller.send.call_count == 4
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 124)
queue.enqueue(packet, 124)
queue.enqueue(packet, 124)
queue.on_packets_completed(1, 123)
assert queue.queued == 10
assert queue.completed == 5
queue.flush(123)
queue.flush(124)
assert queue.queued == 10
assert queue.completed == 10
queue.enqueue(packet, 123)
queue.on_packets_completed(1, 124)
assert queue.queued == 11
assert queue.completed == 10
queue.on_packets_completed(1000, 123)
assert queue.queued == 11
assert queue.completed == 11
drain_listener = unittest.mock.Mock()
queue.on('flow', drain_listener.on_flow)
queue.enqueue(packet, 123)
assert drain_listener.on_flow.call_count == 0
queue.on_packets_completed(1, 123)
assert drain_listener.on_flow.call_count == 1
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.enqueue(packet, 123)
queue.flush(123)
assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15
assert queue.completed == 15

View File

@@ -53,7 +53,7 @@ def test_import():
le_audio, le_audio,
pacs, pacs,
pbp, pbp,
vcp, vcs,
) )
assert att assert att
@@ -87,7 +87,7 @@ def test_import():
assert le_audio assert le_audio
assert pacs assert pacs
assert pbp assert pbp
assert vcp assert vcs
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

Some files were not shown because too many files have changed in this diff Show More