Compare commits

...

187 Commits

Author SHA1 Message Date
Charlie Boutier
7237619d3b A2DP example: Codec selection based on file type
Currently support SBC and AAC
2025-05-08 14:24:42 -07:00
Slvr
a88a034ce2 cryptography: bump version to 44.0.3 to fix python parsing (#684)
Bug: 404336381
2025-05-08 08:28:33 -07:00
zxzxwu
6b2cd1147d Merge pull request #682 from zxzxwu/linkkey
Move connection.link_key_type to keystore
2025-05-08 11:23:28 +08:00
Josh Wu
bb8dcaf63e Move connection.link_key_type to keystore 2025-05-06 02:11:25 +08:00
Gilles Boccon-Gibod
8e84b528ce Merge pull request #679 from google/gbg/pairing-ios 2025-05-05 09:50:49 -07:00
Gilles Boccon-Gibod
8b59b4f515 address PR comments 2025-05-04 17:50:00 -07:00
Gilles Boccon-Gibod
dcc72e49a2 forward legacy constants 2025-05-04 11:34:11 -07:00
Gilles Boccon-Gibod
ce04c163db fix merge conflict 2025-05-04 11:32:25 -07:00
Gilles Boccon-Gibod
9f1e95d87f more merge fixes 2025-05-04 11:31:15 -07:00
Gilles Boccon-Gibod
088bcbed0b resolve merge conflicts 2025-05-04 11:31:15 -07:00
Gilles Boccon-Gibod
57fbad6fa4 add LE advertisement and HR service 2025-05-04 11:31:15 -07:00
Gilles Boccon-Gibod
6926d5cb70 Merge pull request #678 from google/gbg/fix-timescales
fix a few timescale adjustments
2025-05-04 11:19:05 -07:00
Gilles Boccon-Gibod
00c7df6a11 update pyee version 2025-05-03 12:24:59 -07:00
Gilles Boccon-Gibod
fbd03ed4a5 fix a few timescale adjustments 2025-05-03 12:07:53 -07:00
Gilles Boccon-Gibod
d3bd5a759f Revert "fix a few timescale adjustments"
This reverts commit dedef79bef.
2025-05-03 12:05:31 -07:00
Gilles Boccon-Gibod
dedef79bef fix a few timescale adjustments 2025-05-03 12:00:34 -07:00
zxzxwu
8db974877e Merge pull request #677 from zxzxwu/java-workflow
Add a workflow to build btbench
2025-04-26 09:44:50 -07:00
Josh Wu
e7d1531eae Add a workflow to build btbench 2025-04-26 18:51:19 +08:00
zxzxwu
4785fe6002 Merge pull request #674 from zxzxwu/event
Declare emitted events as constants
2025-04-26 02:45:50 -07:00
Josh Wu
22d6a7bf05 Declare emitted events as constants 2025-04-26 03:55:31 +08:00
Gilles Boccon-Gibod
97757c0c3d Merge pull request #676 from google/gbg/bt-bench-fixes
fix numeric entries and phy request
2025-04-24 17:27:55 -07:00
Gilles Boccon-Gibod
ab60b42b85 minor fix 2025-04-24 17:22:25 -07:00
Gilles Boccon-Gibod
febed8179b fix numeric entries and phy request 2025-04-22 17:14:39 -07:00
zxzxwu
1bd83273e8 Merge pull request #671 from zxzxwu/gatt_typing
Add missing characteristic type parameters
2025-04-16 10:06:51 -07:00
Josh Wu
5e9fc89f80 Add missing characteristic type parameters 2025-04-16 20:34:12 +08:00
zxzxwu
2686663eb2 Merge pull request #670 from zxzxwu/ee
Make all event emitters abortable and async
2025-04-15 22:33:51 -07:00
Josh Wu
55801bc2ca Make all event emitters async
* Also remove AbortableEventEmitter
2025-04-16 12:40:57 +08:00
zxzxwu
6cecc16519 Merge pull request #669 from zxzxwu/import
Cleanup relative imports
2025-04-14 10:07:13 -07:00
Josh Wu
a57cf13e2e Cleanup relative imports 2025-04-12 23:06:52 +08:00
zxzxwu
58f153afc4 Merge pull request #667 from zxzxwu/transport
Replace legacy transport and role constants
2025-04-10 12:02:27 +08:00
Josh Wu
7569da37e4 Replace legacy transport and role constants 2025-04-09 19:04:02 +08:00
Gilles Boccon-Gibod
a8019a70da Merge pull request #666 from canatella/fix-l2cap-signaling-packet-identifiers
Fix L2CAP signaling packet identifiers
2025-04-08 14:49:43 -04:00
Damien Merenne
685f1dc43e Fix L2CAP signaling packet identifiers
According to the Bluetooth Core Spec, Volume 3, Part A, Section 4, 0x00 is an invalid identifier:

 4. Signaling packet formats
...
    Identifier (1 octet)

    ... Signaling identifier 0x00 is an invalid identifier and shall never be used in any command.
2025-04-08 14:37:02 +00:00
Gilles Boccon-Gibod
220b3b0236 Merge pull request #664 from google/gbg/auracast-broadcast-code
add broadcast code encoding
2025-03-20 14:33:05 -04:00
Gilles Boccon-Gibod
3495eb52ba reset parser before raising exception 2025-03-19 11:32:51 -04:00
zxzxwu
1f7a1401eb Merge pull request #644 from zxzxwu/pasync
Advertising Set Info Transfer
2025-03-18 22:12:23 +08:00
Josh Wu
ce2b02b62a Advertising Set Info Transfer 2025-03-18 21:59:35 +08:00
Gilles Boccon-Gibod
5e55c0e358 add broadcast code encoding 2025-03-17 19:56:02 -04:00
Gilles Boccon-Gibod
ebeb0dc9f1 Merge pull request #663 from google/gbg/ancs
Initial support for ANCS client functionality
2025-03-14 14:07:14 -04:00
Gilles Boccon-Gibod
776bdae519 Initial support for ANCS client functionality 2025-03-12 15:44:13 -04:00
zxzxwu
b2d9541f8f Merge pull request #332 from zxzxwu/role
Enumify: PhysicalTransport, Role, AddressType
2025-03-10 00:04:18 +08:00
Josh Wu
637224d5bc Enum: PhysicalTransport, Role, AddressType 2025-03-09 23:34:01 +08:00
Gilles Boccon-Gibod
92ab171013 Merge pull request #659 from pcondoleon/le_scan_interval_fix
Fixed le_scan_interval incorrectly being set with scan_window
2025-02-28 15:04:03 -05:00
Peter Condoleon
592475e2ed Fixed le_scan_interval incorrectly being set with scan_window 2025-02-27 13:54:20 +10:00
Gilles Boccon-Gibod
12bcdb7770 Merge pull request #658 from google/gbg/auracast-doc
add auracast doc
2025-02-25 07:18:38 -08:00
Gilles Boccon-Gibod
7a58f36020 add auracast doc 2025-02-24 09:10:03 -08:00
Gilles Boccon-Gibod
ed0eb912c5 Merge pull request #650 from google/gbg/gatt-adapter-typing
new GATT adapter classes with proper typing support
2025-02-23 18:06:16 -08:00
Gilles Boccon-Gibod
752ce6c830 Merge pull request #657 from google/gbg/auracast-iso-data-path-refactor
use bis link API
2025-02-23 07:42:13 -08:00
Gilles Boccon-Gibod
8e509c18c9 remove unused import 2025-02-22 13:34:57 -08:00
Gilles Boccon-Gibod
cc21ed27c7 use bis link API 2025-02-22 13:32:58 -08:00
Gilles Boccon-Gibod
82d825071c address PR comments 2025-02-22 12:43:38 -08:00
Gilles Boccon-Gibod
b932bafe6d Merge pull request #655 from markusjellitsch/fix/acl_packet_queue
FIX - acl_packet_queue.flush
2025-02-22 11:33:37 -08:00
markus
4e35aba033 fix acl_packet_queue flush when controller does not support HCI_READ_BUFFER_SIZE_COMMAND 2025-02-22 08:37:12 +01:00
zxzxwu
0060ee8ee2 Merge pull request #646 from zxzxwu/ad
Improve AdvertisingData type annotations
2025-02-22 02:52:09 +08:00
zxzxwu
3263d71f54 Merge pull request #654 from zxzxwu/init
Fix mutable default values
2025-02-22 02:51:11 +08:00
Josh Wu
f321143837 Improve AdvertisingData type annotations
* Add overloads to provide better return type hints
* Make advertising data type enum so they can be considered constants
2025-02-21 17:12:14 +08:00
Josh Wu
bac6f5baaf Fix mutable default values 2025-02-21 16:00:18 +08:00
Gilles Boccon-Gibod
e027bcb57a Merge pull request #651 from google/gbg/update-lc3-and-rootcanal
update rootcanal and lc3 dependencies
2025-02-18 14:08:42 -08:00
Gilles Boccon-Gibod
eeb9de31ed update dependencies 2025-02-18 12:52:57 -08:00
Gilles Boccon-Gibod
4befc5bbae fix doc strings 2025-02-18 09:50:15 -08:00
Gilles Boccon-Gibod
2c3af5b2bb Merge pull request #649 from google/gbg/async-read-phy
make connection phy async
2025-02-18 07:25:06 -08:00
Gilles Boccon-Gibod
dfb92e8ed1 fix pandora connection waiting 2025-02-17 19:50:57 -08:00
Gilles Boccon-Gibod
73d2b54e30 make connection phy async 2025-02-17 19:24:18 -08:00
Gilles Boccon-Gibod
8315a60f24 Merge pull request #648 from pstrueb/fix/auracast_transmit
Add setup_data_path for iso queues
2025-02-17 17:22:47 -08:00
185d5fd577 reformatting 2025-02-17 20:11:17 +01:00
pstrueb
ae5f9cf690 Remove little accident 2025-02-17 18:21:16 +01:00
pstrueb
4b66a38fe6 Refractoring again 2025-02-17 09:49:09 +01:00
pstrueb
f526f549ee refractoring 2025-02-17 08:51:28 +01:00
Gilles Boccon-Gibod
da029a1749 new adapter classes 2025-02-16 16:26:13 -08:00
pstrueb
8761129677 Add setup_data_path for iso queues 2025-02-16 14:38:05 +01:00
Gilles Boccon-Gibod
3f6f036270 Merge pull request #643 from google/gbg/auracast-audio-io
auracast audio io
2025-02-08 18:19:24 -05:00
Gilles Boccon-Gibod
859bb0609f fix support for float32 2025-02-08 18:12:45 -05:00
Gilles Boccon-Gibod
5f2d24570e larger queue size 2025-02-06 21:24:58 -05:00
Gilles Boccon-Gibod
dbf94c8f3e print encoding params 2025-02-06 18:23:31 -05:00
Gilles Boccon-Gibod
b6adc29365 python 3.9 compat 2025-02-06 18:17:13 -05:00
Gilles Boccon-Gibod
5caa7bfa90 fix type checker and linter errors 2025-02-06 17:05:56 -05:00
Gilles Boccon-Gibod
f39d706fa0 remove obsolete code 2025-02-06 16:45:37 -05:00
zxzxwu
c02c1f33d2 Merge pull request #642 from zxzxwu/btbench
Add missing permissions in btbench
2025-02-07 04:48:46 +08:00
Gilles Boccon-Gibod
33435c2980 better docs and GATT fixes 2025-02-06 15:48:39 -05:00
Josh Wu
c08449d9db Add missing permissions in btbench 2025-02-07 03:13:53 +08:00
Gilles Boccon-Gibod
3c8718bb5b Merge pull request #608 from google/gbg/bench-android-enhancements
add startupDelay and connectionPriority params to BtBench snippets
2025-02-06 10:37:55 -05:00
Gilles Boccon-Gibod
26e87f09fe better error message 2025-02-05 22:28:05 -05:00
Gilles Boccon-Gibod
7f5e0d190e fix import checking 2025-02-05 22:19:39 -05:00
Gilles Boccon-Gibod
efae307b3d wip 2025-02-05 16:23:47 -05:00
zxzxwu
26d38a855c Merge pull request #641 from zxzxwu/pasync
Receive Periodic Advertising Sync Transfer
2025-02-06 05:18:47 +08:00
Josh Wu
7360a887d9 Receive Periodic Advertising Sync Transfer 2025-02-06 05:12:22 +08:00
Gilles Boccon-Gibod
9756572c93 add audio module 2025-02-04 17:58:54 -05:00
Gilles Boccon-Gibod
d6100755b1 add bond listener 2025-02-04 17:47:55 -05:00
Gilles Boccon-Gibod
a66eef6630 Merge pull request #640 from whitevegagabriel/cleanup
Rust library cleanup
2025-02-04 12:35:37 -05:00
Gabriel White-Vega
ae23ef7b9b Rust library cleanup
* Fix error code extraction from Python to Rust
* Add documentation for dealing with HCI packets
2025-02-04 12:23:06 -05:00
Gilles Boccon-Gibod
f368b5e518 wip 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
5293d32dc6 fix linter config 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
6d9a0bf4e1 fix linter config 2025-02-03 18:02:14 -05:00
Gilles Boccon-Gibod
3c7b5df7c5 add startupDelay and connectionPriority params to BtBench snippets 2025-02-03 18:00:46 -05:00
Gilles Boccon-Gibod
70141c0439 improvements 2025-02-03 17:58:09 -05:00
zxzxwu
dedc0aca54 Merge pull request #639 from zxzxwu/sdp
Correct SDP_ALL_ATTRIBUTES_RANGE value
2025-02-04 00:53:27 +08:00
Gilles Boccon-Gibod
7c019b574f Merge pull request #633 from markusjellitsch/fix/legacy-adv-params
fix advertising parameter usage for legacy advertising
2025-02-03 10:29:52 -05:00
markus
9b485fd943 revert python-avatar.yml 2025-02-03 15:17:22 +01:00
Josh Wu
fdee8269ec Correct SDP_ALL_ATTRIBUTES_RANGE value 2025-02-03 21:40:39 +08:00
zxzxwu
0767f2d4ae Merge pull request #638 from zxzxwu/avatar
Update actions/upload-artifact to v4
2025-02-03 21:31:42 +08:00
Josh Wu
c4a0846727 Update actions/upload-artifact to v4 2025-02-03 16:41:09 +08:00
zxzxwu
83ac70e426 Merge pull request #619 from zxzxwu/cs
Channel Sounding
2025-02-01 03:46:59 +08:00
markus
01cce3525f update avatar to github actions v4 2025-01-30 23:55:15 +01:00
markus
b9d35aea47 revert advertising_interval to type int 2025-01-30 19:47:20 +01:00
zxzxwu
079cf6b896 Merge pull request #624 from zxzxwu/gatt
Support GATT Service
2025-01-28 20:02:43 +08:00
Markus Jellitsch
180655088c run linter 2025-01-27 22:17:31 +01:00
Gilles Boccon-Gibod
a1bade6f20 Merge pull request #632 from markusjellitsch/fix/adapt-param-types
Adapt scanning and connection parameters type
2025-01-27 10:46:08 -05:00
Gilles Boccon-Gibod
5d80e7fd80 Merge pull request #634 from jmdietrich-gcx/fix_missing_await_for_update_rpa
Add missing await for update_rpa()
2025-01-27 10:45:42 -05:00
Jan-Marcel Dietrich
2198692961 Add missing await for update_rpa() 2025-01-27 15:14:52 +01:00
Gilles Boccon-Gibod
55d3fd90f5 wip 2025-01-25 21:04:59 -05:00
Gilles Boccon-Gibod
afee659ca6 Merge pull request #630 from google/gbg/iso-packet-queue
add support for ACL and ISO HCI packet queues
2025-01-24 15:59:19 -05:00
Gilles Boccon-Gibod
6fe7931d7d rename drain event to flow 2025-01-24 11:05:02 -05:00
Markus Jellitsch
9023407ee4 fix advertising parameters for legacy advertising 2025-01-23 15:14:54 +01:00
Markus Jellitsch
54d961bbe5 adapt scanning and connection parameters type 2025-01-23 14:53:20 +01:00
Gilles Boccon-Gibod
cbd46adbcf add support for ACL and ISO HCI packet queues 2025-01-22 13:42:29 -05:00
Josh Wu
745e107849 Channel Sounding device handlers 2025-01-22 23:38:44 +08:00
Gilles Boccon-Gibod
af466c2970 Merge pull request #629 from google/gbg/sdp-enforce-mtu
SDP: enforce MTU limits
2025-01-21 12:29:18 -05:00
Gilles Boccon-Gibod
931e2de854 address PR comments 2025-01-21 12:18:06 -05:00
Gilles Boccon-Gibod
55eb7eb237 enforce MTU limits 2025-01-21 10:31:10 -05:00
zxzxwu
bade4502f9 Merge pull request #628 from zxzxwu/cs-hci
Channel Sounding HCI packet definitions
2025-01-19 16:14:08 +08:00
Josh Wu
9f952f202f Channel Sounding HCI packet definitions 2025-01-16 14:33:34 +08:00
Josh Wu
1eb9d8d055 Support GATT Service 2025-01-15 02:13:25 +08:00
Gilles Boccon-Gibod
5a477eb391 Merge pull request #626 from markusjellitsch/fix/set-ext-scan-param-cmd
Update device.py - Fix scan_interval param in hci.HCI_LE_Set_Extended_Scan_Parameters_Command
2025-01-14 11:04:15 -05:00
Markus Jellitsch
86cda8771d Update device.py 2025-01-14 10:43:49 +01:00
zxzxwu
c1ea0ddd35 Merge pull request #622 from markusjellitsch/main
Fix: _IsoLink.write() struct.exception
2025-01-13 16:21:41 +08:00
Markus Jellitsch
f567711a6c avoid struct.error exception when packet_sequence_number > 0xFFFF 2025-01-10 01:33:43 +01:00
Gilles Boccon-Gibod
509df4c676 Merge pull request #618 from google/gbg/hci-event-multi-vendor
support multiple event factories
2025-01-07 15:00:20 -05:00
Gilles Boccon-Gibod
b375ed07b4 add test 2025-01-07 14:54:59 -05:00
Gilles Boccon-Gibod
69d62d3dd1 support multiple event factories 2025-01-06 08:42:09 -05:00
zxzxwu
fe3fa3d505 Merge pull request #617 from zxzxwu/iso
Unify ISO methods
2025-01-06 14:31:47 +08:00
Josh Wu
27fcd43224 Unify ISO methods 2025-01-02 14:19:36 +08:00
zxzxwu
c3b2bb19d5 Merge pull request #589 from zxzxwu/auracast
Auracast support
2025-01-02 01:02:13 +08:00
Gilles Boccon-Gibod
34287177b9 Merge pull request #615 from google/gbg/bluetooth-6-constants
add bluetooth 6.0 constants
2024-12-23 08:46:13 -05:00
Josh Wu
d238dd4059 Use dynamic sample rate 2024-12-23 17:01:11 +08:00
Gilles Boccon-Gibod
865f3a249f add bluetooth 6.0 constants 2024-12-22 12:47:37 -05:00
Josh Wu
7324d322fe BIG 2024-12-20 13:45:12 +08:00
Gilles Boccon-Gibod
af148b476d Merge pull request #613 from google/gbg/update-cryptography-dependency
update cryptography dependency
2024-12-19 08:42:51 -05:00
zxzxwu
80d60aaf15 Merge pull request #612 from zxzxwu/lc3
Replace liblc3 wasm library
2024-12-19 15:06:22 +08:00
Gilles Boccon-Gibod
c80f89d20f update cryptography dependency 2024-12-18 22:01:42 -05:00
Josh Wu
a27f55a588 Replace liblc3 wasm library 2024-12-19 02:21:38 +08:00
Gilles Boccon-Gibod
62e4670a39 Merge pull request #606 from wpiet/gmap-wip
Add `Gaming Audio Profile`
2024-12-18 11:56:57 -05:00
zxzxwu
99695bb264 Merge pull request #610 from zxzxwu/cfg
Remove setup.py and setup.cfg
2024-12-19 00:53:12 +08:00
Josh Wu
eb54898106 Remove setup.py and setup.cfg 2024-12-19 00:45:13 +08:00
Gilles Boccon-Gibod
4f5ee204d2 Update code-check.yml
Hot fix because 3.13.1 somehow breaks the current version of pylint. Will revert to 3.13 without pining to 3.13.0 when pylint is fixed
2024-12-18 11:36:08 -05:00
Wojciech Pietraszewski
2552e21db1 Add characteristics initial values
Sets default values for characteristics if not specified explicitly
2024-12-04 17:00:29 +01:00
Wojciech Pietraszewski
6168f87e2f Add characteristics conditionally
Only adds a characteristic if the corresponding role has been set
2024-12-04 12:57:34 +01:00
Gilles Boccon-Gibod
ca7d2ca4df Merge pull request #607 from google/gbg/pandora-deps
move pandora deps to development
2024-12-03 09:42:44 -08:00
Gilles Boccon-Gibod
60723323e9 move pandora deps to development 2024-12-03 09:08:30 -08:00
Gilles Boccon-Gibod
3ce7b9255b Merge pull request #598 from google/gbg/gatt-class-adapter
Add a class-based GATT adapter
2024-12-03 08:46:30 -08:00
Gilles Boccon-Gibod
97fcfc2fa0 Merge pull request #604 from jmdietrich-gcx/add_encryption_key_size_to_pairing_config
Add maximum encryption key size to PairingDelegate
2024-12-03 08:30:53 -08:00
Wojciech Pietraszewski
19674e3758 Add Gaming Audio Profile
Adds initial support for `Gaming Audio Service`.
2024-12-02 11:15:10 +01:00
Jan-Marcel Dietrich
1130e1db8f Fix code formatting 2024-12-02 09:01:18 +01:00
Gilles Boccon-Gibod
37c7f3a58a Merge pull request #603 from google/gbg/fix-pair-oob
fix oob support in pair.py
2024-12-01 08:43:04 -08:00
Gilles Boccon-Gibod
0a12b2bf2e Merge pull request #585 from wpiet/vocs
Add `Volume Offset Control Service`
2024-11-29 10:41:30 -08:00
Gilles Boccon-Gibod
d014acbe63 Merge pull request #597 from google/gbg/intel-hci
intel hci
2024-11-29 10:41:10 -08:00
Jan-Marcel Dietrich
07f9997a49 Add maximum encryption key size to PairingDelegate
So far the maxmium encryption key size has been hardcoded to 16 bytes in
'send_pairing_request_command()' and 'send_pairing_response_comman()'. By
making this configurable via the PairingDelegate, one can test how devices
respond to smaller encryption key sizes. Default remains 16 bytes.
2024-11-28 14:15:51 +01:00
Gilles Boccon-Gibod
b9f91f695a fix oob support in pair.py 2024-11-27 12:58:03 -08:00
Gilles Boccon-Gibod
082d55af10 Merge pull request #599 from google/gbg/hfp-19
add super wide band constants
2024-11-25 07:47:40 -08:00
Gilles Boccon-Gibod
4c3fd5688d Merge pull request #600 from google/gbg/unify-to-bytes
only use `__bytes__` when not argument is needed.
2024-11-25 07:44:17 -08:00
Gilles Boccon-Gibod
9d3d5495ce only use __bytes__ when not argument is needed. 2024-11-23 15:56:14 -08:00
Gilles Boccon-Gibod
b3869f267c add super wide band constants 2024-11-23 09:27:03 -08:00
Gilles Boccon-Gibod
8715333706 Add a GATT adapter that uses from_bytes and __bytes__ as conversion methods. 2024-11-23 09:13:04 -08:00
Gilles Boccon-Gibod
b57096abe2 Merge pull request #595 from wpiet/aics-opcode-fix
Amend Opcode value in `Audio Input Control Service`
2024-11-23 08:56:23 -08:00
Gilles Boccon-Gibod
48685c8587 improve vendor event support 2024-11-23 08:55:50 -08:00
Wojciech Pietraszewski
100bea6b41 Fix typos
Amends the typo in the `INACTIVE` field in `Audio Input Status` characteristic.
Amends the typo in the log message of `_set_gain_settings` method.
2024-11-21 18:29:44 +01:00
Wojciech Pietraszewski
63819bf9dd Amend Opcode value in Audio Input Control Service
Corrects the Audio Input Control Point
Opcode value for `Set Gain Setting` field.
2024-11-21 16:40:49 +01:00
Wojciech Pietraszewski
6e55390930 Add Volume Offset Control Service
Adds initial support for VOCS.
2024-11-21 11:56:14 +01:00
zxzxwu
e3fdab4175 Merge pull request #593 from zxzxwu/periodic
Support Periodic Advertising
2024-11-19 17:22:37 +08:00
Josh Wu
bbcd14dbf0 Support Periodic Advertising 2024-11-19 16:27:13 +08:00
zxzxwu
01dc0d574b Merge pull request #590 from SergeantSerk/parse-scan-response-data
Correctly parse scan response from device config
2024-11-17 15:39:11 +08:00
zxzxwu
5e959d638e Merge pull request #591 from zxzxwu/auracast_scan
Improve Broadcast Scanning
2024-11-16 04:10:27 +08:00
Gilles Boccon-Gibod
8d908288c8 Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-15 10:19:20 -08:00
Josh Wu
c88b32a406 Improve Broadcast Scanning 2024-11-16 02:02:28 +08:00
zxzxwu
5a72eefb89 Merge pull request #587 from zxzxwu/device
Replace HCI member imports in device.py
2024-11-13 15:25:32 +08:00
Josh Wu
430046944b Replace HCI member import in device.py 2024-11-12 16:53:21 +08:00
zxzxwu
21d23320eb Merge pull request #584 from zxzxwu/commands6.0
Add Core Spec 6.0 new commands support mapping
2024-11-12 04:17:24 +00:00
Serkan
d0990ee04d Correctly parse scan response from device config
Parses scan response data correctly just like advertising data
2024-11-07 21:49:33 +03:00
Josh Wu
2d88e853e8 Add Core Spec 6.0 new commands support mapping 2024-11-07 14:36:54 +08:00
Gilles Boccon-Gibod
a060a70fba Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-04 13:03:57 -08:00
Gilles Boccon-Gibod
a06394ad4a Merge pull request #582 from google/gbg/580
fix #580
2024-11-04 13:03:15 -08:00
Gilles Boccon-Gibod
a1414c2b5b add unsubscribe test 2024-11-03 19:08:27 -08:00
Gilles Boccon-Gibod
b2864dac2d fix #580 2024-11-02 10:29:40 -07:00
Gilles Boccon-Gibod
b78f895143 Merge pull request #579 from jmdietrich-gcx/unsubscribe_characteristic_in_gatt_client
Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
2024-10-31 04:07:02 -07:00
zxzxwu
c4e9726828 Merge pull request #581 from zxzxwu/context
[BAP] Add missing Unspecified context type
2024-10-31 11:04:25 +00:00
Gilles Boccon-Gibod
d4b8e8348a Merge pull request #574 from google/gbg/update-python-versions
remove test for deprecated Python 3.8 and add 3.13
2024-10-31 03:44:01 -07:00
Josh Wu
19debaa52e [BAP] Add missing Unspecified context type 2024-10-31 18:11:40 +08:00
Jan-Marcel Dietrich
73fe564321 Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
GATT Client's subscribe() adds the characteristic itself as subscriber.
Therefore the characteristic has to be removed in unsubscribe(), if it's
the last subscriber. Otherwise the clean up does not work correctly and
the CCCD never is set back to 0 in the remote device.
2024-10-30 07:34:22 +01:00
172 changed files with 14244 additions and 4106 deletions

26
.github/ci-gradle.properties vendored Normal file
View File

@@ -0,0 +1,26 @@
#
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
org.gradle.configureondemand=true
org.gradle.caching=true
org.gradle.parallel=true
# Declare we support AndroidX
android.useAndroidX=true
org.gradle.jvmargs=-Xmx4608m -XX:MaxMetaspaceSize=1536m -XX:+HeapDumpOnOutOfMemoryError
kotlin.compiler.execution.strategy=in-process

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
fail-fast: false fail-fast: false
steps: steps:
@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,pandora]" python -m pip install ".[build,examples,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit

33
.github/workflows/gradle-btbench.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Gradle Android Build & test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
- 'extras/android/BtBench/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 40
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Set up JDK
uses: actions/setup-java@v4
with:
distribution: 'zulu'
java-version: 17
- name: Setup Gradle
uses: gradle/actions/setup-gradle@v3
- name: Build with Gradle
run: cd extras/android/BtBench && ./gradlew build

View File

@@ -32,7 +32,7 @@ jobs:
- name: Install - name: Install
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install .[avatar,pandora] python -m pip install .[avatar]
- name: Rootcanal - name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log & run: nohup python -m rootcanal > rootcanal.log &
- name: Test - name: Test
@@ -44,7 +44,7 @@ jobs:
run: cat rootcanal.log run: cat rootcanal.log
- name: Upload Mobly logs - name: Upload Mobly logs
if: always() if: always()
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: mobly-logs name: mobly-logs-${{ strategy.job-index }}
path: /tmp/logs/mobly/bumble.bumbles/ path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -14,9 +14,12 @@
"ASHA", "ASHA",
"asyncio", "asyncio",
"ATRAC", "ATRAC",
"auracast",
"avctp", "avctp",
"avdtp", "avdtp",
"avrcp", "avrcp",
"biginfo",
"bigs",
"bitpool", "bitpool",
"bitstruct", "bitstruct",
"BSCP", "BSCP",
@@ -36,6 +39,7 @@
"deregistration", "deregistration",
"dhkey", "dhkey",
"diversifier", "diversifier",
"ediv",
"endianness", "endianness",
"ESCO", "ESCO",
"Fitbit", "Fitbit",
@@ -47,6 +51,7 @@
"libc", "libc",
"liblc", "liblc",
"libusb", "libusb",
"maxs",
"MITM", "MITM",
"MSBC", "MSBC",
"NDIS", "NDIS",
@@ -54,8 +59,10 @@
"NONBLOCK", "NONBLOCK",
"NONCONN", "NONCONN",
"OXIMETER", "OXIMETER",
"PDUS",
"popleft", "popleft",
"PRAND", "PRAND",
"prefs",
"protobuf", "protobuf",
"psms", "psms",
"pyee", "pyee",

File diff suppressed because it is too large Load Diff

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import dataclasses
import enum import enum
import logging import logging
import os import os
@@ -27,8 +28,7 @@ import click
from bumble import l2cap from bumble import l2cap
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_LE_TRANSPORT,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
UUID, UUID,
@@ -41,8 +41,7 @@ from bumble.hci import (
HCI_LE_1M_PHY, HCI_LE_1M_PHY,
HCI_LE_2M_PHY, HCI_LE_2M_PHY,
HCI_LE_CODED_PHY, HCI_LE_CODED_PHY,
HCI_CENTRAL_ROLE, Role,
HCI_PERIPHERAL_ROLE,
HCI_Constant, HCI_Constant,
HCI_Error, HCI_Error,
HCI_StatusError, HCI_StatusError,
@@ -97,49 +96,22 @@ DEFAULT_RFCOMM_MTU = 2048
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def parse_packet(packet):
if len(packet) < 1:
logging.info(
color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = PacketType(packet[0])
except ValueError:
logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
raise
return (packet_type, packet[1:])
def parse_packet_sequence(packet_data):
if len(packet_data) < 5:
logging.info(
color(
f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
'red',
)
)
raise ValueError('packet too short')
return struct.unpack_from('>bI', packet_data, 0)
def le_phy_name(phy_id): def le_phy_name(phy_id):
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get( return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id) phy_id, HCI_Constant.le_phy_name(phy_id)
) )
def print_connection_phy(phy):
logging.info(
color('@@@ PHY: ', 'yellow') + f'TX:{le_phy_name(phy.tx_phy)}/'
f'RX:{le_phy_name(phy.rx_phy)}'
)
def print_connection(connection): def print_connection(connection):
params = [] params = []
if connection.transport == BT_LE_TRANSPORT: if connection.transport == PhysicalTransport.LE:
params.append(
'PHY='
f'TX:{le_phy_name(connection.phy.tx_phy)}/'
f'RX:{le_phy_name(connection.phy.rx_phy)}'
)
params.append( params.append(
'DL=(' 'DL=('
f'TX:{connection.data_length[0]}/{connection.data_length[1]},' f'TX:{connection.data_length[0]}/{connection.data_length[1]},'
@@ -149,9 +121,9 @@ def print_connection(connection):
params.append( params.append(
'Parameters=' 'Parameters='
f'{connection.parameters.connection_interval * 1.25:.2f}/' f'{connection.parameters.connection_interval:.2f}/'
f'{connection.parameters.peripheral_latency}/' f'{connection.parameters.peripheral_latency}/'
f'{connection.parameters.supervision_timeout * 10} ' f'{connection.parameters.supervision_timeout:.2f} '
) )
params.append(f'MTU={connection.att_mtu}') params.append(f'MTU={connection.att_mtu}')
@@ -199,7 +171,7 @@ def log_stats(title, stats, precision=2):
stats_min = min(stats) stats_min = min(stats)
stats_max = max(stats) stats_max = max(stats)
stats_avg = statistics.mean(stats) stats_avg = statistics.mean(stats)
stats_stdev = statistics.stdev(stats) stats_stdev = statistics.stdev(stats) if len(stats) >= 2 else 0
logging.info( logging.info(
color( color(
( (
@@ -215,7 +187,7 @@ def log_stats(title, stats, precision=2):
async def switch_roles(connection, role): async def switch_roles(connection, role):
target_role = HCI_CENTRAL_ROLE if role == "central" else HCI_PERIPHERAL_ROLE target_role = Role.CENTRAL if role == "central" else Role.PERIPHERAL
if connection.role != target_role: if connection.role != target_role:
logging.info(f'{color("### Switching roles to:", "cyan")} {role}') logging.info(f'{color("### Switching roles to:", "cyan")} {role}')
try: try:
@@ -225,13 +197,135 @@ async def switch_roles(connection, role):
logging.info(f'{color("### Role switch failed:", "red")} {error}') logging.info(f'{color("### Role switch failed:", "red")} {error}')
class PacketType(enum.IntEnum): # -----------------------------------------------------------------------------
RESET = 0 # Packet
SEQUENCE = 1 # -----------------------------------------------------------------------------
ACK = 2 @dataclasses.dataclass
class Packet:
class PacketType(enum.IntEnum):
RESET = 0
SEQUENCE = 1
ACK = 2
class PacketFlags(enum.IntFlag):
LAST = 1
packet_type: PacketType
flags: PacketFlags = PacketFlags(0)
sequence: int = 0
timestamp: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes):
if len(data) < 1:
logging.warning(
color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = cls.PacketType(data[0])
except ValueError:
logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red'))
raise
if packet_type == cls.PacketType.RESET:
return cls(packet_type)
flags = cls.PacketFlags(data[1])
(sequence,) = struct.unpack_from("<I", data, 2)
if packet_type == cls.PacketType.ACK:
if len(data) < 6:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 6)',
'red',
)
)
return cls(packet_type, flags, sequence)
if len(data) < 10:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 10)', 'red'
)
)
raise ValueError('packet too short')
(timestamp,) = struct.unpack_from("<I", data, 6)
return cls(packet_type, flags, sequence, timestamp, data[10:])
def __bytes__(self):
if self.packet_type == self.PacketType.RESET:
return bytes([self.packet_type])
if self.packet_type == self.PacketType.ACK:
return struct.pack("<BBI", self.packet_type, self.flags, self.sequence)
return (
struct.pack(
"<BBII", self.packet_type, self.flags, self.sequence, self.timestamp
)
+ self.payload
)
PACKET_FLAG_LAST = 1 # -----------------------------------------------------------------------------
# Jitter Stats
# -----------------------------------------------------------------------------
class JitterStats:
def __init__(self):
self.reset()
def reset(self):
self.packets = []
self.receive_times = []
self.jitter = []
def on_packet_received(self, packet):
now = time.time()
self.packets.append(packet)
self.receive_times.append(now)
if packet.timestamp and len(self.packets) > 1:
expected_time = (
self.receive_times[0]
+ (packet.timestamp - self.packets[0].timestamp) / 1000000
)
jitter = now - expected_time
else:
jitter = 0.0
self.jitter.append(jitter)
return jitter
def show_stats(self):
if len(self.jitter) < 3:
return
average = sum(self.jitter) / len(self.jitter)
adjusted = [jitter - average for jitter in self.jitter]
log_stats('Jitter (signed)', adjusted, 3)
log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(adjusted)
interval_max = max(adjusted)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count) for i in range(bin_count)
]
for jitter in adjusted:
for i in reversed(range(bin_count)):
if jitter >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -281,19 +375,37 @@ class Sender:
await asyncio.sleep(self.tx_start_delay) await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta')) logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET])) await self.packet_io.send_packet(
bytes(Packet(packet_type=Packet.PacketType.RESET))
)
self.start_time = time.time() self.start_time = time.time()
self.bytes_sent = 0 self.bytes_sent = 0
for tx_i in range(self.tx_packet_count): for tx_i in range(self.tx_packet_count):
packet_flags = ( if self.pace > 0:
PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0 # Wait until it is time to send the next packet
target_time = self.start_time + (tx_i * self.pace / 1000)
now = time.time()
if now < target_time:
await asyncio.sleep(target_time - now)
else:
await self.packet_io.drain()
packet = bytes(
Packet(
packet_type=Packet.PacketType.SEQUENCE,
flags=(
Packet.PacketFlags.LAST
if tx_i == self.tx_packet_count - 1
else 0
),
sequence=tx_i,
timestamp=int((time.time() - self.start_time) * 1000000),
payload=bytes(
self.tx_packet_size - 10 - self.packet_io.overhead_size
),
)
) )
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
logging.info( logging.info(
color( color(
f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow' f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
@@ -302,14 +414,6 @@ class Sender:
self.bytes_sent += len(packet) self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet) await self.packet_io.send_packet(packet)
if self.pace is None:
continue
if self.pace > 0:
await asyncio.sleep(self.pace / 1000)
else:
await self.packet_io.drain()
await self.done.wait() await self.done.wait()
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
@@ -321,13 +425,13 @@ class Sender:
if self.repeat: if self.repeat:
logging.info(color('--- End of runs', 'blue')) logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, _ = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
if packet_type == PacketType.ACK: if packet.packet_type == Packet.PacketType.ACK:
elapsed = time.time() - self.start_time elapsed = time.time() - self.start_time
average_tx_speed = self.bytes_sent / elapsed average_tx_speed = self.bytes_sent / elapsed
self.stats.append(average_tx_speed) self.stats.append(average_tx_speed)
@@ -350,52 +454,53 @@ class Receiver:
last_timestamp: float last_timestamp: float
def __init__(self, packet_io, linger): def __init__(self, packet_io, linger):
self.reset() self.jitter_stats = JitterStats()
self.packet_io = packet_io self.packet_io = packet_io
self.packet_io.packet_listener = self self.packet_io.packet_listener = self
self.linger = linger self.linger = linger
self.done = asyncio.Event() self.done = asyncio.Event()
self.reset()
def reset(self): def reset(self):
self.expected_packet_index = 0 self.expected_packet_index = 0
self.measurements = [(time.time(), 0)] self.measurements = [(time.time(), 0)]
self.total_bytes_received = 0 self.total_bytes_received = 0
self.jitter_stats.reset()
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
logging.exception("invalid packet")
return return
if packet_type == PacketType.RESET: if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta')) logging.info(color('=== Received RESET', 'magenta'))
self.reset() self.reset()
return return
try: jitter = self.jitter_stats.on_packet_received(packet)
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
logging.info( logging.info(
f'<<< Received packet {packet_index}: ' f'<<< Received packet {packet.sequence}: '
f'flags=0x{packet_flags:02X}, ' f'flags={packet.flags}, '
f'{len(packet) + self.packet_io.overhead_size} bytes' f'jitter={jitter:.4f}, '
f'{len(data) + self.packet_io.overhead_size} bytes',
) )
if packet_index != self.expected_packet_index: if packet.sequence != self.expected_packet_index:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
now = time.time() now = time.time()
elapsed_since_start = now - self.measurements[0][0] elapsed_since_start = now - self.measurements[0][0]
elapsed_since_last = now - self.measurements[-1][0] elapsed_since_last = now - self.measurements[-1][0]
self.measurements.append((now, len(packet))) self.measurements.append((now, len(data)))
self.total_bytes_received += len(packet) self.total_bytes_received += len(data)
instant_rx_speed = len(packet) / elapsed_since_last instant_rx_speed = len(data) / elapsed_since_last
average_rx_speed = self.total_bytes_received / elapsed_since_start average_rx_speed = self.total_bytes_received / elapsed_since_start
window = self.measurements[-64:] window = self.measurements[-64:]
windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / ( windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
@@ -411,15 +516,17 @@ class Receiver:
) )
) )
self.expected_packet_index = packet_index + 1 self.expected_packet_index = packet.sequence + 1
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
AsyncRunner.spawn( AsyncRunner.spawn(
self.packet_io.send_packet( self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
) )
) )
logging.info(color('@@@ Received last packet', 'green')) logging.info(color('@@@ Received last packet', 'green'))
self.jitter_stats.show_stats()
if not self.linger: if not self.linger:
self.done.set() self.done.set()
@@ -468,6 +575,7 @@ class Ping:
for run in range(self.repeat + 1): for run in range(self.repeat + 1):
self.done.clear() self.done.clear()
self.ping_times = []
if run > 0 and self.repeat and self.repeat_delay: if run > 0 and self.repeat and self.repeat_delay:
logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green')) logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
@@ -478,25 +586,32 @@ class Ping:
await asyncio.sleep(self.tx_start_delay) await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta')) logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET])) await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET)))
packet_interval = self.pace / 1000
start_time = time.time() start_time = time.time()
self.next_expected_packet_index = 0 self.next_expected_packet_index = 0
for i in range(self.tx_packet_count): for i in range(self.tx_packet_count):
target_time = start_time + (i * packet_interval) target_time = start_time + (i * self.pace / 1000)
now = time.time() now = time.time()
if now < target_time: if now < target_time:
await asyncio.sleep(target_time - now) await asyncio.sleep(target_time - now)
now = time.time()
packet = struct.pack( packet = bytes(
'>bbI', Packet(
PacketType.SEQUENCE, packet_type=Packet.PacketType.SEQUENCE,
(PACKET_FLAG_LAST if i == self.tx_packet_count - 1 else 0), flags=(
i, Packet.PacketFlags.LAST
) + bytes(self.tx_packet_size - 6) if i == self.tx_packet_count - 1
else 0
),
sequence=i,
timestamp=int((now - start_time) * 1000000),
payload=bytes(self.tx_packet_size - 10),
)
)
logging.info(color(f'Sending packet {i}', 'yellow')) logging.info(color(f'Sending packet {i}', 'yellow'))
self.ping_times.append(time.time()) self.ping_times.append(now)
await self.packet_io.send_packet(packet) await self.packet_io.send_packet(packet)
await self.done.wait() await self.done.wait()
@@ -530,40 +645,35 @@ class Ping:
if self.repeat: if self.repeat:
logging.info(color('--- End of runs', 'blue')) logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet): def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
try: if packet.packet_type == Packet.PacketType.ACK:
packet_flags, packet_index = parse_packet_sequence(packet_data) elapsed = time.time() - self.ping_times[packet.sequence]
except ValueError:
return
if packet_type == PacketType.ACK:
elapsed = time.time() - self.ping_times[packet_index]
rtt = elapsed * 1000 rtt = elapsed * 1000
self.rtts.append(rtt) self.rtts.append(rtt)
logging.info( logging.info(
color( color(
f'<<< Received ACK [{packet_index}], RTT={rtt:.2f}ms', f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms',
'green', 'green',
) )
) )
if packet_index == self.next_expected_packet_index: if packet.sequence == self.next_expected_packet_index:
self.next_expected_packet_index += 1 self.next_expected_packet_index += 1
else: else:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, ' f'!!! Unexpected packet, '
f'expected {self.next_expected_packet_index} ' f'expected {self.next_expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
self.done.set() self.done.set()
return return
@@ -575,89 +685,56 @@ class Pong:
expected_packet_index: int expected_packet_index: int
def __init__(self, packet_io, linger): def __init__(self, packet_io, linger):
self.reset() self.jitter_stats = JitterStats()
self.packet_io = packet_io self.packet_io = packet_io
self.packet_io.packet_listener = self self.packet_io.packet_listener = self
self.linger = linger self.linger = linger
self.done = asyncio.Event() self.done = asyncio.Event()
self.reset()
def reset(self): def reset(self):
self.expected_packet_index = 0 self.expected_packet_index = 0
self.receive_times = [] self.jitter_stats.reset()
def on_packet_received(self, packet):
self.receive_times.append(time.time())
def on_packet_received(self, data):
try: try:
packet_type, packet_data = parse_packet(packet) packet = Packet.from_bytes(data)
except ValueError: except ValueError:
return return
if packet_type == PacketType.RESET: if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta')) logging.info(color('=== Received RESET', 'magenta'))
self.reset() self.reset()
return return
try: jitter = self.jitter_stats.on_packet_received(packet)
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
interval = (
self.receive_times[-1] - self.receive_times[-2]
if len(self.receive_times) >= 2
else 0
)
logging.info( logging.info(
color( color(
f'<<< Received packet {packet_index}: ' f'<<< Received packet {packet.sequence}: '
f'flags=0x{packet_flags:02X}, {len(packet)} bytes, ' f'flags={packet.flags}, {len(data)} bytes, '
f'interval={interval:.4f}', f'jitter={jitter:.4f}',
'green', 'green',
) )
) )
if packet_index != self.expected_packet_index: if packet.sequence != self.expected_packet_index:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}' f'but received {packet.sequence}'
) )
) )
self.expected_packet_index = packet_index + 1 self.expected_packet_index = packet.sequence + 1
AsyncRunner.spawn( AsyncRunner.spawn(
self.packet_io.send_packet( self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index) bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
) )
) )
if packet_flags & PACKET_FLAG_LAST: if packet.flags & Packet.PacketFlags.LAST:
if len(self.receive_times) >= 3: self.jitter_stats.show_stats()
# Show basic stats
intervals = [
self.receive_times[i + 1] - self.receive_times[i]
for i in range(len(self.receive_times) - 1)
]
log_stats('Packet intervals', intervals, 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(intervals)
interval_max = max(intervals)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count)
for i in range(bin_count)
]
for interval in intervals:
for i in reversed(range(bin_count)):
if interval >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
if not self.linger: if not self.linger:
self.done.set() self.done.set()
@@ -1179,6 +1256,7 @@ class Central(Connection.Listener):
self.device.classic_enabled = self.classic self.device.classic_enabled = self.classic
# Set up a pairing config factory with minimal requirements. # Set up a pairing config factory with minimal requirements.
self.device.config.keystore = "JsonKeyStore"
self.device.pairing_config_factory = lambda _: PairingConfig( self.device.pairing_config_factory = lambda _: PairingConfig(
sc=False, mitm=False, bonding=False sc=False, mitm=False, bonding=False
) )
@@ -1196,7 +1274,11 @@ class Central(Connection.Listener):
self.connection = await self.device.connect( self.connection = await self.device.connect(
self.peripheral_address, self.peripheral_address,
connection_parameters_preferences=self.connection_parameter_preferences, connection_parameters_preferences=self.connection_parameter_preferences,
transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT, transport=(
PhysicalTransport.BR_EDR
if self.classic
else PhysicalTransport.LE
),
) )
except CommandTimeoutError: except CommandTimeoutError:
logging.info(color('!!! Connection timed out', 'red')) logging.info(color('!!! Connection timed out', 'red'))
@@ -1211,6 +1293,10 @@ class Central(Connection.Listener):
self.connection.listener = self self.connection.listener = self
print_connection(self.connection) print_connection(self.connection)
if not self.classic:
phy = await self.connection.get_phy()
print_connection_phy(phy)
# Switch roles if needed. # Switch roles if needed.
if self.role_switch: if self.role_switch:
await switch_roles(self.connection, self.role_switch) await switch_roles(self.connection, self.role_switch)
@@ -1267,8 +1353,8 @@ class Central(Connection.Listener):
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
print_connection(self.connection) print_connection(self.connection)
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
print_connection(self.connection) print_connection_phy(phy)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
print_connection(self.connection) print_connection(self.connection)
@@ -1323,6 +1409,7 @@ class Peripheral(Device.Listener, Connection.Listener):
self.device.classic_enabled = self.classic self.device.classic_enabled = self.classic
# Set up a pairing config factory with minimal requirements. # Set up a pairing config factory with minimal requirements.
self.device.config.keystore = "JsonKeyStore"
self.device.pairing_config_factory = lambda _: PairingConfig( self.device.pairing_config_factory = lambda _: PairingConfig(
sc=False, mitm=False, bonding=False sc=False, mitm=False, bonding=False
) )
@@ -1394,8 +1481,8 @@ class Peripheral(Device.Listener, Connection.Listener):
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
print_connection(self.connection) print_connection(self.connection)
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
print_connection(self.connection) print_connection_phy(phy)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
print_connection(self.connection) print_connection(self.connection)
@@ -1470,7 +1557,7 @@ def create_mode_factory(ctx, default_mode):
def create_scenario_factory(ctx, default_scenario): def create_scenario_factory(ctx, default_scenario):
scenario = ctx.obj['scenario'] scenario = ctx.obj['scenario']
if scenario is None: if scenario is None:
scenarion = default_scenario scenario = default_scenario
def create_scenario(packet_io): def create_scenario(packet_io):
if scenario == 'send': if scenario == 'send':
@@ -1529,6 +1616,7 @@ def create_scenario_factory(ctx, default_scenario):
'--att-mtu', '--att-mtu',
metavar='MTU', metavar='MTU',
type=click.IntRange(23, 517), type=click.IntRange(23, 517),
default=517,
help='GATT MTU (gatt-client mode)', help='GATT MTU (gatt-client mode)',
) )
@click.option( @click.option(
@@ -1604,7 +1692,7 @@ def create_scenario_factory(ctx, default_scenario):
'--packet-size', '--packet-size',
'-s', '-s',
metavar='SIZE', metavar='SIZE',
type=click.IntRange(8, 8192), type=click.IntRange(10, 8192),
default=500, default=500,
help='Packet size (send or ping scenario)', help='Packet size (send or ping scenario)',
) )

View File

@@ -22,7 +22,6 @@
import asyncio import asyncio
import logging import logging
import os import os
import random
import re import re
import humanize import humanize
from typing import Optional, Union from typing import Optional, Union
@@ -56,8 +55,14 @@ from prompt_toolkit.layout import (
from bumble import __version__ from bumble import __version__
import bumble.core import bumble.core
from bumble import colors from bumble import colors
from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT from bumble.core import UUID, AdvertisingData, PhysicalTransport
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer from bumble.device import (
ConnectionParametersPreferences,
ConnectionPHY,
Device,
Connection,
Peer,
)
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
@@ -125,6 +130,7 @@ def parse_phys(phys):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConsoleApp: class ConsoleApp:
connected_peer: Optional[Peer] connected_peer: Optional[Peer]
connection_phy: Optional[ConnectionPHY]
def __init__(self): def __init__(self):
self.known_addresses = set() self.known_addresses = set()
@@ -132,6 +138,7 @@ class ConsoleApp:
self.known_local_attributes = [] self.known_local_attributes = []
self.device = None self.device = None
self.connected_peer = None self.connected_peer = None
self.connection_phy = None
self.top_tab = 'device' self.top_tab = 'device'
self.monitor_rssi = False self.monitor_rssi = False
self.connection_rssi = None self.connection_rssi = None
@@ -328,14 +335,14 @@ class ConsoleApp:
elif self.connected_peer: elif self.connected_peer:
connection = self.connected_peer.connection connection = self.connected_peer.connection
connection_parameters = ( connection_parameters = (
f'{connection.parameters.connection_interval}/' f'{connection.parameters.connection_interval:.2f}/'
f'{connection.parameters.peripheral_latency}/' f'{connection.parameters.peripheral_latency}/'
f'{connection.parameters.supervision_timeout}' f'{connection.parameters.supervision_timeout:.2f}'
) )
if connection.transport == BT_LE_TRANSPORT: if self.connection_phy is not None:
phy_state = ( phy_state = (
f' RX={le_phy_name(connection.phy.rx_phy)}/' f' RX={le_phy_name(self.connection_phy.rx_phy)}/'
f'TX={le_phy_name(connection.phy.tx_phy)}' f'TX={le_phy_name(self.connection_phy.tx_phy)}'
) )
else: else:
phy_state = '' phy_state = ''
@@ -654,11 +661,12 @@ class ConsoleApp:
self.append_to_output('connecting...') self.append_to_output('connecting...')
try: try:
await self.device.connect( connection = await self.device.connect(
params[0], params[0],
connection_parameters_preferences=connection_parameters_preferences, connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT, timeout=DEFAULT_CONNECTION_TIMEOUT,
) )
self.connection_phy = await connection.get_phy()
self.top_tab = 'services' self.top_tab = 'services'
except bumble.core.TimeoutError: except bumble.core.TimeoutError:
self.show_error('connection timed out') self.show_error('connection timed out')
@@ -838,8 +846,8 @@ class ConsoleApp:
phy = await self.connected_peer.connection.get_phy() phy = await self.connected_peer.connection.get_phy()
self.append_to_output( self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, ' f'PHY: RX={HCI_Constant.le_phy_name(phy.rx_phy)}, '
f'TX={HCI_Constant.le_phy_name(phy[1])}' f'TX={HCI_Constant.le_phy_name(phy.tx_phy)}'
) )
async def do_request_mtu(self, params): async def do_request_mtu(self, params):
@@ -1076,10 +1084,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
f'{self.app.connected_peer.connection.parameters}' f'{self.app.connected_peer.connection.parameters}'
) )
def on_connection_phy_update(self): def on_connection_phy_update(self, phy):
self.app.append_to_output( self.app.connection_phy = phy
f'connection phy update: {self.app.connected_peer.connection.phy}' self.app.append_to_output(f'connection phy update: {phy}')
)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
self.app.append_to_output( self.app.append_to_output(

View File

@@ -37,6 +37,8 @@ from bumble.hci import (
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command, HCI_Read_Buffer_Size_Command,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
@@ -75,7 +77,7 @@ async def get_classic_info(host: Host) -> None:
if command_succeeded(response): if command_succeeded(response):
print() print()
print( print(
color('Classic Address:', 'yellow'), color('Public Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False), response.return_parameters.bd_addr.to_string(False),
) )
@@ -147,7 +149,7 @@ async def get_le_info(host: Host) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None: async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
@@ -160,14 +162,28 @@ async def get_acl_flow_control_info(host: Host) -> None:
f'packets of size {response.return_parameters.hc_acl_data_packet_length}', f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command( response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} ' f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}', f'packets of size {response.return_parameters.le_acl_data_packet_length}',
) )
@@ -274,8 +290,8 @@ async def async_main(latency_probes, transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info # Print the flow control info
await get_acl_flow_control_info(host) await get_flow_control_info(host)
# Get codec info # Get codec info
await get_codecs_info(host) await get_codecs_info(host)

View File

@@ -29,7 +29,9 @@ from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.profiles.vcs import VolumeControlServiceProxy
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -126,14 +128,52 @@ async def show_tmas(
print(color('### Telephony And Media Audio Service', 'yellow')) print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role: if tmas.role:
print( role = await tmas.role.read_value()
color(' Role:', 'green'), print(color(' Role:', 'green'), role)
await tmas.role.read_value(),
)
print() print()
# -----------------------------------------------------------------------------
async def show_pacs(pacs: PublishedAudioCapabilitiesServiceProxy) -> None:
print(color('### Published Audio Capabilities Service', 'yellow'))
contexts = await pacs.available_audio_contexts.read_value()
print(color(' Available Audio Contexts:', 'green'), contexts)
contexts = await pacs.supported_audio_contexts.read_value()
print(color(' Supported Audio Contexts:', 'green'), contexts)
if pacs.sink_pac:
pac = await pacs.sink_pac.read_value()
print(color(' Sink PAC: ', 'green'), pac)
if pacs.sink_audio_locations:
audio_locations = await pacs.sink_audio_locations.read_value()
print(color(' Sink Audio Locations: ', 'green'), audio_locations)
if pacs.source_pac:
pac = await pacs.source_pac.read_value()
print(color(' Source PAC: ', 'green'), pac)
if pacs.source_audio_locations:
audio_locations = await pacs.source_audio_locations.read_value()
print(color(' Source Audio Locations: ', 'green'), audio_locations)
print()
# -----------------------------------------------------------------------------
async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
print(color('### Volume Control Service', 'yellow'))
volume_state = await vcs.volume_state.read_value()
print(color(' Volume State:', 'green'), volume_state)
volume_flags = await vcs.volume_flags.read_value()
print(color(' Volume Flags:', 'green'), volume_flags)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None: async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try: try:
@@ -161,6 +201,12 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy): if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas) await try_show(show_tmas, tmas)
if pacs := peer.create_service_proxy(PublishedAudioCapabilitiesServiceProxy):
await try_show(show_pacs, pacs)
if vcs := peer.create_service_proxy(VolumeControlServiceProxy):
await try_show(show_vcs, vcs)
if done is not None: if done is not None:
done.set_result(None) done.set_result(None)
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -234,7 +234,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
Characteristic.WRITEABLE, Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write), CharacteristicValue(write=self.on_rx_write),
) )
self.tx_characteristic = Characteristic( self.tx_characteristic: Characteristic[bytes] = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID, GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.Properties.NOTIFY, Characteristic.Properties.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,

View File

@@ -83,7 +83,7 @@ async def async_main():
return_parameters=bytes([hci.HCI_SUCCESS]), return_parameters=bytes([hci.HCI_SUCCESS]),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True) return (bytes(response), True)
return None return None

View File

@@ -16,35 +16,36 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import enum
import functools import functools
from importlib import resources from importlib import resources
import json import json
import os import os
import logging import logging
import pathlib import pathlib
from typing import Optional, List, cast
import weakref import weakref
import struct import wave
import ctypes try:
import wasmtime import lc3 # type: ignore # pylint: disable=E0401
import wasmtime.loader except ImportError as e:
import liblc3 # type: ignore raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
import click import click
import aiohttp.web import aiohttp.web
import bumble import bumble
from bumble import utils
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -54,6 +55,7 @@ logger = logging.getLogger(__name__)
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654 DEFAULT_UI_PORT = 7654
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
def _sink_pac_record() -> pacs.PacRecord: def _sink_pac_record() -> pacs.PacRecord:
@@ -100,153 +102,8 @@ def _source_pac_record() -> pacs.PacRecord:
) )
# ----------------------------------------------------------------------------- decoder: lc3.Decoder | None = None
# WASM - liblc3 encoding_config: bap.CodecSpecificConfiguration | None = None
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task( async def lc3_source_task(
@@ -254,44 +111,49 @@ async def lc3_source_task(
sdu_length: int, sdu_length: int,
frame_duration_us: int, frame_duration_us: int,
device: Device, device: Device,
cis_handle: int, cis_link: CisLink,
) -> None: ) -> None:
with open(filename, 'rb') as f: logger.info(
header = f.read(44) "lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
assert header[8:12] == b'WAVE' filename,
sdu_length,
frame_duration_us / 1000,
)
with wave.open(filename, 'rb') as wav:
bits_per_sample = wav.getsampwidth() * 8
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = ( encoder: lc3.Encoder | None = None
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True: while True:
next_round = datetime.datetime.now() + datetime.timedelta( next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us microseconds=frame_duration_us
) )
pcm_data = f.read(frame_bytes) if not encoder:
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data) if (
encoding_config
and (frame_duration := encoding_config.frame_duration)
and (sampling_frequency := encoding_config.sampling_frequency)
and (
audio_channel_allocation := encoding_config.audio_channel_allocation
)
):
logger.info("Use %s", encoding_config)
encoder = lc3.Encoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=audio_channel_allocation.channel_count,
input_sample_rate_hz=wav.getframerate(),
)
else:
sdu = encoder.encode(
pcm=wav.readframes(encoder.get_frame_samples()),
num_bytes=sdu_length,
bit_depth=bits_per_sample,
)
cis_link.write(sdu)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now() sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds()) await asyncio.sleep(sleep_time.total_seconds() * 0.9)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -410,7 +272,7 @@ class Speaker:
def __init__( def __init__(
self, self,
device_config_path: Optional[str], device_config_path: str | None,
ui_port: int, ui_port: int,
transport: str, transport: str,
lc3_input_file_path: str, lc3_input_file_path: str,
@@ -437,6 +299,7 @@ class Speaker:
advertising_interval_min=25, advertising_interval_min=25,
advertising_interval_max=25, advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'), address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
) )
device_config.le_enabled = True device_config.le_enabled = True
@@ -486,21 +349,35 @@ class Speaker:
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine): def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
codec_config = ase.codec_specific_configuration codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration) if (
pcm = decode( not isinstance(codec_config, bap.CodecSpecificConfiguration)
codec_config.frame_duration.us, or codec_config.frame_duration is None
codec_config.audio_channel_allocation.channel_count, or codec_config.audio_channel_allocation is None
pdu.iso_sdu_fragment, or decoder is None
or not pdu.iso_sdu_fragment
):
return
pcm = decoder.decode(
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
)
utils.cancel_on_event(
self.device, 'disconnection', self.ui_server.send_audio(pcm)
) )
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: ascs.AseStateMachine) -> None: def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
codec_config = ase.codec_specific_configuration
if ase.state == ascs.AseStateMachine.State.STREAMING: if ase.state == ascs.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == ascs.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
ase.cis_link.abort_on( if (
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or ase.cis_link is None
or codec_config.octets_per_codec_frame is None
or codec_config.frame_duration is None
or codec_config.codec_frames_per_sdu is None
):
return
utils.cancel_on_event(
ase.cis_link,
'disconnection', 'disconnection',
lc3_source_task( lc3_source_task(
filename=self.lc3_input_file_path, filename=self.lc3_input_file_path,
@@ -510,25 +387,30 @@ class Speaker:
), ),
frame_duration_us=codec_config.frame_duration.us, frame_duration_us=codec_config.frame_duration.us,
device=self.device, device=self.device,
cis_handle=ase.cis_link.handle, cis_link=ase.cis_link,
), ),
) )
else: else:
if not ase.cis_link:
return
ase.cis_link.sink = functools.partial(on_pdu, ase=ase) ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED: elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration if (
assert isinstance(codec_config, bap.CodecSpecificConfiguration) not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.sampling_frequency is None
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
):
return
if ase.role == ascs.AudioRole.SOURCE: if ase.role == ascs.AudioRole.SOURCE:
setup_encoders( global encoding_config
codec_config.sampling_frequency.hz, encoding_config = codec_config
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else: else:
setup_decoders( global decoder
codec_config.sampling_frequency.hz, decoder = lc3.Decoder(
codec_config.frame_duration.us, frame_duration_us=codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count, sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
) )
for ase in ascs_service.ase_state_machines.values(): for ase in ascs_service.ase_state_machines.values():
@@ -567,7 +449,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
speaker() speaker()

Binary file not shown.

View File

@@ -18,9 +18,12 @@
import asyncio import asyncio
import os import os
import logging import logging
import struct
import click import click
from prompt_toolkit.shortcuts import PromptSession from prompt_toolkit.shortcuts import PromptSession
from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -30,17 +33,20 @@ from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.core import ( from bumble.core import (
AdvertisingData, AdvertisingData,
Appearance,
ProtocolError, ProtocolError,
BT_LE_TRANSPORT, PhysicalTransport,
BT_BR_EDR_TRANSPORT, UUID,
) )
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_HEART_RATE_SERVICE,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Service, Service,
Characteristic, Characteristic,
CharacteristicValue,
) )
from bumble.hci import OwnAddressType
from bumble.att import ( from bumble.att import (
ATT_Error, ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR, ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
@@ -63,7 +69,7 @@ class Waiter:
self.linger = linger self.linger = linger
def terminate(self): def terminate(self):
if not self.linger: if not self.linger and not self.done.done:
self.done.set_result(None) self.done.set_result(None)
async def wait_until_terminated(self): async def wait_until_terminated(self):
@@ -194,7 +200,7 @@ class Delegate(PairingDelegate):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_peer_name(peer, mode): async def get_peer_name(peer, mode):
if mode == 'classic': if peer.connection.transport == PhysicalTransport.BR_EDR:
return await peer.request_name() return await peer.request_name()
# Try to get the peer name from GATT # Try to get the peer name from GATT
@@ -226,13 +232,14 @@ def read_with_error(connection):
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR) raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, _value): # -----------------------------------------------------------------------------
if not connection.is_encrypted: def sdp_records():
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR) service_record_handle = 0x00010001
return {
if not AUTHENTICATION_ERROR_RETURNED[1]: service_record_handle: make_audio_sink_service_sdp_records(
AUTHENTICATION_ERROR_RETURNED[1] = True service_record_handle
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR) )
}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -240,15 +247,19 @@ def on_connection(connection, request):
print(color(f'<<< Connection: {connection}', 'green')) print(color(f'<<< Connection: {connection}', 'green'))
# Listen for pairing events # Listen for pairing events
connection.on('pairing_start', on_pairing_start) connection.on(connection.EVENT_PAIRING_START, on_pairing_start)
connection.on('pairing', lambda keys: on_pairing(connection, keys)) connection.on(connection.EVENT_PAIRING, lambda keys: on_pairing(connection, keys))
connection.on( connection.on(
'pairing_failure', lambda reason: on_pairing_failure(connection, reason) connection.EVENT_CLASSIC_PAIRING, lambda: on_classic_pairing(connection)
)
connection.on(
connection.EVENT_PAIRING_FAILURE,
lambda reason: on_pairing_failure(connection, reason),
) )
# Listen for encryption changes # Listen for encryption changes
connection.on( connection.on(
'connection_encryption_change', connection.EVENT_CONNECTION_ENCRYPTION_CHANGE,
lambda: on_connection_encryption_change(connection), lambda: on_connection_encryption_change(connection),
) )
@@ -289,6 +300,20 @@ async def on_pairing(connection, keys):
Waiter.instance.terminate() Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_classic_pairing(connection):
print(color('***-----------------------------------', 'cyan'))
print(
color(
f'*** Paired [Classic]! (peer identity={connection.peer_address})', 'cyan'
)
)
print(color('***-----------------------------------', 'cyan'))
await asyncio.sleep(POST_PAIRING_DELAY)
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason): async def on_pairing_failure(connection, reason):
@@ -306,6 +331,7 @@ async def pair(
mitm, mitm,
bond, bond,
ctkd, ctkd,
advertising_address,
identity_address, identity_address,
linger, linger,
io, io,
@@ -314,6 +340,8 @@ async def pair(
request, request,
print_keys, print_keys,
keystore_file, keystore_file,
advertise_service_uuids,
advertise_appearance,
device_config, device_config,
hci_transport, hci_transport,
address_or_name, address_or_name,
@@ -329,29 +357,33 @@ async def pair(
# Expose a GATT characteristic that can be used to trigger pairing by # Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read # responding with an authentication error when read
if mode == 'le': if mode in ('le', 'dual'):
device.le_enabled = True
device.add_service( device.add_service(
Service( Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5', GATT_HEART_RATE_SERVICE,
[ [
Characteristic( Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714', GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.READ Characteristic.Properties.READ,
| Characteristic.Properties.WRITE, Characteristic.READ_REQUIRES_AUTHENTICATION,
Characteristic.READABLE | Characteristic.WRITEABLE, bytes(1),
CharacteristicValue(
read=read_with_error, write=write_with_error
),
) )
], ],
) )
) )
# Select LE or Classic # LE and Classic support
if mode == 'classic': if mode in ('classic', 'dual'):
device.classic_enabled = True device.classic_enabled = True
device.classic_smp_enabled = ctkd device.classic_smp_enabled = ctkd
if mode in ('le', 'dual'):
device.le_enabled = True
if mode == 'dual':
device.le_simultaneous_enabled = True
# Setup SDP
if mode in ('classic', 'dual'):
device.sdp_service_records = sdp_records()
# Get things going # Get things going
await device.power_on() await device.power_on()
@@ -373,7 +405,9 @@ async def pair(
shared_data = ( shared_data = (
None None
if oob == '-' if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob))) else OobData.from_ad(
AdvertisingData.from_bytes(bytes.fromhex(oob))
).shared_data
) )
legacy_context = OobLegacyContext() legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig( oob_contexts = PairingConfig.OobConfig(
@@ -381,16 +415,19 @@ async def pair(
peer_data=shared_data, peer_data=shared_data,
legacy_context=legacy_context, legacy_context=legacy_context,
) )
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow')) print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow')) print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow')) if shared_data is None:
oob_data = OobData(
address=device.random_address, shared_data=our_oob_context.share()
)
print(
color(
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
'yellow',
)
)
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow')) print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow')) print(color('@@@-----------------------------------', 'yellow'))
else: else:
oob_contexts = None oob_contexts = None
@@ -417,7 +454,9 @@ async def pair(
print(color(f'=== Connecting to {address_or_name}...', 'green')) print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect( connection = await device.connect(
address_or_name, address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT, transport=(
PhysicalTransport.LE if mode == 'le' else PhysicalTransport.BR_EDR
),
) )
if not request: if not request:
@@ -430,13 +469,109 @@ async def pair(
print(color(f'Pairing failed: {error}', 'red')) print(color(f'Pairing failed: {error}', 'red'))
else: else:
if mode == 'le': if mode in ('le', 'dual'):
# Advertise so that peers can find us and connect # Advertise so that peers can find us and connect.
await device.start_advertising(auto_restart=True) # Include the heart rate service UUID in the advertisement data
else: # so that devices like iPhones can show this device in their
# Bluetooth selector.
service_uuids_16 = []
service_uuids_32 = []
service_uuids_128 = []
if advertise_service_uuids:
for uuid in advertise_service_uuids:
uuid = uuid.replace("-", "")
if len(uuid) == 4:
service_uuids_16.append(UUID(uuid))
elif len(uuid) == 8:
service_uuids_32.append(UUID(uuid))
elif len(uuid) == 32:
service_uuids_128.append(UUID(uuid))
else:
print(color('Invalid UUID format', 'red'))
return
else:
service_uuids_16.append(GATT_HEART_RATE_SERVICE)
flags = AdvertisingData.Flags.LE_LIMITED_DISCOVERABLE_MODE
if mode == 'le':
flags |= AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
if mode == 'dual':
flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE
ad_structs = [
(
AdvertisingData.FLAGS,
bytes([flags]),
),
(AdvertisingData.COMPLETE_LOCAL_NAME, 'Bumble'.encode()),
]
if service_uuids_16:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_16),
)
)
if service_uuids_32:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_32),
)
)
if service_uuids_128:
ad_structs.append(
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_128),
)
)
if advertise_appearance:
advertise_appearance = advertise_appearance.upper()
try:
advertise_appearance_int = int(advertise_appearance)
except ValueError:
category, subcategory = advertise_appearance.split('/')
try:
category_enum = Appearance.Category[category]
except ValueError:
print(
color(f'Invalid appearance category {category}', 'red')
)
return
subcategory_class = Appearance.SUBCATEGORY_CLASSES[
category_enum
]
try:
subcategory_enum = subcategory_class[subcategory]
except ValueError:
print(color(f'Invalid subcategory {subcategory}', 'red'))
return
advertise_appearance_int = int(
Appearance(category_enum, subcategory_enum)
)
ad_structs.append(
(
AdvertisingData.APPEARANCE,
struct.pack('<H', advertise_appearance_int),
)
)
device.advertising_data = bytes(AdvertisingData(ad_structs))
await device.start_advertising(
auto_restart=True,
own_address_type=(
OwnAddressType.PUBLIC
if advertising_address == 'public'
else OwnAddressType.RANDOM
),
)
if mode in ('classic', 'dual'):
# Become discoverable and connectable # Become discoverable and connectable
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
print(color('Ready for connections on', 'blue'), device.public_address)
# Run until the user asks to exit # Run until the user asks to exit
await Waiter.instance.wait_until_terminated() await Waiter.instance.wait_until_terminated()
@@ -456,7 +591,10 @@ class LogHandler(logging.Handler):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option( @click.option(
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True '--mode',
type=click.Choice(['le', 'classic', 'dual']),
default='le',
show_default=True,
) )
@click.option( @click.option(
'--sc', '--sc',
@@ -478,6 +616,10 @@ class LogHandler(logging.Handler):
help='Enable CTKD', help='Enable CTKD',
show_default=True, show_default=True,
) )
@click.option(
'--advertising-address',
type=click.Choice(['random', 'public']),
)
@click.option( @click.option(
'--identity-address', '--identity-address',
type=click.Choice(['random', 'public']), type=click.Choice(['random', 'public']),
@@ -506,9 +648,20 @@ class LogHandler(logging.Handler):
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option( @click.option(
'--keystore-file', '--keystore-file',
metavar='<filename>', metavar='FILENAME',
help='File in which to store the pairing keys', help='File in which to store the pairing keys',
) )
@click.option(
'--advertise-service-uuid',
metavar="UUID",
multiple=True,
help="Advertise a GATT service UUID (may be specified more than once)",
)
@click.option(
'--advertise-appearance',
metavar='APPEARANCE',
help='Advertise an Appearance ID (int value or string)',
)
@click.argument('device-config') @click.argument('device-config')
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
@@ -518,6 +671,7 @@ def main(
mitm, mitm,
bond, bond,
ctkd, ctkd,
advertising_address,
identity_address, identity_address,
linger, linger,
io, io,
@@ -526,6 +680,8 @@ def main(
request, request,
print_keys, print_keys,
keystore_file, keystore_file,
advertise_service_uuid,
advertise_appearance,
device_config, device_config,
hci_transport, hci_transport,
address_or_name, address_or_name,
@@ -544,6 +700,7 @@ def main(
mitm, mitm,
bond, bond,
ctkd, ctkd,
advertising_address,
identity_address, identity_address,
linger, linger,
io, io,
@@ -552,6 +709,8 @@ def main(
request, request,
print_keys, print_keys,
keystore_file, keystore_file,
advertise_service_uuid,
advertise_appearance,
device_config, device_config,
hci_transport, hci_transport,
address_or_name, address_or_name,

View File

@@ -56,7 +56,7 @@ from bumble.core import (
AdvertisingData, AdvertisingData,
ConnectionError as BumbleConnectionError, ConnectionError as BumbleConnectionError,
DeviceClass, DeviceClass,
BT_BR_EDR_TRANSPORT, PhysicalTransport,
) )
from bumble.device import Connection, Device, DeviceConfiguration from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant
@@ -286,7 +286,7 @@ class Player:
async def connect(self, device: Device, address: str) -> Connection: async def connect(self, device: Device, address: str) -> Connection:
print(color(f"Connecting to {address}...", "green")) print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT) connection = await device.connect(address, transport=PhysicalTransport.BR_EDR)
# Request authentication # Request authentication
if self.authenticate: if self.authenticate:
@@ -402,7 +402,7 @@ class Player:
async def pair(self, device: Device, address: str) -> None: async def pair(self, device: Device, address: str) -> None:
print(color(f"Connecting to {address}...", "green")) print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT) connection = await device.connect(address, transport=PhysicalTransport.BR_EDR)
print(color("Pairing...", "magenta")) print(color("Pairing...", "magenta"))
await connection.authenticate() await connection.authenticate()

View File

@@ -271,7 +271,7 @@ class ClientBridge:
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue")) print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
assert self.device assert self.device
self.connection = await self.device.connect( self.connection = await self.device.connect(
self.address, transport=core.BT_BR_EDR_TRANSPORT self.address, transport=core.PhysicalTransport.BR_EDR
) )
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue")) print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
self.connection.on("disconnection", self.on_disconnection) self.connection.on("disconnection", self.on_disconnection)

View File

@@ -144,18 +144,18 @@ class Printer:
help='Format of the input file', help='Format of the input file',
) )
@click.option( @click.option(
'--vendors', '--vendor',
type=click.Choice(['android', 'zephyr']), type=click.Choice(['android', 'zephyr']),
multiple=True, multiple=True,
help='Support vendor-specific commands (list one or more)', help='Support vendor-specific commands (list one or more)',
) )
@click.argument('filename') @click.argument('filename')
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
def main(format, vendors, filename): def main(format, vendor, filename):
for vendor in vendors: for vendor_name in vendor:
if vendor == 'android': if vendor_name == 'android':
import bumble.vendor.android.hci import bumble.vendor.android.hci
elif vendor == 'zephyr': elif vendor_name == 'zephyr':
import bumble.vendor.zephyr.hci import bumble.vendor.zephyr.hci
input = open(filename, 'rb') input = open(filename, 'rb')
@@ -180,7 +180,7 @@ def main(format, vendors, filename):
else: else:
printer.print(color("[TRUNCATED]", "red")) printer.print(color("[TRUNCATED]", "red"))
except Exception as error: except Exception as error:
logger.exception() logger.exception('')
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))

View File

@@ -34,7 +34,7 @@ from aiohttp import web
import bumble import bumble
from bumble.colors import color from bumble.colors import color
from bumble.core import BT_BR_EDR_TRANSPORT, CommandTimeoutError from bumble.core import PhysicalTransport, CommandTimeoutError
from bumble.device import Connection, Device, DeviceConfiguration from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import HCI_StatusError from bumble.hci import HCI_StatusError
from bumble.pairing import PairingConfig from bumble.pairing import PairingConfig
@@ -568,7 +568,9 @@ class Speaker:
async def connect(self, address): async def connect(self, address):
# Connect to the source # Connect to the source
print(f'=== Connecting to {address}...') print(f'=== Connecting to {address}...')
connection = await self.device.connect(address, transport=BT_BR_EDR_TRANSPORT) connection = await self.device.connect(
address, transport=PhysicalTransport.BR_EDR
)
print(f'=== Connected to {connection.peer_address}') print(f'=== Connected to {connection.peer_address}')
# Request authentication # Request authentication

View File

@@ -26,9 +26,9 @@ from typing import Awaitable, Callable
from typing_extensions import ClassVar, Self from typing_extensions import ClassVar, Self
from .codecs import AacAudioRtpPacket from bumble.codecs import AacAudioRtpPacket
from .company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from .sdp import ( from bumble.sdp import (
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT, SDP_PUBLIC_BROWSE_ROOT,
@@ -38,7 +38,7 @@ from .sdp import (
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from .core import ( from bumble.core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_AUDIO_SOURCE_SERVICE, BT_AUDIO_SOURCE_SERVICE,
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
@@ -46,7 +46,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number, name_or_number,
) )
from .rtp import MediaPacket from bumble.rtp import MediaPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -155,7 +155,7 @@ def flags_to_list(flags, values):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from bumble.avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
@@ -209,7 +209,7 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from bumble.avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [

View File

@@ -29,27 +29,32 @@ import functools
import inspect import inspect
import struct import struct
from typing import ( from typing import (
Any,
Awaitable, Awaitable,
Callable, Callable,
Generic,
Dict, Dict,
List, List,
Optional, Optional,
Type, Type,
TypeVar,
Union, Union,
TYPE_CHECKING, TYPE_CHECKING,
) )
from pyee import EventEmitter
from bumble import utils from bumble import utils
from bumble.core import UUID, name_or_number, ProtocolError from bumble.core import UUID, name_or_number, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object, key_with_value from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color from bumble.colors import color
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Connection from bumble.device import Connection
_T = TypeVar('_T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -57,6 +62,7 @@ if TYPE_CHECKING:
# pylint: disable=line-too-long # pylint: disable=line-too-long
ATT_CID = 0x04 ATT_CID = 0x04
ATT_PSM = 0x001F
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02 ATT_EXCHANGE_MTU_REQUEST = 0x02
@@ -216,7 +222,12 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ATT_Error(ProtocolError): class ATT_Error(ProtocolError):
def __init__(self, error_code, att_handle=0x0000, message=''): error_code: int
att_handle: int
def __init__(
self, error_code: int, att_handle: int = 0x0000, message: str = ''
) -> None:
super().__init__( super().__init__(
error_code, error_code,
error_namespace='att', error_namespace='att',
@@ -226,7 +237,10 @@ class ATT_Error(ProtocolError):
self.message = message self.message = message
def __str__(self): def __str__(self):
return f'ATT_Error(error={self.error_name}, handle={self.att_handle:04X}): {self.message}' return (
f'ATT_Error(error={self.error_name}, '
f'handle={self.att_handle:04X}): {self.message}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -291,9 +305,6 @@ class ATT_PDU:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
@property @property
def is_command(self): def is_command(self):
return ((self.op_code >> 6) & 1) == 1 return ((self.op_code >> 6) & 1) == 1
@@ -303,7 +314,7 @@ class ATT_PDU:
return ((self.op_code >> 7) & 1) == 1 return ((self.op_code >> 7) & 1) == 1
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
@@ -750,7 +761,7 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AttributeValue: class AttributeValue(Generic[_T]):
''' '''
Attribute value where reading and/or writing is delegated to functions Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor. passed as arguments to the constructor.
@@ -759,33 +770,34 @@ class AttributeValue:
def __init__( def __init__(
self, self,
read: Union[ read: Union[
Callable[[Optional[Connection]], bytes], Callable[[Optional[Connection]], _T],
Callable[[Optional[Connection]], Awaitable[bytes]], Callable[[Optional[Connection]], Awaitable[_T]],
None, None,
] = None, ] = None,
write: Union[ write: Union[
Callable[[Optional[Connection], bytes], None], Callable[[Optional[Connection], _T], None],
Callable[[Optional[Connection], bytes], Awaitable[None]], Callable[[Optional[Connection], _T], Awaitable[None]],
None, None,
] = None, ] = None,
): ):
self._read = read self._read = read
self._write = write self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]: def read(self, connection: Optional[Connection]) -> Union[_T, Awaitable[_T]]:
return self._read(connection) if self._read else b'' if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(connection)
def write( def write(
self, connection: Optional[Connection], value: bytes self, connection: Optional[Connection], value: _T
) -> Union[Awaitable[None], None]: ) -> Union[Awaitable[None], None]:
if self._write: if self._write is None:
return self._write(connection, value) raise InvalidOperationError('AttributeValue has no write function')
return self._write(connection, value)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(EventEmitter): class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag): class Permissions(enum.IntFlag):
READABLE = 0x01 READABLE = 0x01
WRITEABLE = 0x02 WRITEABLE = 0x02
@@ -824,15 +836,18 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue] EVENT_READ = "read"
EVENT_WRITE = "write"
value: Union[AttributeValue[_T], _T, None]
def __init__( def __init__(
self, self,
attribute_type: Union[str, bytes, UUID], attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'', value: Union[AttributeValue[_T], _T, None] = None,
) -> None: ) -> None:
EventEmitter.__init__(self) utils.EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
self.end_group_handle = 0 self.end_group_handle = 0
if isinstance(permissions, str): if isinstance(permissions, str):
@@ -848,17 +863,13 @@ class Attribute(EventEmitter):
else: else:
self.type = attribute_type self.type = attribute_type
# Convert the value to a byte array self.value = value
if isinstance(value, str):
self.value = bytes(value, 'utf-8')
else:
self.value = value
def encode_value(self, value: Any) -> bytes: def encode_value(self, value: _T) -> bytes:
return value return value # type: ignore
def decode_value(self, value_bytes: bytes) -> Any: def decode_value(self, value: bytes) -> _T:
return value_bytes return value # type: ignore
async def read_value(self, connection: Optional[Connection]) -> bytes: async def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
@@ -883,11 +894,14 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
) )
if hasattr(self.value, 'read'): value: Union[_T, None]
if isinstance(self.value, AttributeValue):
try: try:
value = self.value.read(connection) read_value = self.value.read(connection)
if inspect.isawaitable(value): if inspect.isawaitable(read_value):
value = await value value = await read_value
else:
value = read_value
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -895,18 +909,24 @@ class Attribute(EventEmitter):
else: else:
value = self.value value = self.value
return self.encode_value(value) self.emit(self.EVENT_READ, connection, b'' if value is None else value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None: return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Optional[Connection], value: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION (self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
) and not connection.encryption: and connection is not None
and not connection.encryption
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
) )
if ( if (
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION (self.permissions & self.WRITE_REQUIRES_AUTHENTICATION)
) and not connection.authenticated: and connection is not None
and not connection.authenticated
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
) )
@@ -916,11 +936,11 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
) )
value = self.decode_value(value_bytes) decoded_value = self.decode_value(value)
if hasattr(self.value, 'write'): if isinstance(self.value, AttributeValue):
try: try:
result = self.value.write(connection, value) result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result): if inspect.isawaitable(result):
await result await result
except ATT_Error as error: except ATT_Error as error:
@@ -928,9 +948,9 @@ class Attribute(EventEmitter):
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
) from error ) from error
else: else:
self.value = value self.value = decoded_value
self.emit('write', connection, value) self.emit(self.EVENT_WRITE, connection, decoded_value)
def __repr__(self): def __repr__(self):
if isinstance(self.value, bytes): if isinstance(self.value, bytes):

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from setuptools import setup # -----------------------------------------------------------------------------
# Imports
setup() # -----------------------------------------------------------------------------

553
bumble/audio/io.py Normal file
View File

@@ -0,0 +1,553 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import abc
from concurrent.futures import ThreadPoolExecutor
import dataclasses
import enum
import logging
import pathlib
from typing import (
AsyncGenerator,
BinaryIO,
TYPE_CHECKING,
)
import sys
import wave
from bumble.colors import color
if TYPE_CHECKING:
import sounddevice # type: ignore[import-untyped]
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PcmFormat:
class Endianness(enum.Enum):
LITTLE = 0
BIG = 1
class SampleType(enum.Enum):
FLOAT32 = 0
INT16 = 1
endianness: Endianness
sample_type: SampleType
sample_rate: int
channels: int
@classmethod
def from_str(cls, format_str: str) -> PcmFormat:
endianness = cls.Endianness.LITTLE # Others not yet supported.
sample_type_str, sample_rate_str, channels_str = format_str.split(',')
if sample_type_str == 'int16le':
sample_type = cls.SampleType.INT16
elif sample_type_str == 'float32le':
sample_type = cls.SampleType.FLOAT32
else:
raise ValueError(f'sample type {sample_type_str} not supported')
sample_rate = int(sample_rate_str)
channels = int(channels_str)
return cls(endianness, sample_type, sample_rate, channels)
@property
def bytes_per_sample(self) -> int:
return 2 if self.sample_type == self.SampleType.INT16 else 4
def check_audio_output(output: str) -> bool:
if output == 'device' or output.startswith('device:'):
try:
import sounddevice
except ImportError as exc:
raise ValueError(
'audio output not available (sounddevice python module not installed)'
) from exc
except OSError as exc:
raise ValueError(
'audio output not available '
'(sounddevice python module failed to load: '
f'{exc})'
) from exc
if output == 'device':
# Default device
return True
# Specific device
device = output[7:]
if device == '?':
print(color('Audio Devices:', 'yellow'))
for device_info in [
device_info
for device_info in sounddevice.query_devices()
if device_info['max_output_channels'] > 0
]:
device_index = device_info['index']
is_default = (
color(' [default]', 'green')
if sounddevice.default.device[1] == device_index
else ''
)
print(
f'{color(device_index, "cyan")}: {device_info["name"]}{is_default}'
)
return False
try:
device_info = sounddevice.query_devices(int(device))
except sounddevice.PortAudioError as exc:
raise ValueError('No such audio device') from exc
if device_info['max_output_channels'] < 1:
raise ValueError(
f'Device {device} ({device_info["name"]}) does not have an output'
)
return True
async def create_audio_output(output: str) -> AudioOutput:
if output == 'stdout':
return StreamAudioOutput(sys.stdout.buffer)
if output == 'device' or output.startswith('device:'):
device_name = '' if output == 'device' else output[7:]
return SoundDeviceAudioOutput(device_name)
if output == 'ffplay':
return SubprocessAudioOutput(
command=(
'ffplay -probesize 32 -fflags nobuffer -analyzeduration 0 '
'-ar {sample_rate} '
'-ch_layout {channel_layout} '
'-f f32le pipe:0'
)
)
if output.startswith('file:'):
return FileAudioOutput(output[5:])
raise ValueError('unsupported audio output')
class AudioOutput(abc.ABC):
"""Audio output to which PCM samples can be written."""
async def open(self, pcm_format: PcmFormat) -> None:
"""Start the output."""
@abc.abstractmethod
def write(self, pcm_samples: bytes) -> None:
"""Write PCM samples. Must not block."""
async def aclose(self) -> None:
"""Close the output."""
class ThreadedAudioOutput(AudioOutput):
"""Base class for AudioOutput classes that may need to call blocking functions.
The actual writing is performed in a thread, so as to ensure that calling write()
does not block the caller.
"""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
self._write_task = asyncio.create_task(self._write_loop())
async def _write_loop(self) -> None:
while True:
pcm_samples = await self._pcm_samples.get()
await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._write, pcm_samples
)
@abc.abstractmethod
def _write(self, pcm_samples: bytes) -> None:
"""This method does the actual writing and can block."""
def write(self, pcm_samples: bytes) -> None:
self._pcm_samples.put_nowait(pcm_samples)
def _close(self) -> None:
"""This method does the actual closing and can block."""
async def aclose(self) -> None:
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
self._write_task.cancel()
self._thread_pool.shutdown()
class SoundDeviceAudioOutput(ThreadedAudioOutput):
def __init__(self, device_name: str) -> None:
super().__init__()
self._device = int(device_name) if device_name else None
self._stream: sounddevice.RawOutputStream | None = None
async def open(self, pcm_format: PcmFormat) -> None:
import sounddevice # pylint: disable=import-error
self._stream = sounddevice.RawOutputStream(
samplerate=pcm_format.sample_rate,
device=self._device,
channels=pcm_format.channels,
dtype='float32',
)
self._stream.start()
def _write(self, pcm_samples: bytes) -> None:
if self._stream is None:
return
try:
self._stream.write(pcm_samples)
except Exception as error:
print(f'Sound device error: {error}')
raise
def _close(self):
self._stream.stop()
self._stream = None
class StreamAudioOutput(ThreadedAudioOutput):
"""AudioOutput where PCM samples are written to a stream that may block."""
def __init__(self, stream: BinaryIO) -> None:
super().__init__()
self._stream = stream
def _write(self, pcm_samples: bytes) -> None:
self._stream.write(pcm_samples)
self._stream.flush()
class FileAudioOutput(StreamAudioOutput):
"""AudioOutput where PCM samples are written to a file."""
def __init__(self, filename: str) -> None:
self._file = open(filename, "wb")
super().__init__(self._file)
async def shutdown(self):
self._file.close()
return await super().shutdown()
class SubprocessAudioOutput(AudioOutput):
"""AudioOutput where audio samples are written to a subprocess via stdin."""
def __init__(self, command: str) -> None:
self._command = command
self._subprocess: asyncio.subprocess.Process | None
async def open(self, pcm_format: PcmFormat) -> None:
if pcm_format.channels == 1:
channel_layout = 'mono'
elif pcm_format.channels == 2:
channel_layout = 'stereo'
else:
raise ValueError(f'{pcm_format.channels} channels not supported')
command = self._command.format(
sample_rate=pcm_format.sample_rate, channel_layout=channel_layout
)
self._subprocess = await asyncio.create_subprocess_shell(
command,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
def write(self, pcm_samples: bytes) -> None:
if self._subprocess is None or self._subprocess.stdin is None:
return
self._subprocess.stdin.write(pcm_samples)
async def aclose(self):
if self._subprocess:
self._subprocess.terminate()
def check_audio_input(input: str) -> bool:
if input == 'device' or input.startswith('device:'):
try:
import sounddevice # pylint: disable=import-error
except ImportError as exc:
raise ValueError(
'audio input not available (sounddevice python module not installed)'
) from exc
except OSError as exc:
raise ValueError(
'audio input not available '
'(sounddevice python module failed to load: '
f'{exc})'
) from exc
if input == 'device':
# Default device
return True
# Specific device
device = input[7:]
if device == '?':
print(color('Audio Devices:', 'yellow'))
for device_info in [
device_info
for device_info in sounddevice.query_devices()
if device_info['max_input_channels'] > 0
]:
device_index = device_info["index"]
is_mono = device_info['max_input_channels'] == 1
max_channels = color(f'[{"mono" if is_mono else "stereo"}]', 'cyan')
is_default = (
color(' [default]', 'green')
if sounddevice.default.device[0] == device_index
else ''
)
print(
f'{color(device_index, "cyan")}: {device_info["name"]}'
f' {max_channels}{is_default}'
)
return False
try:
device_info = sounddevice.query_devices(int(device))
except sounddevice.PortAudioError as exc:
raise ValueError('No such audio device') from exc
if device_info['max_input_channels'] < 1:
raise ValueError(
f'Device {device} ({device_info["name"]}) does not have an input'
)
return True
async def create_audio_input(input: str, input_format: str) -> AudioInput:
pcm_format: PcmFormat | None
if input_format == 'auto':
pcm_format = None
else:
pcm_format = PcmFormat.from_str(input_format)
if input == 'stdin':
if not pcm_format:
raise ValueError('input format details required for stdin')
return StreamAudioInput(sys.stdin.buffer, pcm_format)
if input == 'device' or input.startswith('device:'):
if not pcm_format:
raise ValueError('input format details required for device')
device_name = '' if input == 'device' else input[7:]
return SoundDeviceAudioInput(device_name, pcm_format)
# If there's no file: prefix, check if we can assume it is a file.
if pathlib.Path(input).is_file():
input = 'file:' + input
if input.startswith('file:'):
filename = input[5:]
if filename.endswith('.wav'):
if input_format != 'auto':
raise ValueError(".wav file only supported with 'auto' format")
return WaveAudioInput(filename)
if pcm_format is None:
raise ValueError('input format details required for raw PCM files')
return FileAudioInput(filename, pcm_format)
raise ValueError('input not supported')
class AudioInput(abc.ABC):
"""Audio input that produces PCM samples."""
@abc.abstractmethod
async def open(self) -> PcmFormat:
"""Open the input."""
@abc.abstractmethod
def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
"""Generate one frame of PCM samples. Must not block."""
async def aclose(self) -> None:
"""Close the input."""
class ThreadedAudioInput(AudioInput):
"""Base class for AudioInput implementation where reading samples may block."""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
@abc.abstractmethod
def _read(self, frame_size: int) -> bytes:
pass
@abc.abstractmethod
def _open(self) -> PcmFormat:
pass
def _close(self) -> None:
pass
async def open(self) -> PcmFormat:
return await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._open
)
async def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
while pcm_sample := await asyncio.get_running_loop().run_in_executor(
self._thread_pool, self._read, frame_size
):
yield pcm_sample
async def aclose(self) -> None:
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
self._thread_pool.shutdown()
class WaveAudioInput(ThreadedAudioInput):
"""Audio input that reads PCM samples from a .wav file."""
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._wav: wave.Wave_read | None = None
self._bytes_read = 0
def _open(self) -> PcmFormat:
self._wav = wave.open(self._filename, 'rb')
if self._wav.getsampwidth() != 2:
raise ValueError('sample width not supported')
return PcmFormat(
PcmFormat.Endianness.LITTLE,
PcmFormat.SampleType.INT16,
self._wav.getframerate(),
self._wav.getnchannels(),
)
def _read(self, frame_size: int) -> bytes:
if not self._wav:
return b''
pcm_samples = self._wav.readframes(frame_size)
if not pcm_samples and self._bytes_read:
# Loop around.
self._wav.rewind()
self._bytes_read = 0
pcm_samples = self._wav.readframes(frame_size)
self._bytes_read += len(pcm_samples)
return pcm_samples
def _close(self) -> None:
if self._wav:
self._wav.close()
class StreamAudioInput(ThreadedAudioInput):
"""AudioInput where samples are read from a raw PCM stream that may block."""
def __init__(self, stream: BinaryIO, pcm_format: PcmFormat) -> None:
super().__init__()
self._stream = stream
self._pcm_format = pcm_format
def _open(self) -> PcmFormat:
return self._pcm_format
def _read(self, frame_size: int) -> bytes:
return self._stream.read(
frame_size * self._pcm_format.channels * self._pcm_format.bytes_per_sample
)
class FileAudioInput(StreamAudioInput):
"""AudioInput where PCM samples are read from a raw PCM file."""
def __init__(self, filename: str, pcm_format: PcmFormat) -> None:
self._stream = open(filename, "rb")
super().__init__(self._stream, pcm_format)
def _close(self) -> None:
self._stream.close()
class SoundDeviceAudioInput(ThreadedAudioInput):
def __init__(self, device_name: str, pcm_format: PcmFormat) -> None:
super().__init__()
self._device = int(device_name) if device_name else None
self._pcm_format = pcm_format
self._stream: sounddevice.RawInputStream | None = None
def _open(self) -> PcmFormat:
import sounddevice # pylint: disable=import-error
self._stream = sounddevice.RawInputStream(
samplerate=self._pcm_format.sample_rate,
device=self._device,
channels=self._pcm_format.channels,
dtype='int16',
)
self._stream.start()
return PcmFormat(
PcmFormat.Endianness.LITTLE,
PcmFormat.SampleType.INT16,
self._pcm_format.sample_rate,
2,
)
def _read(self, frame_size: int) -> bytes:
if not self._stream:
return b''
pcm_buffer, overflowed = self._stream.read(frame_size)
if overflowed:
logger.warning("input overflow")
# Convert the buffer to stereo if needed
if self._pcm_format.channels == 1:
stereo_buffer = bytearray()
for i in range(frame_size):
sample = pcm_buffer[i * 2 : i * 2 + 2]
stereo_buffer += sample + sample
return stereo_buffer
return bytes(pcm_buffer)
def _close(self):
self._stream.stop()
self._stream = None

View File

@@ -21,7 +21,7 @@ import struct
from typing import Dict, Type, Union, Tuple from typing import Dict, Type, Union, Tuple
from bumble import core from bumble import core
from bumble.utils import OpenIntEnum from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -43,7 +43,7 @@ class Frame:
EXTENDED = 0x1E EXTENDED = 0x1E
UNIT = 0x1F UNIT = 0x1F
class OperationCode(OpenIntEnum): class OperationCode(utils.OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands # 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00 VENDOR_DEPENDENT = 0x00
RESERVE = 0x01 RESERVE = 0x01
@@ -204,7 +204,7 @@ class Frame:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CommandFrame(Frame): class CommandFrame(Frame):
class CommandType(OpenIntEnum): class CommandType(utils.OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1 # AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1 # Table 7.1
CONTROL = 0x00 CONTROL = 0x00
@@ -240,7 +240,7 @@ class CommandFrame(Frame):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ResponseFrame(Frame): class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum): class ResponseCode(utils.OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1 # AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2 # Table 7.2
NOT_IMPLEMENTED = 0x08 NOT_IMPLEMENTED = 0x08
@@ -368,7 +368,7 @@ class PassThroughFrame:
PRESSED = 0 PRESSED = 0
RELEASED = 1 RELEASED = 1
class OperationId(OpenIntEnum): class OperationId(utils.OpenIntEnum):
SELECT = 0x00 SELECT = 0x00
UP = 0x01 UP = 0x01
DOWN = 0x01 DOWN = 0x01

View File

@@ -166,8 +166,8 @@ class Protocol:
# Register to receive PDUs from the channel # Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open) l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close) l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)
def on_l2cap_channel_open(self): def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta")) logger.debug(color("<<< AVCTP channel open", "magenta"))

View File

@@ -37,16 +37,15 @@ from typing import (
cast, cast,
) )
from pyee import EventEmitter
from .core import ( from bumble.core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError, InvalidStateError,
ProtocolError, ProtocolError,
InvalidArgumentError, InvalidArgumentError,
name_or_number, name_or_number,
) )
from .a2dp import ( from bumble.a2dp import (
A2DP_CODEC_TYPE_NAMES, A2DP_CODEC_TYPE_NAMES,
A2DP_MPEG_2_4_AAC_CODEC_TYPE, A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE, A2DP_NON_A2DP_CODEC_TYPE,
@@ -56,9 +55,9 @@ from .a2dp import (
SbcMediaCodecInformation, SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation, VendorSpecificMediaCodecInformation,
) )
from .rtp import MediaPacket from bumble.rtp import MediaPacket
from . import sdp, device, l2cap from bumble import sdp, device, l2cap, utils
from .colors import color from bumble.colors import color
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -897,7 +896,7 @@ class Set_Configuration_Reject(Message):
self.service_category = self.payload[0] self.service_category = self.payload[0]
self.error_code = self.payload[1] self.error_code = self.payload[1]
def __init__(self, service_category, error_code): def __init__(self, error_code: int, service_category: int = 0) -> None:
super().__init__(payload=bytes([service_category, error_code])) super().__init__(payload=bytes([service_category, error_code]))
self.service_category = service_category self.service_category = service_category
self.error_code = error_code self.error_code = error_code
@@ -1133,6 +1132,14 @@ class Security_Control_Command(Message):
See Bluetooth AVDTP spec - 8.17.1 Security Control Command See Bluetooth AVDTP spec - 8.17.1 Security Control Command
''' '''
def init_from_payload(self):
# pylint: disable=attribute-defined-outside-init
self.acp_seid = self.payload[0] >> 2
self.data = self.payload[1:]
def __str__(self) -> str:
return self.to_string([f'ACP_SEID: {self.acp_seid}', f'data: {self.data}'])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@Message.subclass @Message.subclass
@@ -1194,13 +1201,16 @@ class DelayReport_Reject(Simple_Reject):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Protocol(EventEmitter): class Protocol(utils.EventEmitter):
local_endpoints: List[LocalStreamEndPoint] local_endpoints: List[LocalStreamEndPoint]
remote_endpoints: Dict[int, DiscoveredStreamEndPoint] remote_endpoints: Dict[int, DiscoveredStreamEndPoint]
streams: Dict[int, Stream] streams: Dict[int, Stream]
transaction_results: List[Optional[asyncio.Future[Message]]] transaction_results: List[Optional[asyncio.Future[Message]]]
channel_connector: Callable[[], Awaitable[l2cap.ClassicChannel]] channel_connector: Callable[[], Awaitable[l2cap.ClassicChannel]]
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
class PacketType(enum.IntEnum): class PacketType(enum.IntEnum):
SINGLE_PACKET = 0 SINGLE_PACKET = 0
START_PACKET = 1 START_PACKET = 1
@@ -1240,8 +1250,8 @@ class Protocol(EventEmitter):
# Register to receive PDUs from the channel # Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu l2cap_channel.sink = self.on_pdu
l2cap_channel.on('open', self.on_l2cap_channel_open) l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
l2cap_channel.on('close', self.on_l2cap_channel_close) l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)
def get_local_endpoint_by_seid(self, seid: int) -> Optional[LocalStreamEndPoint]: def get_local_endpoint_by_seid(self, seid: int) -> Optional[LocalStreamEndPoint]:
if 0 < seid <= len(self.local_endpoints): if 0 < seid <= len(self.local_endpoints):
@@ -1411,20 +1421,20 @@ class Protocol(EventEmitter):
self.transaction_results[transaction_label] = None self.transaction_results[transaction_label] = None
self.transaction_semaphore.release() self.transaction_semaphore.release()
def on_l2cap_connection(self, channel): def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
# Forward the channel to the endpoint that's expecting it # Forward the channel to the endpoint that's expecting it
if self.channel_acceptor is None: if self.channel_acceptor is None:
logger.warning(color('!!! l2cap connection with no acceptor', 'red')) logger.warning(color('!!! l2cap connection with no acceptor', 'red'))
return return
self.channel_acceptor.on_l2cap_connection(channel) self.channel_acceptor.on_l2cap_connection(channel)
def on_l2cap_channel_open(self): def on_l2cap_channel_open(self) -> None:
logger.debug(color('<<< L2CAP channel open', 'magenta')) logger.debug(color('<<< L2CAP channel open', 'magenta'))
self.emit('open') self.emit(self.EVENT_OPEN)
def on_l2cap_channel_close(self): def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< L2CAP channel close', 'magenta')) logger.debug(color('<<< L2CAP channel close', 'magenta'))
self.emit('close') self.emit(self.EVENT_CLOSE)
def send_message(self, transaction_label: int, message: Message) -> None: def send_message(self, transaction_label: int, message: Message) -> None:
logger.debug( logger.debug(
@@ -1542,28 +1552,34 @@ class Protocol(EventEmitter):
async def abort(self, seid: int) -> Abort_Response: async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid)) return await self.send_command(Abort_Command(seid))
def on_discover_command(self, _command): def on_discover_command(self, command: Discover_Command) -> Optional[Message]:
endpoint_infos = [ endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep) EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints for endpoint in self.local_endpoints
] ]
return Discover_Response(endpoint_infos) return Discover_Response(endpoint_infos)
def on_get_capabilities_command(self, command): def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Get_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Get_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)
return Get_Capabilities_Response(endpoint.capabilities) return Get_Capabilities_Response(endpoint.capabilities)
def on_get_all_capabilities_command(self, command): def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Get_All_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Get_All_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)
return Get_All_Capabilities_Response(endpoint.capabilities) return Get_All_Capabilities_Response(endpoint.capabilities)
def on_set_configuration_command(self, command): def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Set_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Set_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR)
@@ -1579,7 +1595,9 @@ class Protocol(EventEmitter):
result = stream.on_set_configuration_command(command.capabilities) result = stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response() return result or Set_Configuration_Response()
def on_get_configuration_command(self, command): def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Get_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Get_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR)
@@ -1588,7 +1606,7 @@ class Protocol(EventEmitter):
return endpoint.stream.on_get_configuration_command() return endpoint.stream.on_get_configuration_command()
def on_reconfigure_command(self, command): def on_reconfigure_command(self, command: Reconfigure_Command) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Reconfigure_Reject(0, AVDTP_BAD_ACP_SEID_ERROR) return Reconfigure_Reject(0, AVDTP_BAD_ACP_SEID_ERROR)
@@ -1598,7 +1616,7 @@ class Protocol(EventEmitter):
result = endpoint.stream.on_reconfigure_command(command.capabilities) result = endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response() return result or Reconfigure_Response()
def on_open_command(self, command): def on_open_command(self, command: Open_Command) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
@@ -1608,25 +1626,26 @@ class Protocol(EventEmitter):
result = endpoint.stream.on_open_command() result = endpoint.stream.on_open_command()
return result or Open_Response() return result or Open_Response()
def on_start_command(self, command): def on_start_command(self, command: Start_Command) -> Optional[Message]:
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None: if endpoint is None:
return Start_Reject(seid, AVDTP_BAD_ACP_SEID_ERROR) return Start_Reject(seid, AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None: if endpoint.stream is None:
return Start_Reject(AVDTP_BAD_STATE_ERROR) return Start_Reject(seid, AVDTP_BAD_STATE_ERROR)
# Start all streams # Start all streams
# TODO: deal with partial failures # TODO: deal with partial failures
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
result = endpoint.stream.on_start_command() if not endpoint or not endpoint.stream:
if result is not None: raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_start_command()) is not None:
return result return result
return Start_Response() return Start_Response()
def on_suspend_command(self, command): def on_suspend_command(self, command: Suspend_Command) -> Optional[Message]:
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None: if endpoint is None:
@@ -1638,13 +1657,14 @@ class Protocol(EventEmitter):
# TODO: deal with partial failures # TODO: deal with partial failures
for seid in command.acp_seids: for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid) endpoint = self.get_local_endpoint_by_seid(seid)
result = endpoint.stream.on_suspend_command() if not endpoint or not endpoint.stream:
if result is not None: raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_suspend_command()) is not None:
return result return result
return Suspend_Response() return Suspend_Response()
def on_close_command(self, command): def on_close_command(self, command: Close_Command) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
@@ -1654,7 +1674,7 @@ class Protocol(EventEmitter):
result = endpoint.stream.on_close_command() result = endpoint.stream.on_close_command()
return result or Close_Response() return result or Close_Response()
def on_abort_command(self, command): def on_abort_command(self, command: Abort_Command) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None: if endpoint is None or endpoint.stream is None:
return Abort_Response() return Abort_Response()
@@ -1662,15 +1682,17 @@ class Protocol(EventEmitter):
endpoint.stream.on_abort_command() endpoint.stream.on_abort_command()
return Abort_Response() return Abort_Response()
def on_security_control_command(self, command): def on_security_control_command(
self, command: Security_Control_Command
) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR) return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_security_control_command(command.payload) result = endpoint.on_security_control_command(command.data)
return result or Security_Control_Response() return result or Security_Control_Response()
def on_delayreport_command(self, command): def on_delayreport_command(self, command: DelayReport_Command) -> Optional[Message]:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid) endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None: if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR) return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
@@ -1680,9 +1702,11 @@ class Protocol(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Listener(EventEmitter): class Listener(utils.EventEmitter):
servers: Dict[int, Protocol] servers: Dict[int, Protocol]
EVENT_CONNECTION = "connection"
@staticmethod @staticmethod
def create_registrar(device: device.Device): def create_registrar(device: device.Device):
warnings.warn("Please use Listener.for_device()", DeprecationWarning) warnings.warn("Please use Listener.for_device()", DeprecationWarning)
@@ -1717,7 +1741,7 @@ class Listener(EventEmitter):
l2cap_server = device.create_l2cap_server( l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM) spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
) )
l2cap_server.on('connection', listener.on_l2cap_connection) l2cap_server.on(l2cap_server.EVENT_CONNECTION, listener.on_l2cap_connection)
return listener return listener
def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None: def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
@@ -1733,14 +1757,14 @@ class Listener(EventEmitter):
logger.debug('setting up new Protocol for the connection') logger.debug('setting up new Protocol for the connection')
server = Protocol(channel, self.version) server = Protocol(channel, self.version)
self.set_server(channel.connection, server) self.set_server(channel.connection, server)
self.emit('connection', server) self.emit(self.EVENT_CONNECTION, server)
def on_channel_close(): def on_channel_close():
logger.debug('removing Protocol for the connection') logger.debug('removing Protocol for the connection')
self.remove_server(channel.connection) self.remove_server(channel.connection)
channel.on('open', on_channel_open) channel.on(channel.EVENT_OPEN, on_channel_open)
channel.on('close', on_channel_close) channel.on(channel.EVENT_CLOSE, on_channel_close)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1789,6 +1813,7 @@ class Stream:
) )
async def start(self) -> None: async def start(self) -> None:
"""[Source] Start streaming."""
# Auto-open if needed # Auto-open if needed
if self.state == AVDTP_CONFIGURED_STATE: if self.state == AVDTP_CONFIGURED_STATE:
await self.open() await self.open()
@@ -1805,6 +1830,7 @@ class Stream:
self.change_state(AVDTP_STREAMING_STATE) self.change_state(AVDTP_STREAMING_STATE)
async def stop(self) -> None: async def stop(self) -> None:
"""[Source] Stop streaming and transit to OPEN state."""
if self.state != AVDTP_STREAMING_STATE: if self.state != AVDTP_STREAMING_STATE:
raise InvalidStateError('current state is not STREAMING') raise InvalidStateError('current state is not STREAMING')
@@ -1817,6 +1843,7 @@ class Stream:
self.change_state(AVDTP_OPEN_STATE) self.change_state(AVDTP_OPEN_STATE)
async def close(self) -> None: async def close(self) -> None:
"""[Source] Close channel and transit to IDLE state."""
if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE): if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE):
raise InvalidStateError('current state is not OPEN or STREAMING') raise InvalidStateError('current state is not OPEN or STREAMING')
@@ -1848,7 +1875,7 @@ class Stream:
self.change_state(AVDTP_CONFIGURED_STATE) self.change_state(AVDTP_CONFIGURED_STATE)
return None return None
def on_get_configuration_command(self, configuration): def on_get_configuration_command(self):
if self.state not in ( if self.state not in (
AVDTP_CONFIGURED_STATE, AVDTP_CONFIGURED_STATE,
AVDTP_OPEN_STATE, AVDTP_OPEN_STATE,
@@ -1856,7 +1883,7 @@ class Stream:
): ):
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR) return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command(configuration) return self.local_endpoint.on_get_configuration_command()
def on_reconfigure_command(self, configuration): def on_reconfigure_command(self, configuration):
if self.state != AVDTP_OPEN_STATE: if self.state != AVDTP_OPEN_STATE:
@@ -1936,20 +1963,20 @@ class Stream:
# Wait for the RTP channel to be closed # Wait for the RTP channel to be closed
self.change_state(AVDTP_ABORTING_STATE) self.change_state(AVDTP_ABORTING_STATE)
def on_l2cap_connection(self, channel): def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
logger.debug(color('<<< stream channel connected', 'magenta')) logger.debug(color('<<< stream channel connected', 'magenta'))
self.rtp_channel = channel self.rtp_channel = channel
channel.on('open', self.on_l2cap_channel_open) channel.on(channel.EVENT_OPEN, self.on_l2cap_channel_open)
channel.on('close', self.on_l2cap_channel_close) channel.on(channel.EVENT_CLOSE, self.on_l2cap_channel_close)
# We don't need more channels # We don't need more channels
self.protocol.channel_acceptor = None self.protocol.channel_acceptor = None
def on_l2cap_channel_open(self): def on_l2cap_channel_open(self) -> None:
logger.debug(color('<<< stream channel open', 'magenta')) logger.debug(color('<<< stream channel open', 'magenta'))
self.local_endpoint.on_rtp_channel_open() self.local_endpoint.on_rtp_channel_open()
def on_l2cap_channel_close(self): def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< stream channel closed', 'magenta')) logger.debug(color('<<< stream channel closed', 'magenta'))
self.local_endpoint.on_rtp_channel_close() self.local_endpoint.on_rtp_channel_close()
self.local_endpoint.in_use = 0 self.local_endpoint.in_use = 0
@@ -2063,9 +2090,22 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint, EventEmitter): class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
stream: Optional[Stream] stream: Optional[Stream]
EVENT_CONFIGURATION = "configuration"
EVENT_OPEN = "open"
EVENT_START = "start"
EVENT_STOP = "stop"
EVENT_RTP_PACKET = "rtp_packet"
EVENT_SUSPEND = "suspend"
EVENT_CLOSE = "close"
EVENT_ABORT = "abort"
EVENT_DELAY_REPORT = "delay_report"
EVENT_SECURITY_CONTROL = "security_control"
EVENT_RTP_CHANNEL_OPEN = "rtp_channel_open"
EVENT_RTP_CHANNEL_CLOSE = "rtp_channel_close"
def __init__( def __init__(
self, self,
protocol: Protocol, protocol: Protocol,
@@ -2076,57 +2116,70 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
configuration: Optional[Iterable[ServiceCapabilities]] = None, configuration: Optional[Iterable[ServiceCapabilities]] = None,
): ):
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities) StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
EventEmitter.__init__(self) utils.EventEmitter.__init__(self)
self.protocol = protocol self.protocol = protocol
self.configuration = configuration if configuration is not None else [] self.configuration = configuration if configuration is not None else []
self.stream = None self.stream = None
async def start(self): async def start(self) -> None:
pass """[Source Only] Handles when receiving start command."""
async def stop(self): async def stop(self) -> None:
pass """[Source Only] Handles when receiving stop command."""
async def close(self): async def close(self) -> None:
pass """[Source Only] Handles when receiving close command."""
def on_reconfigure_command(self, command): def on_reconfigure_command(self, command) -> Optional[Message]:
pass return None
def on_set_configuration_command(self, configuration): def on_set_configuration_command(self, configuration) -> Optional[Message]:
logger.debug( logger.debug(
'<<< received configuration: ' '<<< received configuration: '
f'{",".join([str(capability) for capability in configuration])}' f'{",".join([str(capability) for capability in configuration])}'
) )
self.configuration = configuration self.configuration = configuration
self.emit('configuration') self.emit(self.EVENT_CONFIGURATION)
return None
def on_get_configuration_command(self): def on_get_configuration_command(self) -> Optional[Message]:
return Get_Configuration_Response(self.configuration) return Get_Configuration_Response(self.configuration)
def on_open_command(self): def on_open_command(self) -> Optional[Message]:
self.emit('open') self.emit(self.EVENT_OPEN)
return None
def on_start_command(self): def on_start_command(self) -> Optional[Message]:
self.emit('start') self.emit(self.EVENT_START)
return None
def on_suspend_command(self): def on_suspend_command(self) -> Optional[Message]:
self.emit('suspend') self.emit(self.EVENT_SUSPEND)
return None
def on_close_command(self): def on_close_command(self) -> Optional[Message]:
self.emit('close') self.emit(self.EVENT_CLOSE)
return None
def on_abort_command(self): def on_abort_command(self) -> Optional[Message]:
self.emit('abort') self.emit(self.EVENT_ABORT)
return None
def on_delayreport_command(self, delay: int): def on_delayreport_command(self, delay: int) -> Optional[Message]:
self.emit('delay_report', delay) self.emit(self.EVENT_DELAY_REPORT, delay)
return None
def on_rtp_channel_open(self): def on_security_control_command(self, data: bytes) -> Optional[Message]:
self.emit('rtp_channel_open') self.emit(self.EVENT_SECURITY_CONTROL, data)
return None
def on_rtp_channel_close(self): def on_rtp_channel_open(self) -> None:
self.emit('rtp_channel_close') self.emit(self.EVENT_RTP_CHANNEL_OPEN)
return None
def on_rtp_channel_close(self) -> None:
self.emit(self.EVENT_RTP_CHANNEL_CLOSE)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -2157,13 +2210,13 @@ class LocalSource(LocalStreamEndPoint):
if self.packet_pump and self.stream and self.stream.rtp_channel: if self.packet_pump and self.stream and self.stream.rtp_channel:
return await self.packet_pump.start(self.stream.rtp_channel) return await self.packet_pump.start(self.stream.rtp_channel)
self.emit('start') self.emit(self.EVENT_START)
async def stop(self) -> None: async def stop(self) -> None:
if self.packet_pump: if self.packet_pump:
return await self.packet_pump.stop() return await self.packet_pump.stop()
self.emit('stop') self.emit(self.EVENT_STOP)
def on_start_command(self): def on_start_command(self):
asyncio.create_task(self.start()) asyncio.create_task(self.start())
@@ -2204,4 +2257,4 @@ class LocalSink(LocalStreamEndPoint):
f'{color("<<< RTP Packet:", "green")} ' f'{color("<<< RTP Packet:", "green")} '
f'{rtp_packet} {rtp_packet.payload[:16].hex()}' f'{rtp_packet} {rtp_packet.payload[:16].hex()}'
) )
self.emit('rtp_packet', rtp_packet) self.emit(self.EVENT_RTP_PACKET, rtp_packet)

View File

@@ -38,7 +38,6 @@ from typing import (
Union, Union,
) )
import pyee
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -53,7 +52,7 @@ from bumble.sdp import (
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
) )
from bumble.utils import AsyncRunner, OpenIntEnum from bumble import utils
from bumble.core import ( from bumble.core import (
InvalidArgumentError, InvalidArgumentError,
ProtocolError, ProtocolError,
@@ -307,7 +306,7 @@ class Command:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GetCapabilitiesCommand(Command): class GetCapabilitiesCommand(Command):
class CapabilityId(OpenIntEnum): class CapabilityId(utils.OpenIntEnum):
COMPANY_ID = 0x02 COMPANY_ID = 0x02
EVENTS_SUPPORTED = 0x03 EVENTS_SUPPORTED = 0x03
@@ -637,7 +636,7 @@ class RegisterNotificationResponse(Response):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class EventId(OpenIntEnum): class EventId(utils.OpenIntEnum):
PLAYBACK_STATUS_CHANGED = 0x01 PLAYBACK_STATUS_CHANGED = 0x01
TRACK_CHANGED = 0x02 TRACK_CHANGED = 0x02
TRACK_REACHED_END = 0x03 TRACK_REACHED_END = 0x03
@@ -657,12 +656,12 @@ class EventId(OpenIntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacterSetId(OpenIntEnum): class CharacterSetId(utils.OpenIntEnum):
UTF_8 = 0x06 UTF_8 = 0x06
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class MediaAttributeId(OpenIntEnum): class MediaAttributeId(utils.OpenIntEnum):
TITLE = 0x01 TITLE = 0x01
ARTIST_NAME = 0x02 ARTIST_NAME = 0x02
ALBUM_NAME = 0x03 ALBUM_NAME = 0x03
@@ -682,7 +681,7 @@ class MediaAttribute:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PlayStatus(OpenIntEnum): class PlayStatus(utils.OpenIntEnum):
STOPPED = 0x00 STOPPED = 0x00
PLAYING = 0x01 PLAYING = 0x01
PAUSED = 0x02 PAUSED = 0x02
@@ -701,33 +700,33 @@ class SongAndPlayStatus:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ApplicationSetting: class ApplicationSetting:
class AttributeId(OpenIntEnum): class AttributeId(utils.OpenIntEnum):
EQUALIZER_ON_OFF = 0x01 EQUALIZER_ON_OFF = 0x01
REPEAT_MODE = 0x02 REPEAT_MODE = 0x02
SHUFFLE_ON_OFF = 0x03 SHUFFLE_ON_OFF = 0x03
SCAN_ON_OFF = 0x04 SCAN_ON_OFF = 0x04
class EqualizerOnOffStatus(OpenIntEnum): class EqualizerOnOffStatus(utils.OpenIntEnum):
OFF = 0x01 OFF = 0x01
ON = 0x02 ON = 0x02
class RepeatModeStatus(OpenIntEnum): class RepeatModeStatus(utils.OpenIntEnum):
OFF = 0x01 OFF = 0x01
SINGLE_TRACK_REPEAT = 0x02 SINGLE_TRACK_REPEAT = 0x02
ALL_TRACK_REPEAT = 0x03 ALL_TRACK_REPEAT = 0x03
GROUP_REPEAT = 0x04 GROUP_REPEAT = 0x04
class ShuffleOnOffStatus(OpenIntEnum): class ShuffleOnOffStatus(utils.OpenIntEnum):
OFF = 0x01 OFF = 0x01
ALL_TRACKS_SHUFFLE = 0x02 ALL_TRACKS_SHUFFLE = 0x02
GROUP_SHUFFLE = 0x03 GROUP_SHUFFLE = 0x03
class ScanOnOffStatus(OpenIntEnum): class ScanOnOffStatus(utils.OpenIntEnum):
OFF = 0x01 OFF = 0x01
ALL_TRACKS_SCAN = 0x02 ALL_TRACKS_SCAN = 0x02
GROUP_SCAN = 0x03 GROUP_SCAN = 0x03
class GenericValue(OpenIntEnum): class GenericValue(utils.OpenIntEnum):
pass pass
@@ -816,7 +815,7 @@ class PlayerApplicationSettingChangedEvent(Event):
@dataclass @dataclass
class Setting: class Setting:
attribute_id: ApplicationSetting.AttributeId attribute_id: ApplicationSetting.AttributeId
value_id: OpenIntEnum value_id: utils.OpenIntEnum
player_application_settings: List[Setting] player_application_settings: List[Setting]
@@ -824,7 +823,7 @@ class PlayerApplicationSettingChangedEvent(Event):
def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent: def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent:
def setting(attribute_id_int: int, value_id_int: int): def setting(attribute_id_int: int, value_id_int: int):
attribute_id = ApplicationSetting.AttributeId(attribute_id_int) attribute_id = ApplicationSetting.AttributeId(attribute_id_int)
value_id: OpenIntEnum value_id: utils.OpenIntEnum
if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF:
value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int) value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int)
elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE:
@@ -994,16 +993,20 @@ class Delegate:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Protocol(pyee.EventEmitter): class Protocol(utils.EventEmitter):
"""AVRCP Controller and Target protocol.""" """AVRCP Controller and Target protocol."""
EVENT_CONNECTION = "connection"
EVENT_START = "start"
EVENT_STOP = "stop"
class PacketType(enum.IntEnum): class PacketType(enum.IntEnum):
SINGLE = 0b00 SINGLE = 0b00
START = 0b01 START = 0b01
CONTINUE = 0b10 CONTINUE = 0b10
END = 0b11 END = 0b11
class PduId(OpenIntEnum): class PduId(utils.OpenIntEnum):
GET_CAPABILITIES = 0x10 GET_CAPABILITIES = 0x10
LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11 LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11
LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12 LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12
@@ -1024,7 +1027,7 @@ class Protocol(pyee.EventEmitter):
GET_FOLDER_ITEMS = 0x71 GET_FOLDER_ITEMS = 0x71
GET_TOTAL_NUMBER_OF_ITEMS = 0x75 GET_TOTAL_NUMBER_OF_ITEMS = 0x75
class StatusCode(OpenIntEnum): class StatusCode(utils.OpenIntEnum):
INVALID_COMMAND = 0x00 INVALID_COMMAND = 0x00
INVALID_PARAMETER = 0x01 INVALID_PARAMETER = 0x01
PARAMETER_CONTENT_ERROR = 0x02 PARAMETER_CONTENT_ERROR = 0x02
@@ -1457,16 +1460,18 @@ class Protocol(pyee.EventEmitter):
def _on_avctp_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def _on_avctp_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug("AVCTP connection established") logger.debug("AVCTP connection established")
l2cap_channel.on("open", lambda: self._on_avctp_channel_open(l2cap_channel)) l2cap_channel.on(
l2cap_channel.EVENT_OPEN, lambda: self._on_avctp_channel_open(l2cap_channel)
)
self.emit("connection") self.emit(self.EVENT_CONNECTION)
def _on_avctp_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: def _on_avctp_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug("AVCTP channel open") logger.debug("AVCTP channel open")
if self.avctp_protocol is not None: if self.avctp_protocol is not None:
# TODO: find a better strategy instead of just closing # TODO: find a better strategy instead of just closing
logger.warning("AVCTP protocol already active, closing connection") logger.warning("AVCTP protocol already active, closing connection")
AsyncRunner.spawn(l2cap_channel.disconnect()) utils.AsyncRunner.spawn(l2cap_channel.disconnect())
return return
self.avctp_protocol = avctp.Protocol(l2cap_channel) self.avctp_protocol = avctp.Protocol(l2cap_channel)
@@ -1474,15 +1479,15 @@ class Protocol(pyee.EventEmitter):
self.avctp_protocol.register_response_handler( self.avctp_protocol.register_response_handler(
AVRCP_PID, self._on_avctp_response AVRCP_PID, self._on_avctp_response
) )
l2cap_channel.on("close", self._on_avctp_channel_close) l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self._on_avctp_channel_close)
self.emit("start") self.emit(self.EVENT_START)
def _on_avctp_channel_close(self) -> None: def _on_avctp_channel_close(self) -> None:
logger.debug("AVCTP channel closed") logger.debug("AVCTP channel closed")
self.avctp_protocol = None self.avctp_protocol = None
self.emit("stop") self.emit(self.EVENT_STOP)
def _on_avctp_command( def _on_avctp_command(
self, transaction_label: int, command: avc.CommandFrame self, transaction_label: int, command: avc.CommandFrame

View File

@@ -17,8 +17,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from .hci import HCI_Packet from bumble.hci import HCI_Packet
from .helpers import PacketTracer from bumble.helpers import PacketTracer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -25,10 +25,7 @@ import random
import struct import struct
from bumble.colors import color from bumble.colors import color
from bumble.core import ( from bumble.core import (
BT_CENTRAL_ROLE, PhysicalTransport,
BT_PERIPHERAL_ROLE,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
) )
from bumble.hci import ( from bumble.hci import (
@@ -47,6 +44,7 @@ from bumble.hci import (
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0, HCI_VERSION_BLUETOOTH_CORE_5_0,
Address, Address,
Role,
HCI_AclDataPacket, HCI_AclDataPacket,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
@@ -98,7 +96,7 @@ class CisLink:
class Connection: class Connection:
controller: Controller controller: Controller
handle: int handle: int
role: int role: Role
peer_address: Address peer_address: Address
link: Any link: Any
transport: int transport: int
@@ -154,15 +152,17 @@ class Controller:
'0000000060000000' '0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller) ) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27 self.acl_data_packet_length = 27
self.hc_total_num_data_packets = 64 self.total_num_acl_data_packets = 64
self.hc_le_data_packet_length = 27 self.le_acl_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.total_num_le_acl_data_packets = 64
self.iso_data_packet_length = 960
self.total_num_iso_data_packets = 64
self.event_mask = 0 self.event_mask = 0
self.event_mask_page_2 = 0 self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex( self.supported_commands = bytes.fromhex(
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000' '2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000' '30f0f9ff01008004002000000000000000000000000000000000000000000000'
) )
self.le_event_mask = 0 self.le_event_mask = 0
self.advertising_parameters = None self.advertising_parameters = None
@@ -314,7 +314,7 @@ class Controller:
f'{color("CONTROLLER -> HOST", "green")}: {packet}' f'{color("CONTROLLER -> HOST", "green")}: {packet}'
) )
if self.host: if self.host:
self.host.on_packet(packet.to_bytes()) self.host.on_packet(bytes(packet))
# This method allows the controller to emulate the same API as a transport source # This method allows the controller to emulate the same API as a transport source
async def wait_for_termination(self): async def wait_for_termination(self):
@@ -388,10 +388,10 @@ class Controller:
connection = Connection( connection = Connection(
controller=self, controller=self,
handle=connection_handle, handle=connection_handle,
role=BT_PERIPHERAL_ROLE, role=Role.PERIPHERAL,
peer_address=peer_address, peer_address=peer_address,
link=self.link, link=self.link,
transport=BT_LE_TRANSPORT, transport=PhysicalTransport.LE,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.peripheral_connections[peer_address] = connection self.peripheral_connections[peer_address] = connection
@@ -448,10 +448,10 @@ class Controller:
connection = Connection( connection = Connection(
controller=self, controller=self,
handle=connection_handle, handle=connection_handle,
role=BT_CENTRAL_ROLE, role=Role.CENTRAL,
peer_address=peer_address, peer_address=peer_address,
link=self.link, link=self.link,
transport=BT_LE_TRANSPORT, transport=PhysicalTransport.LE,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.central_connections[peer_address] = connection self.central_connections[peer_address] = connection
@@ -467,7 +467,7 @@ class Controller:
HCI_LE_Connection_Complete_Event( HCI_LE_Connection_Complete_Event(
status=status, status=status,
connection_handle=connection.handle if connection else 0, connection_handle=connection.handle if connection else 0,
role=BT_CENTRAL_ROLE, role=Role.CENTRAL,
peer_address_type=le_create_connection_command.peer_address_type, peer_address_type=le_create_connection_command.peer_address_type,
peer_address=le_create_connection_command.peer_address, peer_address=le_create_connection_command.peer_address,
connection_interval=le_create_connection_command.connection_interval_min, connection_interval=le_create_connection_command.connection_interval_min,
@@ -529,7 +529,7 @@ class Controller:
def on_link_acl_data(self, sender_address, transport, data): def on_link_acl_data(self, sender_address, transport, data):
# Look for the connection to which this data belongs # Look for the connection to which this data belongs
if transport == BT_LE_TRANSPORT: if transport == PhysicalTransport.LE:
connection = self.find_le_connection_by_address(sender_address) connection = self.find_le_connection_by_address(sender_address)
else: else:
connection = self.find_classic_connection_by_address(sender_address) connection = self.find_classic_connection_by_address(sender_address)
@@ -691,10 +691,10 @@ class Controller:
controller=self, controller=self,
handle=connection_handle, handle=connection_handle,
# Role doesn't matter in Classic because they are managed by HCI_Role_Change and HCI_Role_Discovery # Role doesn't matter in Classic because they are managed by HCI_Role_Change and HCI_Role_Discovery
role=BT_CENTRAL_ROLE, role=Role.CENTRAL,
peer_address=peer_address, peer_address=peer_address,
link=self.link, link=self.link,
transport=BT_BR_EDR_TRANSPORT, transport=PhysicalTransport.BR_EDR,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE, link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.classic_connections[peer_address] = connection self.classic_connections[peer_address] = connection
@@ -759,10 +759,10 @@ class Controller:
controller=self, controller=self,
handle=connection_handle, handle=connection_handle,
# Role doesn't matter in SCO. # Role doesn't matter in SCO.
role=BT_CENTRAL_ROLE, role=Role.CENTRAL,
peer_address=peer_address, peer_address=peer_address,
link=self.link, link=self.link,
transport=BT_BR_EDR_TRANSPORT, transport=PhysicalTransport.BR_EDR,
link_type=link_type, link_type=link_type,
) )
self.classic_connections[peer_address] = connection self.classic_connections[peer_address] = connection
@@ -1181,9 +1181,9 @@ class Controller:
return struct.pack( return struct.pack(
'<BHBHH', '<BHBHH',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_data_packet_length, self.acl_data_packet_length,
0, 0,
self.hc_total_num_data_packets, self.total_num_acl_data_packets,
0, 0,
) )
@@ -1192,7 +1192,7 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
''' '''
bd_addr = ( bd_addr = (
self._public_address.to_bytes() bytes(self._public_address)
if self._public_address is not None if self._public_address is not None
else bytes(6) else bytes(6)
) )
@@ -1212,8 +1212,21 @@ class Controller:
return struct.pack( return struct.pack(
'<BHB', '<BHB',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_le_data_packet_length, self.le_acl_data_packet_length,
self.hc_total_num_le_data_packets, self.total_num_le_acl_data_packets,
)
def on_hci_le_read_buffer_size_v2_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack(
'<BHBHB',
HCI_SUCCESS,
self.le_acl_data_packet_length,
self.total_num_le_acl_data_packets,
self.iso_data_packet_length,
self.total_num_iso_data_packets,
) )
def on_hci_le_read_local_supported_features_command(self, _command): def on_hci_le_read_local_supported_features_command(self, _command):
@@ -1543,6 +1556,41 @@ class Controller:
} }
return bytes([HCI_SUCCESS]) return bytes([HCI_SUCCESS])
def on_hci_le_set_advertising_set_random_address_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS, 0])
def on_hci_le_set_extended_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_scan_response_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_advertising_data_length_command(self, _command): def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
@@ -1557,6 +1605,27 @@ class Controller:
''' '''
return struct.pack('<BB', HCI_SUCCESS, 0xF0) return struct.pack('<BB', HCI_SUCCESS, 0xF0)
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_transmit_power_command(self, _command): def on_hci_le_read_transmit_power_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,14 +16,14 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum import enum
import struct import struct
from typing import List, Optional, Tuple, Union, cast, Dict from typing import cast, overload, Literal, Union, Optional
from typing_extensions import Self from typing_extensions import Self
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.utils import OpenIntEnum from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -31,11 +31,12 @@ from bumble.utils import OpenIntEnum
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off # fmt: off
BT_CENTRAL_ROLE = 0 class PhysicalTransport(enum.IntEnum):
BT_PERIPHERAL_ROLE = 1 BR_EDR = 0
LE = 1
BT_BR_EDR_TRANSPORT = 0 BT_BR_EDR_TRANSPORT = PhysicalTransport.BR_EDR
BT_LE_TRANSPORT = 1 BT_LE_TRANSPORT = PhysicalTransport.LE
# fmt: on # fmt: on
@@ -57,7 +58,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
return names return names
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str: def name_or_number(dictionary: dict[int, str], number: int, width: int = 2) -> str:
name = dictionary.get(number) name = dictionary.get(number)
if name is not None: if name is not None:
return name return name
@@ -200,7 +201,7 @@ class UUID:
''' '''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created UUIDS: list[UUID] = [] # Registry of all instances created
uuid_bytes: bytes uuid_bytes: bytes
name: Optional[str] name: Optional[str]
@@ -259,11 +260,11 @@ class UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name) return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod @classmethod
def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]: def parse_uuid(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:]) return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod @classmethod
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> Tuple[int, UUID]: def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2]) return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128: bool = False) -> bytes: def to_bytes(self, force_128: bool = False) -> bytes:
@@ -729,7 +730,7 @@ class DeviceClass:
# Appearance # Appearance
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Appearance: class Appearance:
class Category(OpenIntEnum): class Category(utils.OpenIntEnum):
UNKNOWN = 0x0000 UNKNOWN = 0x0000
PHONE = 0x0001 PHONE = 0x0001
COMPUTER = 0x0002 COMPUTER = 0x0002
@@ -783,13 +784,13 @@ class Appearance:
SPIROMETER = 0x0037 SPIROMETER = 0x0037
OUTDOOR_SPORTS_ACTIVITY = 0x0051 OUTDOOR_SPORTS_ACTIVITY = 0x0051
class UnknownSubcategory(OpenIntEnum): class UnknownSubcategory(utils.OpenIntEnum):
GENERIC_UNKNOWN = 0x00 GENERIC_UNKNOWN = 0x00
class PhoneSubcategory(OpenIntEnum): class PhoneSubcategory(utils.OpenIntEnum):
GENERIC_PHONE = 0x00 GENERIC_PHONE = 0x00
class ComputerSubcategory(OpenIntEnum): class ComputerSubcategory(utils.OpenIntEnum):
GENERIC_COMPUTER = 0x00 GENERIC_COMPUTER = 0x00
DESKTOP_WORKSTATION = 0x01 DESKTOP_WORKSTATION = 0x01
SERVER_CLASS_COMPUTER = 0x02 SERVER_CLASS_COMPUTER = 0x02
@@ -807,49 +808,49 @@ class Appearance:
MINI_PC = 0x0E MINI_PC = 0x0E
STICK_PC = 0x0F STICK_PC = 0x0F
class WatchSubcategory(OpenIntEnum): class WatchSubcategory(utils.OpenIntEnum):
GENENERIC_WATCH = 0x00 GENERIC_WATCH = 0x00
SPORTS_WATCH = 0x01 SPORTS_WATCH = 0x01
SMARTWATCH = 0x02 SMARTWATCH = 0x02
class ClockSubcategory(OpenIntEnum): class ClockSubcategory(utils.OpenIntEnum):
GENERIC_CLOCK = 0x00 GENERIC_CLOCK = 0x00
class DisplaySubcategory(OpenIntEnum): class DisplaySubcategory(utils.OpenIntEnum):
GENERIC_DISPLAY = 0x00 GENERIC_DISPLAY = 0x00
class RemoteControlSubcategory(OpenIntEnum): class RemoteControlSubcategory(utils.OpenIntEnum):
GENERIC_REMOTE_CONTROL = 0x00 GENERIC_REMOTE_CONTROL = 0x00
class EyeglassesSubcategory(OpenIntEnum): class EyeglassesSubcategory(utils.OpenIntEnum):
GENERIC_EYEGLASSES = 0x00 GENERIC_EYEGLASSES = 0x00
class TagSubcategory(OpenIntEnum): class TagSubcategory(utils.OpenIntEnum):
GENERIC_TAG = 0x00 GENERIC_TAG = 0x00
class KeyringSubcategory(OpenIntEnum): class KeyringSubcategory(utils.OpenIntEnum):
GENERIC_KEYRING = 0x00 GENERIC_KEYRING = 0x00
class MediaPlayerSubcategory(OpenIntEnum): class MediaPlayerSubcategory(utils.OpenIntEnum):
GENERIC_MEDIA_PLAYER = 0x00 GENERIC_MEDIA_PLAYER = 0x00
class BarcodeScannerSubcategory(OpenIntEnum): class BarcodeScannerSubcategory(utils.OpenIntEnum):
GENERIC_BARCODE_SCANNER = 0x00 GENERIC_BARCODE_SCANNER = 0x00
class ThermometerSubcategory(OpenIntEnum): class ThermometerSubcategory(utils.OpenIntEnum):
GENERIC_THERMOMETER = 0x00 GENERIC_THERMOMETER = 0x00
EAR_THERMOMETER = 0x01 EAR_THERMOMETER = 0x01
class HeartRateSensorSubcategory(OpenIntEnum): class HeartRateSensorSubcategory(utils.OpenIntEnum):
GENERIC_HEART_RATE_SENSOR = 0x00 GENERIC_HEART_RATE_SENSOR = 0x00
HEART_RATE_BELT = 0x01 HEART_RATE_BELT = 0x01
class BloodPressureSubcategory(OpenIntEnum): class BloodPressureSubcategory(utils.OpenIntEnum):
GENERIC_BLOOD_PRESSURE = 0x00 GENERIC_BLOOD_PRESSURE = 0x00
ARM_BLOOD_PRESSURE = 0x01 ARM_BLOOD_PRESSURE = 0x01
WRIST_BLOOD_PRESSURE = 0x02 WRIST_BLOOD_PRESSURE = 0x02
class HumanInterfaceDeviceSubcategory(OpenIntEnum): class HumanInterfaceDeviceSubcategory(utils.OpenIntEnum):
GENERIC_HUMAN_INTERFACE_DEVICE = 0x00 GENERIC_HUMAN_INTERFACE_DEVICE = 0x00
KEYBOARD = 0x01 KEYBOARD = 0x01
MOUSE = 0x02 MOUSE = 0x02
@@ -862,16 +863,16 @@ class Appearance:
TOUCHPAD = 0x09 TOUCHPAD = 0x09
PRESENTATION_REMOTE = 0x0A PRESENTATION_REMOTE = 0x0A
class GlucoseMeterSubcategory(OpenIntEnum): class GlucoseMeterSubcategory(utils.OpenIntEnum):
GENERIC_GLUCOSE_METER = 0x00 GENERIC_GLUCOSE_METER = 0x00
class RunningWalkingSensorSubcategory(OpenIntEnum): class RunningWalkingSensorSubcategory(utils.OpenIntEnum):
GENERIC_RUNNING_WALKING_SENSOR = 0x00 GENERIC_RUNNING_WALKING_SENSOR = 0x00
IN_SHOE_RUNNING_WALKING_SENSOR = 0x01 IN_SHOE_RUNNING_WALKING_SENSOR = 0x01
ON_SHOW_RUNNING_WALKING_SENSOR = 0x02 ON_SHOW_RUNNING_WALKING_SENSOR = 0x02
ON_HIP_RUNNING_WALKING_SENSOR = 0x03 ON_HIP_RUNNING_WALKING_SENSOR = 0x03
class CyclingSubcategory(OpenIntEnum): class CyclingSubcategory(utils.OpenIntEnum):
GENERIC_CYCLING = 0x00 GENERIC_CYCLING = 0x00
CYCLING_COMPUTER = 0x01 CYCLING_COMPUTER = 0x01
SPEED_SENSOR = 0x02 SPEED_SENSOR = 0x02
@@ -879,7 +880,7 @@ class Appearance:
POWER_SENSOR = 0x04 POWER_SENSOR = 0x04
SPEED_AND_CADENCE_SENSOR = 0x05 SPEED_AND_CADENCE_SENSOR = 0x05
class ControlDeviceSubcategory(OpenIntEnum): class ControlDeviceSubcategory(utils.OpenIntEnum):
GENERIC_CONTROL_DEVICE = 0x00 GENERIC_CONTROL_DEVICE = 0x00
SWITCH = 0x01 SWITCH = 0x01
MULTI_SWITCH = 0x02 MULTI_SWITCH = 0x02
@@ -894,13 +895,13 @@ class Appearance:
ENERGY_HARVESTING_SWITCH = 0x0B ENERGY_HARVESTING_SWITCH = 0x0B
PUSH_BUTTON = 0x0C PUSH_BUTTON = 0x0C
class NetworkDeviceSubcategory(OpenIntEnum): class NetworkDeviceSubcategory(utils.OpenIntEnum):
GENERIC_NETWORK_DEVICE = 0x00 GENERIC_NETWORK_DEVICE = 0x00
ACCESS_POINT = 0x01 ACCESS_POINT = 0x01
MESH_DEVICE = 0x02 MESH_DEVICE = 0x02
MESH_NETWORK_PROXY = 0x03 MESH_NETWORK_PROXY = 0x03
class SensorSubcategory(OpenIntEnum): class SensorSubcategory(utils.OpenIntEnum):
GENERIC_SENSOR = 0x00 GENERIC_SENSOR = 0x00
MOTION_SENSOR = 0x01 MOTION_SENSOR = 0x01
AIR_QUALITY_SENSOR = 0x02 AIR_QUALITY_SENSOR = 0x02
@@ -928,7 +929,7 @@ class Appearance:
FLAME_DETECTOR = 0x18 FLAME_DETECTOR = 0x18
VEHICLE_TIRE_PRESSURE_SENSOR = 0x19 VEHICLE_TIRE_PRESSURE_SENSOR = 0x19
class LightFixturesSubcategory(OpenIntEnum): class LightFixturesSubcategory(utils.OpenIntEnum):
GENERIC_LIGHT_FIXTURES = 0x00 GENERIC_LIGHT_FIXTURES = 0x00
WALL_LIGHT = 0x01 WALL_LIGHT = 0x01
CEILING_LIGHT = 0x02 CEILING_LIGHT = 0x02
@@ -956,7 +957,7 @@ class Appearance:
LOW_BAY_LIGHT = 0x18 LOW_BAY_LIGHT = 0x18
HIGH_BAY_LIGHT = 0x19 HIGH_BAY_LIGHT = 0x19
class FanSubcategory(OpenIntEnum): class FanSubcategory(utils.OpenIntEnum):
GENERIC_FAN = 0x00 GENERIC_FAN = 0x00
CEILING_FAN = 0x01 CEILING_FAN = 0x01
AXIAL_FAN = 0x02 AXIAL_FAN = 0x02
@@ -965,7 +966,7 @@ class Appearance:
DESK_FAN = 0x05 DESK_FAN = 0x05
WALL_FAN = 0x06 WALL_FAN = 0x06
class HvacSubcategory(OpenIntEnum): class HvacSubcategory(utils.OpenIntEnum):
GENERIC_HVAC = 0x00 GENERIC_HVAC = 0x00
THERMOSTAT = 0x01 THERMOSTAT = 0x01
HUMIDIFIER = 0x02 HUMIDIFIER = 0x02
@@ -979,13 +980,13 @@ class Appearance:
FAN_HEATER = 0x0A FAN_HEATER = 0x0A
AIR_CURTAIN = 0x0B AIR_CURTAIN = 0x0B
class AirConditioningSubcategory(OpenIntEnum): class AirConditioningSubcategory(utils.OpenIntEnum):
GENERIC_AIR_CONDITIONING = 0x00 GENERIC_AIR_CONDITIONING = 0x00
class HumidifierSubcategory(OpenIntEnum): class HumidifierSubcategory(utils.OpenIntEnum):
GENERIC_HUMIDIFIER = 0x00 GENERIC_HUMIDIFIER = 0x00
class HeatingSubcategory(OpenIntEnum): class HeatingSubcategory(utils.OpenIntEnum):
GENERIC_HEATING = 0x00 GENERIC_HEATING = 0x00
RADIATOR = 0x01 RADIATOR = 0x01
BOILER = 0x02 BOILER = 0x02
@@ -995,7 +996,7 @@ class Appearance:
FAN_HEATER = 0x06 FAN_HEATER = 0x06
AIR_CURTAIN = 0x07 AIR_CURTAIN = 0x07
class AccessControlSubcategory(OpenIntEnum): class AccessControlSubcategory(utils.OpenIntEnum):
GENERIC_ACCESS_CONTROL = 0x00 GENERIC_ACCESS_CONTROL = 0x00
ACCESS_DOOR = 0x01 ACCESS_DOOR = 0x01
GARAGE_DOOR = 0x02 GARAGE_DOOR = 0x02
@@ -1007,7 +1008,7 @@ class Appearance:
DOOR_LOCK = 0x08 DOOR_LOCK = 0x08
LOCKER = 0x09 LOCKER = 0x09
class MotorizedDeviceSubcategory(OpenIntEnum): class MotorizedDeviceSubcategory(utils.OpenIntEnum):
GENERIC_MOTORIZED_DEVICE = 0x00 GENERIC_MOTORIZED_DEVICE = 0x00
MOTORIZED_GATE = 0x01 MOTORIZED_GATE = 0x01
AWNING = 0x02 AWNING = 0x02
@@ -1015,7 +1016,7 @@ class Appearance:
CURTAINS = 0x04 CURTAINS = 0x04
SCREEN = 0x05 SCREEN = 0x05
class PowerDeviceSubcategory(OpenIntEnum): class PowerDeviceSubcategory(utils.OpenIntEnum):
GENERIC_POWER_DEVICE = 0x00 GENERIC_POWER_DEVICE = 0x00
POWER_OUTLET = 0x01 POWER_OUTLET = 0x01
POWER_STRIP = 0x02 POWER_STRIP = 0x02
@@ -1027,7 +1028,7 @@ class Appearance:
CHARGE_CASE = 0x08 CHARGE_CASE = 0x08
POWER_BANK = 0x09 POWER_BANK = 0x09
class LightSourceSubcategory(OpenIntEnum): class LightSourceSubcategory(utils.OpenIntEnum):
GENERIC_LIGHT_SOURCE = 0x00 GENERIC_LIGHT_SOURCE = 0x00
INCANDESCENT_LIGHT_BULB = 0x01 INCANDESCENT_LIGHT_BULB = 0x01
LED_LAMP = 0x02 LED_LAMP = 0x02
@@ -1038,7 +1039,7 @@ class Appearance:
LOW_VOLTAGE_HALOGEN = 0x07 LOW_VOLTAGE_HALOGEN = 0x07
ORGANIC_LIGHT_EMITTING_DIODE = 0x08 ORGANIC_LIGHT_EMITTING_DIODE = 0x08
class WindowCoveringSubcategory(OpenIntEnum): class WindowCoveringSubcategory(utils.OpenIntEnum):
GENERIC_WINDOW_COVERING = 0x00 GENERIC_WINDOW_COVERING = 0x00
WINDOW_SHADES = 0x01 WINDOW_SHADES = 0x01
WINDOW_BLINDS = 0x02 WINDOW_BLINDS = 0x02
@@ -1047,7 +1048,7 @@ class Appearance:
EXTERIOR_SHUTTER = 0x05 EXTERIOR_SHUTTER = 0x05
EXTERIOR_SCREEN = 0x06 EXTERIOR_SCREEN = 0x06
class AudioSinkSubcategory(OpenIntEnum): class AudioSinkSubcategory(utils.OpenIntEnum):
GENERIC_AUDIO_SINK = 0x00 GENERIC_AUDIO_SINK = 0x00
STANDALONE_SPEAKER = 0x01 STANDALONE_SPEAKER = 0x01
SOUNDBAR = 0x02 SOUNDBAR = 0x02
@@ -1055,7 +1056,7 @@ class Appearance:
STANDMOUNTED_SPEAKER = 0x04 STANDMOUNTED_SPEAKER = 0x04
SPEAKERPHONE = 0x05 SPEAKERPHONE = 0x05
class AudioSourceSubcategory(OpenIntEnum): class AudioSourceSubcategory(utils.OpenIntEnum):
GENERIC_AUDIO_SOURCE = 0x00 GENERIC_AUDIO_SOURCE = 0x00
MICROPHONE = 0x01 MICROPHONE = 0x01
ALARM = 0x02 ALARM = 0x02
@@ -1067,7 +1068,7 @@ class Appearance:
BROADCASTING_ROOM = 0x08 BROADCASTING_ROOM = 0x08
AUDITORIUM = 0x09 AUDITORIUM = 0x09
class MotorizedVehicleSubcategory(OpenIntEnum): class MotorizedVehicleSubcategory(utils.OpenIntEnum):
GENERIC_MOTORIZED_VEHICLE = 0x00 GENERIC_MOTORIZED_VEHICLE = 0x00
CAR = 0x01 CAR = 0x01
LARGE_GOODS_VEHICLE = 0x02 LARGE_GOODS_VEHICLE = 0x02
@@ -1085,7 +1086,7 @@ class Appearance:
CAMPER_CARAVAN = 0x0E CAMPER_CARAVAN = 0x0E
RECREATIONAL_VEHICLE_MOTOR_HOME = 0x0F RECREATIONAL_VEHICLE_MOTOR_HOME = 0x0F
class DomesticApplianceSubcategory(OpenIntEnum): class DomesticApplianceSubcategory(utils.OpenIntEnum):
GENERIC_DOMESTIC_APPLIANCE = 0x00 GENERIC_DOMESTIC_APPLIANCE = 0x00
REFRIGERATOR = 0x01 REFRIGERATOR = 0x01
FREEZER = 0x02 FREEZER = 0x02
@@ -1103,21 +1104,21 @@ class Appearance:
RICE_COOKER = 0x0E RICE_COOKER = 0x0E
CLOTHES_STEAMER = 0x0F CLOTHES_STEAMER = 0x0F
class WearableAudioDeviceSubcategory(OpenIntEnum): class WearableAudioDeviceSubcategory(utils.OpenIntEnum):
GENERIC_WEARABLE_AUDIO_DEVICE = 0x00 GENERIC_WEARABLE_AUDIO_DEVICE = 0x00
EARBUD = 0x01 EARBUD = 0x01
HEADSET = 0x02 HEADSET = 0x02
HEADPHONES = 0x03 HEADPHONES = 0x03
NECK_BAND = 0x04 NECK_BAND = 0x04
class AircraftSubcategory(OpenIntEnum): class AircraftSubcategory(utils.OpenIntEnum):
GENERIC_AIRCRAFT = 0x00 GENERIC_AIRCRAFT = 0x00
LIGHT_AIRCRAFT = 0x01 LIGHT_AIRCRAFT = 0x01
MICROLIGHT = 0x02 MICROLIGHT = 0x02
PARAGLIDER = 0x03 PARAGLIDER = 0x03
LARGE_PASSENGER_AIRCRAFT = 0x04 LARGE_PASSENGER_AIRCRAFT = 0x04
class AvEquipmentSubcategory(OpenIntEnum): class AvEquipmentSubcategory(utils.OpenIntEnum):
GENERIC_AV_EQUIPMENT = 0x00 GENERIC_AV_EQUIPMENT = 0x00
AMPLIFIER = 0x01 AMPLIFIER = 0x01
RECEIVER = 0x02 RECEIVER = 0x02
@@ -1126,69 +1127,69 @@ class Appearance:
TURNTABLE = 0x05 TURNTABLE = 0x05
CD_PLAYER = 0x06 CD_PLAYER = 0x06
DVD_PLAYER = 0x07 DVD_PLAYER = 0x07
BLUERAY_PLAYER = 0x08 BLURAY_PLAYER = 0x08
OPTICAL_DISC_PLAYER = 0x09 OPTICAL_DISC_PLAYER = 0x09
SET_TOP_BOX = 0x0A SET_TOP_BOX = 0x0A
class DisplayEquipmentSubcategory(OpenIntEnum): class DisplayEquipmentSubcategory(utils.OpenIntEnum):
GENERIC_DISPLAY_EQUIPMENT = 0x00 GENERIC_DISPLAY_EQUIPMENT = 0x00
TELEVISION = 0x01 TELEVISION = 0x01
MONITOR = 0x02 MONITOR = 0x02
PROJECTOR = 0x03 PROJECTOR = 0x03
class HearingAidSubcategory(OpenIntEnum): class HearingAidSubcategory(utils.OpenIntEnum):
GENERIC_HEARING_AID = 0x00 GENERIC_HEARING_AID = 0x00
IN_EAR_HEARING_AID = 0x01 IN_EAR_HEARING_AID = 0x01
BEHIND_EAR_HEARING_AID = 0x02 BEHIND_EAR_HEARING_AID = 0x02
COCHLEAR_IMPLANT = 0x03 COCHLEAR_IMPLANT = 0x03
class GamingSubcategory(OpenIntEnum): class GamingSubcategory(utils.OpenIntEnum):
GENERIC_GAMING = 0x00 GENERIC_GAMING = 0x00
HOME_VIDEO_GAME_CONSOLE = 0x01 HOME_VIDEO_GAME_CONSOLE = 0x01
PORTABLE_HANDHELD_CONSOLE = 0x02 PORTABLE_HANDHELD_CONSOLE = 0x02
class SignageSubcategory(OpenIntEnum): class SignageSubcategory(utils.OpenIntEnum):
GENERIC_SIGNAGE = 0x00 GENERIC_SIGNAGE = 0x00
DIGITAL_SIGNAGE = 0x01 DIGITAL_SIGNAGE = 0x01
ELECTRONIC_LABEL = 0x02 ELECTRONIC_LABEL = 0x02
class PulseOximeterSubcategory(OpenIntEnum): class PulseOximeterSubcategory(utils.OpenIntEnum):
GENERIC_PULSE_OXIMETER = 0x00 GENERIC_PULSE_OXIMETER = 0x00
FINGERTIP_PULSE_OXIMETER = 0x01 FINGERTIP_PULSE_OXIMETER = 0x01
WRIST_WORN_PULSE_OXIMETER = 0x02 WRIST_WORN_PULSE_OXIMETER = 0x02
class WeightScaleSubcategory(OpenIntEnum): class WeightScaleSubcategory(utils.OpenIntEnum):
GENERIC_WEIGHT_SCALE = 0x00 GENERIC_WEIGHT_SCALE = 0x00
class PersonalMobilityDeviceSubcategory(OpenIntEnum): class PersonalMobilityDeviceSubcategory(utils.OpenIntEnum):
GENERIC_PERSONAL_MOBILITY_DEVICE = 0x00 GENERIC_PERSONAL_MOBILITY_DEVICE = 0x00
POWERED_WHEELCHAIR = 0x01 POWERED_WHEELCHAIR = 0x01
MOBILITY_SCOOTER = 0x02 MOBILITY_SCOOTER = 0x02
class ContinuousGlucoseMonitorSubcategory(OpenIntEnum): class ContinuousGlucoseMonitorSubcategory(utils.OpenIntEnum):
GENERIC_CONTINUOUS_GLUCOSE_MONITOR = 0x00 GENERIC_CONTINUOUS_GLUCOSE_MONITOR = 0x00
class InsulinPumpSubcategory(OpenIntEnum): class InsulinPumpSubcategory(utils.OpenIntEnum):
GENERIC_INSULIN_PUMP = 0x00 GENERIC_INSULIN_PUMP = 0x00
INSULIN_PUMP_DURABLE_PUMP = 0x01 INSULIN_PUMP_DURABLE_PUMP = 0x01
INSULIN_PUMP_PATCH_PUMP = 0x02 INSULIN_PUMP_PATCH_PUMP = 0x02
INSULIN_PEN = 0x03 INSULIN_PEN = 0x03
class MedicationDeliverySubcategory(OpenIntEnum): class MedicationDeliverySubcategory(utils.OpenIntEnum):
GENERIC_MEDICATION_DELIVERY = 0x00 GENERIC_MEDICATION_DELIVERY = 0x00
class SpirometerSubcategory(OpenIntEnum): class SpirometerSubcategory(utils.OpenIntEnum):
GENERIC_SPIROMETER = 0x00 GENERIC_SPIROMETER = 0x00
HANDHELD_SPIROMETER = 0x01 HANDHELD_SPIROMETER = 0x01
class OutdoorSportsActivitySubcategory(OpenIntEnum): class OutdoorSportsActivitySubcategory(utils.OpenIntEnum):
GENERIC_OUTDOOR_SPORTS_ACTIVITY = 0x00 GENERIC_OUTDOOR_SPORTS_ACTIVITY = 0x00
LOCATION_DISPLAY = 0x01 LOCATION_DISPLAY = 0x01
LOCATION_AND_NAVIGATION_DISPLAY = 0x02 LOCATION_AND_NAVIGATION_DISPLAY = 0x02
LOCATION_POD = 0x03 LOCATION_POD = 0x03
LOCATION_AND_NAVIGATION_POD = 0x04 LOCATION_AND_NAVIGATION_POD = 0x04
class _OpenSubcategory(OpenIntEnum): class _OpenSubcategory(utils.OpenIntEnum):
GENERIC = 0x00 GENERIC = 0x00
SUBCATEGORY_CLASSES = { SUBCATEGORY_CLASSES = {
@@ -1280,13 +1281,13 @@ class Appearance:
# Advertising Data # Advertising Data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
AdvertisingDataObject = Union[ AdvertisingDataObject = Union[
List[UUID], list[UUID],
Tuple[UUID, bytes], tuple[UUID, bytes],
bytes, bytes,
str, str,
int, int,
Tuple[int, int], tuple[int, int],
Tuple[int, bytes], tuple[int, bytes],
Appearance, Appearance,
] ]
@@ -1295,129 +1296,135 @@ class AdvertisingData:
# fmt: off # fmt: off
# pylint: disable=line-too-long # pylint: disable=line-too-long
FLAGS = 0x01 class Type(utils.OpenIntEnum):
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 FLAGS = 0x01
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04 COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05 INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06 COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07 INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06
SHORTENED_LOCAL_NAME = 0x08 COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07
COMPLETE_LOCAL_NAME = 0x09 SHORTENED_LOCAL_NAME = 0x08
TX_POWER_LEVEL = 0x0A COMPLETE_LOCAL_NAME = 0x09
CLASS_OF_DEVICE = 0x0D TX_POWER_LEVEL = 0x0A
SIMPLE_PAIRING_HASH_C = 0x0E CLASS_OF_DEVICE = 0x0D
SIMPLE_PAIRING_HASH_C_192 = 0x0E SIMPLE_PAIRING_HASH_C = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R = 0x0F SIMPLE_PAIRING_HASH_C_192 = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F SIMPLE_PAIRING_RANDOMIZER_R = 0x0F
DEVICE_ID = 0x10 SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F
SECURITY_MANAGER_TK_VALUE = 0x10 DEVICE_ID = 0x10
SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11 SECURITY_MANAGER_TK_VALUE = 0x10
PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12 SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14 PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15 LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14
SERVICE_DATA = 0x16 LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15
SERVICE_DATA_16_BIT_UUID = 0x16 SERVICE_DATA_16_BIT_UUID = 0x16
PUBLIC_TARGET_ADDRESS = 0x17 PUBLIC_TARGET_ADDRESS = 0x17
RANDOM_TARGET_ADDRESS = 0x18 RANDOM_TARGET_ADDRESS = 0x18
APPEARANCE = 0x19 APPEARANCE = 0x19
ADVERTISING_INTERVAL = 0x1A ADVERTISING_INTERVAL = 0x1A
LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B
LE_ROLE = 0x1C LE_ROLE = 0x1C
SIMPLE_PAIRING_HASH_C_256 = 0x1D SIMPLE_PAIRING_HASH_C_256 = 0x1D
SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F
SERVICE_DATA_32_BIT_UUID = 0x20 SERVICE_DATA_32_BIT_UUID = 0x20
SERVICE_DATA_128_BIT_UUID = 0x21 SERVICE_DATA_128_BIT_UUID = 0x21
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22 LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22
LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23 LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23
URI = 0x24 URI = 0x24
INDOOR_POSITIONING = 0x25 INDOOR_POSITIONING = 0x25
TRANSPORT_DISCOVERY_DATA = 0x26 TRANSPORT_DISCOVERY_DATA = 0x26
LE_SUPPORTED_FEATURES = 0x27 LE_SUPPORTED_FEATURES = 0x27
CHANNEL_MAP_UPDATE_INDICATION = 0x28 CHANNEL_MAP_UPDATE_INDICATION = 0x28
PB_ADV = 0x29 PB_ADV = 0x29
MESH_MESSAGE = 0x2A MESH_MESSAGE = 0x2A
MESH_BEACON = 0x2B MESH_BEACON = 0x2B
BIGINFO = 0x2C BIGINFO = 0x2C
BROADCAST_CODE = 0x2D BROADCAST_CODE = 0x2D
RESOLVABLE_SET_IDENTIFIER = 0x2E RESOLVABLE_SET_IDENTIFIER = 0x2E
ADVERTISING_INTERVAL_LONG = 0x2F ADVERTISING_INTERVAL_LONG = 0x2F
BROADCAST_NAME = 0x30 BROADCAST_NAME = 0x30
ENCRYPTED_ADVERTISING_DATA = 0X31 ENCRYPTED_ADVERTISING_DATA = 0x31
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = 0X32 PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = 0x32
ELECTRONIC_SHELF_LABEL = 0X34 ELECTRONIC_SHELF_LABEL = 0x34
THREE_D_INFORMATION_DATA = 0x3D THREE_D_INFORMATION_DATA = 0x3D
MANUFACTURER_SPECIFIC_DATA = 0xFF MANUFACTURER_SPECIFIC_DATA = 0xFF
AD_TYPE_NAMES = { class Flags(enum.IntFlag):
FLAGS: 'FLAGS', LE_LIMITED_DISCOVERABLE_MODE = 1 << 0
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', LE_GENERAL_DISCOVERABLE_MODE = 1 << 1
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', BR_EDR_NOT_SUPPORTED = 1 << 2
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS', SIMULTANEOUS_LE_BR_EDR_CAPABLE = 1 << 3
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS',
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS',
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS',
SHORTENED_LOCAL_NAME: 'SHORTENED_LOCAL_NAME',
COMPLETE_LOCAL_NAME: 'COMPLETE_LOCAL_NAME',
TX_POWER_LEVEL: 'TX_POWER_LEVEL',
CLASS_OF_DEVICE: 'CLASS_OF_DEVICE',
SIMPLE_PAIRING_HASH_C: 'SIMPLE_PAIRING_HASH_C',
SIMPLE_PAIRING_HASH_C_192: 'SIMPLE_PAIRING_HASH_C_192',
SIMPLE_PAIRING_RANDOMIZER_R: 'SIMPLE_PAIRING_RANDOMIZER_R',
SIMPLE_PAIRING_RANDOMIZER_R_192: 'SIMPLE_PAIRING_RANDOMIZER_R_192',
DEVICE_ID: 'DEVICE_ID',
SECURITY_MANAGER_TK_VALUE: 'SECURITY_MANAGER_TK_VALUE',
SECURITY_MANAGER_OUT_OF_BAND_FLAGS: 'SECURITY_MANAGER_OUT_OF_BAND_FLAGS',
PERIPHERAL_CONNECTION_INTERVAL_RANGE: 'PERIPHERAL_CONNECTION_INTERVAL_RANGE',
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS',
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS',
SERVICE_DATA_16_BIT_UUID: 'SERVICE_DATA_16_BIT_UUID',
PUBLIC_TARGET_ADDRESS: 'PUBLIC_TARGET_ADDRESS',
RANDOM_TARGET_ADDRESS: 'RANDOM_TARGET_ADDRESS',
APPEARANCE: 'APPEARANCE',
ADVERTISING_INTERVAL: 'ADVERTISING_INTERVAL',
LE_BLUETOOTH_DEVICE_ADDRESS: 'LE_BLUETOOTH_DEVICE_ADDRESS',
LE_ROLE: 'LE_ROLE',
SIMPLE_PAIRING_HASH_C_256: 'SIMPLE_PAIRING_HASH_C_256',
SIMPLE_PAIRING_RANDOMIZER_R_256: 'SIMPLE_PAIRING_RANDOMIZER_R_256',
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS',
SERVICE_DATA_32_BIT_UUID: 'SERVICE_DATA_32_BIT_UUID',
SERVICE_DATA_128_BIT_UUID: 'SERVICE_DATA_128_BIT_UUID',
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE: 'LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE',
LE_SECURE_CONNECTIONS_RANDOM_VALUE: 'LE_SECURE_CONNECTIONS_RANDOM_VALUE',
URI: 'URI',
INDOOR_POSITIONING: 'INDOOR_POSITIONING',
TRANSPORT_DISCOVERY_DATA: 'TRANSPORT_DISCOVERY_DATA',
LE_SUPPORTED_FEATURES: 'LE_SUPPORTED_FEATURES',
CHANNEL_MAP_UPDATE_INDICATION: 'CHANNEL_MAP_UPDATE_INDICATION',
PB_ADV: 'PB_ADV',
MESH_MESSAGE: 'MESH_MESSAGE',
MESH_BEACON: 'MESH_BEACON',
BIGINFO: 'BIGINFO',
BROADCAST_CODE: 'BROADCAST_CODE',
RESOLVABLE_SET_IDENTIFIER: 'RESOLVABLE_SET_IDENTIFIER',
ADVERTISING_INTERVAL_LONG: 'ADVERTISING_INTERVAL_LONG',
BROADCAST_NAME: 'BROADCAST_NAME',
ENCRYPTED_ADVERTISING_DATA: 'ENCRYPTED_ADVERTISING_DATA',
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION: 'PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION',
ELECTRONIC_SHELF_LABEL: 'ELECTRONIC_SHELF_LABEL',
THREE_D_INFORMATION_DATA: 'THREE_D_INFORMATION_DATA',
MANUFACTURER_SPECIFIC_DATA: 'MANUFACTURER_SPECIFIC_DATA'
}
LE_LIMITED_DISCOVERABLE_MODE_FLAG = 0x01 # For backward-compatibility
LE_GENERAL_DISCOVERABLE_MODE_FLAG = 0x02 FLAGS = Type.FLAGS
BR_EDR_NOT_SUPPORTED_FLAG = 0x04 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
BR_EDR_CONTROLLER_FLAG = 0x08 COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
BR_EDR_HOST_FLAG = 0x10 INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
SHORTENED_LOCAL_NAME = Type.SHORTENED_LOCAL_NAME
COMPLETE_LOCAL_NAME = Type.COMPLETE_LOCAL_NAME
TX_POWER_LEVEL = Type.TX_POWER_LEVEL
CLASS_OF_DEVICE = Type.CLASS_OF_DEVICE
SIMPLE_PAIRING_HASH_C = Type.SIMPLE_PAIRING_HASH_C
SIMPLE_PAIRING_HASH_C_192 = Type.SIMPLE_PAIRING_HASH_C_192
SIMPLE_PAIRING_RANDOMIZER_R = Type.SIMPLE_PAIRING_RANDOMIZER_R
SIMPLE_PAIRING_RANDOMIZER_R_192 = Type.SIMPLE_PAIRING_RANDOMIZER_R_192
DEVICE_ID = Type.DEVICE_ID
SECURITY_MANAGER_TK_VALUE = Type.SECURITY_MANAGER_TK_VALUE
SECURITY_MANAGER_OUT_OF_BAND_FLAGS = Type.SECURITY_MANAGER_OUT_OF_BAND_FLAGS
PERIPHERAL_CONNECTION_INTERVAL_RANGE = Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS
SERVICE_DATA = Type.SERVICE_DATA_16_BIT_UUID
SERVICE_DATA_16_BIT_UUID = Type.SERVICE_DATA_16_BIT_UUID
PUBLIC_TARGET_ADDRESS = Type.PUBLIC_TARGET_ADDRESS
RANDOM_TARGET_ADDRESS = Type.RANDOM_TARGET_ADDRESS
APPEARANCE = Type.APPEARANCE
ADVERTISING_INTERVAL = Type.ADVERTISING_INTERVAL
LE_BLUETOOTH_DEVICE_ADDRESS = Type.LE_BLUETOOTH_DEVICE_ADDRESS
LE_ROLE = Type.LE_ROLE
SIMPLE_PAIRING_HASH_C_256 = Type.SIMPLE_PAIRING_HASH_C_256
SIMPLE_PAIRING_RANDOMIZER_R_256 = Type.SIMPLE_PAIRING_RANDOMIZER_R_256
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS
SERVICE_DATA_32_BIT_UUID = Type.SERVICE_DATA_32_BIT_UUID
SERVICE_DATA_128_BIT_UUID = Type.SERVICE_DATA_128_BIT_UUID
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = Type.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE
LE_SECURE_CONNECTIONS_RANDOM_VALUE = Type.LE_SECURE_CONNECTIONS_RANDOM_VALUE
URI = Type.URI
INDOOR_POSITIONING = Type.INDOOR_POSITIONING
TRANSPORT_DISCOVERY_DATA = Type.TRANSPORT_DISCOVERY_DATA
LE_SUPPORTED_FEATURES = Type.LE_SUPPORTED_FEATURES
CHANNEL_MAP_UPDATE_INDICATION = Type.CHANNEL_MAP_UPDATE_INDICATION
PB_ADV = Type.PB_ADV
MESH_MESSAGE = Type.MESH_MESSAGE
MESH_BEACON = Type.MESH_BEACON
BIGINFO = Type.BIGINFO
BROADCAST_CODE = Type.BROADCAST_CODE
RESOLVABLE_SET_IDENTIFIER = Type.RESOLVABLE_SET_IDENTIFIER
ADVERTISING_INTERVAL_LONG = Type.ADVERTISING_INTERVAL_LONG
BROADCAST_NAME = Type.BROADCAST_NAME
ENCRYPTED_ADVERTISING_DATA = Type.ENCRYPTED_ADVERTISING_DATA
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = Type.PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION
ELECTRONIC_SHELF_LABEL = Type.ELECTRONIC_SHELF_LABEL
THREE_D_INFORMATION_DATA = Type.THREE_D_INFORMATION_DATA
MANUFACTURER_SPECIFIC_DATA = Type.MANUFACTURER_SPECIFIC_DATA
ad_structures: List[Tuple[int, bytes]] LE_LIMITED_DISCOVERABLE_MODE_FLAG = Flags.LE_LIMITED_DISCOVERABLE_MODE
LE_GENERAL_DISCOVERABLE_MODE_FLAG = Flags.LE_GENERAL_DISCOVERABLE_MODE
BR_EDR_NOT_SUPPORTED_FLAG = Flags.BR_EDR_NOT_SUPPORTED
BR_EDR_CONTROLLER_FLAG = Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE
BR_EDR_HOST_FLAG = 0x10 # Deprecated
ad_structures: list[tuple[int, bytes]]
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None: def __init__(self, ad_structures: Optional[list[tuple[int, bytes]]] = None) -> None:
if ad_structures is None: if ad_structures is None:
ad_structures = [] ad_structures = []
self.ad_structures = ad_structures[:] self.ad_structures = ad_structures[:]
@@ -1444,7 +1451,7 @@ class AdvertisingData:
return ','.join(bit_flags_to_strings(flags, flag_names)) return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod @staticmethod
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]: def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> list[UUID]:
uuids = [] uuids = []
offset = 0 offset = 0
while (offset + uuid_size) <= len(ad_data): while (offset + uuid_size) <= len(ad_data):
@@ -1461,8 +1468,8 @@ class AdvertisingData:
] ]
) )
@staticmethod @classmethod
def ad_data_to_string(ad_type, ad_data): def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
if ad_type == AdvertisingData.FLAGS: if ad_type == AdvertisingData.FLAGS:
ad_type_str = 'Flags' ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True) ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
@@ -1501,7 +1508,10 @@ class AdvertisingData:
ad_data_str = f'"{ad_data.decode("utf-8")}"' ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME: elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name' ad_type_str = 'Complete Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"' try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL: elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level' ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0]) ad_data_str = str(ad_data[0])
@@ -1518,72 +1528,72 @@ class AdvertisingData:
ad_type_str = 'Broadcast Name' ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8') ad_data_str = ad_data.decode('utf-8')
else: else:
ad_type_str = AdvertisingData.AD_TYPE_NAMES.get(ad_type, f'0x{ad_type:02X}') ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex() ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}' return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
@staticmethod @classmethod
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingDataObject: def ad_data_to_object(cls, ad_type: int, ad_data: bytes) -> AdvertisingDataObject:
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 2) return AdvertisingData.uuid_list_to_objects(ad_data, 2)
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 4) return AdvertisingData.uuid_list_to_objects(ad_data, 4)
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS, AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
): ):
return AdvertisingData.uuid_list_to_objects(ad_data, 16) return AdvertisingData.uuid_list_to_objects(ad_data, 16)
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:]) return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:]) return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID: if ad_type == AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:]) return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
if ad_type in ( if ad_type in (
AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME, AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.URI, AdvertisingData.Type.URI,
AdvertisingData.BROADCAST_NAME, AdvertisingData.Type.BROADCAST_NAME,
): ):
return ad_data.decode("utf-8") return ad_data.decode("utf-8")
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS): if ad_type in (AdvertisingData.Type.TX_POWER_LEVEL, AdvertisingData.Type.FLAGS):
return cast(int, struct.unpack('B', ad_data)[0]) return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (AdvertisingData.ADVERTISING_INTERVAL,): if ad_type in (AdvertisingData.Type.ADVERTISING_INTERVAL,):
return cast(int, struct.unpack('<H', ad_data)[0]) return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE: if ad_type == AdvertisingData.Type.CLASS_OF_DEVICE:
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0]) return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE: if ad_type == AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return cast(Tuple[int, int], struct.unpack('<HH', ad_data)) return cast(tuple[int, int], struct.unpack('<HH', ad_data))
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA: if ad_type == AdvertisingData.Type.APPEARANCE:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
if ad_type == AdvertisingData.APPEARANCE:
return Appearance.from_int( return Appearance.from_int(
cast(int, struct.unpack_from('<H', ad_data, 0)[0]) cast(int, struct.unpack_from('<H', ad_data, 0)[0])
) )
if ad_type == AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
return ad_data return ad_data
def append(self, data: bytes) -> None: def append(self, data: bytes) -> None:
@@ -1597,7 +1607,80 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data)) self.ad_structures.append((ad_type, ad_data))
offset += length offset += length
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingDataObject]: @overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> list[list[UUID]]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> list[tuple[UUID, bytes]]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.Type.URI,
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> list[str]: ...
@overload
def get_all(
self,
type_id: Literal[
AdvertisingData.Type.TX_POWER_LEVEL,
AdvertisingData.Type.FLAGS,
AdvertisingData.Type.ADVERTISING_INTERVAL,
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> list[int]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> list[tuple[int, int]]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> list[tuple[int, bytes]]: ...
@overload
def get_all(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> list[Appearance]: ...
@overload
def get_all(self, type_id: int, raw: Literal[True]) -> list[bytes]: ...
@overload
def get_all(
self, type_id: int, raw: bool = False
) -> list[AdvertisingDataObject]: ...
def get_all(self, type_id: int, raw: bool = False) -> list[AdvertisingDataObject]: # type: ignore[misc]
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
@@ -1609,6 +1692,79 @@ class AdvertisingData:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
AdvertisingData.Type.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> Optional[list[UUID]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.SERVICE_DATA_16_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_32_BIT_UUID,
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> Optional[tuple[UUID, bytes]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.SHORTENED_LOCAL_NAME,
AdvertisingData.Type.COMPLETE_LOCAL_NAME,
AdvertisingData.Type.URI,
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> Optional[Optional[str]]: ...
@overload
def get(
self,
type_id: Literal[
AdvertisingData.Type.TX_POWER_LEVEL,
AdvertisingData.Type.FLAGS,
AdvertisingData.Type.ADVERTISING_INTERVAL,
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> Optional[int]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> Optional[tuple[int, int]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> Optional[tuple[int, bytes]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> Optional[Appearance]: ...
@overload
def get(self, type_id: int, raw: Literal[True]) -> Optional[bytes]: ...
@overload
def get(
self, type_id: int, raw: bool = False
) -> Optional[AdvertisingDataObject]: ...
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]: def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]:
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type

File diff suppressed because it is too large Load Diff

View File

@@ -25,8 +25,8 @@ import pathlib
import platform import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk, intel from bumble.drivers import rtk, intel
from .common import Driver from bumble.drivers.common import Driver
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.host import Host from bumble.host import Host

View File

@@ -20,6 +20,8 @@ Common types for drivers.
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import abc import abc
from bumble import core
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes

View File

@@ -11,18 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Support for Intel USB controllers.
Loosely based on the Fuchsia OS implementation.
"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging import logging
import os
import pathlib
import platform
import struct
from typing import Any, Deque, Optional, TYPE_CHECKING
from bumble import core
from bumble.drivers import common from bumble.drivers import common
from bumble.hci import ( from bumble import hci
hci_vendor_command_op_code, # type: ignore from bumble import utils
HCI_Command,
HCI_Reset_Command, if TYPE_CHECKING:
) from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -34,39 +49,328 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = { INTEL_USB_PRODUCTS = {
# Intel AX210 (0x8087, 0x0032), # AX210
(0x8087, 0x0032), (0x8087, 0x0036), # BE200
# Intel BE200
(0x8087, 0x0036),
} }
INTEL_FW_IMAGE_NAMES = [
"ibt-0040-0041",
"ibt-0040-1020",
"ibt-0040-1050",
"ibt-0040-2120",
"ibt-0040-4150",
"ibt-0041-0041",
"ibt-0180-0041",
"ibt-0180-1050",
"ibt-0180-4150",
"ibt-0291-0291",
"ibt-1040-0041",
"ibt-1040-1020",
"ibt-1040-1050",
"ibt-1040-2120",
"ibt-1040-4150",
]
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
_MAX_FRAGMENT_SIZE = 252
_POST_RESET_DELAY = 0.2
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HCI Commands # HCI Commands
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00] HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore @hci.HCI_Command.command(
fields=[("params", "*")], fields=[
("param0", 1),
],
return_parameters_fields=[ return_parameters_fields=[
("params", "*"), ("status", hci.STATUS_SPEC),
("tlv", "*"),
], ],
) )
class Hci_Intel_DDC_Config_Write_Command(HCI_Command): class HCI_Intel_Read_Version_Command(hci.HCI_Command):
pass pass
@hci.HCI_Command.command(
fields=[("data_type", 1), ("data", "*")],
return_parameters_fields=[
("status", 1),
],
)
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[
("reset_type", 1),
("patch_enable", 1),
("ddc_reload", 1),
("boot_option", 1),
("boot_address", 4),
],
return_parameters_fields=[
("data", "*"),
],
)
class HCI_Intel_Reset_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[("data", "*")],
return_parameters_fields=[
("status", hci.STATUS_SPEC),
("params", "*"),
],
)
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
pass
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
def intel_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Intel firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "intel"
p.mkdir(parents=True, exist_ok=True)
return p
def _find_binary_path(file_name: str) -> pathlib.Path | None:
# First check if an environment variable is set
if INTEL_FIRMWARE_DIR_ENV in os.environ:
if (
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
).is_file():
logger.debug(f"{file_name} found in env dir")
return path
# When the environment variable is set, don't look elsewhere
return None
# Then, look where the firmware download tool writes by default
if (path := intel_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
logger.debug(f"{file_name} found in package dir")
return path
# On Linux, check the system's FW directory
if (
platform.system() == "Linux"
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
):
logger.debug(f"{file_name} found in Linux system FW dir")
return path
# Finally look in the current directory
if (path := pathlib.Path.cwd() / file_name).is_file():
logger.debug(f"{file_name} found in CWD")
return path
return None
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
result: list[tuple[ValueType, Any]] = []
while len(data) >= 2:
value_type = ValueType(data[0])
value_length = data[1]
value = data[2 : 2 + value_length]
typed_value: Any
if value_type == ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]
return result
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class DriverError(core.BaseBumbleError):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
def __str__(self) -> str:
return f"IntelDriverError({self.message})"
class ValueType(utils.OpenIntEnum):
END = 0x00
CNVI = 0x10
CNVR = 0x11
HARDWARE_INFO = 0x12
DEVICE_REVISION = 0x16
CURRENT_MODE_OF_OPERATION = 0x1C
USB_VENDOR_ID = 0x17
USB_PRODUCT_ID = 0x18
TIMESTAMP = 0x1D
BUILD_TYPE = 0x1E
BUILD_NUMBER = 0x1F
SECURE_BOOT = 0x28
OTP_LOCK = 0x2A
API_LOCK = 0x2B
DEBUG_LOCK = 0x2C
FIRMWARE_BUILD = 0x2D
SECURE_BOOT_ENGINE_TYPE = 0x2F
BLUETOOTH_ADDRESS = 0x30
class HardwarePlatform(utils.OpenIntEnum):
INTEL_37 = 0x37
class HardwareVariant(utils.OpenIntEnum):
# This is a just a partial list.
# Add other constants here as new hardware is encountered and tested.
TYPHOON_PEAK = 0x17
GALE_PEAK = 0x1C
@dataclasses.dataclass
class HardwareInfo:
platform: HardwarePlatform
variant: HardwareVariant
@dataclasses.dataclass
class Timestamp:
week: int
year: int
@dataclasses.dataclass
class FirmwareBuild:
build_number: int
timestamp: Timestamp
class ModeOfOperation(utils.OpenIntEnum):
BOOTLOADER = 0x01
INTERMEDIATE = 0x02
OPERATIONAL = 0x03
class SecureBootEngineType(utils.OpenIntEnum):
RSA = 0x00
ECDSA = 0x01
@dataclasses.dataclass
class BootParams:
css_header_offset: int
css_header_size: int
pki_offset: int
pki_size: int
sig_offset: int
sig_size: int
write_offset: int
_BOOT_PARAMS = {
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
}
class Driver(common.Driver): class Driver(common.Driver):
def __init__(self, host): def __init__(self, host: Host) -> None:
self.host = host self.host = host
self.max_in_flight_firmware_load_commands = 1
self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
collections.deque()
)
self.can_send_firmware_load_command = asyncio.Event()
self.can_send_firmware_load_command.set()
self.firmware_load_complete = asyncio.Event()
self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name.
self.ddc_addon: Optional[bytes] = None
self.ddc_override: Optional[bytes] = None
driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"):
for key, value in [
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
]:
if key == "ddc_addon":
self.ddc_addon = bytes.fromhex(value)
elif key == "ddc_override":
self.ddc_override = bytes.fromhex(value)
@staticmethod @staticmethod
def check(host): def check(host: Host) -> bool:
driver = host.hci_metadata.get("driver") driver = host.hci_metadata.get("driver")
if driver == "intel": if driver == "intel" or driver is not None and driver.startswith("intel/"):
return True return True
vendor_id = host.hci_metadata.get("vendor_id") vendor_id = host.hci_metadata.get("vendor_id")
@@ -85,18 +389,283 @@ class Driver(common.Driver):
return True return True
@classmethod @classmethod
async def for_host(cls, host, force=False): # type: ignore async def for_host(cls, host: Host, force: bool = False):
# Only instantiate this driver if explicitly selected # Only instantiate this driver if explicitly selected
if not force and not cls.check(host): if not force and not cls.check(host):
return None return None
return cls(host) return cls(host)
async def init_controller(self): def on_packet(self, packet: bytes) -> None:
"""Handler for event packets that are received from an ACL channel"""
event = hci.HCI_Event.from_bytes(packet)
if not isinstance(event, hci.HCI_Command_Complete_Event):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
logger.debug(
"max_in_flight_firmware_load_commands update: "
f"{event.num_hci_command_packets}"
)
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
logger.debug(f"event: {event}")
self.pending_firmware_load_commands.popleft()
in_flight = len(self.pending_firmware_load_commands)
logger.debug(f"event received, {in_flight} still in flight")
if in_flight < self.max_in_flight_firmware_load_commands:
self.can_send_firmware_load_command.set()
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
# Wait until we can send.
await self.can_send_firmware_load_command.wait()
# Send the command and adjust counters.
self.host.send_hci_packet(command)
self.pending_firmware_load_commands.append(command)
in_flight = len(self.pending_firmware_load_commands)
if in_flight >= self.max_in_flight_firmware_load_commands:
logger.debug(f"max commands in flight reached [{in_flight}]")
self.can_send_firmware_load_command.clear()
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
while data:
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
fragment = data[:fragment_size]
data = data[fragment_size:]
await self.send_firmware_load_command(
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
)
async def load_firmware(self) -> None:
self.host.ready = True self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True) device_info = await self.read_device_info()
await self.host.send_command( logger.debug(
Hci_Intel_DDC_Config_Write_Command( "device info: \n%s",
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD "\n".join(
[
f" {value_type.name}: {value}"
for value_type, value in device_info.items()
]
),
)
# Check if the firmware is already loaded.
if (
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
== ModeOfOperation.OPERATIONAL
):
logger.debug("firmware already loaded")
return
# We only support some platforms and variants.
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
if hardware_info is None:
raise DriverError("hardware info missing")
if hardware_info.platform != HardwarePlatform.INTEL_37:
raise DriverError("hardware platform not supported")
if hardware_info.variant not in (
HardwareVariant.TYPHOON_PEAK,
HardwareVariant.GALE_PEAK,
):
raise DriverError("hardware variant not supported")
# Compute the firmware name.
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = (
"ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
)
logger.debug(f"FW base name: {firmware_base_name}")
firmware_name = f"{firmware_base_name}.sfi"
firmware_path = _find_binary_path(firmware_name)
if not firmware_path:
logger.warning(f"Firmware file {firmware_name} not found")
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
return None
logger.debug(f"loading firmware from {firmware_path}")
firmware_image = firmware_path.read_bytes()
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
if engine_type is None:
raise DriverError("secure boot engine type missing")
if engine_type not in _BOOT_PARAMS:
raise DriverError("secure boot engine type not supported")
boot_params = _BOOT_PARAMS[engine_type]
if len(firmware_image) < boot_params.write_offset:
raise DriverError("firmware image too small")
# Register to receive vendor events.
def on_vendor_event(event: hci.HCI_Vendor_Event):
logger.debug(f"vendor event: {event}")
event_type = event.parameters[0]
if event_type == 0x02:
# Boot event
logger.debug("boot complete")
self.reset_complete.set()
elif event_type == 0x06:
# Firmware load event
logger.debug("download complete")
self.firmware_load_complete.set()
else:
logger.debug(f"ignoring vendor event type {event_type}")
self.host.on("vendor_event", on_vendor_event)
# We need to temporarily intercept packets from the controller,
# because they are formatted as HCI event packets but are received
# on the ACL channel, so the host parser would get confused.
saved_on_packet = self.host.on_packet
self.host.on_packet = self.on_packet # type: ignore
self.firmware_load_complete.clear()
# Send the CSS header
data = firmware_image[
boot_params.css_header_offset : boot_params.css_header_offset
+ boot_params.css_header_size
]
await self.send_firmware_data(0x00, data)
# Send the PKI header
data = firmware_image[
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
]
await self.send_firmware_data(0x03, data)
# Send the Signature header
data = firmware_image[
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
]
await self.send_firmware_data(0x02, data)
# Send the rest of the image.
# The payload consists of command objects, which are sent when they add up
# to a multiple of 4 bytes.
boot_address = 0
offset = boot_params.write_offset
fragment_size = 0
while offset + 3 < len(firmware_image):
(command_opcode,) = struct.unpack_from(
"<H", firmware_image, offset + fragment_size
)
command_size = firmware_image[offset + fragment_size + 2]
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
(boot_address,) = struct.unpack_from(
"<I", firmware_image, offset + fragment_size + 3
)
logger.debug(
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
f"boot_address={boot_address}"
)
fragment_size += 3 + command_size
if fragment_size % 4 == 0:
await self.send_firmware_data(
0x01, firmware_image[offset : offset + fragment_size]
)
logger.debug(f"sent {fragment_size} bytes")
offset += fragment_size
fragment_size = 0
# Wait for the firmware loading to be complete.
logger.debug("waiting for firmware to be loaded")
await self.firmware_load_complete.wait()
logger.debug("firmware loaded")
# Restore the original packet handler.
self.host.on_packet = saved_on_packet # type: ignore
# Reset
self.reset_complete.clear()
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x00,
patch_enable=0x01,
ddc_reload=0x00,
boot_option=0x01,
boot_address=boot_address,
) )
) )
logger.debug("waiting for reset completion")
await self.reset_complete.wait()
logger.debug("reset complete")
# Load the device config if there is one.
if self.ddc_override:
logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon:
logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon)
async def load_device_config(self, ddc_data: bytes) -> None:
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
async def reboot_bootloader(self) -> None:
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x01,
patch_enable=0x01,
ddc_reload=0x01,
boot_option=0x00,
boot_address=0,
)
)
await asyncio.sleep(_POST_RESET_DELAY)
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF)
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.
return dict(tlvs)
async def init_controller(self):
await self.load_firmware()

View File

@@ -18,7 +18,7 @@
import logging import logging
import struct import struct
from .gatt import ( from bumble.gatt import (
Service, Service,
Characteristic, Characteristic,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,

View File

@@ -27,25 +27,16 @@ import enum
import functools import functools
import logging import logging
import struct import struct
from typing import ( from typing import Iterable, List, Optional, Sequence, TypeVar, Union
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from bumble.colors import color from bumble.colors import color
from bumble.core import BaseBumbleError, UUID from bumble.core import BaseBumbleError, UUID
from bumble.att import Attribute, AttributeValue from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING: # -----------------------------------------------------------------------------
from bumble.gatt_client import AttributeProxy # Typing
from bumble.device import Connection # -----------------------------------------------------------------------------
_T = TypeVar('_T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -275,6 +266,13 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts') GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts') GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# Gaming Audio Service (GMAS)
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
# Hearing Access Service # Hearing Access Service
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features') GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point') GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
@@ -288,6 +286,22 @@ GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-32
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume') GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT') GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Apple Notification Center Service
GATT_ANCS_SERVICE = UUID('7905F431-B5CE-4E99-A40F-4B1E122D00D0', 'Apple Notification Center')
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC = UUID('9FBF120D-6301-42D9-8C58-25E699A21DBD', 'Notification Source')
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC = UUID('69D1D8F3-45E1-49A8-9821-9BBDFDAAD9D9', 'Control Point')
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC = UUID('22EAC6E9-24D6-4BB5-BE44-B36ACE7C7BFB', 'Data Source')
# Apple Media Service
GATT_AMS_SERVICE = UUID('89D3502B-0F36-433A-8EF4-C502AD55F8DC', 'Apple Media')
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC = UUID('9B3C81D8-57B1-4A8A-B8DF-0E56F7CA51C2', 'Remote Command')
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC = UUID('2F7CABCE-808D-411F-9A0C-BB92BA96C102', 'Entity Update')
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC = UUID('C6B2F38C-23AB-46D8-A6AB-A3A870BBD5D7', 'Entity Attribute')
# Misc Apple Services
GATT_APPLE_CONTINUITY_SERVICE = UUID('D0611E78-BBB4-4591-A5F8-487910AE4366', 'Apple Continuity')
GATT_APPLE_NEARBY_SERVICE = UUID('9FA480E0-4967-4542-9390-D343DC5D04AE', 'Apple Nearby')
# Misc # Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name') GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance') GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
@@ -304,6 +318,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features') GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash') GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features') GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
GATT_LE_GATT_SECURITY_LEVELS_CHARACTERISTIC = UUID.from_16_bits(0x2BF5, 'E GATT Security Levels')
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -312,8 +327,6 @@ GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bi
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None: def show_services(services: Iterable[Service]) -> None:
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -343,7 +356,7 @@ class Service(Attribute):
def __init__( def __init__(
self, self,
uuid: Union[str, UUID], uuid: Union[str, UUID],
characteristics: List[Characteristic], characteristics: Iterable[Characteristic],
primary=True, primary=True,
included_services: Iterable[Service] = (), included_services: Iterable[Service] = (),
) -> None: ) -> None:
@@ -362,7 +375,7 @@ class Service(Attribute):
) )
self.uuid = uuid self.uuid = uuid
self.included_services = list(included_services) self.included_services = list(included_services)
self.characteristics = characteristics[:] self.characteristics = list(characteristics)
self.primary = primary self.primary = primary
def get_advertising_data(self) -> Optional[bytes]: def get_advertising_data(self) -> Optional[bytes]:
@@ -393,7 +406,7 @@ class TemplateService(Service):
def __init__( def __init__(
self, self,
characteristics: List[Characteristic], characteristics: Iterable[Characteristic],
primary: bool = True, primary: bool = True,
included_services: Iterable[Service] = (), included_services: Iterable[Service] = (),
) -> None: ) -> None:
@@ -410,7 +423,7 @@ class IncludedServiceDeclaration(Attribute):
def __init__(self, service: Service) -> None: def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack( declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() '<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
) )
super().__init__( super().__init__(
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
@@ -427,7 +440,7 @@ class IncludedServiceDeclaration(Attribute):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Characteristic(Attribute): class Characteristic(Attribute[_T]):
''' '''
See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION
''' '''
@@ -435,6 +448,8 @@ class Characteristic(Attribute):
uuid: UUID uuid: UUID
properties: Characteristic.Properties properties: Characteristic.Properties
EVENT_SUBSCRIPTION = "subscription"
class Properties(enum.IntFlag): class Properties(enum.IntFlag):
"""Property flags""" """Property flags"""
@@ -490,7 +505,7 @@ class Characteristic(Attribute):
uuid: Union[str, bytes, UUID], uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties, properties: Characteristic.Properties,
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'', value: Union[AttributeValue[_T], _T, None] = None,
descriptors: Sequence[Descriptor] = (), descriptors: Sequence[Descriptor] = (),
): ):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
@@ -525,7 +540,11 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic characteristic: Characteristic
def __init__(self, characteristic: Characteristic, value_handle: int) -> None: def __init__(
self,
characteristic: Characteristic,
value_handle: int,
) -> None:
declaration_bytes = ( declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle) struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes() + characteristic.uuid.to_pdu_bytes()
@@ -546,195 +565,10 @@ class CharacteristicDeclaration(Attribute):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue): class CharacteristicValue(AttributeValue[_T]):
"""Same as AttributeValue, for backward compatibility""" """Same as AttributeValue, for backward compatibility"""
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
'''
read_value: Callable
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[Callable, Callable] = (
{}
) # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
else:
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in (
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe',
):
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(
self.encode_value(value), with_response
)
def encode_value(self, value):
return value
def decode_value(self, value):
return value
def wrapped_subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
# -----------------------------------------------------------------------------
class DelegatedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic)
self.encode = encode
self.decode = decode
def encode_value(self, value):
return self.encode(value) if self.encode else value
def decode_value(self, value):
return self.decode(value) if self.decode else value
# -----------------------------------------------------------------------------
class PackedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
For formats with a single value, the adapted `read_value` and `write_value`
methods return/accept single values. For formats with multiple values,
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic, pack_format):
super().__init__(characteristic)
self.struct = struct.Struct(pack_format)
def pack(self, *values):
return self.struct.pack(*values)
def unpack(self, buffer):
return self.struct.unpack(buffer)
def encode_value(self, value):
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value):
unpacked = self.unpack(value)
return unpacked[0] if len(unpacked) == 1 else unpacked
# -----------------------------------------------------------------------------
class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, pack_format, keys):
super().__init__(characteristic, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values):
return super().pack(*(values[key] for key in self.keys))
def unpack(self, buffer):
return dict(zip(self.keys, super().unpack(buffer)))
# -----------------------------------------------------------------------------
class UTF8CharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Descriptor(Attribute): class Descriptor(Attribute):
''' '''
@@ -769,3 +603,23 @@ class ClientCharacteristicConfigurationBits(enum.IntFlag):
DEFAULT = 0x0000 DEFAULT = 0x0000
NOTIFICATION = 0x0001 NOTIFICATION = 0x0001
INDICATION = 0x0002 INDICATION = 0x0002
# -----------------------------------------------------------------------------
class ClientSupportedFeatures(enum.IntFlag):
'''
See Vol 3, Part G - 7.2 - Table 7.6: Client Supported Features bit assignments.
'''
ROBUST_CACHING = 0x01
ENHANCED_ATT_BEARER = 0x02
MULTIPLE_HANDLE_VALUE_NOTIFICATIONS = 0x04
# -----------------------------------------------------------------------------
class ServerSupportedFeatures(enum.IntFlag):
'''
See Vol 3, Part G - 7.4 - Table 7.11: Server Supported Features bit assignments.
'''
EATT_SUPPORTED = 0x01

374
bumble/gatt_adapters.py Normal file
View File

@@ -0,0 +1,374 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# GATT - Type Adapters
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import (
Any,
Callable,
Generic,
Iterable,
Literal,
Optional,
Type,
TypeVar,
)
from bumble.core import InvalidOperationError
from bumble.gatt import Characteristic
from bumble.gatt_client import CharacteristicProxy
from bumble import utils
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
_T2 = TypeVar('_T2', bound=utils.ByteSerializable)
_T3 = TypeVar('_T3', bound=utils.IntConvertible)
# -----------------------------------------------------------------------------
class CharacteristicAdapter(Characteristic, Generic[_T]):
'''Base class for GATT Characteristic adapters.'''
def __init__(self, characteristic: Characteristic) -> None:
super().__init__(
characteristic.uuid,
characteristic.properties,
characteristic.permissions,
characteristic.value,
characteristic.descriptors,
)
# -----------------------------------------------------------------------------
class CharacteristicProxyAdapter(CharacteristicProxy[_T]):
'''Base class for GATT CharacteristicProxy adapters.'''
def __init__(self, characteristic_proxy: CharacteristicProxy):
super().__init__(
characteristic_proxy.client,
characteristic_proxy.handle,
characteristic_proxy.end_group_handle,
characteristic_proxy.uuid,
characteristic_proxy.properties,
)
# -----------------------------------------------------------------------------
class DelegatedCharacteristicAdapter(CharacteristicAdapter[_T]):
'''
Adapter that converts bytes values using an encode and/or a decode function.
'''
def __init__(
self,
characteristic: Characteristic,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
):
super().__init__(characteristic)
self.encode = encode
self.decode = decode
def encode_value(self, value: _T) -> bytes:
if self.encode is None:
raise InvalidOperationError('delegated adapter does not have an encoder')
return self.encode(value)
def decode_value(self, value: bytes) -> _T:
if self.decode is None:
raise InvalidOperationError('delegate adapter does not have a decoder')
return self.decode(value)
# -----------------------------------------------------------------------------
class DelegatedCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T]):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(
self,
characteristic_proxy: CharacteristicProxy,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
):
super().__init__(characteristic_proxy)
self.encode = encode
self.decode = decode
def encode_value(self, value: _T) -> bytes:
if self.encode is None:
raise InvalidOperationError('delegated adapter does not have an encoder')
return self.encode(value)
def decode_value(self, value: bytes) -> _T:
if self.decode is None:
raise InvalidOperationError('delegate adapter does not have a decoder')
return self.decode(value)
# -----------------------------------------------------------------------------
class PackedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
For formats with a single value, the adapted `read_value` and `write_value`
methods return/accept single values. For formats with multiple values,
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic: Characteristic, pack_format: str) -> None:
super().__init__(characteristic)
self.struct = struct.Struct(pack_format)
def pack(self, *values) -> bytes:
return self.struct.pack(*values)
def unpack(self, buffer: bytes) -> tuple:
return self.struct.unpack(buffer)
def encode_value(self, value: Any) -> bytes:
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value: bytes) -> Any:
unpacked = self.unpack(value)
return unpacked[0] if len(unpacked) == 1 else unpacked
# -----------------------------------------------------------------------------
class PackedCharacteristicProxyAdapter(CharacteristicProxyAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
For formats with a single value, the adapted `read_value` and `write_value`
methods return/accept single values. For formats with multiple values,
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic_proxy, pack_format):
super().__init__(characteristic_proxy)
self.struct = struct.Struct(pack_format)
def pack(self, *values) -> bytes:
return self.struct.pack(*values)
def unpack(self, buffer: bytes) -> tuple:
return self.struct.unpack(buffer)
def encode_value(self, value: Any) -> bytes:
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value: bytes) -> Any:
unpacked = self.unpack(value)
return unpacked[0] if len(unpacked) == 1 else unpacked
# -----------------------------------------------------------------------------
class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept a dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(
self, characteristic: Characteristic, pack_format: str, keys: Iterable[str]
) -> None:
super().__init__(characteristic, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values) -> bytes:
return super().pack(*(values[key] for key in self.keys))
def unpack(self, buffer: bytes) -> Any:
return dict(zip(self.keys, super().unpack(buffer)))
# -----------------------------------------------------------------------------
class MappedCharacteristicProxyAdapter(PackedCharacteristicProxyAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept a dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(
self,
characteristic_proxy: CharacteristicProxy,
pack_format: str,
keys: Iterable[str],
) -> None:
super().__init__(characteristic_proxy, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values) -> bytes:
return super().pack(*(values[key] for key in self.keys))
def unpack(self, buffer: bytes) -> Any:
return dict(zip(self.keys, super().unpack(buffer)))
# -----------------------------------------------------------------------------
class UTF8CharacteristicAdapter(CharacteristicAdapter[str]):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
# -----------------------------------------------------------------------------
class UTF8CharacteristicProxyAdapter(CharacteristicProxyAdapter[str]):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
# -----------------------------------------------------------------------------
class SerializableCharacteristicAdapter(CharacteristicAdapter[_T2]):
'''
Adapter that converts any class to/from bytes using the class'
`to_bytes` and `__bytes__` methods, respectively.
'''
def __init__(self, characteristic: Characteristic, cls: Type[_T2]) -> None:
super().__init__(characteristic)
self.cls = cls
def encode_value(self, value: _T2) -> bytes:
return bytes(value)
def decode_value(self, value: bytes) -> _T2:
return self.cls.from_bytes(value)
# -----------------------------------------------------------------------------
class SerializableCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T2]):
'''
Adapter that converts any class to/from bytes using the class'
`to_bytes` and `__bytes__` methods, respectively.
'''
def __init__(
self, characteristic_proxy: CharacteristicProxy, cls: Type[_T2]
) -> None:
super().__init__(characteristic_proxy)
self.cls = cls
def encode_value(self, value: _T2) -> bytes:
return bytes(value)
def decode_value(self, value: bytes) -> _T2:
return self.cls.from_bytes(value)
# -----------------------------------------------------------------------------
class EnumCharacteristicAdapter(CharacteristicAdapter[_T3]):
'''
Adapter that converts int-enum-like classes to/from bytes using the class'
`int().to_bytes()` and `from_bytes()` methods, respectively.
'''
def __init__(
self,
characteristic: Characteristic,
cls: Type[_T3],
length: int,
byteorder: Literal['little', 'big'] = 'little',
):
"""
Initialize an instance.
Params:
characteristic: the Characteristic to adapt to/from
cls: the class to/from which to convert integer values
length: number of bytes used to represent integer values
byteorder: byte order of the byte representation of integers.
"""
super().__init__(characteristic)
self.cls = cls
self.length = length
self.byteorder = byteorder
def encode_value(self, value: _T3) -> bytes:
return int(value).to_bytes(self.length, self.byteorder)
def decode_value(self, value: bytes) -> _T3:
int_value = int.from_bytes(value, self.byteorder)
return self.cls(int_value)
# -----------------------------------------------------------------------------
class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
'''
Adapter that converts int-enum-like classes to/from bytes using the class'
`int().to_bytes()` and `from_bytes()` methods, respectively.
'''
def __init__(
self,
characteristic_proxy: CharacteristicProxy,
cls: Type[_T3],
length: int,
byteorder: Literal['little', 'big'] = 'little',
):
"""
Initialize an instance.
Params:
characteristic_proxy: the CharacteristicProxy to adapt to/from
cls: the class to/from which to convert integer values
length: number of bytes used to represent integer values
byteorder: byte order of the byte representation of integers.
"""
super().__init__(characteristic_proxy)
self.cls = cls
self.length = length
self.byteorder = byteorder
def encode_value(self, value: _T3) -> bytes:
return int(value).to_bytes(self.length, self.byteorder)
def decode_value(self, value: bytes) -> _T3:
int_value = int.from_bytes(value, self.byteorder)
a = self.cls(int_value)
return self.cls(int_value)

View File

@@ -29,24 +29,25 @@ import logging
import struct import struct
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List, List,
Optional, Optional,
Dict,
Tuple,
Callable,
Union,
Any,
Iterable,
Type,
Set, Set,
Tuple,
Union,
Type,
TypeVar,
TYPE_CHECKING, TYPE_CHECKING,
) )
from pyee import EventEmitter
from .colors import color from bumble.colors import color
from .hci import HCI_Constant from bumble.hci import HCI_Constant
from .att import ( from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR, ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID, ATT_CID,
@@ -67,9 +68,10 @@ from .att import (
ATT_Write_Request, ATT_Write_Request,
ATT_Error, ATT_Error,
) )
from . import core from bumble import utils
from .core import UUID, InvalidStateError from bumble import core
from .gatt import ( from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
@@ -78,12 +80,18 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE, GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
ClientCharacteristicConfigurationBits, ClientCharacteristicConfigurationBits,
InvalidServiceError,
TemplateService, TemplateService,
) )
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Connection from bumble.device import Connection
_T = TypeVar('_T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -109,31 +117,31 @@ def show_services(services: Iterable[ServiceProxy]) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Proxies # Proxies
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter): class AttributeProxy(utils.EventEmitter, Generic[_T]):
def __init__( def __init__(
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
) -> None: ) -> None:
EventEmitter.__init__(self) utils.EventEmitter.__init__(self)
self.client = client self.client = client
self.handle = handle self.handle = handle
self.end_group_handle = end_group_handle self.end_group_handle = end_group_handle
self.type = attribute_type self.type = attribute_type
async def read_value(self, no_long_read: bool = False) -> bytes: async def read_value(self, no_long_read: bool = False) -> _T:
return self.decode_value( return self.decode_value(
await self.client.read_value(self.handle, no_long_read) await self.client.read_value(self.handle, no_long_read)
) )
async def write_value(self, value, with_response=False): async def write_value(self, value: _T, with_response=False):
return await self.client.write_value( return await self.client.write_value(
self.handle, self.encode_value(value), with_response self.handle, self.encode_value(value), with_response
) )
def encode_value(self, value: Any) -> bytes: def encode_value(self, value: _T) -> bytes:
return value return value # type: ignore
def decode_value(self, value_bytes: bytes) -> Any: def decode_value(self, value: bytes) -> _T:
return value_bytes return value # type: ignore
def __str__(self) -> str: def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -141,7 +149,7 @@ class AttributeProxy(EventEmitter):
class ServiceProxy(AttributeProxy): class ServiceProxy(AttributeProxy):
uuid: UUID uuid: UUID
characteristics: List[CharacteristicProxy] characteristics: List[CharacteristicProxy[bytes]]
included_services: List[ServiceProxy] included_services: List[ServiceProxy]
@staticmethod @staticmethod
@@ -162,29 +170,48 @@ class ServiceProxy(AttributeProxy):
self.uuid = uuid self.uuid = uuid
self.characteristics = [] self.characteristics = []
async def discover_characteristics(self, uuids=()): async def discover_characteristics(
self, uuids=()
) -> list[CharacteristicProxy[bytes]]:
return await self.client.discover_characteristics(uuids, self) return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid): def get_characteristics_by_uuid(
self, uuid: UUID
) -> list[CharacteristicProxy[bytes]]:
"""Get all the characteristics with a specified UUID."""
return self.client.get_characteristics_by_uuid(uuid, self) return self.client.get_characteristics_by_uuid(uuid, self)
def get_required_characteristic_by_uuid(
self, uuid: UUID
) -> CharacteristicProxy[bytes]:
"""
Get the first characteristic with a specified UUID.
If no characteristic with that UUID is found, an InvalidServiceError is raised.
"""
if not (characteristics := self.get_characteristics_by_uuid(uuid)):
raise InvalidServiceError(f'{uuid} characteristic not found')
return characteristics[0]
def __str__(self) -> str: def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy): class CharacteristicProxy(AttributeProxy[_T]):
properties: Characteristic.Properties properties: Characteristic.Properties
descriptors: List[DescriptorProxy] descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable[[bytes], Any]] subscribers: Dict[Any, Callable[[_T], Any]]
EVENT_UPDATE = "update"
def __init__( def __init__(
self, self,
client, client: Client,
handle, handle: int,
end_group_handle, end_group_handle: int,
uuid, uuid: UUID,
properties: int, properties: int,
): ) -> None:
super().__init__(client, handle, end_group_handle, uuid) super().__init__(client, handle, end_group_handle, uuid)
self.uuid = uuid self.uuid = uuid
self.properties = Characteristic.Properties(properties) self.properties = Characteristic.Properties(properties)
@@ -192,21 +219,21 @@ class CharacteristicProxy(AttributeProxy):
self.descriptors_discovered = False self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type): def get_descriptor(self, descriptor_type: UUID) -> Optional[DescriptorProxy]:
for descriptor in self.descriptors: for descriptor in self.descriptors:
if descriptor.type == descriptor_type: if descriptor.type == descriptor_type:
return descriptor return descriptor
return None return None
async def discover_descriptors(self): async def discover_descriptors(self) -> list[DescriptorProxy]:
return await self.client.discover_descriptors(self) return await self.client.discover_descriptors(self)
async def subscribe( async def subscribe(
self, self,
subscriber: Optional[Callable[[bytes], Any]] = None, subscriber: Optional[Callable[[_T], Any]] = None,
prefer_notify: bool = True, prefer_notify: bool = True,
): ) -> None:
if subscriber is not None: if subscriber is not None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
# We already have a proxy subscriber # We already have a proxy subscriber
@@ -221,13 +248,13 @@ class CharacteristicProxy(AttributeProxy):
self.subscribers[subscriber] = on_change self.subscribers[subscriber] = on_change
subscriber = on_change subscriber = on_change
return await self.client.subscribe(self, subscriber, prefer_notify) await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None, force=False): async def unsubscribe(self, subscriber=None, force=False) -> None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber) subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber, force) await self.client.unsubscribe(self, subscriber, force)
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
@@ -237,8 +264,8 @@ class CharacteristicProxy(AttributeProxy):
) )
class DescriptorProxy(AttributeProxy): class DescriptorProxy(AttributeProxy[bytes]):
def __init__(self, client, handle, descriptor_type): def __init__(self, client: Client, handle: int, descriptor_type: UUID) -> None:
super().__init__(client, handle, 0, descriptor_type) super().__init__(client, handle, 0, descriptor_type)
def __str__(self) -> str: def __str__(self) -> str:
@@ -283,7 +310,7 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
connection.on('disconnection', self.on_disconnection) connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None: def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
@@ -292,7 +319,7 @@ class Client:
logger.debug( logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
) )
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(bytes(command))
async def send_request(self, request: ATT_PDU): async def send_request(self, request: ATT_PDU):
logger.debug( logger.debug(
@@ -310,7 +337,7 @@ class Client:
self.pending_request = request self.pending_request = request
try: try:
self.send_gatt_pdu(request.to_bytes()) self.send_gatt_pdu(bytes(request))
response = await asyncio.wait_for( response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT self.pending_response, GATT_REQUEST_TIMEOUT
) )
@@ -328,7 +355,7 @@ class Client:
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}' f'{confirmation}'
) )
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
@@ -357,7 +384,7 @@ class Client:
def get_characteristics_by_uuid( def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> List[CharacteristicProxy]: ) -> List[CharacteristicProxy[bytes]]:
services = [service] if service else self.services services = [service] if service else self.services
return [ return [
c c
@@ -609,7 +636,7 @@ class Client:
async def discover_characteristics( async def discover_characteristics(
self, uuids, service: Optional[ServiceProxy] self, uuids, service: Optional[ServiceProxy]
) -> List[CharacteristicProxy]: ) -> List[CharacteristicProxy[bytes]]:
''' '''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID Discover Characteristics by UUID
@@ -622,12 +649,12 @@ class Client:
services = [service] if service else self.services services = [service] if service else self.services
# Perform characteristic discovery for each service # Perform characteristic discovery for each service
discovered_characteristics: List[CharacteristicProxy] = [] discovered_characteristics: List[CharacteristicProxy[bytes]] = []
for service in services: for service in services:
starting_handle = service.handle starting_handle = service.handle
ending_handle = service.end_group_handle ending_handle = service.end_group_handle
characteristics: List[CharacteristicProxy] = [] characteristics: List[CharacteristicProxy[bytes]] = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
@@ -667,7 +694,7 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value) properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:]) characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = CharacteristicProxy( characteristic = CharacteristicProxy[bytes](
self, handle, 0, characteristic_uuid, properties self, handle, 0, characteristic_uuid, properties
) )
@@ -760,7 +787,7 @@ class Client:
return descriptors return descriptors
async def discover_attributes(self) -> List[AttributeProxy]: async def discover_attributes(self) -> List[AttributeProxy[bytes]]:
''' '''
Discover all attributes, regardless of type Discover all attributes, regardless of type
''' '''
@@ -793,7 +820,7 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}') logger.warning(f'bogus handle value: {attribute_handle}')
return [] return []
attribute = AttributeProxy( attribute = AttributeProxy[bytes](
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid) self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
) )
attributes.append(attribute) attributes.append(attribute)
@@ -806,7 +833,7 @@ class Client:
async def subscribe( async def subscribe(
self, self,
characteristic: CharacteristicProxy, characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None, subscriber: Optional[Callable[[Any], Any]] = None,
prefer_notify: bool = True, prefer_notify: bool = True,
) -> None: ) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
@@ -856,7 +883,7 @@ class Client:
async def unsubscribe( async def unsubscribe(
self, self,
characteristic: CharacteristicProxy, characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None, subscriber: Optional[Callable[[Any], Any]] = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
''' '''
@@ -898,6 +925,12 @@ class Client:
) and subscriber in subscribers: ) and subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
# The characteristic itself is added as subscriber. If it is the
# last remaining subscriber, we remove it, such that the clean up
# works correctly. Otherwise the CCCD never is set back to 0.
if len(subscribers) == 1 and characteristic in subscribers:
subscribers.remove(characteristic)
# Cleanup if we removed the last one # Cleanup if we removed the last one
if not subscribers: if not subscribers:
del subscriber_set[characteristic.handle] del subscriber_set[characteristic.handle]
@@ -1111,7 +1144,7 @@ class Client:
if callable(subscriber): if callable(subscriber):
subscriber(notification.attribute_value) subscriber(notification.attribute_value)
else: else:
subscriber.emit('update', notification.attribute_value) subscriber.emit(subscriber.EVENT_UPDATE, notification.attribute_value)
def on_att_handle_value_indication(self, indication): def on_att_handle_value_indication(self, indication):
# Call all subscribers # Call all subscribers
@@ -1126,7 +1159,7 @@ class Client:
if callable(subscriber): if callable(subscriber):
subscriber(indication.attribute_value) subscriber(indication.attribute_value)
else: else:
subscriber.emit('update', indication.attribute_value) subscriber.emit(subscriber.EVENT_UPDATE, indication.attribute_value)
# Confirm that we received the indication # Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation()) self.send_confirmation(ATT_Handle_Value_Confirmation())

View File

@@ -28,8 +28,16 @@ import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
import struct import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from typing import (
from pyee import EventEmitter Dict,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Type,
TYPE_CHECKING,
)
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID from bumble.core import UUID
@@ -74,7 +82,7 @@ from bumble.gatt import (
Descriptor, Descriptor,
Service, Service,
) )
from bumble.utils import AsyncRunner from bumble import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -94,7 +102,7 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GATT Server # GATT Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(utils.EventEmitter):
attributes: List[Attribute] attributes: List[Attribute]
services: List[Service] services: List[Service]
attributes_by_handle: Dict[int, Attribute] attributes_by_handle: Dict[int, Attribute]
@@ -102,6 +110,8 @@ class Server(EventEmitter):
indication_semaphores: defaultdict[int, asyncio.Semaphore] indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]] pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
def __init__(self, device: Device) -> None: def __init__(self, device: Device) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
@@ -339,10 +349,13 @@ class Server(EventEmitter):
notify_enabled = value[0] & 0x01 != 0 notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0 indicate_enabled = value[0] & 0x02 != 0
characteristic.emit( characteristic.emit(
'subscription', connection, notify_enabled, indicate_enabled characteristic.EVENT_SUBSCRIPTION,
connection,
notify_enabled,
indicate_enabled,
) )
self.emit( self.emit(
'characteristic_subscription', self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
connection, connection,
characteristic, characteristic,
notify_enabled, notify_enabled,
@@ -353,7 +366,7 @@ class Server(EventEmitter):
logger.debug( logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}' f'GATT Response from server: [0x{connection.handle:04X}] {response}'
) )
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, bytes(response))
async def notify_subscriber( async def notify_subscriber(
self, self,
@@ -450,7 +463,7 @@ class Server(EventEmitter):
) )
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
@@ -458,7 +471,7 @@ class Server(EventEmitter):
finally: finally:
self.pending_confirmations[connection.handle] = None self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers( async def _notify_or_indicate_subscribers(
self, self,
indicate: bool, indicate: bool,
attribute: Attribute, attribute: Attribute,
@@ -492,7 +505,9 @@ class Server(EventEmitter):
value: Optional[bytes] = None, value: Optional[bytes] = None,
force: bool = False, force: bool = False,
): ):
return await self.notify_or_indicate_subscribers(False, attribute, value, force) return await self._notify_or_indicate_subscribers(
False, attribute, value, force
)
async def indicate_subscribers( async def indicate_subscribers(
self, self,
@@ -500,7 +515,7 @@ class Server(EventEmitter):
value: Optional[bytes] = None, value: Optional[bytes] = None,
force: bool = False, force: bool = False,
): ):
return await self.notify_or_indicate_subscribers(True, attribute, value, force) return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection: Connection) -> None: def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers: if connection.handle in self.subscribers:
@@ -651,7 +666,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request): async def on_att_find_by_type_value_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -704,7 +719,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request): async def on_att_read_by_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -770,7 +785,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request): async def on_att_read_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -796,7 +811,7 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request): async def on_att_read_blob_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -841,7 +856,7 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request): async def on_att_read_by_group_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -909,7 +924,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request): async def on_att_write_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -956,7 +971,7 @@ class Server(EventEmitter):
response = ATT_Write_Response() response = ATT_Write_Response()
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() @utils.AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request): async def on_att_write_command(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,6 @@ import asyncio
import dataclasses import dataclasses
import enum import enum
import traceback import traceback
import pyee
import re import re
from typing import ( from typing import (
Dict, Dict,
@@ -45,6 +44,7 @@ from bumble import at
from bumble import device from bumble import device
from bumble import rfcomm from bumble import rfcomm
from bumble import sdp from bumble import sdp
from bumble import utils
from bumble.colors import color from bumble.colors import color
from bumble.core import ( from bumble.core import (
ProtocolError, ProtocolError,
@@ -141,7 +141,7 @@ class HfFeature(enum.IntFlag):
""" """
HF supported features (AT+BRSF=) (normative). HF supported features (AT+BRSF=) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007. Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
""" """
EC_NR = 0x001 # Echo Cancel & Noise reduction EC_NR = 0x001 # Echo Cancel & Noise reduction
@@ -155,14 +155,14 @@ class HfFeature(enum.IntFlag):
HF_INDICATORS = 0x100 HF_INDICATORS = 0x100
ESCO_S4_SETTINGS_SUPPORTED = 0x200 ESCO_S4_SETTINGS_SUPPORTED = 0x200
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400 ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
VOICE_RECOGNITION_TEST = 0x800 VOICE_RECOGNITION_TEXT = 0x800
class AgFeature(enum.IntFlag): class AgFeature(enum.IntFlag):
""" """
AG supported features (+BRSF:) (normative). AG supported features (+BRSF:) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007. Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
""" """
THREE_WAY_CALLING = 0x001 THREE_WAY_CALLING = 0x001
@@ -178,7 +178,7 @@ class AgFeature(enum.IntFlag):
HF_INDICATORS = 0x400 HF_INDICATORS = 0x400
ESCO_S4_SETTINGS_SUPPORTED = 0x800 ESCO_S4_SETTINGS_SUPPORTED = 0x800
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000 ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
VOICE_RECOGNITION_TEST = 0x2000 VOICE_RECOGNITION_TEXT = 0x2000
class AudioCodec(enum.IntEnum): class AudioCodec(enum.IntEnum):
@@ -690,7 +690,7 @@ class HfIndicatorState:
current_status: int = 0 current_status: int = 0
class HfProtocol(pyee.EventEmitter): class HfProtocol(utils.EventEmitter):
""" """
Implementation for the Hands-Free side of the Hands-Free profile. Implementation for the Hands-Free side of the Hands-Free profile.
@@ -720,6 +720,14 @@ class HfProtocol(pyee.EventEmitter):
vrec: VoiceRecognitionState vrec: VoiceRecognitionState
""" """
EVENT_CODEC_NEGOTIATION = "codec_negotiation"
EVENT_AG_INDICATOR = "ag_indicator"
EVENT_SPEAKER_VOLUME = "speaker_volume"
EVENT_MICROPHONE_VOLUME = "microphone_volume"
EVENT_RING = "ring"
EVENT_CLI_NOTIFICATION = "cli_notification"
EVENT_VOICE_RECOGNITION = "voice_recognition"
class HfLoopTermination(HfpProtocolError): class HfLoopTermination(HfpProtocolError):
"""Termination signal for run() loop.""" """Termination signal for run() loop."""
@@ -777,7 +785,8 @@ class HfProtocol(pyee.EventEmitter):
self.dlc.sink = self._read_at self.dlc.sink = self._read_at
# Stop the run() loop when L2CAP is closed. # Stop the run() loop when L2CAP is closed.
self.dlc.multiplexer.l2cap_channel.on( self.dlc.multiplexer.l2cap_channel.on(
'close', lambda: self.unsolicited_queue.put_nowait(None) self.dlc.multiplexer.l2cap_channel.EVENT_CLOSE,
lambda: self.unsolicited_queue.put_nowait(None),
) )
def supports_hf_feature(self, feature: HfFeature) -> bool: def supports_hf_feature(self, feature: HfFeature) -> bool:
@@ -1034,7 +1043,7 @@ class HfProtocol(pyee.EventEmitter):
# ID. The HF shall be ready to accept the synchronous connection # ID. The HF shall be ready to accept the synchronous connection
# establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>. # establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>.
self.active_codec = AudioCodec(codec_id) self.active_codec = AudioCodec(codec_id)
self.emit('codec_negotiation', self.active_codec) self.emit(self.EVENT_CODEC_NEGOTIATION, self.active_codec)
logger.info("codec connection setup completed") logger.info("codec connection setup completed")
@@ -1095,7 +1104,7 @@ class HfProtocol(pyee.EventEmitter):
# CIEV is in 1-index, while ag_indicators is in 0-index. # CIEV is in 1-index, while ag_indicators is in 0-index.
ag_indicator = self.ag_indicators[index - 1] ag_indicator = self.ag_indicators[index - 1]
ag_indicator.current_status = value ag_indicator.current_status = value
self.emit('ag_indicator', ag_indicator) self.emit(self.EVENT_AG_INDICATOR, ag_indicator)
logger.info(f"AG indicator updated: {ag_indicator.indicator}, {value}") logger.info(f"AG indicator updated: {ag_indicator.indicator}, {value}")
async def handle_unsolicited(self): async def handle_unsolicited(self):
@@ -1110,19 +1119,21 @@ class HfProtocol(pyee.EventEmitter):
int(result.parameters[0]), int(result.parameters[1]) int(result.parameters[0]), int(result.parameters[1])
) )
elif result.code == "+VGS": elif result.code == "+VGS":
self.emit('speaker_volume', int(result.parameters[0])) self.emit(self.EVENT_SPEAKER_VOLUME, int(result.parameters[0]))
elif result.code == "+VGM": elif result.code == "+VGM":
self.emit('microphone_volume', int(result.parameters[0])) self.emit(self.EVENT_MICROPHONE_VOLUME, int(result.parameters[0]))
elif result.code == "RING": elif result.code == "RING":
self.emit('ring') self.emit(self.EVENT_RING)
elif result.code == "+CLIP": elif result.code == "+CLIP":
self.emit( self.emit(
'cli_notification', CallLineIdentification.parse_from(result.parameters) self.EVENT_CLI_NOTIFICATION,
CallLineIdentification.parse_from(result.parameters),
) )
elif result.code == "+BVRA": elif result.code == "+BVRA":
# TODO: Support Enhanced Voice Recognition. # TODO: Support Enhanced Voice Recognition.
self.emit( self.emit(
'voice_recognition', VoiceRecognitionState(int(result.parameters[0])) self.EVENT_VOICE_RECOGNITION,
VoiceRecognitionState(int(result.parameters[0])),
) )
else: else:
logging.info(f"unhandled unsolicited response {result.code}") logging.info(f"unhandled unsolicited response {result.code}")
@@ -1146,7 +1157,7 @@ class HfProtocol(pyee.EventEmitter):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
class AgProtocol(pyee.EventEmitter): class AgProtocol(utils.EventEmitter):
""" """
Implementation for the Audio-Gateway side of the Hands-Free profile. Implementation for the Audio-Gateway side of the Hands-Free profile.
@@ -1179,6 +1190,19 @@ class AgProtocol(pyee.EventEmitter):
volume: Int volume: Int
""" """
EVENT_SLC_COMPLETE = "slc_complete"
EVENT_SUPPORTED_AUDIO_CODECS = "supported_audio_codecs"
EVENT_CODEC_NEGOTIATION = "codec_negotiation"
EVENT_VOICE_RECOGNITION = "voice_recognition"
EVENT_CALL_HOLD = "call_hold"
EVENT_HF_INDICATOR = "hf_indicator"
EVENT_CODEC_CONNECTION_REQUEST = "codec_connection_request"
EVENT_ANSWER = "answer"
EVENT_DIAL = "dial"
EVENT_HANG_UP = "hang_up"
EVENT_SPEAKER_VOLUME = "speaker_volume"
EVENT_MICROPHONE_VOLUME = "microphone_volume"
supported_hf_features: int supported_hf_features: int
supported_hf_indicators: Set[HfIndicator] supported_hf_indicators: Set[HfIndicator]
supported_audio_codecs: List[AudioCodec] supported_audio_codecs: List[AudioCodec]
@@ -1371,7 +1395,7 @@ class AgProtocol(pyee.EventEmitter):
def _check_remained_slc_commands(self) -> None: def _check_remained_slc_commands(self) -> None:
if not self._remained_slc_setup_features: if not self._remained_slc_setup_features:
self.emit('slc_complete') self.emit(self.EVENT_SLC_COMPLETE)
def _on_brsf(self, hf_features: bytes) -> None: def _on_brsf(self, hf_features: bytes) -> None:
self.supported_hf_features = int(hf_features) self.supported_hf_features = int(hf_features)
@@ -1390,16 +1414,17 @@ class AgProtocol(pyee.EventEmitter):
def _on_bac(self, *args) -> None: def _on_bac(self, *args) -> None:
self.supported_audio_codecs = [AudioCodec(int(value)) for value in args] self.supported_audio_codecs = [AudioCodec(int(value)) for value in args]
self.emit(self.EVENT_SUPPORTED_AUDIO_CODECS, self.supported_audio_codecs)
self.send_ok() self.send_ok()
def _on_bcs(self, codec: bytes) -> None: def _on_bcs(self, codec: bytes) -> None:
self.active_codec = AudioCodec(int(codec)) self.active_codec = AudioCodec(int(codec))
self.send_ok() self.send_ok()
self.emit('codec_negotiation', self.active_codec) self.emit(self.EVENT_CODEC_NEGOTIATION, self.active_codec)
def _on_bvra(self, vrec: bytes) -> None: def _on_bvra(self, vrec: bytes) -> None:
self.send_ok() self.send_ok()
self.emit('voice_recognition', VoiceRecognitionState(int(vrec))) self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec)))
def _on_chld(self, operation_code: bytes) -> None: def _on_chld(self, operation_code: bytes) -> None:
call_index: Optional[int] = None call_index: Optional[int] = None
@@ -1426,7 +1451,7 @@ class AgProtocol(pyee.EventEmitter):
# Real three-way calls have more complicated situations, but this is not a popular issue - let users to handle the remaining :) # Real three-way calls have more complicated situations, but this is not a popular issue - let users to handle the remaining :)
self.send_ok() self.send_ok()
self.emit('call_hold', operation, call_index) self.emit(self.EVENT_CALL_HOLD, operation, call_index)
def _on_chld_test(self) -> None: def _on_chld_test(self) -> None:
if not self.supports_ag_feature(AgFeature.THREE_WAY_CALLING): if not self.supports_ag_feature(AgFeature.THREE_WAY_CALLING):
@@ -1552,7 +1577,7 @@ class AgProtocol(pyee.EventEmitter):
return return
self.hf_indicators[index].current_status = int(value_bytes) self.hf_indicators[index].current_status = int(value_bytes)
self.emit('hf_indicator', self.hf_indicators[index]) self.emit(self.EVENT_HF_INDICATOR, self.hf_indicators[index])
self.send_ok() self.send_ok()
def _on_bia(self, *args) -> None: def _on_bia(self, *args) -> None:
@@ -1561,21 +1586,21 @@ class AgProtocol(pyee.EventEmitter):
self.send_ok() self.send_ok()
def _on_bcc(self) -> None: def _on_bcc(self) -> None:
self.emit('codec_connection_request') self.emit(self.EVENT_CODEC_CONNECTION_REQUEST)
self.send_ok() self.send_ok()
def _on_a(self) -> None: def _on_a(self) -> None:
"""ATA handler.""" """ATA handler."""
self.emit('answer') self.emit(self.EVENT_ANSWER)
self.send_ok() self.send_ok()
def _on_d(self, number: bytes) -> None: def _on_d(self, number: bytes) -> None:
"""ATD handler.""" """ATD handler."""
self.emit('dial', number.decode()) self.emit(self.EVENT_DIAL, number.decode())
self.send_ok() self.send_ok()
def _on_chup(self) -> None: def _on_chup(self) -> None:
self.emit('hang_up') self.emit(self.EVENT_HANG_UP)
self.send_ok() self.send_ok()
def _on_clcc(self) -> None: def _on_clcc(self) -> None:
@@ -1601,11 +1626,11 @@ class AgProtocol(pyee.EventEmitter):
self.send_ok() self.send_ok()
def _on_vgs(self, level: bytes) -> None: def _on_vgs(self, level: bytes) -> None:
self.emit('speaker_volume', int(level)) self.emit(self.EVENT_SPEAKER_VOLUME, int(level))
self.send_ok() self.send_ok()
def _on_vgm(self, level: bytes) -> None: def _on_vgm(self, level: bytes) -> None:
self.emit('microphone_volume', int(level)) self.emit(self.EVENT_MICROPHONE_VOLUME, int(level))
self.send_ok() self.send_ok()
@@ -1618,7 +1643,7 @@ class ProfileVersion(enum.IntEnum):
""" """
Profile version (normative). Profile version (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.8, 6.3 SDP Interoperability Requirements.
""" """
V1_5 = 0x0105 V1_5 = 0x0105
@@ -1632,7 +1657,7 @@ class HfSdpFeature(enum.IntFlag):
""" """
HF supported features (normative). HF supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
""" """
EC_NR = 0x01 # Echo Cancel & Noise reduction EC_NR = 0x01 # Echo Cancel & Noise reduction
@@ -1640,16 +1665,17 @@ class HfSdpFeature(enum.IntFlag):
CLI_PRESENTATION_CAPABILITY = 0x04 CLI_PRESENTATION_CAPABILITY = 0x04
VOICE_RECOGNITION_ACTIVATION = 0x08 VOICE_RECOGNITION_ACTIVATION = 0x08
REMOTE_VOLUME_CONTROL = 0x10 REMOTE_VOLUME_CONTROL = 0x10
WIDE_BAND = 0x20 # Wide band speech WIDE_BAND_SPEECH = 0x20
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40 ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80 VOICE_RECOGNITION_TEXT = 0x80
SUPER_WIDE_BAND = 0x100
class AgSdpFeature(enum.IntFlag): class AgSdpFeature(enum.IntFlag):
""" """
AG supported features (normative). AG supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements. Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
""" """
THREE_WAY_CALLING = 0x01 THREE_WAY_CALLING = 0x01
@@ -1657,9 +1683,10 @@ class AgSdpFeature(enum.IntFlag):
VOICE_RECOGNITION_FUNCTION = 0x04 VOICE_RECOGNITION_FUNCTION = 0x04
IN_BAND_RING_TONE_CAPABILITY = 0x08 IN_BAND_RING_TONE_CAPABILITY = 0x08
VOICE_TAG = 0x10 # Attach a number to voice tag VOICE_TAG = 0x10 # Attach a number to voice tag
WIDE_BAND = 0x20 # Wide band speech WIDE_BAND_SPEECH = 0x20
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40 ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80 VOICE_RECOGNITION_TEXT = 0x80
SUPER_WIDE_BAND_SPEED_SPEECH = 0x100
def make_hf_sdp_records( def make_hf_sdp_records(
@@ -1692,11 +1719,11 @@ def make_hf_sdp_records(
in configuration.supported_hf_features in configuration.supported_hf_features
): ):
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features: if HfFeature.VOICE_RECOGNITION_TEXT in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEXT
if AudioCodec.MSBC in configuration.supported_audio_codecs: if AudioCodec.MSBC in configuration.supported_audio_codecs:
hf_supported_features |= HfSdpFeature.WIDE_BAND hf_supported_features |= HfSdpFeature.WIDE_BAND_SPEECH
return [ return [
sdp.ServiceAttribute( sdp.ServiceAttribute(
@@ -1772,14 +1799,14 @@ def make_ag_sdp_records(
in configuration.supported_ag_features in configuration.supported_ag_features
): ):
ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if AgFeature.VOICE_RECOGNITION_TEST in configuration.supported_ag_features: if AgFeature.VOICE_RECOGNITION_TEXT in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEST ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEXT
if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features: if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY
if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features: if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features:
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION
if AudioCodec.MSBC in configuration.supported_audio_codecs: if AudioCodec.MSBC in configuration.supported_audio_codecs:
ag_supported_features |= AgSdpFeature.WIDE_BAND ag_supported_features |= AgSdpFeature.WIDE_BAND_SPEECH
return [ return [
sdp.ServiceAttribute( sdp.ServiceAttribute(

View File

@@ -22,11 +22,12 @@ import enum
import struct import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable from typing import Optional, Callable
from typing_extensions import override from typing_extensions import override
from bumble import l2cap, device from bumble import l2cap
from bumble import device
from bumble import utils
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from bumble.hci import Address from bumble.hci import Address
@@ -195,11 +196,18 @@ class SendHandshakeMessage(Message):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HID(ABC, EventEmitter): class HID(ABC, utils.EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None connection: Optional[device.Connection] = None
EVENT_INTERRUPT_DATA = "interrupt_data"
EVENT_CONTROL_DATA = "control_data"
EVENT_SUSPEND = "suspend"
EVENT_EXIT_SUSPEND = "exit_suspend"
EVENT_VIRTUAL_CABLE_UNPLUG = "virtual_cable_unplug"
EVENT_HANDSHAKE = "handshake"
class Role(enum.IntEnum): class Role(enum.IntEnum):
HOST = 0x00 HOST = 0x00
DEVICE = 0x01 DEVICE = 0x01
@@ -214,7 +222,7 @@ class HID(ABC, EventEmitter):
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection) device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection) device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection) device.on(device.EVENT_CONNECTION, self.on_device_connection)
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
@@ -257,15 +265,20 @@ class HID(ABC, EventEmitter):
def on_device_connection(self, connection: device.Connection) -> None: def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection self.connection = connection
self.remote_device_bd_address = connection.peer_address self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection) connection.on(connection.EVENT_DISCONNECTION, self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None: def on_device_disconnection(self, reason: int) -> None:
self.connection = None self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on(
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel)) l2cap_channel.EVENT_OPEN, lambda: self.on_l2cap_channel_open(l2cap_channel)
)
l2cap_channel.on(
l2cap_channel.EVENT_CLOSE,
lambda: self.on_l2cap_channel_close(l2cap_channel),
)
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM: if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -289,7 +302,7 @@ class HID(ABC, EventEmitter):
def on_intr_pdu(self, pdu: bytes) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("interrupt_data", pdu) self.emit(self.EVENT_INTERRUPT_DATA, pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel assert self.l2cap_ctrl_channel
@@ -362,17 +375,17 @@ class Device(HID):
self.handle_set_protocol(pdu) self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA: elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu) self.emit(self.EVENT_CONTROL_DATA, pdu)
elif message_type == Message.MessageType.CONTROL: elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND: if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND') logger.debug('<<< HID SUSPEND')
self.emit('suspend') self.emit(self.EVENT_SUSPEND)
elif param == Message.ControlCommand.EXIT_SUSPEND: elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND') logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend') self.emit(self.EVENT_EXIT_SUSPEND)
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG') logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug') self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG)
else: else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else: else:
@@ -537,14 +550,14 @@ class Host(HID):
message_type = pdu[0] >> 4 message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE: if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}') logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param)) self.emit(self.EVENT_HANDSHAKE, Message.Handshake(param))
elif message_type == Message.MessageType.DATA: elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu) self.emit(self.EVENT_CONTROL_DATA, pdu)
elif message_type == Message.MessageType.CONTROL: elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG') logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug') self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG)
else: else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else: else:

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -34,22 +34,23 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
) )
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper from bumble.snoop import Snooper
from bumble import drivers from bumble import drivers
from bumble import hci from bumble import hci
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_LE_TRANSPORT, PhysicalTransport,
ConnectionPHY, ConnectionPHY,
ConnectionParameters, ConnectionParameters,
) )
from bumble.utils import AbortableEventEmitter from bumble import utils
from bumble.transport.common import TransportLostError from bumble.transport.common import TransportLostError
if TYPE_CHECKING: if TYPE_CHECKING:
from .transport.common import TransportSink, TransportSource from bumble.transport.common import TransportSink, TransportSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -59,7 +60,19 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AclPacketQueue: class DataPacketQueue(utils.EventEmitter):
"""
Flow-control queue for host->controller data packets (ACL, ISO).
The queue holds packets associated with a connection handle. The packets
are sent to the controller, up to a maximum total number of packets in flight.
A packet is considered to be "in flight" when it has been sent to the controller
but not completed yet. Packets are no longer "in flight" when the controller
declares them as completed.
The queue emits a 'flow' event whenever one or more packets are completed.
"""
max_packet_size: int max_packet_size: int
def __init__( def __init__(
@@ -68,55 +81,124 @@ class AclPacketQueue:
max_in_flight: int, max_in_flight: int,
send: Callable[[hci.HCI_Packet], None], send: Callable[[hci.HCI_Packet], None],
) -> None: ) -> None:
super().__init__()
self.max_packet_size = max_packet_size self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight self.max_in_flight = max_in_flight
self.in_flight = 0 self._in_flight = 0 # Total number of packets in flight across all connections
self.send = send self._in_flight_per_connection: dict[int, int] = collections.defaultdict(
self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque() int
) # Number of packets in flight per connection
self._send = send
self._packets: Deque[tuple[hci.HCI_Packet, int]] = collections.deque()
self._queued = 0
self._completed = 0
def enqueue(self, packet: hci.HCI_AclDataPacket) -> None: @property
self.packets.appendleft(packet) def queued(self) -> int:
self.check_queue() """Total number of packets queued since creation."""
return self._queued
if self.packets: @property
def completed(self) -> int:
"""Total number of packets completed since creation."""
return self._completed
@property
def pending(self) -> int:
"""Number of packets that have been queued but not completed."""
return self._queued - self._completed
def enqueue(self, packet: hci.HCI_Packet, connection_handle: int) -> None:
"""Enqueue a packet associated with a connection"""
self._packets.appendleft((packet, connection_handle))
self._queued += 1
self._check_queue()
if self._packets:
logger.debug( logger.debug(
f'{self.in_flight} ACL packets in flight, ' f'{self._in_flight} packets in flight, '
f'{len(self.packets)} in queue' f'{len(self._packets)} in queue'
) )
def check_queue(self) -> None: def flush(self, connection_handle: int) -> None:
while self.packets and self.in_flight < self.max_in_flight: """
packet = self.packets.pop() Remove all packets associated with a connection.
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None: All packets associated with the connection that are in flight are implicitly
if packet_count > self.in_flight: marked as completed, but no 'flow' event is emitted.
"""
packets_to_keep = [
(packet, handle)
for (packet, handle) in self._packets
if handle != connection_handle
]
if flushed_count := len(self._packets) - len(packets_to_keep):
self._completed += flushed_count
self._packets = collections.deque(packets_to_keep)
if connection_handle in self._in_flight_per_connection:
in_flight = self._in_flight_per_connection[connection_handle]
self._completed += in_flight
self._in_flight -= in_flight
del self._in_flight_per_connection[connection_handle]
def _check_queue(self) -> None:
while self._packets and self._in_flight < self.max_in_flight:
packet, connection_handle = self._packets.pop()
self._send(packet)
self._in_flight += 1
self._in_flight_per_connection[connection_handle] += 1
def on_packets_completed(self, packet_count: int, connection_handle: int) -> None:
"""Mark one or more packets associated with a connection as completed."""
if connection_handle not in self._in_flight_per_connection:
logger.warning( logger.warning(
color( f'received completion for unknown connection {connection_handle}'
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
) )
packet_count = self.in_flight return
self.in_flight -= packet_count in_flight_for_connection = self._in_flight_per_connection[connection_handle]
self.check_queue() if packet_count <= in_flight_for_connection:
self._in_flight_per_connection[connection_handle] -= packet_count
else:
logger.warning(
f'{packet_count} completed for {connection_handle} '
f'but only {in_flight_for_connection} in flight'
)
self._in_flight_per_connection[connection_handle] = 0
if packet_count <= self._in_flight:
self._in_flight -= packet_count
self._completed += packet_count
else:
logger.warning(
f'{packet_count} completed but only {self._in_flight} in flight'
)
self._in_flight = 0
self._completed = self._queued
self._check_queue()
self.emit('flow')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection: class Connection:
def __init__( def __init__(
self, host: Host, handle: int, peer_address: hci.Address, transport: int self,
host: Host,
handle: int,
peer_address: hci.Address,
transport: PhysicalTransport,
): ):
self.host = host self.host = host
self.handle = handle self.handle = handle
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = ( acl_packet_queue: Optional[DataPacketQueue] = (
host.le_acl_packet_queue host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT if transport == PhysicalTransport.LE
else host.acl_packet_queue else host.acl_packet_queue
) )
assert acl_packet_queue assert acl_packet_queue
@@ -129,28 +211,37 @@ class Connection:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
def __str__(self) -> str:
return (
f'Connection(transport={self.transport}, peer_address={self.peer_address})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class ScoLink: class ScoLink:
peer_address: hci.Address peer_address: hci.Address
handle: int connection_handle: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class CisLink: class IsoLink:
peer_address: hci.Address
handle: int handle: int
packet_queue: DataPacketQueue = dataclasses.field(repr=False)
packet_sequence_number: int = 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(utils.EventEmitter):
connections: Dict[int, Connection] connections: Dict[int, Connection]
cis_links: Dict[int, CisLink] cis_links: Dict[int, IsoLink]
bis_links: Dict[int, IsoLink]
sco_links: Dict[int, ScoLink] sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None bigs: dict[int, set[int]]
le_acl_packet_queue: Optional[AclPacketQueue] = None acl_packet_queue: Optional[DataPacketQueue] = None
le_acl_packet_queue: Optional[DataPacketQueue] = None
iso_packet_queue: Optional[DataPacketQueue] = None
hci_sink: Optional[TransportSink] = None hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any] hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[ long_term_key_provider: Optional[
@@ -169,7 +260,9 @@ class Host(AbortableEventEmitter):
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle self.cis_links = {} # CIS links, by connection handle
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None self.pending_response: Optional[asyncio.Future[Any]] = None
self.number_of_supported_advertising_sets = 0 self.number_of_supported_advertising_sets = 0
@@ -199,7 +292,7 @@ class Host(AbortableEventEmitter):
check_address_type: bool = False, check_address_type: bool = False,
) -> Optional[Connection]: ) -> Optional[Connection]:
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address.to_bytes() == bd_addr.to_bytes(): if bytes(connection.peer_address) == bytes(bd_addr):
if ( if (
check_address_type check_address_type
and connection.peer_address.address_type != bd_addr.address_type and connection.peer_address.address_type != bd_addr.address_type
@@ -342,6 +435,14 @@ class Host(AbortableEventEmitter):
) )
) )
) )
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command(
hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
)
)
)
if ( if (
self.local_version is not None self.local_version is not None
@@ -363,6 +464,7 @@ class Host(AbortableEventEmitter):
hci.HCI_LE_READ_LOCAL_P_256_PUBLIC_KEY_COMPLETE_EVENT, hci.HCI_LE_READ_LOCAL_P_256_PUBLIC_KEY_COMPLETE_EVENT,
hci.HCI_LE_GENERATE_DHKEY_COMPLETE_EVENT, hci.HCI_LE_GENERATE_DHKEY_COMPLETE_EVENT,
hci.HCI_LE_ENHANCED_CONNECTION_COMPLETE_EVENT, hci.HCI_LE_ENHANCED_CONNECTION_COMPLETE_EVENT,
hci.HCI_LE_ENHANCED_CONNECTION_COMPLETE_V2_EVENT,
hci.HCI_LE_DIRECTED_ADVERTISING_REPORT_EVENT, hci.HCI_LE_DIRECTED_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_PHY_UPDATE_COMPLETE_EVENT, hci.HCI_LE_PHY_UPDATE_COMPLETE_EVENT,
hci.HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT, hci.HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT,
@@ -387,6 +489,12 @@ class Host(AbortableEventEmitter):
hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT, hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT, hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_SUBRATE_CHANGE_EVENT, hci.HCI_LE_SUBRATE_CHANGE_EVENT,
hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT,
hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT,
] ]
) )
@@ -411,39 +519,70 @@ class Host(AbortableEventEmitter):
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}' f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
) )
self.acl_packet_queue = AclPacketQueue( self.acl_packet_queue = DataPacketQueue(
max_packet_size=hc_acl_data_packet_length, max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets, max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
hc_le_acl_data_packet_length = 0 le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0 total_num_le_acl_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND): iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
)
logger.debug(
'HCI LE flow control: '
f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
f'iso_data_packet_length={iso_data_packet_length},'
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
hc_le_acl_data_packet_length = ( le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length response.return_parameters.le_acl_data_packet_length
) )
hc_total_num_le_acl_data_packets = ( total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets response.return_parameters.total_num_le_acl_data_packets
) )
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},' f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}' f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
) )
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0: if le_acl_data_packet_length == 0 or total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue # LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue self.le_acl_packet_queue = self.acl_packet_queue
else: else:
# Create a separate queue for LE # Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue( self.le_acl_packet_queue = DataPacketQueue(
max_packet_size=hc_le_acl_data_packet_length, max_packet_size=le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets, max_in_flight=total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if iso_data_packet_length and total_num_iso_data_packets:
self.iso_packet_queue = DataPacketQueue(
max_packet_size=iso_data_packet_length,
max_in_flight=total_num_iso_data_packets,
send=self.send_hci_packet, send=self.send_hci_packet,
) )
@@ -552,7 +691,7 @@ class Host(AbortableEventEmitter):
return response return response
except Exception as error: except Exception as error:
logger.warning( logger.exception(
f'{color("!!! Exception while sending command:", "red")} {error}' f'{color("!!! Exception while sending command:", "red")} {error}'
) )
raise error raise error
@@ -595,11 +734,78 @@ class Host(AbortableEventEmitter):
data=l2cap_pdu[offset : offset + data_total_length], data=l2cap_pdu[offset : offset + data_total_length],
) )
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}') logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet) packet_queue.enqueue(acl_packet, connection_handle)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
bytes_remaining -= data_total_length bytes_remaining -= data_total_length
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
if connection := self.connections.get(connection_handle):
return connection.acl_packet_queue
if iso_link := self.cis_links.get(connection_handle) or self.bis_links.get(
connection_handle
):
return iso_link.packet_queue
return None
def send_iso_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (
iso_link := self.cis_links.get(connection_handle)
or self.bis_links.get(connection_handle)
):
logger.warning(f"no ISO link for connection handle {connection_handle}")
return
if iso_link.packet_queue is None:
logger.warning("ISO link has no data packet queue")
return
bytes_remaining = len(sdu)
offset = 0
while bytes_remaining:
is_first_fragment = offset == 0
header_length = 4 if is_first_fragment else 0
assert iso_link.packet_queue.max_packet_size > header_length
fragment_length = min(
bytes_remaining, iso_link.packet_queue.max_packet_size - header_length
)
is_last_fragment = bytes_remaining == fragment_length
iso_sdu_fragment = sdu[offset : offset + fragment_length]
iso_link.packet_queue.enqueue(
(
hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=header_length + fragment_length,
packet_sequence_number=iso_link.packet_sequence_number,
pb_flag=0b10 if is_last_fragment else 0b00,
packet_status_flag=0,
iso_sdu_length=len(sdu),
iso_sdu_fragment=iso_sdu_fragment,
)
if is_first_fragment
else hci.HCI_IsoDataPacket(
connection_handle=connection_handle,
data_total_length=fragment_length,
pb_flag=0b11 if is_last_fragment else 0b01,
iso_sdu_fragment=iso_sdu_fragment,
)
),
connection_handle,
)
offset += fragment_length
bytes_remaining -= fragment_length
iso_link.packet_sequence_number = (iso_link.packet_sequence_number + 1) & 0xFFFF
def remove_big(self, big_handle: int) -> None:
if big := self.bigs.pop(big_handle, None):
for connection_handle in big:
if bis_link := self.bis_links.pop(connection_handle, None):
bis_link.packet_queue.flush(bis_link.handle)
def supports_command(self, op_code: int) -> bool: def supports_command(self, op_code: int) -> bool:
return ( return (
self.local_supported_commands self.local_supported_commands
@@ -727,16 +933,17 @@ class Host(AbortableEventEmitter):
def on_hci_command_status_event(self, event): def on_hci_command_status_event(self, event):
return self.on_command_processed(event) return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event): def on_hci_number_of_completed_packets_event(
self, event: hci.HCI_Number_Of_Completed_Packets_Event
) -> None:
for connection_handle, num_completed_packets in zip( for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets event.connection_handles, event.num_completed_packets
): ):
if connection := self.connections.get(connection_handle): if queue := self.get_data_packet_queue(connection_handle):
connection.acl_packet_queue.on_packets_completed(num_completed_packets) queue.on_packets_completed(num_completed_packets, connection_handle)
elif not ( continue
self.cis_links.get(connection_handle)
or self.sco_links.get(connection_handle) if connection_handle not in self.sco_links:
):
logger.warning( logger.warning(
'received packet completion event for unknown handle ' 'received packet completion event for unknown handle '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -767,7 +974,7 @@ class Host(AbortableEventEmitter):
self, self,
event.connection_handle, event.connection_handle,
event.peer_address, event.peer_address,
BT_LE_TRANSPORT, PhysicalTransport.LE,
) )
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
@@ -780,11 +987,11 @@ class Host(AbortableEventEmitter):
self.emit( self.emit(
'connection', 'connection',
event.connection_handle, event.connection_handle,
BT_LE_TRANSPORT, PhysicalTransport.LE,
event.peer_address, event.peer_address,
getattr(event, 'local_resolvable_private_address', None), getattr(event, 'local_resolvable_private_address', None),
getattr(event, 'peer_resolvable_private_address', None), getattr(event, 'peer_resolvable_private_address', None),
event.role, hci.Role(event.role),
connection_parameters, connection_parameters,
) )
else: else:
@@ -792,7 +999,10 @@ class Host(AbortableEventEmitter):
# Notify the listeners # Notify the listeners
self.emit( self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status 'connection_failure',
PhysicalTransport.LE,
event.peer_address,
event.status,
) )
def on_hci_le_enhanced_connection_complete_event(self, event): def on_hci_le_enhanced_connection_complete_event(self, event):
@@ -817,7 +1027,7 @@ class Host(AbortableEventEmitter):
self, self,
event.connection_handle, event.connection_handle,
event.bd_addr, event.bd_addr,
BT_BR_EDR_TRANSPORT, PhysicalTransport.BR_EDR,
) )
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
@@ -825,7 +1035,7 @@ class Host(AbortableEventEmitter):
self.emit( self.emit(
'connection', 'connection',
event.connection_handle, event.connection_handle,
BT_BR_EDR_TRANSPORT, PhysicalTransport.BR_EDR,
event.bd_addr, event.bd_addr,
None, None,
None, None,
@@ -837,7 +1047,10 @@ class Host(AbortableEventEmitter):
# Notify the client # Notify the client
self.emit( self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status 'connection_failure',
PhysicalTransport.BR_EDR,
event.bd_addr,
event.status,
) )
def on_hci_disconnection_complete_event(self, event): def on_hci_disconnection_complete_event(self, event):
@@ -854,11 +1067,7 @@ class Host(AbortableEventEmitter):
return return
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
logger.debug( logger.debug(f'### DISCONNECTION: {connection}, reason={event.reason}')
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
# Notify the listeners # Notify the listeners
self.emit('disconnection', handle, event.reason) self.emit('disconnection', handle, event.reason)
@@ -869,6 +1078,14 @@ class Host(AbortableEventEmitter):
or self.cis_links.pop(handle, 0) or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0) or self.sco_links.pop(handle, 0)
) )
# Flush the data queues
if self.acl_packet_queue:
self.acl_packet_queue.flush(handle)
if self.le_acl_packet_queue:
self.le_acl_packet_queue.flush(handle)
if self.iso_packet_queue:
self.iso_packet_queue.flush(handle)
else: else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
@@ -902,8 +1119,11 @@ class Host(AbortableEventEmitter):
# Notify the client # Notify the client
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy) self.emit(
self.emit('connection_phy_update', connection.handle, connection_phy) 'connection_phy_update',
connection.handle,
ConnectionPHY(event.tx_phy, event.rx_phy),
)
else: else:
self.emit('connection_phy_update_failure', connection.handle, event.status) self.emit('connection_phy_update_failure', connection.handle, event.status)
@@ -953,12 +1173,94 @@ class Host(AbortableEventEmitter):
event.cis_id, event.cis_id,
) )
def on_hci_le_create_big_complete_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle)
if self.iso_packet_queue is None:
logger.warning("BIS established but ISO packets not supported")
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit(
'big_establishment',
event.status,
event.big_handle,
event.connection_handle,
event.big_sync_delay,
event.transport_latency_big,
event.phy,
event.nse,
event.bn,
event.pto,
event.irc,
event.max_pdu,
event.iso_interval,
)
def on_hci_le_big_sync_established_event(self, event):
self.bigs[event.big_handle] = set(event.connection_handle)
for connection_handle in event.connection_handle:
self.bis_links[connection_handle] = IsoLink(
connection_handle, self.iso_packet_queue
)
self.emit(
'big_sync_establishment',
event.status,
event.big_handle,
event.transport_latency_big,
event.nse,
event.bn,
event.pto,
event.irc,
event.max_pdu,
event.iso_interval,
event.connection_handle,
)
def on_hci_le_big_sync_lost_event(self, event):
self.remove_big(event.big_handle)
self.emit('big_sync_lost', event.big_handle, event.reason)
def on_hci_le_terminate_big_complete_event(self, event):
self.remove_big(event.big_handle)
self.emit('big_termination', event.reason, event.big_handle)
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
event.connection_handle,
event.sync_handle,
event.advertising_sid,
event.advertiser_address,
event.advertiser_phy,
event.periodic_advertising_interval,
event.advertiser_clock_accuracy,
)
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
self.emit(
'periodic_advertising_sync_transfer',
event.status,
event.connection_handle,
event.sync_handle,
event.advertising_sid,
event.advertiser_address,
event.advertiser_phy,
event.periodic_advertising_interval,
event.advertiser_clock_accuracy,
)
def on_hci_le_cis_established_event(self, event): def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now. # The remaining parameters are unused for now.
if event.status == hci.HCI_SUCCESS: if event.status == hci.HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink( if self.iso_packet_queue is None:
handle=event.connection_handle, logger.warning("CIS established but ISO packets not supported")
peer_address=hci.Address.ANY, self.cis_links[event.connection_handle] = IsoLink(
handle=event.connection_handle, packet_queue=self.iso_packet_queue
) )
self.emit('cis_establishment', event.connection_handle) self.emit('cis_establishment', event.connection_handle)
else: else:
@@ -995,7 +1297,8 @@ class Host(AbortableEventEmitter):
logger.debug('no long term key provider') logger.debug('no long term key provider')
long_term_key = None long_term_key = None
else: else:
long_term_key = await self.abort_on( long_term_key = await utils.cancel_on_event(
self,
'flush', 'flush',
# pylint: disable-next=not-callable # pylint: disable-next=not-callable
self.long_term_key_provider( self.long_term_key_provider(
@@ -1028,7 +1331,7 @@ class Host(AbortableEventEmitter):
self.sco_links[event.connection_handle] = ScoLink( self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr, peer_address=event.bd_addr,
handle=event.connection_handle, connection_handle=event.connection_handle,
) )
# Notify the client # Notify the client
@@ -1053,7 +1356,7 @@ class Host(AbortableEventEmitter):
f'role change for {event.bd_addr}: ' f'role change for {event.bd_addr}: '
f'{hci.HCI_Constant.role_name(event.new_role)}' f'{hci.HCI_Constant.role_name(event.new_role)}'
) )
self.emit('role_change', event.bd_addr, event.new_role) self.emit('role_change', event.bd_addr, hci.Role(event.new_role))
else: else:
logger.debug( logger.debug(
f'role change for {event.bd_addr} failed: ' f'role change for {event.bd_addr} failed: '
@@ -1089,6 +1392,21 @@ class Host(AbortableEventEmitter):
'connection_encryption_change', 'connection_encryption_change',
event.connection_handle, event.connection_handle,
event.encryption_enabled, event.encryption_enabled,
0,
)
else:
self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_change_v2_event(self, event):
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
event.encryption_key_size,
) )
else: else:
self.emit( self.emit(
@@ -1153,7 +1471,8 @@ class Host(AbortableEventEmitter):
logger.debug('no link key provider') logger.debug('no link key provider')
link_key = None link_key = None
else: else:
link_key = await self.abort_on( link_key = await utils.cancel_on_event(
self,
'flush', 'flush',
# pylint: disable-next=not-callable # pylint: disable-next=not-callable
self.link_key_provider(event.bd_addr), self.link_key_provider(event.bd_addr),
@@ -1248,3 +1567,24 @@ class Host(AbortableEventEmitter):
event.connection_handle, event.connection_handle,
int.from_bytes(event.le_features, 'little'), int.from_bytes(event.le_features, 'little'),
) )
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
self.emit('cs_remote_supported_capabilities', event)
def on_hci_le_cs_security_enable_complete_event(self, event):
self.emit('cs_security', event)
def on_hci_le_cs_config_complete_event(self, event):
self.emit('cs_config', event)
def on_hci_le_cs_procedure_enable_complete_event(self, event):
self.emit('cs_procedure', event)
def on_hci_le_cs_subevent_result_event(self, event):
self.emit('cs_subevent_result', event)
def on_hci_le_cs_subevent_result_continue_event(self, event):
self.emit('cs_subevent_result_continue', event)
def on_hci_vendor_event(self, event):
self.emit('vendor_event', event)

View File

@@ -22,17 +22,18 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Any
from typing_extensions import Self from typing_extensions import Self
from .colors import color from bumble.colors import color
from .hci import Address from bumble import hci
if TYPE_CHECKING: if TYPE_CHECKING:
from .device import Device from bumble.device import Device
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -42,16 +43,17 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class PairingKeys: class PairingKeys:
@dataclasses.dataclass
class Key: class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None): value: bytes
self.value = value authenticated: bool = False
self.authenticated = authenticated ediv: Optional[int] = None
self.ediv = ediv rand: Optional[bytes] = None
self.rand = rand
@classmethod @classmethod
def from_dict(cls, key_dict): def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
value = bytes.fromhex(key_dict['value']) value = bytes.fromhex(key_dict['value'])
authenticated = key_dict.get('authenticated', False) authenticated = key_dict.get('authenticated', False)
ediv = key_dict.get('ediv') ediv = key_dict.get('ediv')
@@ -61,7 +63,7 @@ class PairingKeys:
return cls(value, authenticated, ediv, rand) return cls(value, authenticated, ediv, rand)
def to_dict(self): def to_dict(self) -> dict[str, Any]:
key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
if self.ediv is not None: if self.ediv is not None:
key_dict['ediv'] = self.ediv key_dict['ediv'] = self.ediv
@@ -70,39 +72,42 @@ class PairingKeys:
return key_dict return key_dict
def __init__(self): address_type: Optional[hci.AddressType] = None
self.address_type = None ltk: Optional[Key] = None
self.ltk = None ltk_central: Optional[Key] = None
self.ltk_central = None ltk_peripheral: Optional[Key] = None
self.ltk_peripheral = None irk: Optional[Key] = None
self.irk = None csrk: Optional[Key] = None
self.csrk = None link_key: Optional[Key] = None # Classic
self.link_key = None # Classic link_key_type: Optional[int] = None # Classic
@staticmethod @classmethod
def key_from_dict(keys_dict, key_name): def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is None: if key_dict is None:
return None return None
return PairingKeys.Key.from_dict(key_dict) return PairingKeys.Key.from_dict(key_dict)
@staticmethod @classmethod
def from_dict(keys_dict): def from_dict(cls, keys_dict: dict[str, Any]) -> PairingKeys:
keys = PairingKeys() return PairingKeys(
address_type=(
hci.AddressType(t)
if (t := keys_dict.get('address_type')) is not None
else None
),
ltk=PairingKeys.key_from_dict(keys_dict, 'ltk'),
ltk_central=PairingKeys.key_from_dict(keys_dict, 'ltk_central'),
ltk_peripheral=PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral'),
irk=PairingKeys.key_from_dict(keys_dict, 'irk'),
csrk=PairingKeys.key_from_dict(keys_dict, 'csrk'),
link_key=PairingKeys.key_from_dict(keys_dict, 'link_key'),
link_key_type=keys_dict.get('link_key_type'),
)
keys.address_type = keys_dict.get('address_type') def to_dict(self) -> dict[str, Any]:
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') keys: dict[str, Any] = {}
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys
def to_dict(self):
keys = {}
if self.address_type is not None: if self.address_type is not None:
keys['address_type'] = self.address_type keys['address_type'] = self.address_type
@@ -125,9 +130,12 @@ class PairingKeys:
if self.link_key is not None: if self.link_key is not None:
keys['link_key'] = self.link_key.to_dict() keys['link_key'] = self.link_key.to_dict()
if self.link_key_type is not None:
keys['link_key_type'] = self.link_key_type
return keys return keys
def print(self, prefix=''): def print(self, prefix: str = '') -> None:
keys_dict = self.to_dict() keys_dict = self.to_dict()
for container_property, value in keys_dict.items(): for container_property, value in keys_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
@@ -156,20 +164,28 @@ class KeyStore:
all_keys = await self.get_all() all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self) -> list[tuple[bytes, hci.Address]]:
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
for name, keys in all_keys: for name, keys in all_keys:
if keys.irk is not None: if keys.irk is not None:
if keys.address_type is None: resolving_keys.append(
address_type = Address.RANDOM_DEVICE_ADDRESS (
else: keys.irk.value,
address_type = keys.address_type hci.Address(
resolving_keys.append((keys.irk.value, Address(name, address_type))) name,
(
keys.address_type
if keys.address_type is not None
else hci.Address.RANDOM_DEVICE_ADDRESS
),
),
)
)
return resolving_keys return resolving_keys
async def print(self, prefix=''): async def print(self, prefix: str = '') -> None:
entries = await self.get_all() entries = await self.get_all()
separator = '' separator = ''
for name, keys in entries: for name, keys in entries:
@@ -177,8 +193,8 @@ class KeyStore:
keys.print(prefix=prefix + ' ') keys.print(prefix=prefix + ' ')
separator = '\n' separator = '\n'
@staticmethod @classmethod
def create_for_device(device: Device) -> KeyStore: def create_for_device(cls, device: Device) -> KeyStore:
if device.config.keystore is None: if device.config.keystore is None:
return MemoryKeyStore() return MemoryKeyStore()
@@ -266,9 +282,9 @@ class JsonKeyStore(KeyStore):
filename = params[0] filename = params[0]
# Use a namespace based on the device address # Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM): if device.public_address not in (hci.Address.ANY, hci.Address.ANY_RANDOM):
namespace = str(device.public_address) namespace = str(device.public_address)
elif device.random_address != Address.ANY_RANDOM: elif device.random_address != hci.Address.ANY_RANDOM:
namespace = str(device.random_address) namespace = str(device.random_address)
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE namespace = JsonKeyStore.DEFAULT_NAMESPACE

View File

@@ -23,7 +23,6 @@ import logging
import struct import struct
from collections import deque from collections import deque
from pyee import EventEmitter
from typing import ( from typing import (
Dict, Dict,
Type, Type,
@@ -39,19 +38,19 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
) )
from .utils import deprecated from bumble import utils
from .colors import color from bumble.colors import color
from .core import ( from bumble.core import (
BT_CENTRAL_ROLE,
InvalidStateError, InvalidStateError,
InvalidArgumentError, InvalidArgumentError,
InvalidPacketError, InvalidPacketError,
OutOfResourcesError, OutOfResourcesError,
ProtocolError, ProtocolError,
) )
from .hci import ( from bumble.hci import (
HCI_LE_Connection_Update_Command, HCI_LE_Connection_Update_Command,
HCI_Object, HCI_Object,
Role,
key_with_value, key_with_value,
name_or_number, name_or_number,
) )
@@ -225,7 +224,7 @@ class L2CAP_PDU:
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload) return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
def to_bytes(self) -> bytes: def __bytes__(self) -> bytes:
header = struct.pack('<HH', len(self.payload), self.cid) header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload return header + self.payload
@@ -233,9 +232,6 @@ class L2CAP_PDU:
self.cid = cid self.cid = cid
self.payload = payload self.payload = payload
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self) -> str: def __str__(self) -> str:
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}' return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
@@ -333,11 +329,8 @@ class L2CAP_Control_Frame:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self) -> bytes:
return self.pdu
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.to_bytes() return self.pdu
def __str__(self) -> str: def __str__(self) -> str:
result = f'{color(self.name, "yellow")} [ID={self.identifier}]' result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
@@ -726,7 +719,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ClassicChannel(EventEmitter): class ClassicChannel(utils.EventEmitter):
class State(enum.IntEnum): class State(enum.IntEnum):
# States # States
CLOSED = 0x00 CLOSED = 0x00
@@ -751,6 +744,9 @@ class ClassicChannel(EventEmitter):
WAIT_FINAL_RSP = 0x16 WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17 WAIT_CONTROL_IND = 0x17
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
connection_result: Optional[asyncio.Future[None]] connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]] response: Optional[asyncio.Future[bytes]]
@@ -779,7 +775,6 @@ class ClassicChannel(EventEmitter):
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = 0 self.destination_cid = 0
self.response = None
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.sink = None self.sink = None
@@ -789,27 +784,15 @@ class ClassicChannel(EventEmitter):
self.state = new_state self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame) self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request: SupportsBytes) -> bytes:
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
self.send_pdu(request)
return await self.response
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.response: if self.sink:
self.response.set_result(pdu)
self.response = None
elif self.sink:
# pylint: disable=not-callable # pylint: disable=not-callable
self.sink(pdu) self.sink(pdu)
else: else:
@@ -840,8 +823,8 @@ class ClassicChannel(EventEmitter):
# Wait for the connection to succeed or fail # Wait for the connection to succeed or fail
try: try:
return await self.connection.abort_on( return await utils.cancel_on_event(
'disconnection', self.connection_result self.connection, 'disconnection', self.connection_result
) )
finally: finally:
self.connection_result = None self.connection_result = None
@@ -867,7 +850,7 @@ class ClassicChannel(EventEmitter):
def abort(self) -> None: def abort(self) -> None:
if self.state == self.State.OPEN: if self.state == self.State.OPEN:
self._change_state(self.State.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit(self.EVENT_CLOSE)
def send_configure_request(self) -> None: def send_configure_request(self) -> None:
options = L2CAP_Control_Frame.encode_configuration_options( options = L2CAP_Control_Frame.encode_configuration_options(
@@ -960,7 +943,7 @@ class ClassicChannel(EventEmitter):
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
self.emit('open') self.emit(self.EVENT_OPEN)
elif self.state == self.State.WAIT_CONFIG_REQ_RSP: elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
self._change_state(self.State.WAIT_CONFIG_RSP) self._change_state(self.State.WAIT_CONFIG_RSP)
@@ -976,7 +959,7 @@ class ClassicChannel(EventEmitter):
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
self.emit('open') self.emit(self.EVENT_OPEN)
else: else:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
elif ( elif (
@@ -1011,7 +994,7 @@ class ClassicChannel(EventEmitter):
) )
) )
self._change_state(self.State.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit(self.EVENT_CLOSE)
self.manager.on_channel_closed(self) self.manager.on_channel_closed(self)
else: else:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
@@ -1032,7 +1015,7 @@ class ClassicChannel(EventEmitter):
if self.disconnection_result: if self.disconnection_result:
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
self.emit('close') self.emit(self.EVENT_CLOSE)
self.manager.on_channel_closed(self) self.manager.on_channel_closed(self)
def __str__(self) -> str: def __str__(self) -> str:
@@ -1045,7 +1028,7 @@ class ClassicChannel(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LeCreditBasedChannel(EventEmitter): class LeCreditBasedChannel(utils.EventEmitter):
""" """
LE Credit-based Connection Oriented Channel LE Credit-based Connection Oriented Channel
""" """
@@ -1067,6 +1050,9 @@ class LeCreditBasedChannel(EventEmitter):
connection: Connection connection: Connection
sink: Optional[Callable[[bytes], Any]] sink: Optional[Callable[[bytes], Any]]
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
def __init__( def __init__(
self, self,
manager: ChannelManager, manager: ChannelManager,
@@ -1118,9 +1104,9 @@ class LeCreditBasedChannel(EventEmitter):
self.state = new_state self.state = new_state
if new_state == self.State.CONNECTED: if new_state == self.State.CONNECTED:
self.emit('open') self.emit(self.EVENT_OPEN)
elif new_state == self.State.DISCONNECTED: elif new_state == self.State.DISCONNECTED:
self.emit('close') self.emit(self.EVENT_CLOSE)
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
@@ -1400,7 +1386,9 @@ class LeCreditBasedChannel(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ClassicChannelServer(EventEmitter): class ClassicChannelServer(utils.EventEmitter):
EVENT_CONNECTION = "connection"
def __init__( def __init__(
self, self,
manager: ChannelManager, manager: ChannelManager,
@@ -1415,7 +1403,7 @@ class ClassicChannelServer(EventEmitter):
self.mtu = mtu self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None: def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel) self.emit(self.EVENT_CONNECTION, channel)
if self.handler: if self.handler:
self.handler(channel) self.handler(channel)
@@ -1425,7 +1413,9 @@ class ClassicChannelServer(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LeCreditBasedChannelServer(EventEmitter): class LeCreditBasedChannelServer(utils.EventEmitter):
EVENT_CONNECTION = "connection"
def __init__( def __init__(
self, self,
manager: ChannelManager, manager: ChannelManager,
@@ -1444,7 +1434,7 @@ class LeCreditBasedChannelServer(EventEmitter):
self.mps = mps self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None: def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel) self.emit(self.EVENT_CONNECTION, channel)
if self.handler: if self.handler:
self.handler(channel) self.handler(channel)
@@ -1540,6 +1530,9 @@ class ChannelManager:
def next_identifier(self, connection: Connection) -> int: def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
# 0x00 is an invalid ID (BT Core Spec, Vol 3, Part A, Sect 4
if identifier == 0:
identifier = 1
self.identifiers[connection.handle] = identifier self.identifiers[connection.handle] = identifier
return identifier return identifier
@@ -1552,7 +1545,7 @@ class ChannelManager:
if cid in self.fixed_channels: if cid in self.fixed_channels:
del self.fixed_channels[cid] del self.fixed_channels[cid]
@deprecated("Please use create_classic_server") @utils.deprecated("Please use create_classic_server")
def register_server( def register_server(
self, self,
psm: int, psm: int,
@@ -1598,7 +1591,7 @@ class ChannelManager:
return self.servers[spec.psm] return self.servers[spec.psm]
@deprecated("Please use create_le_credit_based_server()") @utils.deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server( def register_le_coc_server(
self, self,
psm: int, psm: int,
@@ -1927,7 +1920,7 @@ class ChannelManager:
def on_l2cap_connection_parameter_update_request( def on_l2cap_connection_parameter_update_request(
self, connection: Connection, cid: int, request self, connection: Connection, cid: int, request
): ):
if connection.role == BT_CENTRAL_ROLE: if connection.role == Role.CENTRAL:
self.send_control_frame( self.send_control_frame(
connection, connection,
cid, cid,
@@ -2142,7 +2135,7 @@ class ChannelManager:
if channel.source_cid in connection_channels: if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid] del connection_channels[channel.source_cid]
@deprecated("Please use create_le_credit_based_channel()") @utils.deprecated("Please use create_le_credit_based_channel()")
async def open_le_coc( async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeCreditBasedChannel: ) -> LeCreditBasedChannel:
@@ -2199,7 +2192,7 @@ class ChannelManager:
return channel return channel
@deprecated("Please use create_classic_channel()") @utils.deprecated("Please use create_classic_channel()")
async def connect(self, connection: Connection, psm: int) -> ClassicChannel: async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
return await self.create_classic_channel( return await self.create_classic_channel(
connection=connection, spec=ClassicChannelSpec(psm=psm) connection=connection, spec=ClassicChannelSpec(psm=psm)
@@ -2249,12 +2242,12 @@ class ChannelManager:
class Channel(ClassicChannel): class Channel(ClassicChannel):
@deprecated("Please use ClassicChannel") @utils.deprecated("Please use ClassicChannel")
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class LeConnectionOrientedChannel(LeCreditBasedChannel): class LeConnectionOrientedChannel(LeCreditBasedChannel):
@deprecated("Please use LeCreditBasedChannel") @utils.deprecated("Please use LeCreditBasedChannel")
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@@ -20,14 +20,13 @@ import asyncio
from functools import partial from functools import partial
from bumble.core import ( from bumble.core import (
BT_PERIPHERAL_ROLE, PhysicalTransport,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError, InvalidStateError,
) )
from bumble.colors import color from bumble.colors import color
from bumble.hci import ( from bumble.hci import (
Address, Address,
Role,
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR, HCI_CONNECTION_TIMEOUT_ERROR,
@@ -116,10 +115,10 @@ class LocalLink:
def send_acl_data(self, sender_controller, destination_address, transport, data): def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address # Send the data to the first controller with a matching address
if transport == BT_LE_TRANSPORT: if transport == PhysicalTransport.LE:
destination_controller = self.find_controller(destination_address) destination_controller = self.find_controller(destination_address)
source_address = sender_controller.random_address source_address = sender_controller.random_address
elif transport == BT_BR_EDR_TRANSPORT: elif transport == PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address) destination_controller = self.find_classic_controller(destination_address)
source_address = sender_controller.public_address source_address = sender_controller.public_address
else: else:
@@ -292,7 +291,7 @@ class LocalLink:
return return
async def task(): async def task():
if responder_role != BT_PERIPHERAL_ROLE: if responder_role != Role.PERIPHERAL:
initiator_controller.on_classic_role_change( initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role)) responder_controller.public_address, int(not (responder_role))
) )

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -20,14 +20,14 @@ import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from .hci import ( from bumble.hci import (
Address, Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY, HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY, HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY, HCI_KEYBOARD_ONLY_IO_CAPABILITY,
) )
from .smp import ( from bumble.smp import (
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY, SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY, SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_DISPLAY_ONLY_IO_CAPABILITY, SMP_DISPLAY_ONLY_IO_CAPABILITY,
@@ -41,7 +41,7 @@ from .smp import (
OobLegacyContext, OobLegacyContext,
OobSharedData, OobSharedData,
) )
from .core import AdvertisingData, LeRole from bumble.core import AdvertisingData, LeRole
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -76,18 +76,18 @@ class OobData:
return instance return instance
def to_ad(self) -> AdvertisingData: def to_ad(self) -> AdvertisingData:
ad_structures = [] ad_structures: list[tuple[int, bytes]] = []
if self.address is not None: if self.address is not None:
ad_structures.append( ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address)) (AdvertisingData.Type.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
) )
if self.role is not None: if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role]))) ad_structures.append((AdvertisingData.Type.LE_ROLE, bytes([self.role])))
if self.shared_data is not None: if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures) ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None: if self.legacy_context is not None:
ad_structures.append( ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk) (AdvertisingData.Type.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
) )
return AdvertisingData(ad_structures) return AdvertisingData(ad_structures)
@@ -139,16 +139,19 @@ class PairingDelegate:
io_capability: IoCapability io_capability: IoCapability
local_initiator_key_distribution: KeyDistribution local_initiator_key_distribution: KeyDistribution
local_responder_key_distribution: KeyDistribution local_responder_key_distribution: KeyDistribution
maximum_encryption_key_size: int
def __init__( def __init__(
self, self,
io_capability: IoCapability = NO_OUTPUT_NO_INPUT, io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION, local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
maximum_encryption_key_size: int = 16,
) -> None: ) -> None:
self.io_capability = io_capability self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution self.local_responder_key_distribution = local_responder_key_distribution
self.maximum_encryption_key_size = maximum_encryption_key_size
@property @property
def classic_io_capability(self) -> int: def classic_io_capability(self) -> int:

View File

@@ -22,11 +22,11 @@ __version__ = "0.0.1"
import grpc import grpc
import grpc.aio import grpc.aio
from .config import Config from bumble.pandora.config import Config
from .device import PandoraDevice from bumble.pandora.device import PandoraDevice
from .host import HostService from bumble.pandora.host import HostService
from .l2cap import L2CAPService from bumble.pandora.l2cap import L2CAPService
from .security import SecurityService, SecurityStorageService from bumble.pandora.security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import ( from pandora.security_grpc_aio import (

View File

@@ -20,12 +20,11 @@ import grpc.aio
import logging import logging
import struct import struct
from . import utils import bumble.utils
from .config import Config from bumble.pandora import utils
from bumble.pandora.config import Config
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
UUID, UUID,
AdvertisingData, AdvertisingData,
Appearance, Appearance,
@@ -39,7 +38,6 @@ from bumble.device import (
AdvertisingEventProperties, AdvertisingEventProperties,
AdvertisingType, AdvertisingType,
Device, Device,
Phy,
) )
from bumble.gatt import Service from bumble.gatt import Service
from bumble.hci import ( from bumble.hci import (
@@ -47,6 +45,9 @@ from bumble.hci import (
HCI_PAGE_TIMEOUT_ERROR, HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address, Address,
Phy,
Role,
OwnAddressType,
) )
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
@@ -114,11 +115,11 @@ SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_CODED: Phy.LE_CODED, SECONDARY_CODED: Phy.LE_CODED,
} }
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = { OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, OwnAddressType] = {
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC, host_pb2.PUBLIC: OwnAddressType.PUBLIC,
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM, host_pb2.RANDOM: OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC, host_pb2.RESOLVABLE_OR_PUBLIC: OwnAddressType.RESOLVABLE_OR_PUBLIC,
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM, host_pb2.RESOLVABLE_OR_RANDOM: OwnAddressType.RESOLVABLE_OR_RANDOM,
} }
@@ -184,7 +185,7 @@ class HostService(HostServicer):
try: try:
connection = await self.device.connect( connection = await self.device.connect(
address, transport=BT_BR_EDR_TRANSPORT address, transport=PhysicalTransport.BR_EDR
) )
except ConnectionError as e: except ConnectionError as e:
if e.error_code == HCI_PAGE_TIMEOUT_ERROR: if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
@@ -217,7 +218,7 @@ class HostService(HostServicer):
self.log.debug(f"WaitConnection from {address}...") self.log.debug(f"WaitConnection from {address}...")
connection = self.device.find_connection_by_bd_addr( connection = self.device.find_connection_by_bd_addr(
address, transport=BT_BR_EDR_TRANSPORT address, transport=PhysicalTransport.BR_EDR
) )
if connection and id(connection) in self.waited_connections: if connection and id(connection) in self.waited_connections:
# this connection was already returned: wait for a new one. # this connection was already returned: wait for a new one.
@@ -249,8 +250,8 @@ class HostService(HostServicer):
try: try:
connection = await self.device.connect( connection = await self.device.connect(
address, address,
transport=BT_LE_TRANSPORT, transport=PhysicalTransport.LE,
own_address_type=request.own_address_type, own_address_type=OwnAddressType(request.own_address_type),
) )
except ConnectionError as e: except ConnectionError as e:
if e.error_code == HCI_PAGE_TIMEOUT_ERROR: if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
@@ -295,12 +296,12 @@ class HostService(HostServicer):
def on_disconnection(_: None) -> None: def on_disconnection(_: None) -> None:
disconnection_future.set_result(None) disconnection_future.set_result(None)
connection.on('disconnection', on_disconnection) connection.on(connection.EVENT_DISCONNECTION, on_disconnection)
try: try:
await disconnection_future await disconnection_future
self.log.debug("Disconnected") self.log.debug("Disconnected")
finally: finally:
connection.remove_listener('disconnection', on_disconnection) # type: ignore connection.remove_listener(connection.EVENT_DISCONNECTION, on_disconnection) # type: ignore
return empty_pb2.Empty() return empty_pb2.Empty()
@@ -371,20 +372,18 @@ class HostService(HostServicer):
scan_response_data=scan_response_data, scan_response_data=scan_response_data,
) )
pending_connection: asyncio.Future[bumble.device.Connection] = ( connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
asyncio.get_running_loop().create_future()
)
if request.connectable: if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None: def on_connection(connection: bumble.device.Connection) -> None:
if ( if (
connection.transport == BT_LE_TRANSPORT connection.transport == PhysicalTransport.LE
and connection.role == BT_PERIPHERAL_ROLE and connection.role == Role.PERIPHERAL
): ):
pending_connection.set_result(connection) connections.put_nowait(connection)
self.device.on('connection', on_connection) self.device.on(self.device.EVENT_CONNECTION, on_connection)
try: try:
# Advertise until RPC is canceled # Advertise until RPC is canceled
@@ -397,8 +396,7 @@ class HostService(HostServicer):
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
connection = await pending_connection connection = await connections.get()
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big')) cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie)) yield AdvertiseResponse(connection=Connection(cookie=cookie))
@@ -492,16 +490,18 @@ class HostService(HostServicer):
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS) target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
if request.connectable: if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None: def on_connection(connection: bumble.device.Connection) -> None:
if ( if (
connection.transport == BT_LE_TRANSPORT connection.transport == PhysicalTransport.LE
and connection.role == BT_PERIPHERAL_ROLE and connection.role == Role.PERIPHERAL
): ):
pending_connection.set_result(connection) connections.put_nowait(connection)
self.device.on('connection', on_connection) self.device.on(self.device.EVENT_CONNECTION, on_connection)
try: try:
while True: while True:
@@ -510,19 +510,15 @@ class HostService(HostServicer):
await self.device.start_advertising( await self.device.start_advertising(
target=target, target=target,
advertising_type=advertising_type, advertising_type=advertising_type,
own_address_type=request.own_address_type, own_address_type=OwnAddressType(request.own_address_type),
) )
if not request.connectable: if not request.connectable:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
self.log.debug('Wait for LE connection...') self.log.debug('Wait for LE connection...')
connection = await pending_connection connection = await connections.get()
self.log.debug( self.log.debug(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})" f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
@@ -535,11 +531,13 @@ class HostService(HostServicer):
await asyncio.sleep(1) await asyncio.sleep(1)
finally: finally:
if request.connectable: if request.connectable:
self.device.remove_listener('connection', on_connection) # type: ignore self.device.remove_listener(self.device.EVENT_CONNECTION, on_connection) # type: ignore
try: try:
self.log.debug('Stop advertising') self.log.debug('Stop advertising')
await self.device.abort_on('flush', self.device.stop_advertising()) await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_advertising()
)
except: except:
pass pass
@@ -559,11 +557,11 @@ class HostService(HostServicer):
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)] scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue() scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait) handler = self.device.on(self.device.EVENT_ADVERTISEMENT, scan_queue.put_nowait)
await self.device.start_scanning( await self.device.start_scanning(
legacy=request.legacy, legacy=request.legacy,
active=not request.passive, active=not request.passive,
own_address_type=request.own_address_type, own_address_type=OwnAddressType(request.own_address_type),
scan_interval=( scan_interval=(
int(request.interval) int(request.interval)
if request.interval if request.interval
@@ -604,10 +602,12 @@ class HostService(HostServicer):
yield sr yield sr
finally: finally:
self.device.remove_listener('advertisement', handler) # type: ignore self.device.remove_listener(self.device.EVENT_ADVERTISEMENT, handler) # type: ignore
try: try:
self.log.debug('Stop scanning') self.log.debug('Stop scanning')
await self.device.abort_on('flush', self.device.stop_scanning()) await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_scanning()
)
except: except:
pass pass
@@ -621,10 +621,10 @@ class HostService(HostServicer):
Optional[Tuple[Address, int, AdvertisingData, int]] Optional[Tuple[Address, int, AdvertisingData, int]]
] = asyncio.Queue() ] = asyncio.Queue()
complete_handler = self.device.on( complete_handler = self.device.on(
'inquiry_complete', lambda: inquiry_queue.put_nowait(None) self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None)
) )
result_handler = self.device.on( # type: ignore result_handler = self.device.on( # type: ignore
'inquiry_result', self.device.EVENT_INQUIRY_RESULT,
lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore
(address, class_of_device, eir_data, rssi) # type: ignore (address, class_of_device, eir_data, rssi) # type: ignore
), ),
@@ -643,11 +643,13 @@ class HostService(HostServicer):
) )
finally: finally:
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore self.device.remove_listener(self.device.EVENT_INQUIRY_COMPLETE, complete_handler) # type: ignore
self.device.remove_listener('inquiry_result', result_handler) # type: ignore self.device.remove_listener(self.device.EVENT_INQUIRY_RESULT, result_handler) # type: ignore
try: try:
self.log.debug('Stop inquiry') self.log.debug('Stop inquiry')
await self.device.abort_on('flush', self.device.stop_discovery()) await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_discovery()
)
except: except:
pass pass

View File

@@ -19,8 +19,8 @@ import logging
from asyncio import Queue as AsyncQueue, Future from asyncio import Queue as AsyncQueue, Future
from . import utils from bumble.pandora import utils
from .config import Config from bumble.pandora.config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device from bumble.device import Device
from bumble.l2cap import ( from bumble.l2cap import (
@@ -83,7 +83,7 @@ class L2CAPService(L2CAPServicer):
close_future.set_result(None) close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close) l2cap_channel.on(l2cap_channel.EVENT_CLOSE, on_close)
return ChannelContext(close_future, sdu_queue) return ChannelContext(close_future, sdu_queue)
@@ -151,7 +151,7 @@ class L2CAPService(L2CAPServicer):
spec=spec, handler=on_l2cap_channel spec=spec, handler=on_l2cap_channel
) )
else: else:
l2cap_server.on('connection', on_l2cap_channel) l2cap_server.on(l2cap_server.EVENT_CONNECTION, on_l2cap_channel)
try: try:
self.log.debug('Waiting for a channel connection.') self.log.debug('Waiting for a channel connection.')

View File

@@ -15,21 +15,21 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
from collections.abc import Awaitable
import grpc import grpc
import logging import logging
from . import utils from bumble.pandora import utils
from .config import Config from bumble.pandora.config import Config
from bumble import hci from bumble import hci
from bumble.core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
ProtocolError, ProtocolError,
InvalidArgumentError,
) )
import bumble.utils
from bumble.device import Connection as BumbleConnection, Device from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error from bumble.hci import HCI_Error, Role
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
@@ -95,7 +95,7 @@ class PairingDelegate(BasePairingDelegate):
else: else:
# In BR/EDR, connection may not be complete, # In BR/EDR, connection may not be complete,
# use address instead # use address instead
assert self.connection.transport == BT_BR_EDR_TRANSPORT assert self.connection.transport == PhysicalTransport.BR_EDR
ev.address = bytes(reversed(bytes(self.connection.peer_address))) ev.address = bytes(reversed(bytes(self.connection.peer_address)))
return ev return ev
@@ -174,7 +174,7 @@ class PairingDelegate(BasePairingDelegate):
async def display_number(self, number: int, digits: int = 6) -> None: async def display_number(self, number: int, digits: int = 6) -> None:
if ( if (
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == PhysicalTransport.BR_EDR
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
): ):
return return
@@ -190,35 +190,6 @@ class PairingDelegate(BasePairingDelegate):
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
LEVEL0: lambda connection: True,
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
),
LEVEL4: lambda connection: connection.encryption
== hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and connection.link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
LE_LEVEL1: lambda connection: True,
LE_LEVEL2: lambda connection: connection.encryption != 0,
LE_LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated,
LE_LEVEL4: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.sc,
}
class SecurityService(SecurityServicer): class SecurityService(SecurityServicer):
def __init__(self, device: Device, config: Config) -> None: def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter( self.log = utils.BumbleServerLoggerAdapter(
@@ -250,6 +221,59 @@ class SecurityService(SecurityServicer):
self.device.pairing_config_factory = pairing_config_factory self.device.pairing_config_factory = pairing_config_factory
async def _classic_level_reached(
self, level: SecurityLevel, connection: BumbleConnection
) -> bool:
if level == LEVEL0:
return True
if level == LEVEL1:
return connection.encryption == 0 or connection.authenticated
if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated
link_key_type: Optional[int] = None
if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address))
):
link_key_type = keys.link_key_type
self.log.debug("link_key_type: %d", link_key_type)
if level == LEVEL3:
return (
connection.encryption != 0
and connection.authenticated
and link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)
)
if level == LEVEL4:
return (
connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
)
raise InvalidArgumentError(f"Unexpected level {level}")
def _le_level_reached(
self, level: LESecurityLevel, connection: BumbleConnection
) -> bool:
if level == LE_LEVEL1:
return True
if level == LE_LEVEL2:
return connection.encryption != 0
if level == LE_LEVEL3:
return connection.encryption != 0 and connection.authenticated
if level == LE_LEVEL4:
return (
connection.encryption != 0
and connection.authenticated
and connection.sc
)
raise InvalidArgumentError(f"Unexpected level {level}")
@utils.rpc @utils.rpc
async def OnPairing( async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
@@ -287,12 +311,12 @@ class SecurityService(SecurityServicer):
oneof = request.WhichOneof('level') oneof = request.WhichOneof('level')
level = getattr(request, oneof) level = getattr(request, oneof)
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[ assert {PhysicalTransport.BR_EDR: 'classic', PhysicalTransport.LE: 'le'}[
connection.transport connection.transport
] == oneof ] == oneof
# security level already reached # security level already reached
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed # trigger pairing if needed
@@ -302,23 +326,23 @@ class SecurityService(SecurityServicer):
security_result = asyncio.get_running_loop().create_future() security_result = asyncio.get_running_loop().create_future()
with contextlib.closing(EventWatcher()) as watcher: with contextlib.closing(bumble.utils.EventWatcher()) as watcher:
@watcher.on(connection, 'pairing') @watcher.on(connection, connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None: def on_pairing(*_: Any) -> None:
security_result.set_result('success') security_result.set_result('success')
@watcher.on(connection, 'pairing_failure') @watcher.on(connection, connection.EVENT_PAIRING_FAILURE)
def on_pairing_failure(*_: Any) -> None: def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure') security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection') @watcher.on(connection, connection.EVENT_DISCONNECTION)
def on_disconnection(*_: Any) -> None: def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died') security_result.set_result('connection_died')
if ( if (
connection.transport == BT_LE_TRANSPORT connection.transport == PhysicalTransport.LE
and connection.role == BT_PERIPHERAL_ROLE and connection.role == Role.PERIPHERAL
): ):
connection.request_pairing() connection.request_pairing()
else: else:
@@ -363,7 +387,7 @@ class SecurityService(SecurityServicer):
return SecureResponse(encryption_failure=empty_pb2.Empty()) return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ? # security level has been reached ?
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty()) return SecureResponse(success=empty_pb2.Empty())
return SecureResponse(not_reached=empty_pb2.Empty()) return SecureResponse(not_reached=empty_pb2.Empty())
@@ -379,7 +403,7 @@ class SecurityService(SecurityServicer):
assert request.level assert request.level
level = request.level level = request.level
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[ assert {PhysicalTransport.BR_EDR: 'classic', PhysicalTransport.LE: 'le'}[
connection.transport connection.transport
] == request.level_variant() ] == request.level_variant()
@@ -390,13 +414,10 @@ class SecurityService(SecurityServicer):
pair_task: Optional[asyncio.Future[None]] = None pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
assert connection
if (encryption := connection.encryption) != 0: if (encryption := connection.encryption) != 0:
self.log.debug('Disable encryption...') self.log.debug('Disable encryption...')
try: with contextlib.suppress(Exception):
await connection.encrypt(enable=False) await connection.encrypt(enable=False)
except:
pass
self.log.debug('Disable encryption: done') self.log.debug('Disable encryption: done')
self.log.debug('Authenticate...') self.log.debug('Authenticate...')
@@ -415,19 +436,17 @@ class SecurityService(SecurityServicer):
return wrapper return wrapper
def try_set_success(*_: Any) -> None: async def try_set_success(*_: Any) -> None:
assert connection if await self.reached_security_level(connection, level):
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None: async def on_encryption_change(*_: Any) -> None:
assert connection if await self.reached_security_level(connection, level):
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done') self.log.debug('Wait for security: done')
wait_for_security.set_result('success') wait_for_security.set_result('success')
elif ( elif (
connection.transport == BT_BR_EDR_TRANSPORT connection.transport == PhysicalTransport.BR_EDR
and self.need_authentication(connection, level) and self.need_authentication(connection, level)
): ):
nonlocal authenticate_task nonlocal authenticate_task
@@ -438,7 +457,7 @@ class SecurityService(SecurityServicer):
if self.need_pairing(connection, level): if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair()) pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = { listeners: Dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'), 'connection_authentication_failure': set_failure('authentication_failure'),
@@ -451,13 +470,13 @@ class SecurityService(SecurityServicer):
'security_request': pair, 'security_request': pair,
} }
with contextlib.closing(EventWatcher()) as watcher: with contextlib.closing(bumble.utils.EventWatcher()) as watcher:
# register event handlers # register event handlers
for event, listener in listeners.items(): for event, listener in listeners.items():
watcher.on(connection, event, listener) watcher.on(connection, event, listener)
# security level already reached # security level already reached
if self.reached_security_level(connection, level): if await self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty()) return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...') self.log.debug('Wait for security...')
@@ -467,24 +486,20 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any # wait for `authenticate` to finish if any
if authenticate_task is not None: if authenticate_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
try: with contextlib.suppress(Exception):
await authenticate_task # type: ignore await authenticate_task # type: ignore
except:
pass
self.log.debug('Authenticated') self.log.debug('Authenticated')
# wait for `pair` to finish if any # wait for `pair` to finish if any
if pair_task is not None: if pair_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
try: with contextlib.suppress(Exception):
await pair_task # type: ignore await pair_task # type: ignore
except:
pass
self.log.debug('paired') self.log.debug('paired')
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
def reached_security_level( async def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel] self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool: ) -> bool:
self.log.debug( self.log.debug(
@@ -494,23 +509,22 @@ class SecurityService(SecurityServicer):
'encryption': connection.encryption, 'encryption': connection.encryption,
'authenticated': connection.authenticated, 'authenticated': connection.authenticated,
'sc': connection.sc, 'sc': connection.sc,
'link_key_type': connection.link_key_type,
} }
) )
) )
if isinstance(level, LESecurityLevel): if isinstance(level, LESecurityLevel):
return LE_LEVEL_REACHED[level](connection) return self._le_level_reached(level, connection)
return BR_LEVEL_REACHED[level](connection) return await self._classic_level_reached(level, connection)
def need_pairing(self, connection: BumbleConnection, level: int) -> bool: def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT: if connection.transport == PhysicalTransport.LE:
return level >= LE_LEVEL3 and not connection.authenticated return level >= LE_LEVEL3 and not connection.authenticated
return False return False
def need_authentication(self, connection: BumbleConnection, level: int) -> bool: def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT: if connection.transport == PhysicalTransport.LE:
return False return False
if level == LEVEL2 and connection.encryption != 0: if level == LEVEL2 and connection.encryption != 0:
return not connection.authenticated return not connection.authenticated
@@ -518,7 +532,7 @@ class SecurityService(SecurityServicer):
def need_encryption(self, connection: BumbleConnection, level: int) -> bool: def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
# TODO(abel): need to support MITM # TODO(abel): need to support MITM
if connection.transport == BT_LE_TRANSPORT: if connection.transport == PhysicalTransport.LE:
return level == LE_LEVEL2 and not connection.encryption return level == LE_LEVEL2 and not connection.encryption
return level >= LEVEL2 and not connection.encryption return level >= LEVEL2 and not connection.encryption

View File

@@ -20,11 +20,11 @@ import inspect
import logging import logging
from bumble.device import Device from bumble.device import Device
from bumble.hci import Address from bumble.hci import Address, AddressType
from google.protobuf.message import Message # pytype: disable=pyi-error from google.protobuf.message import Message # pytype: disable=pyi-error
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
ADDRESS_TYPES: Dict[str, int] = { ADDRESS_TYPES: Dict[str, AddressType] = {
"public": Address.PUBLIC_DEVICE_ADDRESS, "public": Address.PUBLIC_DEVICE_ADDRESS,
"random": Address.RANDOM_DEVICE_ADDRESS, "random": Address.RANDOM_DEVICE_ADDRESS,
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS, "public_identity": Address.PUBLIC_IDENTITY_ADDRESS,

View File

@@ -17,21 +17,20 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import struct import struct
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from bumble import gatt
from bumble.device import Connection from bumble.device import Connection
from bumble.att import ATT_Error from bumble.att import ATT_Error
from bumble.gatt import ( from bumble.gatt import (
Attribute,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter,
TemplateService, TemplateService,
CharacteristicValue, CharacteristicValue,
PackedCharacteristicAdapter,
GATT_AUDIO_INPUT_CONTROL_SERVICE, GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
@@ -40,8 +39,16 @@ from bumble.gatt import (
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC, GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC, GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
) )
from bumble.gatt_adapters import (
CharacteristicProxy,
PackedCharacteristicProxyAdapter,
SerializableCharacteristicAdapter,
SerializableCharacteristicProxyAdapter,
UTF8CharacteristicAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -57,7 +64,7 @@ GAIN_SETTINGS_MIN_VALUE = 0
GAIN_SETTINGS_MAX_VALUE = 255 GAIN_SETTINGS_MAX_VALUE = 255
class ErrorCode(OpenIntEnum): class ErrorCode(utils.OpenIntEnum):
''' '''
Cf. 1.6 Application error codes Cf. 1.6 Application error codes
''' '''
@@ -69,7 +76,7 @@ class ErrorCode(OpenIntEnum):
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84 GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
class Mute(OpenIntEnum): class Mute(utils.OpenIntEnum):
''' '''
Cf. 2.2.1.2 Mute Field Cf. 2.2.1.2 Mute Field
''' '''
@@ -79,7 +86,7 @@ class Mute(OpenIntEnum):
DISABLED = 0x02 DISABLED = 0x02
class GainMode(OpenIntEnum): class GainMode(utils.OpenIntEnum):
''' '''
Cf. 2.2.1.3 Gain Mode Cf. 2.2.1.3 Gain Mode
''' '''
@@ -90,21 +97,21 @@ class GainMode(OpenIntEnum):
AUTOMATIC = 0x03 AUTOMATIC = 0x03
class AudioInputStatus(OpenIntEnum): class AudioInputStatus(utils.OpenIntEnum):
''' '''
Cf. 3.4 Audio Input Status Cf. 3.4 Audio Input Status
''' '''
INATIVE = 0x00 INACTIVE = 0x00
ACTIVE = 0x01 ACTIVE = 0x01
class AudioInputControlPointOpCode(OpenIntEnum): class AudioInputControlPointOpCode(utils.OpenIntEnum):
''' '''
Cf. 3.5.1 Audio Input Control Point procedure requirements Cf. 3.5.1 Audio Input Control Point procedure requirements
''' '''
SET_GAIN_SETTING = 0x00 SET_GAIN_SETTING = 0x01
UNMUTE = 0x02 UNMUTE = 0x02
MUTE = 0x03 MUTE = 0x03
SET_MANUAL_GAIN_MODE = 0x04 SET_MANUAL_GAIN_MODE = 0x04
@@ -122,7 +129,7 @@ class AudioInputState:
mute: Mute = Mute.NOT_MUTED mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0 change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None attribute: Optional[Attribute] = None
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return bytes( return bytes(
@@ -149,13 +156,8 @@ class AudioInputState:
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1) self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None: async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None assert self.attribute is not None
await connection.device.notify_subscribers( await connection.device.notify_subscribers(attribute=self.attribute)
attribute=self.attribute_value, value=bytes(self)
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass @dataclass
@@ -173,7 +175,7 @@ class GainSettingsProperties:
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = ( (gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
struct.unpack('BBB', data) struct.unpack('BBB', data)
) )
GainSettingsProperties( return GainSettingsProperties(
gain_settings_unit, gain_settings_minimum, gain_settings_maximum gain_settings_unit, gain_settings_minimum, gain_settings_maximum
) )
@@ -186,9 +188,6 @@ class GainSettingsProperties:
] ]
) )
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass @dataclass
class AudioInputControlPoint: class AudioInputControlPoint:
@@ -239,7 +238,7 @@ class AudioInputControlPoint:
or gain_settings_operand or gain_settings_operand
> self.gain_settings_properties.gain_settings_maximum > self.gain_settings_properties.gain_settings_maximum
): ):
logger.error("gain_seetings value out of range") logger.error("gain_settings value out of range")
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE) raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
if self.audio_input_state.gain_settings != gain_settings_operand: if self.audio_input_state.gain_settings != gain_settings_operand:
@@ -319,31 +318,28 @@ class AudioInputDescription:
''' '''
audio_input_description: str = "Bluetooth" audio_input_description: str = "Bluetooth"
attribute_value: Optional[CharacteristicValue] = None attribute: Optional[Attribute] = None
@classmethod def on_read(self, _connection: Optional[Connection]) -> str:
def from_bytes(cls, data: bytes): return self.audio_input_description
return cls(audio_input_description=data.decode('utf-8'))
def __bytes__(self) -> bytes: async def on_write(self, connection: Optional[Connection], value: str) -> None:
return self.audio_input_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return self.audio_input_description.encode('utf-8')
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection assert connection
assert self.attribute_value assert self.attribute
self.audio_input_description = value.decode('utf-8') self.audio_input_description = value
await connection.device.notify_subscribers( await connection.device.notify_subscribers(attribute=self.attribute)
attribute=self.attribute_value, value=value
)
class AICSService(TemplateService): class AICSService(TemplateService):
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
audio_input_state_characteristic: Characteristic[AudioInputState]
audio_input_type_characteristic: Characteristic[bytes]
audio_input_status_characteristic: Characteristic[bytes]
audio_input_control_point_characteristic: Characteristic[bytes]
gain_settings_properties_characteristic: Characteristic[GainSettingsProperties]
def __init__( def __init__(
self, self,
audio_input_state: Optional[AudioInputState] = None, audio_input_state: Optional[AudioInputState] = None,
@@ -375,26 +371,27 @@ class AICSService(TemplateService):
self.audio_input_state, self.gain_settings_properties self.audio_input_state, self.gain_settings_properties
) )
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter( self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
Characteristic( Characteristic(
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
properties=Characteristic.Properties.READ properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY, | Characteristic.Properties.NOTIFY,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.audio_input_state.on_read), value=self.audio_input_state,
), ),
encode=lambda value: bytes(value), AudioInputState,
)
self.audio_input_state.attribute_value = (
self.audio_input_state_characteristic.value
) )
self.audio_input_state.attribute = self.audio_input_state_characteristic
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter( self.gain_settings_properties_characteristic = (
Characteristic( SerializableCharacteristicAdapter(
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC, Characteristic(
properties=Characteristic.Properties.READ, uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, properties=Characteristic.Properties.READ,
value=CharacteristicValue(read=self.gain_settings_properties.on_read), permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=self.gain_settings_properties,
),
GainSettingsProperties,
) )
) )
@@ -402,7 +399,7 @@ class AICSService(TemplateService):
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
properties=Characteristic.Properties.READ, properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=audio_input_type, value=bytes(audio_input_type, 'utf-8'),
) )
self.audio_input_status_characteristic = Characteristic( self.audio_input_status_characteristic = Characteristic(
@@ -412,18 +409,14 @@ class AICSService(TemplateService):
value=bytes([self.audio_input_status]), value=bytes([self.audio_input_status]),
) )
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter( self.audio_input_control_point_characteristic = Characteristic(
Characteristic( uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC, properties=Characteristic.Properties.WRITE,
properties=Characteristic.Properties.WRITE, permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, value=CharacteristicValue(write=self.audio_input_control_point.on_write),
value=CharacteristicValue(
write=self.audio_input_control_point.on_write
),
)
) )
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter( self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
Characteristic( Characteristic(
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC, uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
properties=Characteristic.Properties.READ properties=Characteristic.Properties.READ
@@ -437,8 +430,8 @@ class AICSService(TemplateService):
), ),
) )
) )
self.audio_input_description.attribute_value = ( self.audio_input_description.attribute = (
self.audio_input_control_point_characteristic.value self.audio_input_control_point_characteristic
) )
super().__init__( super().__init__(
@@ -460,61 +453,43 @@ class AICSService(TemplateService):
class AICSServiceProxy(ProfileServiceProxy): class AICSServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = AICSService SERVICE_CLASS = AICSService
audio_input_state: CharacteristicProxy[AudioInputState]
gain_settings_properties: CharacteristicProxy[GainSettingsProperties]
audio_input_status: CharacteristicProxy[int]
audio_input_control_point: CharacteristicProxy[bytes]
def __init__(self, service_proxy: ServiceProxy) -> None: def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not ( self.audio_input_state = SerializableCharacteristicProxyAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
) ),
): AudioInputState,
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
self.audio_input_state = DelegatedCharacteristicAdapter(
characteristic=characteristics[0], decode=AudioInputState.from_bytes
) )
if not ( self.gain_settings_properties = SerializableCharacteristicProxyAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
) ),
): GainSettingsProperties,
raise gatt.InvalidServiceError(
"Gain Settings Attribute Characteristic not found"
)
self.gain_settings_properties = PackedCharacteristicAdapter(
characteristics[0],
'BBB',
) )
if not ( self.audio_input_status = PackedCharacteristicProxyAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
) ),
):
raise gatt.InvalidServiceError(
"Audio Input Status Characteristic not found"
)
self.audio_input_status = PackedCharacteristicAdapter(
characteristics[0],
'B', 'B',
) )
if not ( self.audio_input_control_point = (
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Audio Input Control Point Characteristic not found"
)
self.audio_input_control_point = characteristics[0]
if not ( self.audio_input_description = UTF8CharacteristicProxyAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Audio Input Description Characteristic not found"
)
self.audio_input_description = characteristics[0]

515
bumble/profiles/ancs.py Normal file
View File

@@ -0,0 +1,515 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apple Notification Center Service (ANCS).
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import datetime
import enum
import logging
import struct
from typing import Optional, Sequence, Union
from bumble.att import ATT_Error
from bumble.device import Peer
from bumble.gatt import (
Characteristic,
GATT_ANCS_SERVICE,
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC,
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC,
TemplateService,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
from bumble.gatt_adapters import SerializableCharacteristicProxyAdapter
from bumble import utils
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
_DEFAULT_ATTRIBUTE_MAX_LENGTH = 65535
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol
# -----------------------------------------------------------------------------
class ActionId(utils.OpenIntEnum):
POSITIVE = 0
NEGATIVE = 1
class AppAttributeId(utils.OpenIntEnum):
DISPLAY_NAME = 0
class CategoryId(utils.OpenIntEnum):
OTHER = 0
INCOMING_CALL = 1
MISSED_CALL = 2
VOICEMAIL = 3
SOCIAL = 4
SCHEDULE = 5
EMAIL = 6
NEWS = 7
HEALTH_AND_FITNESS = 8
BUSINESS_AND_FINANCE = 9
LOCATION = 10
ENTERTAINMENT = 11
class CommandId(utils.OpenIntEnum):
GET_NOTIFICATION_ATTRIBUTES = 0
GET_APP_ATTRIBUTES = 1
PERFORM_NOTIFICATION_ACTION = 2
class EventId(utils.OpenIntEnum):
NOTIFICATION_ADDED = 0
NOTIFICATION_MODIFIED = 1
NOTIFICATION_REMOVED = 2
class EventFlags(enum.IntFlag):
SILENT = 1 << 0
IMPORTANT = 1 << 1
PRE_EXISTING = 1 << 2
POSITIVE_ACTION = 1 << 3
NEGATIVE_ACTION = 1 << 4
class NotificationAttributeId(utils.OpenIntEnum):
APP_IDENTIFIER = 0
TITLE = 1
SUBTITLE = 2
MESSAGE = 3
MESSAGE_SIZE = 4
DATE = 5
POSITIVE_ACTION_LABEL = 6
NEGATIVE_ACTION_LABEL = 7
@dataclasses.dataclass
class NotificationAttribute:
attribute_id: NotificationAttributeId
value: Union[str, int, datetime.datetime]
@dataclasses.dataclass
class AppAttribute:
attribute_id: AppAttributeId
value: str
@dataclasses.dataclass
class Notification:
event_id: EventId
event_flags: EventFlags
category_id: CategoryId
category_count: int
notification_uid: int
@classmethod
def from_bytes(cls, data: bytes) -> Notification:
return cls(
event_id=EventId(data[0]),
event_flags=EventFlags(data[1]),
category_id=CategoryId(data[2]),
category_count=data[3],
notification_uid=int.from_bytes(data[4:8], 'little'),
)
def __bytes__(self) -> bytes:
return struct.pack(
"<BBBBI",
self.event_id,
self.event_flags,
self.category_id,
self.category_count,
self.notification_uid,
)
class ErrorCode(utils.OpenIntEnum):
UNKNOWN_COMMAND = 0xA0
INVALID_COMMAND = 0xA1
INVALID_PARAMETER = 0xA2
ACTION_FAILED = 0xA3
class ProtocolError(Exception):
pass
class CommandError(Exception):
def __init__(self, error_code: ErrorCode) -> None:
self.error_code = error_code
def __str__(self) -> str:
return f"CommandError(error_code={self.error_code.name})"
# -----------------------------------------------------------------------------
# GATT Server-side
# -----------------------------------------------------------------------------
class Ancs(TemplateService):
UUID = GATT_ANCS_SERVICE
notification_source_characteristic: Characteristic
data_source_characteristic: Characteristic
control_point_characteristic: Characteristic
def __init__(self) -> None:
# TODO not the final implementation
self.notification_source_characteristic = Characteristic(
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.READABLE,
)
# TODO not the final implementation
self.data_source_characteristic = Characteristic(
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.READABLE,
)
# TODO not the final implementation
self.control_point_characteristic = Characteristic(
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
)
super().__init__(
[
self.notification_source_characteristic,
self.data_source_characteristic,
self.control_point_characteristic,
]
)
# -----------------------------------------------------------------------------
# GATT Client-side
# -----------------------------------------------------------------------------
class AncsProxy(ProfileServiceProxy):
SERVICE_CLASS = Ancs
notification_source: CharacteristicProxy[Notification]
data_source: CharacteristicProxy
control_point: CharacteristicProxy[bytes]
def __init__(self, service_proxy: ServiceProxy):
self.notification_source = SerializableCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC
),
Notification,
)
self.data_source = service_proxy.get_required_characteristic_by_uuid(
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC
)
self.control_point = service_proxy.get_required_characteristic_by_uuid(
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC
)
class AncsClient(utils.EventEmitter):
_expected_response_command_id: Optional[CommandId]
_expected_response_notification_uid: Optional[int]
_expected_response_app_identifier: Optional[str]
_expected_app_identifier: Optional[str]
_expected_response_tuples: int
_response_accumulator: bytes
EVENT_NOTIFICATION = "notification"
def __init__(self, ancs_proxy: AncsProxy) -> None:
super().__init__()
self._ancs_proxy = ancs_proxy
self._command_semaphore = asyncio.Semaphore()
self._response: Optional[asyncio.Future] = None
self._reset_response()
self._started = False
@classmethod
async def for_peer(cls, peer: Peer) -> Optional[AncsClient]:
ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy)
if ancs_proxy is None:
return None
return cls(ancs_proxy)
async def start(self) -> None:
await self._ancs_proxy.notification_source.subscribe(self._on_notification)
await self._ancs_proxy.data_source.subscribe(self._on_data)
self._started = True
async def stop(self) -> None:
await self._ancs_proxy.notification_source.unsubscribe(self._on_notification)
await self._ancs_proxy.data_source.unsubscribe(self._on_data)
self._started = False
def _reset_response(self) -> None:
self._expected_response_command_id = None
self._expected_response_notification_uid = None
self._expected_app_identifier = None
self._expected_response_tuples = 0
self._response_accumulator = b""
def _on_notification(self, notification: Notification) -> None:
logger.debug(f"ANCS NOTIFICATION: {notification}")
self.emit(self.EVENT_NOTIFICATION, notification)
def _on_data(self, data: bytes) -> None:
logger.debug(f"ANCS DATA: {data.hex()}")
if not self._response:
logger.warning("received unexpected data, discarding")
return
self._response_accumulator += data
# Try to parse the accumulated data until we have all we need.
if not self._response_accumulator:
logger.warning("empty data from data source")
return
command_id = self._response_accumulator[0]
if command_id != self._expected_response_command_id:
logger.warning(
"unexpected response command id: "
f"expected {self._expected_response_command_id} "
f"but got {command_id}"
)
self._reset_response()
if not self._response.done():
self._response.set_exception(ProtocolError())
if len(self._response_accumulator) < 5:
# Not enough data yet.
return
attributes: list[Union[NotificationAttribute, AppAttribute]] = []
if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES:
(notification_uid,) = struct.unpack_from(
"<I", self._response_accumulator, 1
)
if notification_uid != self._expected_response_notification_uid:
logger.warning(
"unexpected response notification uid: "
f"expected {self._expected_response_notification_uid} "
f"but got {notification_uid}"
)
self._reset_response()
if not self._response.done():
self._response.set_exception(ProtocolError())
attribute_data = self._response_accumulator[5:]
while len(attribute_data) >= 3:
attribute_id, attribute_data_length = struct.unpack_from(
"<BH", attribute_data, 0
)
if len(attribute_data) < 3 + attribute_data_length:
return
str_value = attribute_data[3 : 3 + attribute_data_length].decode(
"utf-8"
)
value: Union[str, int, datetime.datetime]
if attribute_id == NotificationAttributeId.MESSAGE_SIZE:
value = int(str_value)
elif attribute_id == NotificationAttributeId.DATE:
year = int(str_value[:4])
month = int(str_value[4:6])
day = int(str_value[6:8])
hour = int(str_value[9:11])
minute = int(str_value[11:13])
second = int(str_value[13:15])
value = datetime.datetime(year, month, day, hour, minute, second)
else:
value = str_value
attributes.append(
NotificationAttribute(NotificationAttributeId(attribute_id), value)
)
attribute_data = attribute_data[3 + attribute_data_length :]
elif command_id == CommandId.GET_APP_ATTRIBUTES:
if 0 not in self._response_accumulator[1:]:
# No null-terminated string yet.
return
app_identifier_length = self._response_accumulator.find(0, 1) - 1
app_identifier = self._response_accumulator[
1 : 1 + app_identifier_length
].decode("utf-8")
if app_identifier != self._expected_response_app_identifier:
logger.warning(
"unexpected response app identifier: "
f"expected {self._expected_response_app_identifier} "
f"but got {app_identifier}"
)
self._reset_response()
if not self._response.done():
self._response.set_exception(ProtocolError())
attribute_data = self._response_accumulator[1 + app_identifier_length + 1 :]
while len(attribute_data) >= 3:
attribute_id, attribute_data_length = struct.unpack_from(
"<BH", attribute_data, 0
)
if len(attribute_data) < 3 + attribute_data_length:
return
attributes.append(
AppAttribute(
AppAttributeId(attribute_id),
attribute_data[3 : 3 + attribute_data_length].decode("utf-8"),
)
)
attribute_data = attribute_data[3 + attribute_data_length :]
else:
logger.warning(f"unexpected response command id {command_id}")
return
if len(attributes) < self._expected_response_tuples:
# We have not received all the tuples yet.
return
if not self._response.done():
self._response.set_result(attributes)
async def _send_command(self, command: bytes) -> None:
try:
await self._ancs_proxy.control_point.write_value(
command, with_response=True
)
except ATT_Error as error:
raise CommandError(error_code=ErrorCode(error.error_code)) from error
async def get_notification_attributes(
self,
notification_uid: int,
attributes: Sequence[
Union[NotificationAttributeId, tuple[NotificationAttributeId, int]]
],
) -> list[NotificationAttribute]:
if not self._started:
raise RuntimeError("client not started")
command = struct.pack(
"<BI", CommandId.GET_NOTIFICATION_ATTRIBUTES, notification_uid
)
for attribute in attributes:
attribute_max_length = 0
if isinstance(attribute, tuple):
attribute_id, attribute_max_length = attribute
if attribute_id not in (
NotificationAttributeId.TITLE,
NotificationAttributeId.SUBTITLE,
NotificationAttributeId.MESSAGE,
):
raise ValueError(
"this attribute does not allow specifying a max length"
)
else:
attribute_id = attribute
if attribute_id in (
NotificationAttributeId.TITLE,
NotificationAttributeId.SUBTITLE,
NotificationAttributeId.MESSAGE,
):
attribute_max_length = _DEFAULT_ATTRIBUTE_MAX_LENGTH
if attribute_max_length:
command += struct.pack("<BH", attribute_id, attribute_max_length)
else:
command += struct.pack("B", attribute_id)
try:
async with self._command_semaphore:
self._expected_response_notification_uid = notification_uid
self._expected_response_tuples = len(attributes)
self._expected_response_command_id = (
CommandId.GET_NOTIFICATION_ATTRIBUTES
)
self._response = asyncio.Future()
# Send the command.
await self._send_command(command)
# Wait for the response.
return await self._response
finally:
self._reset_response()
async def get_app_attributes(
self, app_identifier: str, attributes: Sequence[AppAttributeId]
) -> list[AppAttribute]:
if not self._started:
raise RuntimeError("client not started")
command = (
bytes([CommandId.GET_APP_ATTRIBUTES])
+ app_identifier.encode("utf-8")
+ b"\0"
)
for attribute_id in attributes:
command += struct.pack("B", attribute_id)
try:
async with self._command_semaphore:
self._expected_response_app_identifier = app_identifier
self._expected_response_tuples = len(attributes)
self._expected_response_command_id = CommandId.GET_APP_ATTRIBUTES
self._response = asyncio.Future()
# Send the command.
await self._send_command(command)
# Wait for the response.
return await self._response
finally:
self._reset_response()
async def perform_action(self, notification_uid: int, action: ActionId) -> None:
if not self._started:
raise RuntimeError("client not started")
command = struct.pack(
"<BIB", CommandId.PERFORM_NOTIFICATION_ACTION, notification_uid, action
)
async with self._command_semaphore:
await self._send_command(command)
async def perform_positive_action(self, notification_uid: int) -> None:
return await self.perform_action(notification_uid, ActionId.POSITIVE)
async def perform_negative_action(self, notification_uid: int) -> None:
return await self.perform_action(notification_uid, ActionId.NEGATIVE)

View File

@@ -17,11 +17,13 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import logging import logging
import struct import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from bumble import utils
from bumble import colors from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio from bumble.profiles import le_audio
@@ -258,8 +260,8 @@ class AseReasonCode(enum.IntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum): class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -274,6 +276,8 @@ class AseStateMachine(gatt.Characteristic):
DISABLING = 0x05 DISABLING = 0x05
RELEASING = 0x06 RELEASING = 0x06
EVENT_STATE_CHANGE = "state_change"
cis_link: Optional[device.CisLink] = None cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State # Additional parameters in CODEC_CONFIGURED State
@@ -300,7 +304,7 @@ class AseStateMachine(gatt.Characteristic):
presentation_delay = 0 presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State # Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata() metadata: le_audio.Metadata
def __init__( def __init__(
self, self,
@@ -312,6 +316,7 @@ class AseStateMachine(gatt.Characteristic):
self.ase_id = ase_id self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE self._state = AseStateMachine.State.IDLE
self.role = role self.role = role
self.metadata = le_audio.Metadata()
uuid = ( uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC gatt.GATT_SINK_ASE_CHARACTERISTIC
@@ -326,8 +331,12 @@ class AseStateMachine(gatt.Characteristic):
value=gatt.CharacteristicValue(read=self.on_read), value=gatt.CharacteristicValue(read=self.on_read),
) )
self.service.device.on('cis_request', self.on_cis_request) self.service.device.on(
self.service.device.on('cis_establishment', self.on_cis_establishment) self.service.device.EVENT_CIS_REQUEST, self.on_cis_request
)
self.service.device.on(
self.service.device.EVENT_CIS_ESTABLISHMENT, self.on_cis_establishment
)
def on_cis_request( def on_cis_request(
self, self,
@@ -341,8 +350,10 @@ class AseStateMachine(gatt.Characteristic):
and cis_id == self.cis_id and cis_id == self.cis_id
and self.state == self.State.ENABLING and self.state == self.State.ENABLING
): ):
acl_connection.abort_on( utils.cancel_on_event(
'flush', self.service.device.accept_cis_request(cis_handle) acl_connection,
'flush',
self.service.device.accept_cis_request(cis_handle),
) )
def on_cis_establishment(self, cis_link: device.CisLink) -> None: def on_cis_establishment(self, cis_link: device.CisLink) -> None:
@@ -351,24 +362,17 @@ class AseStateMachine(gatt.Characteristic):
and cis_link.cis_id == self.cis_id and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING and self.state == self.State.ENABLING
): ):
cis_link.on('disconnection', self.on_cis_disconnection) cis_link.on(cis_link.EVENT_DISCONNECTION, self.on_cis_disconnection)
async def post_cis_established(): async def post_cis_established():
await self.service.device.send_command( await cis_link.setup_data_path(direction=self.role)
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK: if self.role == AudioRole.SINK:
self.state = self.State.STREAMING self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established()) utils.cancel_on_event(
cis_link.acl_connection, 'flush', post_cis_established()
)
self.cis_link = cis_link self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None: def on_cis_disconnection(self, _reason) -> None:
@@ -511,16 +515,12 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.RELEASING self.state = self.State.RELEASING
async def remove_cis_async(): async def remove_cis_async():
await self.service.device.send_command( if self.cis_link:
hci.HCI_LE_Remove_ISO_Data_Path_Command( await self.cis_link.remove_data_path(self.role)
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async()) utils.cancel_on_event(self.service.device, 'flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property @property
@@ -531,7 +531,7 @@ class AseStateMachine(gatt.Characteristic):
def state(self, new_state: State) -> None: def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state self._state = new_state
self.emit('state_change') self.emit(self.EVENT_STATE_CHANGE)
@property @property
def value(self): def value(self):
@@ -605,7 +605,7 @@ class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine] ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic ase_control_point: gatt.Characteristic[bytes]
_active_client: Optional[device.Connection] = None _active_client: Optional[device.Connection] = None
def __init__( def __init__(
@@ -702,7 +702,8 @@ class AudioStreamControlService(gatt.TemplateService):
control_point_notification = bytes( control_point_notification = bytes(
[operation.op_code, len(responses)] [operation.op_code, len(responses)]
) + b''.join(map(bytes, responses)) ) + b''.join(map(bytes, responses))
self.device.abort_on( utils.cancel_on_event(
self.device,
'flush', 'flush',
self.device.notify_subscribers( self.device.notify_subscribers(
self.ase_control_point, control_point_notification self.ase_control_point, control_point_notification
@@ -711,7 +712,8 @@ class AudioStreamControlService(gatt.TemplateService):
for ase_id, *_ in responses: for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id): if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on( utils.cancel_on_event(
self.device,
'flush', 'flush',
self.device.notify_subscribers(ase, ase.value), self.device.notify_subscribers(ase, ase.value),
) )
@@ -721,9 +723,9 @@ class AudioStreamControlService(gatt.TemplateService):
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy] sink_ase: List[gatt_client.CharacteristicProxy[bytes]]
source_ase: List[gatt_client.CharacteristicProxy] source_ase: List[gatt_client.CharacteristicProxy[bytes]]
ase_control_point: gatt_client.CharacteristicProxy ase_control_point: gatt_client.CharacteristicProxy[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy

View File

@@ -88,6 +88,11 @@ class AudioStatus(utils.OpenIntEnum):
class AshaService(gatt.TemplateService): class AshaService(gatt.TemplateService):
UUID = gatt.GATT_ASHA_SERVICE UUID = gatt.GATT_ASHA_SERVICE
EVENT_STARTED = "started"
EVENT_STOPPED = "stopped"
EVENT_DISCONNECTED = "disconnected"
EVENT_VOLUME_CHANGED = "volume_changed"
audio_sink: Optional[Callable[[bytes], Any]] audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None audio_type: Optional[AudioType] = None
@@ -134,12 +139,14 @@ class AshaService(gatt.TemplateService):
), ),
) )
self.audio_control_point_characteristic = gatt.Characteristic( self.audio_control_point_characteristic: gatt.Characteristic[bytes] = (
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC, gatt.Characteristic(
gatt.Characteristic.Properties.WRITE gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE, gatt.Characteristic.Properties.WRITE
gatt.Characteristic.WRITEABLE, | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.CharacteristicValue(write=self._on_audio_control_point_write), gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
)
) )
self.audio_status_characteristic = gatt.Characteristic( self.audio_status_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC, gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
@@ -147,7 +154,7 @@ class AshaService(gatt.TemplateService):
gatt.Characteristic.READABLE, gatt.Characteristic.READABLE,
bytes([AudioStatus.OK]), bytes([AudioStatus.OK]),
) )
self.volume_characteristic = gatt.Characteristic( self.volume_characteristic: gatt.Characteristic[bytes] = gatt.Characteristic(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC, gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE, gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE, gatt.Characteristic.WRITEABLE,
@@ -166,13 +173,13 @@ class AshaService(gatt.TemplateService):
struct.pack('<H', self.psm), struct.pack('<H', self.psm),
) )
characteristics = [ characteristics = (
self.read_only_properties_characteristic, self.read_only_properties_characteristic,
self.audio_control_point_characteristic, self.audio_control_point_characteristic,
self.audio_status_characteristic, self.audio_status_characteristic,
self.volume_characteristic, self.volume_characteristic,
self.le_psm_out_characteristic, self.le_psm_out_characteristic,
] )
super().__init__(characteristics) super().__init__(characteristics)
@@ -209,14 +216,14 @@ class AshaService(gatt.TemplateService):
f'volume={self.volume}, ' f'volume={self.volume}, '
f'other_state={self.other_state}' f'other_state={self.other_state}'
) )
self.emit('started') self.emit(self.EVENT_STARTED)
elif opcode == OpCode.STOP: elif opcode == OpCode.STOP:
_logger.debug('### STOP') _logger.debug('### STOP')
self.active_codec = None self.active_codec = None
self.audio_type = None self.audio_type = None
self.volume = None self.volume = None
self.other_state = None self.other_state = None
self.emit('stopped') self.emit(self.EVENT_STOPPED)
elif opcode == OpCode.STATUS: elif opcode == OpCode.STATUS:
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name) _logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
@@ -229,7 +236,7 @@ class AshaService(gatt.TemplateService):
self.audio_type = None self.audio_type = None
self.volume = None self.volume = None
self.other_state = None self.other_state = None
self.emit('disconnected') self.emit(self.EVENT_DISCONNECTED)
connection.once('disconnection', on_disconnection) connection.once('disconnection', on_disconnection)
@@ -243,7 +250,7 @@ class AshaService(gatt.TemplateService):
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None: def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}') _logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0] self.volume = value[0]
self.emit('volume_changed') self.emit(self.EVENT_VOLUME_CHANGED)
# Register an L2CAP CoC server # Register an L2CAP CoC server
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None: def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
@@ -257,11 +264,11 @@ class AshaService(gatt.TemplateService):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AshaServiceProxy(gatt_client.ProfileServiceProxy): class AshaServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AshaService SERVICE_CLASS = AshaService
read_only_properties_characteristic: gatt_client.CharacteristicProxy read_only_properties_characteristic: gatt_client.CharacteristicProxy[bytes]
audio_control_point_characteristic: gatt_client.CharacteristicProxy audio_control_point_characteristic: gatt_client.CharacteristicProxy[bytes]
audio_status_point_characteristic: gatt_client.CharacteristicProxy audio_status_point_characteristic: gatt_client.CharacteristicProxy[bytes]
volume_characteristic: gatt_client.CharacteristicProxy volume_characteristic: gatt_client.CharacteristicProxy[bytes]
psm_characteristic: gatt_client.CharacteristicProxy psm_characteristic: gatt_client.CharacteristicProxy[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
@@ -288,8 +295,8 @@ class AshaServiceProxy(gatt_client.ProfileServiceProxy):
'psm_characteristic', 'psm_characteristic',
), ),
): ):
if not ( setattr(
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid) self,
): attribute_name,
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic") self.service_proxy.get_required_characteristic_by_uuid(uuid),
setattr(self, attribute_name, characteristics[0]) )

View File

@@ -102,6 +102,7 @@ class ContextType(enum.IntFlag):
# fmt: off # fmt: off
PROHIBITED = 0x0000 PROHIBITED = 0x0000
UNSPECIFIED = 0x0001
CONVERSATIONAL = 0x0002 CONVERSATIONAL = 0x0002
MEDIA = 0x0004 MEDIA = 0x0004
GAME = 0x0008 GAME = 0x0008
@@ -264,7 +265,7 @@ class UnicastServerAdvertisingData:
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID, core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
struct.pack( struct.pack(
'<2sBIB', '<2sBIB',
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE.to_bytes(), bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
self.announcement_type, self.announcement_type,
self.available_audio_contexts, self.available_audio_contexts,
len(self.metadata), len(self.metadata),
@@ -397,18 +398,21 @@ class CodecSpecificConfiguration:
OCTETS_PER_FRAME = 0x04 OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05 CODEC_FRAMES_PER_SDU = 0x05
sampling_frequency: SamplingFrequency sampling_frequency: SamplingFrequency | None = None
frame_duration: FrameDuration frame_duration: FrameDuration | None = None
audio_channel_allocation: AudioLocation audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int codec_frames_per_sdu: int | None = None
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration: def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
offset = 0 offset = 0
# Allowed default values. sampling_frequency: SamplingFrequency | None = None
audio_channel_allocation = AudioLocation.NOT_ALLOWED frame_duration: FrameDuration | None = None
codec_frames_per_sdu = 1 audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int | None = None
while offset < len(data): while offset < len(data):
length, type = struct.unpack_from('BB', data, offset) length, type = struct.unpack_from('BB', data, offset)
offset += 2 offset += 2
@@ -426,8 +430,6 @@ class CodecSpecificConfiguration:
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU: elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
codec_frames_per_sdu = value codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment
return CodecSpecificConfiguration( return CodecSpecificConfiguration(
sampling_frequency=sampling_frequency, sampling_frequency=sampling_frequency,
frame_duration=frame_duration, frame_duration=frame_duration,
@@ -437,23 +439,43 @@ class CodecSpecificConfiguration:
) )
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return struct.pack( return b''.join(
'<BBBBBBBBIBBHBBB', [
2, struct.pack(fmt, length, tag, value)
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY, for fmt, length, tag, value in [
self.sampling_frequency, (
2, '<BBB',
CodecSpecificConfiguration.Type.FRAME_DURATION, 2,
self.frame_duration, CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
5, self.sampling_frequency,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION, ),
self.audio_channel_allocation, (
3, '<BBB',
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME, 2,
self.octets_per_codec_frame, CodecSpecificConfiguration.Type.FRAME_DURATION,
2, self.frame_duration,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU, ),
self.codec_frames_per_sdu, (
'<BBI',
5,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
self.audio_channel_allocation,
),
(
'<BBH',
3,
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
self.octets_per_codec_frame,
),
(
'<BBB',
2,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
self.codec_frames_per_sdu,
),
]
if value is not None
]
) )
@@ -465,6 +487,24 @@ class BroadcastAudioAnnouncement:
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls, data: bytes) -> Self:
return cls(int.from_bytes(data[:3], 'little')) return cls(int.from_bytes(data[:3], 'little'))
def __bytes__(self) -> bytes:
return self.broadcast_id.to_bytes(3, 'little')
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)
@dataclasses.dataclass @dataclasses.dataclass
class BasicAudioAnnouncement: class BasicAudioAnnouncement:
@@ -473,26 +513,37 @@ class BasicAudioAnnouncement:
index: int index: int
codec_specific_configuration: CodecSpecificConfiguration codec_specific_configuration: CodecSpecificConfiguration
@dataclasses.dataclass def __bytes__(self) -> bytes:
class CodecInfo: codec_specific_configuration_bytes = bytes(
coding_format: hci.CodecID self.codec_specific_configuration
company_id: int )
vendor_specific_codec_id: int return (
bytes([self.index, len(codec_specific_configuration_bytes)])
@classmethod + codec_specific_configuration_bytes
def from_bytes(cls, data: bytes) -> Self: )
coding_format = hci.CodecID(data[0])
company_id = int.from_bytes(data[1:3], 'little')
vendor_specific_codec_id = int.from_bytes(data[3:5], 'little')
return cls(coding_format, company_id, vendor_specific_codec_id)
@dataclasses.dataclass @dataclasses.dataclass
class Subgroup: class Subgroup:
codec_id: BasicAudioAnnouncement.CodecInfo codec_id: hci.CodingFormat
codec_specific_configuration: CodecSpecificConfiguration codec_specific_configuration: CodecSpecificConfiguration
metadata: le_audio.Metadata metadata: le_audio.Metadata
bis: List[BasicAudioAnnouncement.BIS] bis: List[BasicAudioAnnouncement.BIS]
def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata)
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
return (
bytes([len(self.bis)])
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
+ b''.join(map(bytes, self.bis))
)
presentation_delay: int presentation_delay: int
subgroups: List[BasicAudioAnnouncement.Subgroup] subgroups: List[BasicAudioAnnouncement.Subgroup]
@@ -504,7 +555,7 @@ class BasicAudioAnnouncement:
for _ in range(data[3]): for _ in range(data[3]):
num_bis = data[offset] num_bis = data[offset]
offset += 1 offset += 1
codec_id = cls.CodecInfo.from_bytes(data[offset : offset + 5]) codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
offset += 5 offset += 5
codec_specific_configuration_length = data[offset] codec_specific_configuration_length = data[offset]
offset += 1 offset += 1
@@ -548,3 +599,25 @@ class BasicAudioAnnouncement:
) )
return cls(presentation_delay, subgroups) return cls(presentation_delay, subgroups)
def __bytes__(self) -> bytes:
return (
self.presentation_delay.to_bytes(3, 'little')
+ bytes([len(self.subgroups)])
+ b''.join(map(bytes, self.subgroups))
)
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)

View File

@@ -20,11 +20,12 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import struct import struct
from typing import ClassVar, List, Optional, Sequence from typing import ClassVar, Optional, Sequence
from bumble import core from bumble import core
from bumble import device from bumble import device
from bumble import gatt from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client from bumble import gatt_client
from bumble import hci from bumble import hci
from bumble import utils from bumble import utils
@@ -52,7 +53,7 @@ def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
) )
def decode_subgroups(data: bytes) -> List[SubgroupInfo]: def decode_subgroups(data: bytes) -> list[SubgroupInfo]:
num_subgroups = data[0] num_subgroups = data[0]
offset = 1 offset = 1
subgroups = [] subgroups = []
@@ -273,13 +274,10 @@ class BroadcastReceiveState:
pa_sync_state: PeriodicAdvertisingSyncState pa_sync_state: PeriodicAdvertisingSyncState
big_encryption: BigEncryption big_encryption: BigEncryption
bad_code: bytes bad_code: bytes
subgroups: List[SubgroupInfo] subgroups: list[SubgroupInfo]
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]: def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
if not data:
return None
source_id = data[0] source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2) _, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8] source_adv_sid = data[8]
@@ -356,35 +354,28 @@ class BroadcastAudioScanService(gatt.TemplateService):
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy): class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BroadcastAudioScanService SERVICE_CLASS = BroadcastAudioScanService
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes]
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter] broadcast_receive_states: list[
gatt_client.CharacteristicProxy[Optional[BroadcastReceiveState]]
]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not ( self.broadcast_audio_scan_control_point = (
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
) )
): )
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not ( self.broadcast_receive_states = [
characteristics := service_proxy.get_characteristics_by_uuid( gatt_adapters.DelegatedCharacteristicProxyAdapter(
characteristic,
decode=lambda x: BroadcastReceiveState.from_bytes(x) if x else None,
)
for characteristic in service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
) )
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [
gatt.DelegatedCharacteristicAdapter(
characteristic, decode=BroadcastReceiveState.from_bytes
)
for characteristic in characteristics
] ]
async def send_control_point_operation( async def send_control_point_operation(

View File

@@ -16,14 +16,20 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from ..gatt_client import ProfileServiceProxy from typing import Optional
from ..gatt import (
from bumble.gatt_client import ProfileServiceProxy
from bumble.gatt import (
GATT_BATTERY_SERVICE, GATT_BATTERY_SERVICE,
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
)
from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_adapters import (
PackedCharacteristicAdapter, PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
) )
@@ -32,6 +38,8 @@ class BatteryService(TemplateService):
UUID = GATT_BATTERY_SERVICE UUID = GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B' BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: Characteristic[int]
def __init__(self, read_battery_level): def __init__(self, read_battery_level):
self.battery_level_characteristic = PackedCharacteristicAdapter( self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic( Characteristic(
@@ -49,13 +57,15 @@ class BatteryService(TemplateService):
class BatteryServiceProxy(ProfileServiceProxy): class BatteryServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = BatteryService SERVICE_CLASS = BatteryService
battery_level: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC GATT_BATTERY_LEVEL_CHARACTERISTIC
): ):
self.battery_level = PackedCharacteristicAdapter( self.battery_level = PackedCharacteristicProxyAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
) )
else: else:

View File

@@ -99,10 +99,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None coordinated_set_size_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None set_member_lock_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None set_member_rank_characteristic: Optional[gatt.Characteristic[bytes]] = None
def __init__( def __init__(
self, self,
@@ -170,7 +170,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
else: else:
assert connection assert connection
if connection.transport == core.BT_LE_TRANSPORT: if connection.transport == core.PhysicalTransport.LE:
key = await connection.device.get_long_term_key( key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0 connection_handle=connection.handle, rand=b'', ediv=0
) )
@@ -203,10 +203,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy): class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None coordinated_set_size: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None set_member_lock: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None set_member_rank: Optional[gatt_client.CharacteristicProxy[bytes]] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
@@ -242,7 +242,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
else: else:
connection = self.service_proxy.client.connection connection = self.service_proxy.client.connection
device = connection.device device = connection.device
if connection.transport == core.BT_LE_TRANSPORT: if connection.transport == core.PhysicalTransport.LE:
key = await device.get_long_term_key( key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0 connection_handle=connection.handle, rand=b'', ediv=0
) )

View File

@@ -19,7 +19,6 @@
import struct import struct
from typing import Optional, Tuple from typing import Optional, Tuple
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
@@ -32,9 +31,12 @@ from bumble.gatt import (
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC, GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
TemplateService, TemplateService,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
) )
from bumble.gatt_adapters import (
DelegatedCharacteristicProxyAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -62,9 +64,12 @@ class DeviceInformationService(TemplateService):
ieee_regulatory_certification_data_list: Optional[bytes] = None, ieee_regulatory_certification_data_list: Optional[bytes] = None,
# TODO: pnp_id # TODO: pnp_id
): ):
characteristics = [ characteristics: list[Characteristic[bytes]] = [
Characteristic( Characteristic(
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field uuid,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(field, 'utf-8'),
) )
for (field, uuid) in ( for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), (manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
@@ -104,14 +109,14 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy): class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService SERVICE_CLASS = DeviceInformationService
manufacturer_name: Optional[UTF8CharacteristicAdapter] manufacturer_name: Optional[CharacteristicProxy[str]]
model_number: Optional[UTF8CharacteristicAdapter] model_number: Optional[CharacteristicProxy[str]]
serial_number: Optional[UTF8CharacteristicAdapter] serial_number: Optional[CharacteristicProxy[str]]
hardware_revision: Optional[UTF8CharacteristicAdapter] hardware_revision: Optional[CharacteristicProxy[str]]
firmware_revision: Optional[UTF8CharacteristicAdapter] firmware_revision: Optional[CharacteristicProxy[str]]
software_revision: Optional[UTF8CharacteristicAdapter] software_revision: Optional[CharacteristicProxy[str]]
system_id: Optional[DelegatedCharacteristicAdapter] system_id: Optional[CharacteristicProxy[tuple[int, int]]]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy] ieee_regulatory_certification_data_list: Optional[CharacteristicProxy[bytes]]
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
@@ -125,7 +130,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC), ('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
): ):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid): if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0]) characteristic = UTF8CharacteristicProxyAdapter(characteristics[0])
else: else:
characteristic = None characteristic = None
self.__setattr__(field, characteristic) self.__setattr__(field, characteristic)
@@ -133,7 +138,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC GATT_SYSTEM_ID_CHARACTERISTIC
): ):
self.system_id = DelegatedCharacteristicAdapter( self.system_id = DelegatedCharacteristicProxyAdapter(
characteristics[0], characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v), encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id, decode=DeviceInformationService.unpack_system_id,

View File

@@ -25,14 +25,15 @@ from bumble.core import Appearance
from bumble.gatt import ( from bumble.gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicAdapter,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC,
) )
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy from bumble.gatt_adapters import (
DelegatedCharacteristicProxyAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -49,6 +50,9 @@ logger = logging.getLogger(__name__)
class GenericAccessService(TemplateService): class GenericAccessService(TemplateService):
UUID = GATT_GENERIC_ACCESS_SERVICE UUID = GATT_GENERIC_ACCESS_SERVICE
device_name_characteristic: Characteristic[bytes]
appearance_characteristic: Characteristic[bytes]
def __init__( def __init__(
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0 self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
): ):
@@ -84,8 +88,8 @@ class GenericAccessService(TemplateService):
class GenericAccessServiceProxy(ProfileServiceProxy): class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicAdapter] device_name: Optional[CharacteristicProxy[str]]
appearance: Optional[DelegatedCharacteristicAdapter] appearance: Optional[CharacteristicProxy[Appearance]]
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
@@ -93,14 +97,14 @@ class GenericAccessServiceProxy(ProfileServiceProxy):
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC GATT_DEVICE_NAME_CHARACTERISTIC
): ):
self.device_name = UTF8CharacteristicAdapter(characteristics[0]) self.device_name = UTF8CharacteristicProxyAdapter(characteristics[0])
else: else:
self.device_name = None self.device_name = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_APPEARANCE_CHARACTERISTIC GATT_APPEARANCE_CHARACTERISTIC
): ):
self.appearance = DelegatedCharacteristicAdapter( self.appearance = DelegatedCharacteristicProxyAdapter(
characteristics[0], characteristics[0],
decode=lambda value: Appearance.from_int( decode=lambda value: Appearance.from_int(
struct.unpack_from('<H', value, 0)[0], struct.unpack_from('<H', value, 0)[0],

View File

@@ -0,0 +1,167 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import struct
from typing import TYPE_CHECKING
from bumble import att
from bumble import gatt
from bumble import gatt_client
from bumble import crypto
if TYPE_CHECKING:
from bumble import device
# -----------------------------------------------------------------------------
class GenericAttributeProfileService(gatt.TemplateService):
'''See Vol 3, Part G - 7 - DEFINED GENERIC ATTRIBUTE PROFILE SERVICE.'''
UUID = gatt.GATT_GENERIC_ATTRIBUTE_SERVICE
client_supported_features_characteristic: gatt.Characteristic[bytes] | None = None
server_supported_features_characteristic: gatt.Characteristic[bytes] | None = None
database_hash_characteristic: gatt.Characteristic[bytes] | None = None
service_changed_characteristic: gatt.Characteristic[bytes] | None = None
def __init__(
self,
server_supported_features: gatt.ServerSupportedFeatures | None = None,
database_hash_enabled: bool = True,
service_change_enabled: bool = True,
) -> None:
if server_supported_features is not None:
self.server_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([server_supported_features]),
)
if database_hash_enabled:
self.database_hash_characteristic = gatt.Characteristic(
uuid=gatt.GATT_DATABASE_HASH_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.get_database_hash),
)
if service_change_enabled:
self.service_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.INDICATE,
permissions=gatt.Characteristic.Permissions(0),
value=b'',
)
if (database_hash_enabled and service_change_enabled) or (
server_supported_features
and (
server_supported_features & gatt.ServerSupportedFeatures.EATT_SUPPORTED
)
): # TODO: Support Multiple Handle Value Notifications
self.client_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
),
permissions=(
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE
),
value=bytes(1),
)
super().__init__(
characteristics=[
c
for c in (
self.service_changed_characteristic,
self.client_supported_features_characteristic,
self.database_hash_characteristic,
self.server_supported_features_characteristic,
)
if c is not None
],
primary=True,
)
@classmethod
def get_attribute_data(cls, attribute: att.Attribute) -> bytes:
if attribute.type in (
gatt.GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
gatt.GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
gatt.GATT_INCLUDE_ATTRIBUTE_TYPE,
gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
gatt.GATT_CHARACTERISTIC_EXTENDED_PROPERTIES_DESCRIPTOR,
):
assert isinstance(attribute.value, bytes)
return (
struct.pack("<H", attribute.handle)
+ attribute.type.to_bytes()
+ attribute.value
)
elif attribute.type in (
gatt.GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
gatt.GATT_SERVER_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
gatt.GATT_CHARACTERISTIC_PRESENTATION_FORMAT_DESCRIPTOR,
gatt.GATT_CHARACTERISTIC_AGGREGATE_FORMAT_DESCRIPTOR,
):
return struct.pack("<H", attribute.handle) + attribute.type.to_bytes()
return b''
def get_database_hash(self, connection: device.Connection | None) -> bytes:
assert connection
m = b''.join(
[
self.get_attribute_data(attribute)
for attribute in connection.device.gatt_server.attributes
]
)
return crypto.aes_cmac(m=m, k=bytes(16))
class GenericAttributeProfileServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = GenericAttributeProfileService
client_supported_features_characteristic: (
gatt_client.CharacteristicProxy[bytes] | None
) = None
server_supported_features_characteristic: (
gatt_client.CharacteristicProxy[bytes] | None
) = None
database_hash_characteristic: gatt_client.CharacteristicProxy[bytes] | None = None
service_changed_characteristic: gatt_client.CharacteristicProxy[bytes] | None = None
_CHARACTERISTICS = {
gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC: 'client_supported_features_characteristic',
gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC: 'server_supported_features_characteristic',
gatt.GATT_DATABASE_HASH_CHARACTERISTIC: 'database_hash_characteristic',
gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC: 'service_changed_characteristic',
}
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
for uuid, attribute_name in self._CHARACTERISTICS.items():
if characteristics := self.service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, attribute_name, characteristics[0])

198
bumble/profiles/gmap.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Gaming Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional
from bumble.gatt import (
TemplateService,
Characteristic,
GATT_GAMING_AUDIO_SERVICE,
GATT_GMAP_ROLE_CHARACTERISTIC,
GATT_UGG_FEATURES_CHARACTERISTIC,
GATT_UGT_FEATURES_CHARACTERISTIC,
GATT_BGS_FEATURES_CHARACTERISTIC,
GATT_BGR_FEATURES_CHARACTERISTIC,
)
from bumble.gatt_adapters import DelegatedCharacteristicProxyAdapter
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
from enum import IntFlag
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class GmapRole(IntFlag):
UNICAST_GAME_GATEWAY = 1 << 0
UNICAST_GAME_TERMINAL = 1 << 1
BROADCAST_GAME_SENDER = 1 << 2
BROADCAST_GAME_RECEIVER = 1 << 3
class UggFeatures(IntFlag):
UGG_MULTIPLEX = 1 << 0
UGG_96_KBPS_SOURCE = 1 << 1
UGG_MULTISINK = 1 << 2
class UgtFeatures(IntFlag):
UGT_SOURCE = 1 << 0
UGT_80_KBPS_SOURCE = 1 << 1
UGT_SINK = 1 << 2
UGT_64_KBPS_SINK = 1 << 3
UGT_MULTIPLEX = 1 << 4
UGT_MULTISINK = 1 << 5
UGT_MULTISOURCE = 1 << 6
class BgsFeatures(IntFlag):
BGS_96_KBPS = 1 << 0
class BgrFeatures(IntFlag):
BGR_MULTISINK = 1 << 0
BGR_MULTIPLEX = 1 << 1
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic
ugg_features: Optional[Characteristic] = None
ugt_features: Optional[Characteristic] = None
bgs_features: Optional[Characteristic] = None
bgr_features: Optional[Characteristic] = None
def __init__(
self,
gmap_role: GmapRole,
ugg_features: Optional[UggFeatures] = None,
ugt_features: Optional[UgtFeatures] = None,
bgs_features: Optional[BgsFeatures] = None,
bgr_features: Optional[BgrFeatures] = None,
) -> None:
characteristics = []
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
self.gmap_role = Characteristic(
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', gmap_role),
)
characteristics.append(self.gmap_role)
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
self.ugg_features = Characteristic(
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugg_features),
)
characteristics.append(self.ugg_features)
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
self.ugt_features = Characteristic(
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugt_features),
)
characteristics.append(self.ugt_features)
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
self.bgs_features = Characteristic(
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgs_features),
)
characteristics.append(self.bgs_features)
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
self.bgr_features = Characteristic(
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgr_features),
)
characteristics.append(self.bgr_features)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService
ugg_features: Optional[CharacteristicProxy[UggFeatures]] = None
ugt_features: Optional[CharacteristicProxy[UgtFeatures]] = None
bgs_features: Optional[CharacteristicProxy[BgsFeatures]] = None
bgr_features: Optional[CharacteristicProxy[BgrFeatures]] = None
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
self.gmap_role = DelegatedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_GMAP_ROLE_CHARACTERISTIC
),
decode=lambda value: GmapRole(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGG_FEATURES_CHARACTERISTIC
):
self.ugg_features = DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda value: UggFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGT_FEATURES_CHARACTERISTIC
):
self.ugt_features = DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda value: UgtFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGS_FEATURES_CHARACTERISTIC
):
self.bgs_features = DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda value: BgsFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGR_FEATURES_CHARACTERISTIC
):
self.bgr_features = DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda value: BgrFeatures(value[0]),
)

View File

@@ -18,20 +18,21 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools import functools
from bumble import att, gatt, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Device, Connection
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.hci import Address
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
from bumble import att, gatt, gatt_adapters, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Device, Connection
from bumble import utils
from bumble.hci import Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ErrorCode(OpenIntEnum): class ErrorCode(utils.OpenIntEnum):
'''See Hearing Access Service 2.4. Attribute Profile error codes.''' '''See Hearing Access Service 2.4. Attribute Profile error codes.'''
INVALID_OPCODE = 0x80 INVALID_OPCODE = 0x80
@@ -41,7 +42,7 @@ class ErrorCode(OpenIntEnum):
INVALID_PARAMETERS_LENGTH = 0x84 INVALID_PARAMETERS_LENGTH = 0x84
class HearingAidType(OpenIntEnum): class HearingAidType(utils.OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.''' '''See Hearing Access Service 3.1. Hearing Aid Features.'''
BINAURAL_HEARING_AID = 0b00 BINAURAL_HEARING_AID = 0b00
@@ -49,35 +50,35 @@ class HearingAidType(OpenIntEnum):
BANDED_HEARING_AID = 0b10 BANDED_HEARING_AID = 0b10
class PresetSynchronizationSupport(OpenIntEnum): class PresetSynchronizationSupport(utils.OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.''' '''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0 PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1 PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
class IndependentPresets(OpenIntEnum): class IndependentPresets(utils.OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.''' '''See Hearing Access Service 3.1. Hearing Aid Features.'''
IDENTICAL_PRESET_RECORD = 0b0 IDENTICAL_PRESET_RECORD = 0b0
DIFFERENT_PRESET_RECORD = 0b1 DIFFERENT_PRESET_RECORD = 0b1
class DynamicPresets(OpenIntEnum): class DynamicPresets(utils.OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.''' '''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0 PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
PRESET_RECORDS_MAY_CHANGE = 0b1 PRESET_RECORDS_MAY_CHANGE = 0b1
class WritablePresetsSupport(OpenIntEnum): class WritablePresetsSupport(utils.OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.''' '''See Hearing Access Service 3.1. Hearing Aid Features.'''
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0 WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1 WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
class HearingAidPresetControlPointOpcode(OpenIntEnum): class HearingAidPresetControlPointOpcode(utils.OpenIntEnum):
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.''' '''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
# fmt: off # fmt: off
@@ -129,7 +130,7 @@ def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
class PresetChangedOperation: class PresetChangedOperation:
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.''' '''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
class ChangeId(OpenIntEnum): class ChangeId(utils.OpenIntEnum):
# fmt: off # fmt: off
GENERIC_UPDATE = 0x00 GENERIC_UPDATE = 0x00
PRESET_RECORD_DELETED = 0x01 PRESET_RECORD_DELETED = 0x01
@@ -189,11 +190,11 @@ class PresetRecord:
@dataclass @dataclass
class Property: class Property:
class Writable(OpenIntEnum): class Writable(utils.OpenIntEnum):
CANNOT_BE_WRITTEN = 0b0 CANNOT_BE_WRITTEN = 0b0
CAN_BE_WRITTEN = 0b1 CAN_BE_WRITTEN = 0b1
class IsAvailable(OpenIntEnum): class IsAvailable(utils.OpenIntEnum):
IS_UNAVAILABLE = 0b0 IS_UNAVAILABLE = 0b0
IS_AVAILABLE = 0b1 IS_AVAILABLE = 0b1
@@ -223,9 +224,9 @@ class PresetRecord:
class HearingAccessService(gatt.TemplateService): class HearingAccessService(gatt.TemplateService):
UUID = gatt.GATT_HEARING_ACCESS_SERVICE UUID = gatt.GATT_HEARING_ACCESS_SERVICE
hearing_aid_features_characteristic: gatt.Characteristic hearing_aid_features_characteristic: gatt.Characteristic[bytes]
hearing_aid_preset_control_point: gatt.Characteristic hearing_aid_preset_control_point: gatt.Characteristic[bytes]
active_preset_index_characteristic: gatt.Characteristic active_preset_index_characteristic: gatt.Characteristic[bytes]
active_preset_index: int active_preset_index: int
active_preset_index_per_device: Dict[Address, int] active_preset_index_per_device: Dict[Address, int]
@@ -265,13 +266,13 @@ class HearingAccessService(gatt.TemplateService):
# associate the lowest index as the current active preset at startup # associate the lowest index as the current active preset at startup
self.active_preset_index = sorted(self.preset_records.keys())[0] self.active_preset_index = sorted(self.preset_records.keys())[0]
@device.on('connection') # type: ignore @device.on(device.EVENT_CONNECTION)
def on_connection(connection: Connection) -> None: def on_connection(connection: Connection) -> None:
@connection.on('disconnection') # type: ignore @connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None: def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection) self.currently_connected_clients.remove(connection)
@connection.on('pairing') # type: ignore @connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None: def on_pairing(*_: Any) -> None:
self.on_incoming_paired_connection(connection) self.on_incoming_paired_connection(connection)
@@ -332,7 +333,7 @@ class HearingAccessService(gatt.TemplateService):
# Update the active preset index if needed # Update the active preset index if needed
await self.notify_active_preset_for_connection(connection) await self.notify_active_preset_for_connection(connection)
connection.abort_on('disconnection', on_connection_async()) utils.cancel_on_event(connection, 'disconnection', on_connection_async())
def _on_read_active_preset_index( def _on_read_active_preset_index(
self, __connection__: Optional[Connection] self, __connection__: Optional[Connection]
@@ -381,7 +382,7 @@ class HearingAccessService(gatt.TemplateService):
if len(presets) == 0: if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE) raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
AsyncRunner.spawn(self._read_preset_response(connection, presets)) utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))
async def _read_preset_response( async def _read_preset_response(
self, connection: Connection, presets: List[PresetRecord] self, connection: Connection, presets: List[PresetRecord]
@@ -631,11 +632,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue preset_control_point_indications: asyncio.Queue
active_preset_index_notification: asyncio.Queue
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.server_features = gatt.PackedCharacteristicAdapter( self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
)[0], )[0],
@@ -648,7 +650,7 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
)[0] )[0]
) )
self.active_preset_index = gatt.PackedCharacteristicAdapter( self.active_preset_index = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_characteristics_by_uuid(
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
)[0], )[0],

View File

@@ -16,13 +16,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum from enum import IntEnum
import struct import struct
from typing import Optional
from bumble import core from bumble import core
from ..gatt_client import ProfileServiceProxy from bumble.att import ATT_Error
from ..att import ATT_Error from bumble.gatt import (
from ..gatt import (
GATT_HEART_RATE_SERVICE, GATT_HEART_RATE_SERVICE,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC, GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
@@ -30,9 +31,13 @@ from ..gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter, PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
) )
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -42,6 +47,10 @@ class HeartRateService(TemplateService):
CONTROL_POINT_NOT_SUPPORTED = 0x80 CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01 RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: Characteristic[int]
class BodySensorLocation(IntEnum): class BodySensorLocation(IntEnum):
OTHER = 0 OTHER = 0
CHEST = 1 CHEST = 1
@@ -150,15 +159,14 @@ class HeartRateService(TemplateService):
body_sensor_location=None, body_sensor_location=None,
reset_energy_expended=None, reset_energy_expended=None,
): ):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter( self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic( Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY, Characteristic.Properties.NOTIFY,
0, 0,
CharacteristicValue(read=read_heart_rate_measurement), CharacteristicValue(read=read_heart_rate_measurement),
), ),
# pylint: disable=unnecessary-lambda HeartRateService.HeartRateMeasurement,
encode=lambda value: bytes(value),
) )
characteristics = [self.heart_rate_measurement_characteristic] characteristics = [self.heart_rate_measurement_characteristic]
@@ -198,15 +206,22 @@ class HeartRateService(TemplateService):
class HeartRateServiceProxy(ProfileServiceProxy): class HeartRateServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = HeartRateService SERVICE_CLASS = HeartRateService
heart_rate_measurement: Optional[
CharacteristicProxy[HeartRateService.HeartRateMeasurement]
]
body_sensor_location: Optional[
CharacteristicProxy[HeartRateService.BodySensorLocation]
]
heart_rate_control_point: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
): ):
self.heart_rate_measurement = DelegatedCharacteristicAdapter( self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], characteristics[0], HeartRateService.HeartRateMeasurement
decode=HeartRateService.HeartRateMeasurement.from_bytes,
) )
else: else:
self.heart_rate_measurement = None self.heart_rate_measurement = None

View File

@@ -17,23 +17,35 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import enum
import struct import struct
from typing import List, Type from typing import Any, List, Type
from typing_extensions import Self from typing_extensions import Self
from bumble.profiles import bap
from bumble import utils from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioActiveState(utils.OpenIntEnum):
NO_AUDIO_DATA_TRANSMITTED = 0x00
AUDIO_DATA_TRANSMITTED = 0x01
class AssistedListeningStream(utils.OpenIntEnum):
UNSPECIFIED_AUDIO_ENHANCEMENT = 0x00
@dataclasses.dataclass @dataclasses.dataclass
class Metadata: class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures. '''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse As Metadata fields may extend, and the spec may not guarantee the uniqueness of
Metadata into a key-value style dataclass here. Rather, we encourage users to parse tags, we don't automatically parse the Metadata data into specific classes.
again outside the lib. Users of this class may decode the data by themselves, or use the Entry.decode
method.
''' '''
class Tag(utils.OpenIntEnum): class Tag(utils.OpenIntEnum):
@@ -57,6 +69,44 @@ class Metadata:
tag: Metadata.Tag tag: Metadata.Tag
data: bytes data: bytes
def decode(self) -> Any:
"""
Decode the data into an object, if possible.
If no specific object class exists to represent the data, the raw data
bytes are returned.
"""
if self.tag in (
Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
):
return bap.ContextType(struct.unpack("<H", self.data)[0])
if self.tag in (
Metadata.Tag.PROGRAM_INFO,
Metadata.Tag.PROGRAM_INFO_URI,
Metadata.Tag.BROADCAST_NAME,
):
return self.data.decode("utf-8")
if self.tag == Metadata.Tag.LANGUAGE:
return self.data.decode("ascii")
if self.tag == Metadata.Tag.CCID_LIST:
return list(self.data)
if self.tag == Metadata.Tag.PARENTAL_RATING:
return self.data[0]
if self.tag == Metadata.Tag.AUDIO_ACTIVE_STATE:
return AudioActiveState(self.data[0])
if self.tag == Metadata.Tag.ASSISTED_LISTENING_STREAM:
return AssistedListeningStream(self.data[0])
return self.data
@classmethod @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:]) return cls(tag=Metadata.Tag(data[0]), data=data[1:])
@@ -66,6 +116,29 @@ class Metadata:
entries: List[Entry] = dataclasses.field(default_factory=list) entries: List[Entry] = dataclasses.field(default_factory=list)
def pretty_print(self, indent: str) -> str:
"""Convenience method to generate a string with one key-value pair per line."""
max_key_length = 0
keys = []
values = []
for entry in self.entries:
key = entry.tag.name
max_key_length = max(max_key_length, len(key))
keys.append(key)
decoded = entry.decode()
if isinstance(decoded, enum.Enum):
values.append(decoded.name)
elif isinstance(decoded, bytes):
values.append(decoded.hex())
else:
values.append(str(decoded))
return '\n'.join(
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
for key, value in zip(keys, values)
)
@classmethod @classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self: def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = [] entries = []
@@ -81,3 +154,13 @@ class Metadata:
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries]) return b''.join([bytes(entry) for entry in self.entries])
def __str__(self) -> str:
entries_str = []
for entry in self.entries:
decoded = entry.decode()
entries_str.append(
f'{entry.tag.name}: '
f'{decoded.hex() if isinstance(decoded, bytes) else decoded!r}'
)
return f'Metadata(entries={", ".join(entry_str for entry_str in entries_str)})'

View File

@@ -208,7 +208,7 @@ class MediaControlService(gatt.TemplateService):
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player', value=(media_player_name or 'Bumble Player').encode(),
) )
self.track_changed_characteristic = gatt.Characteristic( self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC, uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
@@ -247,14 +247,16 @@ class MediaControlService(gatt.TemplateService):
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'', value=b'',
) )
self.media_control_point_characteristic = gatt.Characteristic( self.media_control_point_characteristic: gatt.Characteristic[bytes] = (
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC, gatt.Characteristic(
properties=gatt.Characteristic.Properties.WRITE uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION | gatt.Characteristic.Properties.NOTIFY,
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
value=gatt.CharacteristicValue(write=self.on_media_control_point), | gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
) )
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic( self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC, uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
@@ -336,30 +338,38 @@ class MediaControlServiceProxy(
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC, 'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
} }
media_player_name: Optional[gatt_client.CharacteristicProxy] = None EVENT_MEDIA_STATE = "media_state"
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None EVENT_TRACK_CHANGED = "track_changed"
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None EVENT_TRACK_TITLE = "track_title"
track_changed: Optional[gatt_client.CharacteristicProxy] = None EVENT_TRACK_DURATION = "track_duration"
track_title: Optional[gatt_client.CharacteristicProxy] = None EVENT_TRACK_POSITION = "track_position"
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None media_player_name: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None media_player_icon_url: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None track_changed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None track_title: Optional[gatt_client.CharacteristicProxy[bytes]] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None track_duration: Optional[gatt_client.CharacteristicProxy[bytes]] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None track_position: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None playback_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None seeking_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None current_track_segments_object_id: Optional[
media_state: Optional[gatt_client.CharacteristicProxy] = None gatt_client.CharacteristicProxy[bytes]
media_control_point: Optional[gatt_client.CharacteristicProxy] = None ] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = ( current_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
None next_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
) parent_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
search_control_point: Optional[gatt_client.CharacteristicProxy] = None current_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None playing_order: Optional[gatt_client.CharacteristicProxy[bytes]] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None playing_orders_supported: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_state: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point_opcodes_supported: Optional[
gatt_client.CharacteristicProxy[bytes]
] = None
search_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
content_control_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
if TYPE_CHECKING: if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes] media_control_point_notifications: asyncio.Queue[bytes]
@@ -428,20 +438,20 @@ class MediaControlServiceProxy(
self.media_control_point_notifications.put_nowait(data) self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None: def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0])) self.emit(self.EVENT_MEDIA_STATE, MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None: def _on_track_changed(self, data: bytes) -> None:
del data del data
self.emit('track_changed') self.emit(self.EVENT_TRACK_CHANGED)
def _on_track_title(self, data: bytes) -> None: def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8")) self.emit(self.EVENT_TRACK_TITLE, data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None: def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0]) self.emit(self.EVENT_TRACK_DURATION, struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None: def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0]) self.emit(self.EVENT_TRACK_POSITION, struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy): class GenericMediaControlServiceProxy(MediaControlServiceProxy):

View File

@@ -25,6 +25,7 @@ from typing import Optional, Sequence, Union
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
from bumble.profiles import le_audio from bumble.profiles import le_audio
from bumble import gatt from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client from bumble import gatt_client
from bumble import hci from bumble import hci
@@ -72,6 +73,19 @@ class PacRecord:
metadata=metadata, metadata=metadata,
) )
@classmethod
def list_from_bytes(cls, data: bytes) -> list[PacRecord]:
"""Parse a serialized list of records preceded by a one byte list length."""
record_count = data[0]
records = []
offset = 1
for _ in range(record_count):
record = PacRecord.from_bytes(data[offset:])
offset += len(bytes(record))
records.append(record)
return records
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities) capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata) metadata_bytes = bytes(self.metadata)
@@ -90,12 +104,12 @@ class PacRecord:
class PublishedAudioCapabilitiesService(gatt.TemplateService): class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic] sink_pac: Optional[gatt.Characteristic[bytes]]
sink_audio_locations: Optional[gatt.Characteristic] sink_audio_locations: Optional[gatt.Characteristic[bytes]]
source_pac: Optional[gatt.Characteristic] source_pac: Optional[gatt.Characteristic[bytes]]
source_audio_locations: Optional[gatt.Characteristic] source_audio_locations: Optional[gatt.Characteristic[bytes]]
available_audio_contexts: gatt.Characteristic available_audio_contexts: gatt.Characteristic[bytes]
supported_audio_contexts: gatt.Characteristic supported_audio_contexts: gatt.Characteristic[bytes]
def __init__( def __init__(
self, self,
@@ -172,39 +186,70 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy): class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None sink_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None sink_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
source_pac: Optional[gatt_client.CharacteristicProxy] = None None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None )
available_audio_contexts: gatt_client.CharacteristicProxy source_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
supported_audio_contexts: gatt_client.CharacteristicProxy source_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
None
)
available_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
supported_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid( self.available_audio_contexts = (
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC gatt_adapters.DelegatedCharacteristicProxyAdapter(
)[0] service_proxy.get_required_characteristic_by_uuid(
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid( gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC ),
)[0] decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
)
)
self.supported_audio_contexts = (
gatt_adapters.DelegatedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
),
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
)
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC gatt.GATT_SINK_PAC_CHARACTERISTIC
): ):
self.sink_pac = characteristics[0] self.sink_pac = gatt_adapters.DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=PacRecord.list_from_bytes,
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC gatt.GATT_SOURCE_PAC_CHARACTERISTIC
): ):
self.source_pac = characteristics[0] self.source_pac = gatt_adapters.DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=PacRecord.list_from_bytes,
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
): ):
self.sink_audio_locations = characteristics[0] self.sink_audio_locations = (
gatt_adapters.DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
)
)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
): ):
self.source_audio_locations = characteristics[0] self.source_audio_locations = (
gatt_adapters.DelegatedCharacteristicProxyAdapter(
characteristics[0],
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
)
)

View File

@@ -40,7 +40,7 @@ class PublicBroadcastAnnouncement:
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls, data: bytes) -> Self:
features = cls.Features(data[0]) features = cls.Features(data[0])
metadata_length = data[1] metadata_length = data[1]
metadata_ltv = data[1 : 1 + metadata_length] metadata_ltv = data[2 : 2 + metadata_length]
return cls( return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv) features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
) )

View File

@@ -24,12 +24,11 @@ import struct
from bumble.gatt import ( from bumble.gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE, GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC, GATT_TMAP_ROLE_CHARACTERISTIC,
) )
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy from bumble.gatt_adapters import DelegatedCharacteristicProxyAdapter
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -54,6 +53,8 @@ class Role(enum.IntFlag):
class TelephonyAndMediaAudioService(TemplateService): class TelephonyAndMediaAudioService(TemplateService):
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
role_characteristic: Characteristic[bytes]
def __init__(self, role: Role): def __init__(self, role: Role):
self.role_characteristic = Characteristic( self.role_characteristic = Characteristic(
GATT_TMAP_ROLE_CHARACTERISTIC, GATT_TMAP_ROLE_CHARACTERISTIC,
@@ -69,20 +70,15 @@ class TelephonyAndMediaAudioService(TemplateService):
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy): class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = TelephonyAndMediaAudioService SERVICE_CLASS = TelephonyAndMediaAudioService
role: DelegatedCharacteristicAdapter role: CharacteristicProxy[Role]
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if not ( self.role = DelegatedCharacteristicProxyAdapter(
characteristics := service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC GATT_TMAP_ROLE_CHARACTERISTIC
) ),
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Role( decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0], struct.unpack_from('<H', value, 0)[0],
), ),

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2024 Google LLC # Copyright 2021-2025 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -17,14 +17,18 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum import enum
from typing import Optional, Sequence
from bumble import att from bumble import att
from bumble import utils
from bumble import device from bumble import device
from bumble import gatt from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client from bumble import gatt_client
from typing import Optional, Sequence
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -67,15 +71,31 @@ class VolumeControlPointOpcode(enum.IntEnum):
MUTE = 0x06 MUTE = 0x06
@dataclasses.dataclass
class VolumeState:
volume_setting: int
mute: int
change_counter: int
@classmethod
def from_bytes(cls, data: bytes) -> VolumeState:
return cls(data[0], data[1], data[2])
def __bytes__(self) -> bytes:
return bytes([self.volume_setting, self.mute, self.change_counter])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Server # Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService): class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic EVENT_VOLUME_STATE_CHANGE = "volume_state_change"
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic volume_state: gatt.Characteristic[bytes]
volume_control_point: gatt.Characteristic[bytes]
volume_flags: gatt.Characteristic[bytes]
volume_setting: int volume_setting: int
muted: int muted: int
@@ -126,16 +146,8 @@ class VolumeControlService(gatt.TemplateService):
included_services=list(included_services), included_services=list(included_services),
) )
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes: def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
def _on_write_volume_control_point( def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes self, connection: Optional[device.Connection], value: bytes
@@ -151,16 +163,12 @@ class VolumeControlService(gatt.TemplateService):
handler = getattr(self, '_on_' + opcode.name.lower()) handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]): if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256 self.change_counter = (self.change_counter + 1) % 256
connection.abort_on( utils.cancel_on_event(
connection,
'disconnection', 'disconnection',
connection.device.notify_subscribers( connection.device.notify_subscribers(attribute=self.volume_state),
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
) )
self.emit(self.EVENT_VOLUME_STATE_CHANGE)
def _on_relative_volume_down(self) -> bool: def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting old_volume = self.volume_setting
@@ -206,25 +214,27 @@ class VolumeControlService(gatt.TemplateService):
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy): class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy volume_control_point: gatt_client.CharacteristicProxy[bytes]
volume_state: gatt_client.CharacteristicProxy[VolumeState]
volume_flags: gatt_client.CharacteristicProxy[VolumeFlags]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter( self.volume_state = gatt_adapters.SerializableCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0], ),
'BBB', VolumeState,
) )
self.volume_control_point = service_proxy.get_characteristics_by_uuid( self.volume_control_point = service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0] )
self.volume_flags = gatt.PackedCharacteristicAdapter( self.volume_flags = gatt_adapters.DelegatedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0], ),
'B', decode=lambda data: VolumeFlags(data[0]),
) )

307
bumble/profiles/vocs.py Normal file
View File

@@ -0,0 +1,307 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from dataclasses import dataclass
from typing import Optional
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
TemplateService,
CharacteristicValue,
GATT_VOLUME_OFFSET_CONTROL_SERVICE,
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
GATT_AUDIO_LOCATION_CHARACTERISTIC,
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicProxyAdapter,
SerializableCharacteristicProxyAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble import utils
from bumble.profiles.bap import AudioLocation
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME_OFFSET = -255
MAX_VOLUME_OFFSET = 255
CHANGE_COUNTER_MAX_VALUE = 0xFF
class SetVolumeOffsetOpCode(utils.OpenIntEnum):
SET_VOLUME_OFFSET = 0x01
class ErrorCode(utils.OpenIntEnum):
"""
See Volume Offset Control Service 1.6. Application error codes.
"""
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
VALUE_OUT_OF_RANGE = 0x82
# -----------------------------------------------------------------------------
@dataclass
class VolumeOffsetState:
volume_offset: int = 0
change_counter: int = 0
attribute: Optional[Characteristic] = None
def __bytes__(self) -> bytes:
return struct.pack('<hB', self.volume_offset, self.change_counter)
@classmethod
def from_bytes(cls, data: bytes):
volume_offset, change_counter = struct.unpack('<hB', data)
return cls(volume_offset, change_counter)
def increment_change_counter(self) -> None:
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute is not None
await connection.device.notify_subscribers(attribute=self.attribute)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class VocsAudioLocation:
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
attribute: Optional[Characteristic] = None
def __bytes__(self) -> bytes:
return struct.pack('<I', self.audio_location)
@classmethod
def from_bytes(cls, data: bytes):
audio_location = AudioLocation(struct.unpack('<I', data)[0])
return cls(audio_location)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
await connection.device.notify_subscribers(attribute=self.attribute)
@dataclass
class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
change_counter, volume_offset = struct.unpack('<Bh', value[1:])
await self._set_volume_offset(connection, change_counter, volume_offset)
async def _set_volume_offset(
self,
connection: Connection,
change_counter_operand: int,
volume_offset_operand: int,
) -> None:
change_counter = self.volume_offset_state.change_counter
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if not MIN_VOLUME_OFFSET <= volume_offset_operand <= MAX_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
self.volume_offset_state.volume_offset = volume_offset_operand
self.volume_offset_state.increment_change_counter()
await self.volume_offset_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioOutputDescription:
audio_output_description: str = ''
attribute: Optional[Characteristic] = None
@classmethod
def from_bytes(cls, data: bytes):
return cls(audio_output_description=data.decode('utf-8'))
def __bytes__(self) -> bytes:
return self.audio_output_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute
self.audio_output_description = value.decode('utf-8')
await connection.device.notify_subscribers(attribute=self.attribute)
# -----------------------------------------------------------------------------
class VolumeOffsetControlService(TemplateService):
UUID = GATT_VOLUME_OFFSET_CONTROL_SERVICE
def __init__(
self,
volume_offset_state: Optional[VolumeOffsetState] = None,
audio_location: Optional[VocsAudioLocation] = None,
audio_output_description: Optional[AudioOutputDescription] = None,
) -> None:
self.volume_offset_state = (
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
)
self.audio_location = (
VocsAudioLocation() if audio_location is None else audio_location
)
self.audio_output_description = (
AudioOutputDescription()
if audio_output_description is None
else audio_output_description
)
self.volume_offset_control_point: VolumeOffsetControlPoint = (
VolumeOffsetControlPoint(self.volume_offset_state)
)
self.volume_offset_state_characteristic: Characteristic[bytes] = Characteristic(
uuid=GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY
),
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.volume_offset_state.on_read),
)
self.audio_location_characteristic: Characteristic[bytes] = Characteristic(
uuid=GATT_AUDIO_LOCATION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_location.on_read,
write=self.audio_location.on_write,
),
)
self.audio_location.attribute = self.audio_location_characteristic
self.volume_offset_control_point_characteristic: Characteristic[bytes] = (
Characteristic(
uuid=GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.volume_offset_control_point.on_write
),
)
)
self.audio_output_description_characteristic: Characteristic[bytes] = (
Characteristic(
uuid=GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_output_description.on_read,
write=self.audio_output_description.on_write,
),
)
)
self.audio_output_description.attribute = (
self.audio_output_description_characteristic
)
super().__init__(
characteristics=[
self.volume_offset_state_characteristic, # type: ignore
self.audio_location_characteristic, # type: ignore
self.volume_offset_control_point_characteristic, # type: ignore
self.audio_output_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeOffsetControlServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = VolumeOffsetControlService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_offset_state = SerializableCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC
),
VolumeOffsetState,
)
self.audio_location = DelegatedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_LOCATION_CHARACTERISTIC
),
encode=lambda value: bytes([int(value)]),
decode=lambda data: AudioLocation(data[0]),
)
self.volume_offset_control_point = (
service_proxy.get_required_characteristic_by_uuid(
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC
)
)
self.audio_output_description = UTF8CharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC
)
)

View File

@@ -25,16 +25,16 @@ import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from pyee import EventEmitter
from bumble import core from bumble import core
from bumble import l2cap from bumble import l2cap
from bumble import sdp from bumble import sdp
from .colors import color from bumble import utils
from .core import ( from bumble.colors import color
from bumble.core import (
UUID, UUID,
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
InvalidArgumentError, InvalidArgumentError,
InvalidStateError, InvalidStateError,
@@ -441,7 +441,10 @@ class RFCOMM_MCC_MSC:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DLC(EventEmitter): class DLC(utils.EventEmitter):
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
class State(enum.IntEnum): class State(enum.IntEnum):
INIT = 0x00 INIT = 0x00
CONNECTING = 0x01 CONNECTING = 0x01
@@ -529,7 +532,7 @@ class DLC(EventEmitter):
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTED) self.change_state(DLC.State.CONNECTED)
self.emit('open') self.emit(self.EVENT_OPEN)
def on_ua_frame(self, _frame: RFCOMM_Frame) -> None: def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
if self.state == DLC.State.CONNECTING: if self.state == DLC.State.CONNECTING:
@@ -550,7 +553,7 @@ class DLC(EventEmitter):
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
self.multiplexer.on_dlc_disconnection(self) self.multiplexer.on_dlc_disconnection(self)
self.emit('close') self.emit(self.EVENT_CLOSE)
else: else:
logger.warning( logger.warning(
color( color(
@@ -733,7 +736,7 @@ class DLC(EventEmitter):
self.disconnection_result.cancel() self.disconnection_result.cancel()
self.disconnection_result = None self.disconnection_result = None
self.change_state(DLC.State.RESET) self.change_state(DLC.State.RESET)
self.emit('close') self.emit(self.EVENT_CLOSE)
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
@@ -749,7 +752,7 @@ class DLC(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Multiplexer(EventEmitter): class Multiplexer(utils.EventEmitter):
class Role(enum.IntEnum): class Role(enum.IntEnum):
INITIATOR = 0x00 INITIATOR = 0x00
RESPONDER = 0x01 RESPONDER = 0x01
@@ -763,6 +766,8 @@ class Multiplexer(EventEmitter):
DISCONNECTED = 0x05 DISCONNECTED = 0x05
RESET = 0x06 RESET = 0x06
EVENT_DLC = "dlc"
connection_result: Optional[asyncio.Future] connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future] disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future] open_result: Optional[asyncio.Future]
@@ -785,7 +790,7 @@ class Multiplexer(EventEmitter):
# Become a sink for the L2CAP channel # Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu l2cap_channel.sink = self.on_pdu
l2cap_channel.on('close', self.on_l2cap_channel_close) l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)
def change_state(self, new_state: State) -> None: def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
@@ -845,7 +850,7 @@ class Multiplexer(EventEmitter):
self.open_result.set_exception( self.open_result.set_exception(
core.ConnectionError( core.ConnectionError(
core.ConnectionError.CONNECTION_REFUSED, core.ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT, PhysicalTransport.BR_EDR,
self.l2cap_channel.connection.peer_address, self.l2cap_channel.connection.peer_address,
'rfcomm', 'rfcomm',
) )
@@ -901,7 +906,7 @@ class Multiplexer(EventEmitter):
self.dlcs[pn.dlci] = dlc self.dlcs[pn.dlci] = dlc
# Re-emit the handshake completion event # Re-emit the handshake completion event
dlc.on('open', lambda: self.emit('dlc', dlc)) dlc.on(dlc.EVENT_OPEN, lambda: self.emit(self.EVENT_DLC, dlc))
# Respond to complete the handshake # Respond to complete the handshake
dlc.accept() dlc.accept()
@@ -1075,7 +1080,9 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(utils.EventEmitter):
EVENT_START = "start"
def __init__( def __init__(
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None: ) -> None:
@@ -1122,7 +1129,9 @@ class Server(EventEmitter):
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on(
l2cap_channel.EVENT_OPEN, lambda: self.on_l2cap_channel_open(l2cap_channel)
)
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
@@ -1130,10 +1139,10 @@ class Server(EventEmitter):
# Create a new multiplexer for the channel # Create a new multiplexer for the channel
multiplexer = Multiplexer(l2cap_channel, Multiplexer.Role.RESPONDER) multiplexer = Multiplexer(l2cap_channel, Multiplexer.Role.RESPONDER)
multiplexer.acceptor = self.accept_dlc multiplexer.acceptor = self.accept_dlc
multiplexer.on('dlc', self.on_dlc) multiplexer.on(multiplexer.EVENT_DLC, self.on_dlc)
# Notify # Notify
self.emit('start', multiplexer) self.emit(self.EVENT_START, multiplexer)
def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]: def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]:
return self.dlc_configs.get(channel_number) return self.dlc_configs.get(channel_number)

View File

@@ -16,18 +16,24 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import struct import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from . import core, l2cap from bumble import core, l2cap
from .colors import color from bumble.colors import color
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError from bumble.core import (
from .hci import HCI_Object, name_or_number, key_with_value InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
ProtocolError,
)
from bumble.hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING: if TYPE_CHECKING:
from .device import Device, Connection from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -124,7 +130,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified # To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF)
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -242,11 +248,11 @@ class DataElement:
return DataElement(DataElement.BOOLEAN, value) return DataElement(DataElement.BOOLEAN, value)
@staticmethod @staticmethod
def sequence(value: List[DataElement]) -> DataElement: def sequence(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value) return DataElement(DataElement.SEQUENCE, value)
@staticmethod @staticmethod
def alternative(value: List[DataElement]) -> DataElement: def alternative(value: Iterable[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value) return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod @staticmethod
@@ -344,9 +350,6 @@ class DataElement:
] # Keep a copy so we can re-serialize to an exact replica ] # Keep a copy so we can re-serialize to an exact replica
return result return result
def to_bytes(self):
return bytes(self)
def __bytes__(self): def __bytes__(self):
# Return early if we have a cache # Return early if we have a cache
if self.bytes: if self.bytes:
@@ -476,7 +479,9 @@ class ServiceAttribute:
self.value = value self.value = value
@staticmethod @staticmethod
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]: def list_from_data_elements(
elements: Sequence[DataElement],
) -> list[ServiceAttribute]:
attribute_list = [] attribute_list = []
for i in range(0, len(elements) // 2): for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)] attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -489,7 +494,7 @@ class ServiceAttribute:
@staticmethod @staticmethod
def find_attribute_in_list( def find_attribute_in_list(
attribute_list: List[ServiceAttribute], attribute_id: int attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]: ) -> Optional[DataElement]:
return next( return next(
( (
@@ -537,7 +542,12 @@ class SDP_PDU:
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
''' '''
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {} RESPONSE_PDU_IDS = {
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
}
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
name = None name = None
pdu_id = 0 pdu_id = 0
@@ -561,7 +571,7 @@ class SDP_PDU:
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count( def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int data: bytes, offset: int
) -> Tuple[int, List[int]]: ) -> tuple[int, list[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [ handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count) struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
@@ -623,11 +633,8 @@ class SDP_PDU:
def init_from_bytes(self, pdu, offset): def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]' result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
@@ -645,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
''' '''
error_code: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass( @SDP_PDU.subclass(
@@ -681,7 +690,7 @@ class SDP_ServiceSearchResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
''' '''
service_record_handle_list: List[int] service_record_handle_list: list[int]
total_service_record_count: int total_service_record_count: int
current_service_record_count: int current_service_record_count: int
continuation_state: bytes continuation_state: bytes
@@ -758,31 +767,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
''' '''
attribute_list_byte_count: int attribute_lists_byte_count: int
attribute_list: bytes attribute_lists: bytes
continuation_state: bytes continuation_state: bytes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
channel: Optional[l2cap.ClassicChannel] def __init__(self, connection: Connection, mtu: int = 0) -> None:
def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
self.pending_request = None self.channel: Optional[l2cap.ClassicChannel] = None
self.channel = None self.mtu = mtu
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request: Optional[SDP_PDU] = None
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
self.next_transaction_id = 0
async def connect(self) -> None: async def connect(self) -> None:
self.channel = await self.connection.create_l2cap_channel( self.channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM) spec=(
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
if self.mtu
else l2cap.ClassicChannelSpec(SDP_PSM)
)
) )
self.channel.sink = self.on_pdu
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self.channel: if self.channel:
await self.channel.disconnect() await self.channel.disconnect()
self.channel = None self.channel = None
async def search_services(self, uuids: List[core.UUID]) -> List[int]: def make_transaction_id(self) -> int:
transaction_id = self.next_transaction_id
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
return transaction_id
def on_pdu(self, pdu: bytes) -> None:
if not self.pending_request:
logger.warning('received response with no pending request')
return
assert self.pending_response is not None
response = SDP_PDU.from_bytes(pdu)
# Check that the transaction ID is what we expect
if self.pending_request.transaction_id != response.transaction_id:
logger.warning(
f"received response with transaction ID {response.transaction_id} "
f"but expected {self.pending_request.transaction_id}"
)
return
# Check if the response is an error
if isinstance(response, SDP_ErrorResponse):
self.pending_response.set_exception(
ProtocolError(error_code=response.error_code)
)
return
# Check that the type of the response matches the request
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
logger.warning("response type mismatch")
return
self.pending_response.set_result(response)
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
assert self.channel is not None
async with self.request_semaphore:
assert self.pending_request is None
assert self.pending_response is None
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request
try:
self.channel.send_pdu(bytes(request))
return await self.pending_response
finally:
self.pending_request = None
self.pending_response = None
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
"""
Search for services by UUID.
Args:
uuids: service the UUIDs to search for.
Returns:
A list of matching service record handles.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -797,16 +874,16 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchRequest( SDP_ServiceSearchRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF, maximum_service_record_count=0xFFFF,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchResponse)
service_record_handle_list += response.service_record_handle_list service_record_handle_list += response.service_record_handle_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -817,8 +894,21 @@ class Client:
return service_record_handle_list return service_record_handle_list
async def search_attributes( async def search_attributes(
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]] self,
) -> List[List[ServiceAttribute]]: uuids: Iterable[core.UUID],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[list[ServiceAttribute]]:
"""
Search for attributes by UUID and attribute IDs.
Args:
uuids: the service UUIDs to search for.
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
(use (0, 0xFFFF) to include all attributes)
Returns:
A list of list of attributes, one list per matching service.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -830,8 +920,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -845,17 +935,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceSearchAttributeRequest( SDP_ServiceSearchAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
accumulator += response.attribute_lists accumulator += response.attribute_lists
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -878,8 +968,18 @@ class Client:
async def get_attributes( async def get_attributes(
self, self,
service_record_handle: int, service_record_handle: int,
attribute_ids: List[Union[int, Tuple[int, int]]], attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> List[ServiceAttribute]: ) -> list[ServiceAttribute]:
"""
Get attributes for a service.
Args:
service_record_handle: the handle for a service
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
Returns:
A list of attributes.
"""
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.channel is None: if self.channel is None:
@@ -888,8 +988,8 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( (
DataElement.unsigned_integer( DataElement.unsigned_integer_32(
attribute_id[0], value_size=attribute_id[1] attribute_id[0] << 16 | attribute_id[1]
) )
if isinstance(attribute_id, tuple) if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
@@ -903,17 +1003,17 @@ class Client:
continuation_state = bytes([0]) continuation_state = bytes([0])
watchdog = SDP_CONTINUATION_WATCHDOG watchdog = SDP_CONTINUATION_WATCHDOG
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response = await self.send_request(
SDP_ServiceAttributeRequest( SDP_ServiceAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=self.make_transaction_id(),
service_record_handle=service_record_handle, service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu)
logger.debug(f'<<< Response: {response}') logger.debug(f'<<< Response: {response}')
assert isinstance(response, SDP_ServiceAttributeResponse)
accumulator += response.attribute_list accumulator += response.attribute_list
continuation_state = response.continuation_state continuation_state = response.continuation_state
if len(continuation_state) == 1 and continuation_state[0] == 0: if len(continuation_state) == 1 and continuation_state[0] == 0:
@@ -939,17 +1039,17 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x00])
channel: Optional[l2cap.ClassicChannel] channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', List[ServiceAttribute]) Service = NewType('Service', list[ServiceAttribute])
service_records: Dict[int, Service] service_records: dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]] current_response: Union[None, bytes, tuple[int, list[int]]]
def __init__(self, device: Device) -> None: def __init__(self, device: Device) -> None:
self.device = device self.device = device
self.service_records = {} # Service records maps, by record handle self.service_records = {} # Service records maps, by record handle
self.channel = None self.channel = None
self.current_response = None self.current_response = None # Current response data, used for continuations
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None: def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.create_classic_server( l2cap_channel_manager.create_classic_server(
@@ -960,7 +1060,7 @@ class Server:
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}') logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response) self.channel.send_pdu(response)
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]: def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the # Find the services for which the attributes in the pattern is a subset of the
# service's attribute values (NOTE: the value search recurses into sequences) # service's attribute values (NOTE: the value search recurses into sequences)
matching_services = {} matching_services = {}
@@ -1017,6 +1117,31 @@ class Server:
) )
) )
def check_continuation(
self,
continuation_state: bytes,
transaction_id: int,
) -> Optional[bool]:
# Check if this is a valid continuation
if len(continuation_state) > 1:
if (
self.current_response is None
or continuation_state != self.CONTINUATION_STATE
):
self.send_response(
SDP_ErrorResponse(
transaction_id=transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return None
return True
# Cleanup any partial response leftover
self.current_response = None
return False
def get_next_response_payload(self, maximum_size): def get_next_response_payload(self, maximum_size):
if len(self.current_response) > maximum_size: if len(self.current_response) > maximum_size:
payload = self.current_response[:maximum_size] payload = self.current_response[:maximum_size]
@@ -1031,7 +1156,7 @@ class Server:
@staticmethod @staticmethod
def get_service_attributes( def get_service_attributes(
service: Service, attribute_ids: List[DataElement] service: Service, attribute_ids: Iterable[DataElement]
) -> DataElement: ) -> DataElement:
attributes = [] attributes = []
for attribute_id in attribute_ids: for attribute_id in attribute_ids:
@@ -1059,30 +1184,24 @@ class Server:
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None: def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services(request.service_search_pattern) matching_services = self.match_services(request.service_search_pattern)
service_record_handles = list(matching_services.keys()) service_record_handles = list(matching_services.keys())
logger.debug(f'Service Record Handles: {service_record_handles}')
# Only return up to the maximum requested # Only return up to the maximum requested
service_record_handles_subset = service_record_handles[ service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count : request.maximum_service_record_count
] ]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = ( self.current_response = (
len(service_record_handles), len(service_record_handles),
service_record_handles_subset, service_record_handles_subset,
@@ -1090,15 +1209,21 @@ class Server:
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple) assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][ assert self.channel is not None
: request.maximum_service_record_count total_service_record_count, service_record_handles = self.current_response
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
service_record_handles_remaining = service_record_handles[
maximum_service_record_count:
] ]
service_record_handles = service_record_handles[:maximum_service_record_count]
self.current_response = ( self.current_response = (
self.current_response[0], total_service_record_count,
self.current_response[1][request.maximum_service_record_count :], service_record_handles_remaining,
) )
continuation_state = ( continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0]) Server.CONTINUATION_STATE
if service_record_handles_remaining
else bytes([0])
) )
service_record_handle_list = b''.join( service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles] [struct.pack('>I', handle) for handle in service_record_handles]
@@ -1106,7 +1231,7 @@ class Server:
self.send_response( self.send_response(
SDP_ServiceSearchResponse( SDP_ServiceSearchResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0], total_service_record_count=total_service_record_count,
current_service_record_count=len(service_record_handles), current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list, service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state, continuation_state=continuation_state,
@@ -1117,19 +1242,14 @@ class Server:
self, request: SDP_ServiceAttributeRequest self, request: SDP_ServiceAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
return
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Check that the service exists # Check that the service exists
service = self.service_records.get(request.service_record_handle) service = self.service_records.get(request.service_record_handle)
if service is None: if service is None:
@@ -1151,14 +1271,18 @@ class Server:
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_list_response, continuation_state = self.get_next_response_payload( attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list_response), attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list, attribute_list=attribute_list_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
@@ -1167,18 +1291,14 @@ class Server:
self, request: SDP_ServiceSearchAttributeRequest self, request: SDP_ServiceSearchAttributeRequest
) -> None: ) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if (
if self.current_response is None: continuation := self.check_continuation(
self.send_response( request.continuation_state, request.transaction_id
SDP_ErrorResponse( )
transaction_id=request.transaction_id, ) is None:
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR, return
)
)
else:
# Cleanup any partial response leftover
self.current_response = None
if not continuation:
# Find the matching services # Find the matching services
matching_services = self.match_services( matching_services = self.match_services(
request.service_search_pattern request.service_search_pattern
@@ -1198,14 +1318,18 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
assert self.channel is not None
maximum_attribute_byte_count = min(
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
)
attribute_lists_response, continuation_state = self.get_next_response_payload( attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists_response), attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists, attribute_lists=attribute_lists_response,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )

View File

@@ -41,26 +41,25 @@ from typing import (
cast, cast,
) )
from pyee import EventEmitter
from .colors import color from bumble.colors import color
from .hci import ( from bumble.hci import (
Address, Address,
Role,
HCI_LE_Enable_Encryption_Command, HCI_LE_Enable_Encryption_Command,
HCI_Object, HCI_Object,
key_with_value, key_with_value,
) )
from .core import ( from bumble.core import (
BT_BR_EDR_TRANSPORT, PhysicalTransport,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData, AdvertisingData,
InvalidArgumentError, InvalidArgumentError,
ProtocolError, ProtocolError,
name_or_number, name_or_number,
) )
from .keys import PairingKeys from bumble.keys import PairingKeys
from . import crypto from bumble import crypto
from bumble import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Connection, Device from bumble.device import Connection, Device
@@ -298,11 +297,8 @@ class SMP_Command:
def init_from_bytes(self, pdu: bytes, offset: int) -> None: def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields) return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.pdu
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
@@ -698,6 +694,7 @@ class Session:
self.ltk_ediv = 0 self.ltk_ediv = 0
self.ltk_rand = bytes(8) self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None self.link_key: Optional[bytes] = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0 self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0 self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None self.peer_random_value: Optional[bytes] = None
@@ -727,12 +724,13 @@ class Session:
self.is_responder = not self.is_initiator self.is_responder = not self.is_initiator
# Listen for connection events # Listen for connection events
connection.on('disconnection', self.on_disconnection) connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
connection.on( connection.on(
'connection_encryption_change', self.on_connection_encryption_change connection.EVENT_CONNECTION_ENCRYPTION_CHANGE,
self.on_connection_encryption_change,
) )
connection.on( connection.on(
'connection_encryption_key_refresh', connection.EVENT_CONNECTION_ENCRYPTION_KEY_REFRESH,
self.on_connection_encryption_key_refresh, self.on_connection_encryption_key_refresh,
) )
@@ -744,6 +742,10 @@ class Session:
else: else:
self.pairing_result = None self.pairing_result = None
self.maximum_encryption_key_size = (
pairing_config.delegate.maximum_encryption_key_size
)
# Key Distribution (default values before negotiation) # Key Distribution (default values before negotiation)
self.initiator_key_distribution = ( self.initiator_key_distribution = (
pairing_config.delegate.local_initiator_key_distribution pairing_config.delegate.local_initiator_key_distribution
@@ -855,7 +857,7 @@ class Session:
initiator_io_capability: int, initiator_io_capability: int,
responder_io_capability: int, responder_io_capability: int,
) -> None: ) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT: if self.connection.transport == PhysicalTransport.BR_EDR:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return return
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0): if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
@@ -898,7 +900,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.connection.abort_on('disconnection', prompt()) utils.cancel_on_event(self.connection, 'disconnection', prompt())
def prompt_user_for_numeric_comparison( def prompt_user_for_numeric_comparison(
self, code: int, next_steps: Callable[[], None] self, code: int, next_steps: Callable[[], None]
@@ -917,7 +919,7 @@ class Session:
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR) self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.connection.abort_on('disconnection', prompt()) utils.cancel_on_event(self.connection, 'disconnection', prompt())
def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None: def prompt_user_for_number(self, next_steps: Callable[[int], None]) -> None:
async def prompt() -> None: async def prompt() -> None:
@@ -934,7 +936,7 @@ class Session:
logger.warning(f'exception while prompting: {error}') logger.warning(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR) self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.connection.abort_on('disconnection', prompt()) utils.cancel_on_event(self.connection, 'disconnection', prompt())
def display_passkey(self) -> None: def display_passkey(self) -> None:
# Generate random Passkey/PIN code # Generate random Passkey/PIN code
@@ -949,7 +951,8 @@ class Session:
logger.debug(f'TK from passkey = {self.tk.hex()}') logger.debug(f'TK from passkey = {self.tk.hex()}')
try: try:
self.connection.abort_on( utils.cancel_on_event(
self.connection,
'disconnection', 'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6), self.pairing_config.delegate.display_number(self.passkey, digits=6),
) )
@@ -996,7 +999,7 @@ class Session:
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution, responder_key_distribution=self.responder_key_distribution,
) )
@@ -1008,7 +1011,7 @@ class Session:
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution, responder_key_distribution=self.responder_key_distribution,
) )
@@ -1048,7 +1051,7 @@ class Session:
) )
# Perform the next steps asynchronously in case we need to wait for input # Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps()) utils.cancel_on_event(self.connection, 'disconnection', next_steps())
else: else:
confirm_value = crypto.c1( confirm_value = crypto.c1(
self.tk, self.tk,
@@ -1168,11 +1171,11 @@ class Session:
if self.is_initiator: if self.is_initiator:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
if ( if (
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == PhysicalTransport.BR_EDR
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = self.connection.abort_on( self.ctkd_task = utils.cancel_on_event(
'disconnection', self.get_link_key_and_derive_ltk() self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
) )
elif not self.sc: elif not self.sc:
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
@@ -1207,11 +1210,11 @@ class Session:
else: else:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
if ( if (
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == PhysicalTransport.BR_EDR
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = self.connection.abort_on( self.ctkd_task = utils.cancel_on_event(
'disconnection', self.get_link_key_and_derive_ltk() self.connection, 'disconnection', self.get_link_key_and_derive_ltk()
) )
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
elif not self.sc: elif not self.sc:
@@ -1246,7 +1249,7 @@ class Session:
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None: def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase # Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = [] self.peer_expected_distributions = []
if not self.sc and self.connection.transport == BT_LE_TRANSPORT: if not self.sc and self.connection.transport == PhysicalTransport.LE:
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0: if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
self.peer_expected_distributions.append( self.peer_expected_distributions.append(
SMP_Encryption_Information_Command SMP_Encryption_Information_Command
@@ -1303,15 +1306,20 @@ class Session:
# Wait for the pairing process to finish # Wait for the pairing process to finish
assert self.pairing_result assert self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result) await utils.cancel_on_event(
self.connection, 'disconnection', self.pairing_result
)
def on_disconnection(self, _: int) -> None: def on_disconnection(self, _: int) -> None:
self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener( self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change self.connection.EVENT_DISCONNECTION, self.on_disconnection
) )
self.connection.remove_listener( self.connection.remove_listener(
'connection_encryption_key_refresh', self.connection.EVENT_CONNECTION_ENCRYPTION_CHANGE,
self.on_connection_encryption_change,
)
self.connection.remove_listener(
self.connection.EVENT_CONNECTION_ENCRYPTION_KEY_REFRESH,
self.on_connection_encryption_key_refresh, self.on_connection_encryption_key_refresh,
) )
self.manager.on_session_end(self) self.manager.on_session_end(self)
@@ -1321,10 +1329,10 @@ class Session:
if self.is_initiator: if self.is_initiator:
self.distribute_keys() self.distribute_keys()
self.connection.abort_on('disconnection', self.on_pairing()) utils.cancel_on_event(self.connection, 'disconnection', self.on_pairing())
def on_connection_encryption_change(self) -> None: def on_connection_encryption_change(self) -> None:
if self.connection.is_encrypted: if self.connection.is_encrypted and not self.completed:
if self.is_responder: if self.is_responder:
# The responder distributes its keys first, the initiator later # The responder distributes its keys first, the initiator later
self.distribute_keys() self.distribute_keys()
@@ -1363,7 +1371,7 @@ class Session:
keys = PairingKeys() keys = PairingKeys()
keys.address_type = peer_address.address_type keys.address_type = peer_address.address_type
authenticated = self.pairing_method != PairingMethod.JUST_WORKS authenticated = self.pairing_method != PairingMethod.JUST_WORKS
if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT: if self.sc or self.connection.transport == PhysicalTransport.BR_EDR:
keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated) keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated)
else: else:
our_ltk_key = PairingKeys.Key( our_ltk_key = PairingKeys.Key(
@@ -1372,8 +1380,10 @@ class Session:
ediv=self.ltk_ediv, ediv=self.ltk_ediv,
rand=self.ltk_rand, rand=self.ltk_rand,
) )
if not self.peer_ltk:
logger.error("peer_ltk is None")
peer_ltk_key = PairingKeys.Key( peer_ltk_key = PairingKeys.Key(
value=self.peer_ltk, value=self.peer_ltk or b'',
authenticated=authenticated, authenticated=authenticated,
ediv=self.peer_ediv, ediv=self.peer_ediv,
rand=self.peer_rand, rand=self.peer_rand,
@@ -1430,8 +1440,10 @@ class Session:
def on_smp_pairing_request_command( def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command self, command: SMP_Pairing_Request_Command
) -> None: ) -> None:
self.connection.abort_on( utils.cancel_on_event(
'disconnection', self.on_smp_pairing_request_command_async(command) self.connection,
'disconnection',
self.on_smp_pairing_request_command_async(command),
) )
async def on_smp_pairing_request_command_async( async def on_smp_pairing_request_command_async(
@@ -1504,7 +1516,7 @@ class Session:
# CTKD over BR/EDR should happen after the connection has been encrypted, # CTKD over BR/EDR should happen after the connection has been encrypted,
# so when receiving pairing requests, responder should start distributing keys # so when receiving pairing requests, responder should start distributing keys
if ( if (
self.connection.transport == BT_BR_EDR_TRANSPORT self.connection.transport == PhysicalTransport.BR_EDR
and self.connection.is_encrypted and self.connection.is_encrypted
and self.is_responder and self.is_responder
and accepted and accepted
@@ -1839,7 +1851,7 @@ class Session:
if self.is_initiator: if self.is_initiator:
if self.pairing_method == PairingMethod.OOB: if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command() self.send_pairing_random_command()
else: elif self.pairing_method == PairingMethod.PASSKEY:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
else: else:
if self.pairing_method == PairingMethod.PASSKEY: if self.pairing_method == PairingMethod.PASSKEY:
@@ -1876,7 +1888,7 @@ class Session:
self.wait_before_continuing = None self.wait_before_continuing = None
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
self.connection.abort_on('disconnection', next_steps()) utils.cancel_on_event(self.connection, 'disconnection', next_steps())
else: else:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
else: else:
@@ -1920,7 +1932,7 @@ class Session:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Manager(EventEmitter): class Manager(utils.EventEmitter):
''' '''
Implements the Initiator and Responder roles of the Security Manager Protocol Implements the Initiator and Responder roles of the Security Manager Protocol
''' '''
@@ -1948,13 +1960,15 @@ class Manager(EventEmitter):
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] ' f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}' f'{connection.peer_address}: {command}'
) )
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = (
connection.send_l2cap_pdu(cid, command.to_bytes()) SMP_BR_CID if connection.transport == PhysicalTransport.BR_EDR else SMP_CID
)
connection.send_l2cap_pdu(cid, bytes(command))
def on_smp_security_request_command( def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command self, connection: Connection, request: SMP_Security_Request_Command
) -> None: ) -> None:
connection.emit('security_request', request.auth_req) connection.emit(connection.EVENT_SECURITY_REQUEST, request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object # Parse the L2CAP payload into an SMP Command object
@@ -1973,7 +1987,7 @@ class Manager(EventEmitter):
# Look for a session with this connection, and create one if none exists # Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)): if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE: if connection.role == Role.CENTRAL:
logger.warning('Remote starts pairing as Peripheral!') logger.warning('Remote starts pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
session = self.session_proxy( session = self.session_proxy(
@@ -1993,7 +2007,7 @@ class Manager(EventEmitter):
async def pair(self, connection: Connection) -> None: async def pair(self, connection: Connection) -> None:
# TODO: check if there's already a session for this connection # TODO: check if there's already a session for this connection
if connection.role != BT_CENTRAL_ROLE: if connection.role != Role.CENTRAL:
logger.warning('Start pairing as Peripheral!') logger.warning('Start pairing as Peripheral!')
pairing_config = self.pairing_config_factory(connection) pairing_config = self.pairing_config_factory(connection)
session = self.session_proxy( session = self.session_proxy(

View File

@@ -20,8 +20,13 @@ import logging
import os import os
from typing import Optional from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError from bumble.transport.common import (
from ..snoop import create_snooper Transport,
AsyncPipeSink,
SnoopingTransport,
TransportSpecError,
)
from bumble.snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -108,80 +113,80 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from bumble.transport.serial import open_serial_transport
return await open_serial_transport(spec) return await open_serial_transport(spec)
if scheme == 'udp' and spec: if scheme == 'udp' and spec:
from .udp import open_udp_transport from bumble.transport.udp import open_udp_transport
return await open_udp_transport(spec) return await open_udp_transport(spec)
if scheme == 'tcp-client' and spec: if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from bumble.transport.tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec) return await open_tcp_client_transport(spec)
if scheme == 'tcp-server' and spec: if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from bumble.transport.tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec) return await open_tcp_server_transport(spec)
if scheme == 'ws-client' and spec: if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from bumble.transport.ws_client import open_ws_client_transport
return await open_ws_client_transport(spec) return await open_ws_client_transport(spec)
if scheme == 'ws-server' and spec: if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from bumble.transport.ws_server import open_ws_server_transport
return await open_ws_server_transport(spec) return await open_ws_server_transport(spec)
if scheme == 'pty': if scheme == 'pty':
from .pty import open_pty_transport from bumble.transport.pty import open_pty_transport
return await open_pty_transport(spec) return await open_pty_transport(spec)
if scheme == 'file': if scheme == 'file':
from .file import open_file_transport from bumble.transport.file import open_file_transport
assert spec is not None assert spec is not None
return await open_file_transport(spec) return await open_file_transport(spec)
if scheme == 'vhci': if scheme == 'vhci':
from .vhci import open_vhci_transport from bumble.transport.vhci import open_vhci_transport
return await open_vhci_transport(spec) return await open_vhci_transport(spec)
if scheme == 'hci-socket': if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from bumble.transport.hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec) return await open_hci_socket_transport(spec)
if scheme == 'usb': if scheme == 'usb':
from .usb import open_usb_transport from bumble.transport.usb import open_usb_transport
assert spec assert spec
return await open_usb_transport(spec) return await open_usb_transport(spec)
if scheme == 'pyusb': if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from bumble.transport.pyusb import open_pyusb_transport
assert spec assert spec
return await open_pyusb_transport(spec) return await open_pyusb_transport(spec)
if scheme == 'android-emulator': if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from bumble.transport.android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec) return await open_android_emulator_transport(spec)
if scheme == 'android-netsim': if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport from bumble.transport.android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec)
if scheme == 'unix': if scheme == 'unix':
from .unix import open_unix_client_transport from bumble.transport.unix import open_unix_client_transport
assert spec assert spec
return await open_unix_client_transport(spec) return await open_unix_client_transport(spec)
@@ -204,8 +209,8 @@ async def open_transport_or_link(name: str) -> Transport:
""" """
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.') logger.warning('Link Relay has been deprecated.')
from ..controller import Controller from bumble.controller import Controller
from ..link import RemoteLink # lazy import from bumble.link import RemoteLink # lazy import
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])
await link.wait_until_connected() await link.wait_until_connected()

View File

@@ -20,7 +20,7 @@ import grpc.aio
from typing import Optional, Union from typing import Optional, Union
from .common import ( from bumble.transport.common import (
PumpedTransport, PumpedTransport,
PumpedPacketSource, PumpedPacketSource,
PumpedPacketSink, PumpedPacketSink,
@@ -29,9 +29,13 @@ from .common import (
) )
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub from bumble.transport.grpc_protobuf.emulated_bluetooth_pb2_grpc import (
from .grpc_protobuf.emulated_bluetooth_packets_pb2 import HCIPacket EmulatedBluetoothServiceStub,
from .grpc_protobuf.emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub )
from bumble.transport.grpc_protobuf.emulated_bluetooth_packets_pb2 import HCIPacket
from bumble.transport.grpc_protobuf.emulated_bluetooth_vhci_pb2_grpc import (
VhciForwardingServiceStub,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -38,15 +38,18 @@ from bumble.transport.common import (
) )
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.netsim.packet_streamer_pb2_grpc import ( from bumble.transport.grpc_protobuf.netsim.packet_streamer_pb2_grpc import (
PacketStreamerStub, PacketStreamerStub,
PacketStreamerServicer, PacketStreamerServicer,
add_PacketStreamerServicer_to_server, add_PacketStreamerServicer_to_server,
) )
from .grpc_protobuf.netsim.packet_streamer_pb2 import PacketRequest, PacketResponse from bumble.transport.grpc_protobuf.netsim.packet_streamer_pb2 import (
from .grpc_protobuf.netsim.hci_packet_pb2 import HCIPacket PacketRequest,
from .grpc_protobuf.netsim.startup_pb2 import Chip, ChipInfo, DeviceInfo PacketResponse,
from .grpc_protobuf.netsim.common_pb2 import ChipKind )
from bumble.transport.grpc_protobuf.netsim.hci_packet_pb2 import HCIPacket
from bumble.transport.grpc_protobuf.netsim.startup_pb2 import Chip, ChipInfo, DeviceInfo
from bumble.transport.grpc_protobuf.netsim.common_pb2 import ChipKind
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -139,6 +139,7 @@ class PacketParser:
packet_type packet_type
) or self.extended_packet_info.get(packet_type) ) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
self.reset()
raise core.InvalidPacketError( raise core.InvalidPacketError(
f'invalid packet type {packet_type}' f'invalid packet type {packet_type}'
) )
@@ -302,7 +303,10 @@ class ParserSource(BaseSource):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None: def data_received(self, data: bytes) -> None:
self.parser.feed_data(data) try:
self.parser.feed_data(data)
except core.InvalidPacketError:
logger.warning("invalid packet, ignoring data")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -370,11 +374,13 @@ class PumpedPacketSource(ParserSource):
self.parser.feed_data(packet) self.parser.feed_data(packet)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug('source pump task done') logger.debug('source pump task done')
self.terminated.set_result(None) if not self.terminated.done():
self.terminated.set_result(None)
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_exception(error) if not self.terminated.done():
self.terminated.set_exception(error)
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())

View File

@@ -19,7 +19,7 @@ import asyncio
import io import io
import logging import logging
from .common import Transport, StreamPacketSource, StreamPacketSink from bumble.transport.common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -25,7 +25,7 @@ import collections
from typing import Optional from typing import Optional
from .common import Transport, ParserSource from bumble.transport.common import Transport, ParserSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -25,7 +25,7 @@ import logging
from typing import Optional from typing import Optional
from .common import Transport, StreamPacketSource, StreamPacketSink from bumble.transport.common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -29,9 +29,9 @@ from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource, TransportInitError from bumble.transport.common import Transport, ParserSource, TransportInitError
from .. import hci from bumble import hci
from ..colors import color from bumble.colors import color
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -19,7 +19,7 @@ import asyncio
import logging import logging
import serial_asyncio import serial_asyncio
from .common import Transport, StreamPacketSource, StreamPacketSink from bumble.transport.common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -18,7 +18,7 @@
import asyncio import asyncio
import logging import logging
from .common import Transport, StreamPacketSource, StreamPacketSink from bumble.transport.common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -20,7 +20,7 @@ import asyncio
import logging import logging
import socket import socket
from .common import Transport, StreamPacketSource from bumble.transport.common import Transport, StreamPacketSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -18,7 +18,7 @@
import asyncio import asyncio
import logging import logging
from .common import Transport, ParserSource from bumble.transport.common import Transport, ParserSource
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -18,7 +18,7 @@
import asyncio import asyncio
import logging import logging
from .common import Transport, StreamPacketSource, StreamPacketSink from bumble.transport.common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -115,9 +115,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.acl_out = acl_out self.acl_out = acl_out
self.acl_out_transfer = device.getTransfer() self.acl_out_transfer = device.getTransfer()
self.acl_out_transfer_ready = asyncio.Semaphore(1) self.acl_out_transfer_ready = asyncio.Semaphore(1)
self.packets: asyncio.Queue[bytes] = ( self.packets = asyncio.Queue[bytes]() # Queue of packets waiting to be sent
asyncio.Queue()
) # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue_task = None self.queue_task = None
self.cancel_done = self.loop.create_future() self.cancel_done = self.loop.create_future()
@@ -149,7 +147,10 @@ async def open_usb_transport(spec: str) -> Transport:
if status != usb1.TRANSFER_COMPLETED: if status != usb1.TRANSFER_COMPLETED:
logger.warning( logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red') color(
f'!!! OUT transfer not completed: status={status}',
'red',
)
) )
async def process_queue(self): async def process_queue(self):
@@ -275,7 +276,10 @@ async def open_usb_transport(spec: str) -> Transport:
) )
else: else:
logger.warning( logger.warning(
color(f'!!! IN transfer not completed: status={status}', 'red') color(
f'!!! IN[{packet_type}] transfer not completed: status={status}',
'red',
)
) )
self.loop.call_soon_threadsafe(self.on_transport_lost) self.loop.call_soon_threadsafe(self.on_transport_lost)

View File

@@ -19,8 +19,8 @@ import logging
from typing import Optional from typing import Optional
from .common import Transport from bumble.transport.common import Transport
from .file import open_file_transport from bumble.transport.file import open_file_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -18,7 +18,12 @@
import logging import logging
import websockets.client import websockets.client
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport from bumble.transport.common import (
PumpedPacketSource,
PumpedPacketSink,
PumpedTransport,
Transport,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -18,7 +18,7 @@
import logging import logging
import websockets import websockets
from .common import Transport, ParserSource, PumpedPacketSink from bumble.transport.common import Transport, ParserSource, PumpedPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging

View File

@@ -24,21 +24,24 @@ import logging
import sys import sys
import warnings import warnings
from typing import ( from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any, Any,
Awaitable,
Callable,
List,
Optional, Optional,
Protocol,
Set,
Tuple,
TypeVar,
Union, Union,
overload, overload,
) )
from typing_extensions import Self
from pyee import EventEmitter import pyee
import pyee.asyncio
from .colors import color from bumble.colors import color
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -54,6 +57,48 @@ def setup_event_forwarding(emitter, forwarder, event_name):
emitter.on(event_name, emit) emitter.on(event_name, emit)
# -----------------------------------------------------------------------------
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return functools.partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str):
"""
Throw deprecation warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning, stacklevel=2)
return function(*args, **kwargs)
return inner
return wrapper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def composite_listener(cls): def composite_listener(cls):
""" """
@@ -111,21 +156,23 @@ class EventWatcher:
``` ```
''' '''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]] handlers: List[Tuple[pyee.EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None: def __init__(self) -> None:
self.handlers = [] self.handlers = []
@overload @overload
def on( def on(
self, emitter: EventEmitter, event: str self, emitter: pyee.EventEmitter, event: str
) -> Callable[[_Handler], _Handler]: ... ) -> Callable[[_Handler], _Handler]: ...
@overload @overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ... def on(
self, emitter: pyee.EventEmitter, event: str, handler: _Handler
) -> _Handler: ...
def on( def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]: ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed. '''Watch an event until the context is closed.
@@ -145,16 +192,16 @@ class EventWatcher:
@overload @overload
def once( def once(
self, emitter: EventEmitter, event: str self, emitter: pyee.EventEmitter, event: str
) -> Callable[[_Handler], _Handler]: ... ) -> Callable[[_Handler], _Handler]: ...
@overload @overload
def once( def once(
self, emitter: EventEmitter, event: str, handler: _Handler self, emitter: pyee.EventEmitter, event: str, handler: _Handler
) -> _Handler: ... ) -> _Handler: ...
def once( def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]: ) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once. '''Watch an event for once.
@@ -182,38 +229,48 @@ class EventWatcher:
_T = TypeVar('_T') _T = TypeVar('_T')
class AbortableEventEmitter(EventEmitter): def cancel_on_event(
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]: emitter: pyee.EventEmitter, event: str, awaitable: Awaitable[_T]
""" ) -> Awaitable[_T]:
Set a coroutine or future to abort when an event occur. """Set a coroutine or future to cancel when an event occur."""
""" future = asyncio.ensure_future(awaitable)
future = asyncio.ensure_future(awaitable) if future.done():
if future.done():
return future
def on_event(*_):
if future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future return future
def on_event(*args, **kwargs) -> None:
del args, kwargs
if future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
emitter.remove_listener(event, on_event)
emitter.on(event, on_event)
future.add_done_callback(on_done)
return future
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter): class EventEmitter(pyee.asyncio.AsyncIOEventEmitter):
"""A Base EventEmitter for Bumble."""
@deprecated("Use `cancel_on_event` instead.")
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
"""Set a coroutine or future to abort when an event occur."""
return cancel_on_event(self, event, awaitable)
# -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._listener = None self._listener = None
@@ -428,48 +485,6 @@ async def async_call(function, *args, **kwargs):
return function(*args, **kwargs) return function(*args, **kwargs)
# -----------------------------------------------------------------------------
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return functools.partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str):
"""
Throw deprecation warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return function(*args, **kwargs)
return inner
return wrapper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum): class OpenIntEnum(enum.IntEnum):
""" """
@@ -487,3 +502,26 @@ class OpenIntEnum(enum.IntEnum):
obj._value_ = value obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]" obj._name_ = f"{cls.__name__}[{value}]"
return obj return obj
# -----------------------------------------------------------------------------
class ByteSerializable(Protocol):
"""
Type protocol for classes that can be instantiated from bytes and serialized
to bytes.
"""
@classmethod
def from_bytes(cls, data: bytes) -> Self: ...
def __bytes__(self) -> bytes: ...
# -----------------------------------------------------------------------------
class IntConvertible(Protocol):
"""
Type protocol for classes that can be instantiated from int and converted to int.
"""
def __init__(self, value: int) -> None: ...
def __int__(self) -> int: ...

View File

@@ -16,6 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
from typing import Dict, Optional, Type
from bumble.hci import ( from bumble.hci import (
name_or_number, name_or_number,
@@ -24,7 +25,9 @@ from bumble.hci import (
HCI_Constant, HCI_Constant,
HCI_Object, HCI_Object,
HCI_Command, HCI_Command,
HCI_Vendor_Event, HCI_Event,
HCI_Extended_Event,
HCI_VENDOR_EVENT,
STATUS_SPEC, STATUS_SPEC,
) )
@@ -48,7 +51,6 @@ HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58 HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
HCI_Command.register_commands(globals()) HCI_Command.register_commands(globals())
HCI_Vendor_Event.register_subevents(globals())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -279,7 +281,29 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Vendor_Event.event( class HCI_Android_Vendor_Event(HCI_Extended_Event):
event_code: int = HCI_VENDOR_EVENT
subevent_classes: Dict[int, Type[HCI_Extended_Event]] = {}
@classmethod
def subclass_from_parameters(
cls, parameters: bytes
) -> Optional[HCI_Extended_Event]:
subevent_code = parameters[0]
if subevent_code == HCI_BLUETOOTH_QUALITY_REPORT_EVENT:
quality_report_id = parameters[1]
if quality_report_id in (0x01, 0x02, 0x03, 0x04, 0x07, 0x08, 0x09):
return HCI_Bluetooth_Quality_Report_Event.from_parameters(parameters)
return None
HCI_Android_Vendor_Event.register_subevents(globals())
HCI_Event.add_vendor_factory(HCI_Android_Vendor_Event.subclass_from_parameters)
# -----------------------------------------------------------------------------
@HCI_Extended_Event.event(
fields=[ fields=[
('quality_report_id', 1), ('quality_report_id', 1),
('packet_types', 1), ('packet_types', 1),
@@ -308,10 +332,11 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
('tx_last_subevent_packets', 4), ('tx_last_subevent_packets', 4),
('crc_error_packets', 4), ('crc_error_packets', 4),
('rx_duplicate_packets', 4), ('rx_duplicate_packets', 4),
('rx_unreceived_packets', 4),
('vendor_specific_parameters', '*'), ('vendor_specific_parameters', '*'),
] ]
) )
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event): class HCI_Bluetooth_Quality_Report_Event(HCI_Android_Vendor_Event):
# pylint: disable=line-too-long # pylint: disable=line-too-long
''' '''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event

View File

@@ -39,12 +39,14 @@ nav:
- Drivers: - Drivers:
- drivers/index.md - drivers/index.md
- Realtek: drivers/realtek.md - Realtek: drivers/realtek.md
- Intel: drivers/intel.md
- API: - API:
- Guide: api/guide.md - Guide: api/guide.md
- Examples: api/examples.md - Examples: api/examples.md
- Reference: api/reference.md - Reference: api/reference.md
- Apps & Tools: - Apps & Tools:
- apps_and_tools/index.md - apps_and_tools/index.md
- Auracast: apps_and_tools/auracast.md
- Console: apps_and_tools/console.md - Console: apps_and_tools/console.md
- Bench: apps_and_tools/bench.md - Bench: apps_and_tools/bench.md
- Speaker: apps_and_tools/speaker.md - Speaker: apps_and_tools/speaker.md
@@ -108,8 +110,8 @@ markdown_extensions:
- pymdownx.details - pymdownx.details
- pymdownx.superfences - pymdownx.superfences
- pymdownx.emoji: - pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg emoji_generator: !!python/name:material.extensions.emoji.to_svg
- pymdownx.tabbed: - pymdownx.tabbed:
alternate_style: true alternate_style: true
- codehilite: - codehilite:

View File

@@ -0,0 +1,202 @@
AURACAST TOOL
=============
The "auracast" tool implements commands that implement broadcasting, receiving
and controlling LE Audio broadcasts.
=== "Running as an installed package"
```
$ bumble-auracast
```
=== "Running from source"
```
$ python3 apps/auracast.py <args>
```
# Python Dependencies
Try installing the optional `[auracast]` dependencies:
=== "From source"
```bash
$ python3 -m pip install ".[auracast]"
```
=== "From PyPI"
```bash
$ python3 -m pip install "bumble[auracast]"
```
## LC3
The `auracast` app depends on the `lc3` python module, which is available
either as PyPI module (currently only available for Linux x86_64).
When installing Bumble with the optional `auracast` dependency, the `lc3`
module will be installed from the `lc3py` PyPI package if available.
If not, you will need to install it separately. This can be done with:
```bash
$ python3 -m pip install "git+https://github.com/google/liblc3.git"
```
## SoundDevice
The `sounddevice` module is required for audio output to the host's sound
output device(s) and/or input from the host's input device(s).
If not installed, the `auracast` app is still functional, but will be limited
to non-device inputs and output (files, external processes, ...)
On macOS and Windows, the `sounddevice` module gets installed with the
native PortAudio libraries included.
For Linux, however, PortAudio must be installed separately.
This is typically done with a command like:
```bash
$ sudo apt install libportaudio2
```
Visit the [sounddevice documentation](https://python-sounddevice.readthedocs.io/)
for details.
# General Usage
```
Usage: bumble-auracast [OPTIONS] COMMAND [ARGS]...
Options:
--help Show this message and exit.
Commands:
assist Scan for broadcasts on behalf of an audio server
pair Pair with an audio server
receive Receive a broadcast source
scan Scan for public broadcasts
transmit Transmit a broadcast source
```
Use `bumble-auracast <command> --help` to get more detailed usage information
for a specific `<command>`.
## `assist`
Act as a broadcast assistant.
Use `bumble-auracast assist --help` for details on the commands and options.
The assistant commands are:
### `monitor-state`
Subscribe to the state characteristic and monitor changes.
### `add-source`
Add a broadcast source. This will instruct the device to start
receiving a broadcast.
### `modify-source`
Modify a broadcast source.
### `remove-source`
Remote a broadcast source.
## `pair`
Pair with a device.
## `receive`
Receive a broadcast source.
The `--output` option specifies where to send the decoded audio samples.
The following outputs are supported:
### Sound Device
The `--output` argument is either `device`, to send the audio to the hosts's default sound device, or `device:<DEVICE_ID>` where `<DEVICE_ID>`
is the integer ID of one of the available sound devices.
When invoked with `--output "device:?"`, a list of available devices and
their IDs is printed out.
### Standard Output
With `--output stdout`, the decoded audio samples are written to the
standard output (currently always as float32 PCM samples)
### FFPlay
With `--output ffplay`, the decoded audio samples are piped to `ffplay`
in a child process. This option is only available if `ffplay` is a command that is available on the host.
### File
With `--output <filename>` or `--output file:<filename>`, the decoded audio
samples are written to a file (currently always as float32 PCM)
## `transmit`
Broadcast an audio source as a transmitter.
The `--input` and `--input-format` options specify what audio input
source to transmit.
The following inputs are supported:
### Sound Device
The `--input` argument is either `device`, to use the host's default sound
device (typically a builtin microphone), or `device:<DEVICE_ID>` where
`<DEVICE_ID>` is the integer ID of one of the available sound devices.
When invoked with `--input "device:?"`, a list of available devices and their
IDs is printed out.
### Standard Input
With `--input stdout`, the audio samples are read from the standard input.
(currently always as int16 PCM).
### File
With `--input <filename>` or `--input file:<filename>`, the audio samples
are read from a .wav or raw PCM file.
Use the `--input-format <FORMAT>` option to specify the format of the audio
samples in raw PCM files. `<FORMAT>` is expressed as:
`<sample-type>,<sample-rate>,<channels>`
(the only supported <sample-type> currently is 'int16le' for 16 bit signed integers with little-endian byte order)
## `scan`
Scan for public broadcasts.
A live display of the available broadcasts is displayed continuously.
# Compatibility With Some Products
The `auracast` app has been tested for compatibility with a few products.
The list is still very limited. Please let us know if there are products
that are not working well, or if there are specific instructions that should
be shared to allow better compatibiity with certain products.
## Transmitters
The `receive` command has been tested to successfully receive broadcasts from
the following transmitters:
* JBL GO 4
* Flairmesh FlooGoo FMA120
* Eppfun AK3040Pro Max
* HIGHGAZE BA-25T
* Nexum Audio VOCE and USB dongle
## Receivers
### Pixel Buds Pro 2
The Pixel Buds Pro 2 can be used as a broadcast receiver, controlled by the
`auracast assist` command, instructing the buds to receive a broadcast.
Use the `assist --command add-source` command to tell the buds to receive a
broadcast.
Use the `assist --command monitor-state` command to monitor the current sync/receive
state of the buds.
### JBL
The JBL GO 4 and other JBL products that support the Auracast feature can be used
as transmitters or receivers.
When running in receiver mode (pressing the Auracast button while not already playing),
the JBL speaker will scan for broadcast advertisements with a specific manufacturer data.
Use the `--manufacturer-data` option of the `transmit` command in order to include data
that will let the speaker recognize the broadcast as a compatible source.
The manufacturer ID for JBL is 87.
Using an option like `--manufacturer-data 87:00000000000000000000000000000000dffd` should work (tested on the
JBL GO 4. The `dffd` value at the end of the payload may be different on other models?).
### Others
* Nexum Audio VOCE and USB dongle

View File

@@ -4,12 +4,13 @@ APPS & TOOLS
Included in the project are a few apps and tools, built on top of the core libraries. Included in the project are a few apps and tools, built on top of the core libraries.
These include: These include:
* [Console](console.md) - an interactive text-based console * [Auracast](auracast.md) - Commands to broadcast, receive and/or control LE Audio.
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic) * [Console](console.md) - An interactive text-based console.
* [Pair](pair.md) - Pair/bond two devices (LE and Classic) * [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic).
* [Unbond](unbond.md) - Remove a previously established bond * [Pair](pair.md) - Pair/bond two devices (LE and Classic).
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets * [Unbond](unbond.md) - Remove a previously established bond.
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool" * [HCI Bridge](hci_bridge.md) - An HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets.
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form * [Golden Gate Bridge](gg_bridge.md) - Bridge between GATT and UDP to use with the Golden Gate "stack tool".
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form.
* [Speaker](speaker.md) - Virtual Bluetooth speaker, with a command line and browser-based UI. * [Speaker](speaker.md) - Virtual Bluetooth speaker, with a command line and browser-based UI.
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other. * [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

View File

@@ -17,3 +17,4 @@ USB vendor ID and product ID.
Drivers included in the module are: Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
* [Intel](intel.md): Loading of Firmware and Config for Intel USB controllers.

View File

@@ -0,0 +1,73 @@
INTEL DRIVER
==============
This driver supports loading firmware images and optional config data to
Intel USB controllers.
A number of USB dongles are supported, but likely not all.
The initial implementation has been tested on BE200 and AX210 controllers.
When using a USB controller, the USB product ID and vendor ID are used
to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=intel`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=intel]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for the firmware and config files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_INTEL_FIRMWARE_DIR`
if set.
* The directory `<package-dir>/drivers/intel_fw` where `<package-dir>` is the directory
where the `bumble` package is installed.
* The current directory.
It is also possible to override or extend the config data with parameters passed via the
transport name. The driver name `intel` may be suffixed with `/<param:value>[+<param:value>]...`
The supported params are:
* `ddc_addon`: configuration data to add to the data loaded from the config data file
* `ddc_override`: configuration data to use instead of the data loaded from the config data file.
With both `dcc_addon` and `dcc_override`, the param value is a hex-encoded byte array containing
the config data (same format as the config file).
Example transport name:
`usb:[driver=intel/dcc_addon:03E40200]0`
Obtaining Firmware Images and Config Data
-----------------------------------------
Firmware images and config data may be obtained from a variety of online
sources.
To facilitate finding a downloading the, the utility program `bumble-intel-fw-download`
may be used.
```
Usage: bumble-intel-fw-download [OPTIONS]
Download Intel firmware images and configs.
Options:
--output-dir TEXT Output directory where the files will be saved.
Defaults to the OS-specific app data dir, which the
driver will check when trying to find firmware
--source [linux-kernel] [default: linux-kernel]
--single TEXT Only download a single image set, by its base name
--force Overwrite files if they already exist
--help Show this message and exit.
```
Utility
-------
The `bumble-intel-util` utility may be used to interact with an Intel USB controller.
```
Usage: bumble-intel-util [OPTIONS] COMMAND [ARGS]...
Options:
--help Show this message and exit.
Commands:
bootloader Reboot in bootloader mode.
info Get the firmware info.
load Load a firmware image.
```

Some files were not shown because too many files have changed in this diff Show More