Compare commits

...

447 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
c80f89d20f update cryptography dependency 2024-12-18 22:01:42 -05:00
Gilles Boccon-Gibod
62e4670a39 Merge pull request #606 from wpiet/gmap-wip
Add `Gaming Audio Profile`
2024-12-18 11:56:57 -05:00
zxzxwu
99695bb264 Merge pull request #610 from zxzxwu/cfg
Remove setup.py and setup.cfg
2024-12-19 00:53:12 +08:00
Josh Wu
eb54898106 Remove setup.py and setup.cfg 2024-12-19 00:45:13 +08:00
Gilles Boccon-Gibod
4f5ee204d2 Update code-check.yml
Hot fix because 3.13.1 somehow breaks the current version of pylint. Will revert to 3.13 without pining to 3.13.0 when pylint is fixed
2024-12-18 11:36:08 -05:00
Wojciech Pietraszewski
2552e21db1 Add characteristics initial values
Sets default values for characteristics if not specified explicitly
2024-12-04 17:00:29 +01:00
Wojciech Pietraszewski
6168f87e2f Add characteristics conditionally
Only adds a characteristic if the corresponding role has been set
2024-12-04 12:57:34 +01:00
Gilles Boccon-Gibod
ca7d2ca4df Merge pull request #607 from google/gbg/pandora-deps
move pandora deps to development
2024-12-03 09:42:44 -08:00
Gilles Boccon-Gibod
60723323e9 move pandora deps to development 2024-12-03 09:08:30 -08:00
Gilles Boccon-Gibod
3ce7b9255b Merge pull request #598 from google/gbg/gatt-class-adapter
Add a class-based GATT adapter
2024-12-03 08:46:30 -08:00
Gilles Boccon-Gibod
97fcfc2fa0 Merge pull request #604 from jmdietrich-gcx/add_encryption_key_size_to_pairing_config
Add maximum encryption key size to PairingDelegate
2024-12-03 08:30:53 -08:00
Wojciech Pietraszewski
19674e3758 Add Gaming Audio Profile
Adds initial support for `Gaming Audio Service`.
2024-12-02 11:15:10 +01:00
Jan-Marcel Dietrich
1130e1db8f Fix code formatting 2024-12-02 09:01:18 +01:00
Gilles Boccon-Gibod
37c7f3a58a Merge pull request #603 from google/gbg/fix-pair-oob
fix oob support in pair.py
2024-12-01 08:43:04 -08:00
Gilles Boccon-Gibod
0a12b2bf2e Merge pull request #585 from wpiet/vocs
Add `Volume Offset Control Service`
2024-11-29 10:41:30 -08:00
Gilles Boccon-Gibod
d014acbe63 Merge pull request #597 from google/gbg/intel-hci
intel hci
2024-11-29 10:41:10 -08:00
Jan-Marcel Dietrich
07f9997a49 Add maximum encryption key size to PairingDelegate
So far the maxmium encryption key size has been hardcoded to 16 bytes in
'send_pairing_request_command()' and 'send_pairing_response_comman()'. By
making this configurable via the PairingDelegate, one can test how devices
respond to smaller encryption key sizes. Default remains 16 bytes.
2024-11-28 14:15:51 +01:00
Gilles Boccon-Gibod
b9f91f695a fix oob support in pair.py 2024-11-27 12:58:03 -08:00
Gilles Boccon-Gibod
082d55af10 Merge pull request #599 from google/gbg/hfp-19
add super wide band constants
2024-11-25 07:47:40 -08:00
Gilles Boccon-Gibod
4c3fd5688d Merge pull request #600 from google/gbg/unify-to-bytes
only use `__bytes__` when not argument is needed.
2024-11-25 07:44:17 -08:00
Gilles Boccon-Gibod
9d3d5495ce only use __bytes__ when not argument is needed. 2024-11-23 15:56:14 -08:00
Gilles Boccon-Gibod
b3869f267c add super wide band constants 2024-11-23 09:27:03 -08:00
Gilles Boccon-Gibod
8715333706 Add a GATT adapter that uses from_bytes and __bytes__ as conversion methods. 2024-11-23 09:13:04 -08:00
Gilles Boccon-Gibod
b57096abe2 Merge pull request #595 from wpiet/aics-opcode-fix
Amend Opcode value in `Audio Input Control Service`
2024-11-23 08:56:23 -08:00
Gilles Boccon-Gibod
48685c8587 improve vendor event support 2024-11-23 08:55:50 -08:00
Wojciech Pietraszewski
100bea6b41 Fix typos
Amends the typo in the `INACTIVE` field in `Audio Input Status` characteristic.
Amends the typo in the log message of `_set_gain_settings` method.
2024-11-21 18:29:44 +01:00
Wojciech Pietraszewski
63819bf9dd Amend Opcode value in Audio Input Control Service
Corrects the Audio Input Control Point
Opcode value for `Set Gain Setting` field.
2024-11-21 16:40:49 +01:00
Wojciech Pietraszewski
6e55390930 Add Volume Offset Control Service
Adds initial support for VOCS.
2024-11-21 11:56:14 +01:00
zxzxwu
e3fdab4175 Merge pull request #593 from zxzxwu/periodic
Support Periodic Advertising
2024-11-19 17:22:37 +08:00
Josh Wu
bbcd14dbf0 Support Periodic Advertising 2024-11-19 16:27:13 +08:00
zxzxwu
01dc0d574b Merge pull request #590 from SergeantSerk/parse-scan-response-data
Correctly parse scan response from device config
2024-11-17 15:39:11 +08:00
zxzxwu
5e959d638e Merge pull request #591 from zxzxwu/auracast_scan
Improve Broadcast Scanning
2024-11-16 04:10:27 +08:00
Gilles Boccon-Gibod
8d908288c8 Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-15 10:19:20 -08:00
Josh Wu
c88b32a406 Improve Broadcast Scanning 2024-11-16 02:02:28 +08:00
zxzxwu
5a72eefb89 Merge pull request #587 from zxzxwu/device
Replace HCI member imports in device.py
2024-11-13 15:25:32 +08:00
Josh Wu
430046944b Replace HCI member import in device.py 2024-11-12 16:53:21 +08:00
zxzxwu
21d23320eb Merge pull request #584 from zxzxwu/commands6.0
Add Core Spec 6.0 new commands support mapping
2024-11-12 04:17:24 +00:00
Serkan
d0990ee04d Correctly parse scan response from device config
Parses scan response data correctly just like advertising data
2024-11-07 21:49:33 +03:00
Josh Wu
2d88e853e8 Add Core Spec 6.0 new commands support mapping 2024-11-07 14:36:54 +08:00
Gilles Boccon-Gibod
a060a70fba Merge pull request #583 from google/gbg/more-gatt-tests
regression test for GATT unsubscription
2024-11-04 13:03:57 -08:00
Gilles Boccon-Gibod
a06394ad4a Merge pull request #582 from google/gbg/580
fix #580
2024-11-04 13:03:15 -08:00
Gilles Boccon-Gibod
a1414c2b5b add unsubscribe test 2024-11-03 19:08:27 -08:00
Gilles Boccon-Gibod
b2864dac2d fix #580 2024-11-02 10:29:40 -07:00
Gilles Boccon-Gibod
b78f895143 Merge pull request #579 from jmdietrich-gcx/unsubscribe_characteristic_in_gatt_client
Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
2024-10-31 04:07:02 -07:00
zxzxwu
c4e9726828 Merge pull request #581 from zxzxwu/context
[BAP] Add missing Unspecified context type
2024-10-31 11:04:25 +00:00
Gilles Boccon-Gibod
d4b8e8348a Merge pull request #574 from google/gbg/update-python-versions
remove test for deprecated Python 3.8 and add 3.13
2024-10-31 03:44:01 -07:00
Josh Wu
19debaa52e [BAP] Add missing Unspecified context type 2024-10-31 18:11:40 +08:00
Jan-Marcel Dietrich
73fe564321 Remove characteristic in GATT Client unsubscribe() if it's the last subscriber
GATT Client's subscribe() adds the characteristic itself as subscriber.
Therefore the characteristic has to be removed in unsubscribe(), if it's
the last subscriber. Otherwise the clean up does not work correctly and
the CCCD never is set back to 0 in the remote device.
2024-10-30 07:34:22 +01:00
Gilles Boccon-Gibod
a00abd65b3 fix some linter warnings 2024-10-28 12:30:37 -07:00
Gilles Boccon-Gibod
f169ceaebb update linter and type checker 2024-10-28 12:30:32 -07:00
Gilles Boccon-Gibod
528af0d338 remove test for deprecated Python 3.8 and add 3.13 2024-10-28 12:29:21 -07:00
Gilles Boccon-Gibod
4b25eed869 Merge pull request #570 from google/gbg/bench-mobly-snippets
bench mobly snippets
2024-10-28 10:25:28 -07:00
Gilles Boccon-Gibod
fcd6bd7136 address PR comments 2024-10-28 10:13:55 -07:00
Gilles Boccon-Gibod
32642c5d7c Merge pull request #576 from google/gbg/netsim-device-info
update to new netsim proto with DeviceInfo
2024-10-25 04:43:00 -07:00
Gilles Boccon-Gibod
ff8b0c375d add support for netsim device info variant 2024-10-25 04:37:30 -07:00
Gilles Boccon-Gibod
ae0228aeb8 Merge pull request #578 from jmdietrich-gcx/add_missing_parameter_to_att_execute_write
Add missing parameter 'flags' to ATT_Execute_Write_Request PDU
2024-10-25 02:57:24 -07:00
Jan-Marcel Dietrich
5d2dac18c8 Add missing parameter 'flags' to ATT_Execute_Write_Request PDU
Bluetooth spec @ Vol 3, Part F - 3.4.6.3 Table 3.36 shows that the
ATT_EXECUTE_WRITE_REQ PDU contains the parameter 'Flags' with size 1
octet, which allows to cancel all prepared writes (0x00) or to
immediately write all pending prepared values (0x01).
2024-10-24 15:08:10 +02:00
zxzxwu
d03fc14cfd Merge pull request #573 from ypomortsev/yegor
HFP: Fix reading multiple AT commands from a single data packet
2024-10-23 13:23:58 +08:00
Gilles Boccon-Gibod
ad7ce79bc4 use all caps for device kind 2024-10-22 16:30:46 -07:00
Yegor Pomortsev
c6bf27fd2c Fix test_hf_batched_response 2024-10-22 12:41:17 -07:00
Gilles Boccon-Gibod
7584daa3f9 update to new netsim proto with DeviceInfo 2024-10-22 11:48:42 -07:00
Yegor Pomortsev
654030e789 Add tests for batched HFP commands/responses; reformat 2024-10-21 16:32:20 -07:00
Gilles Boccon-Gibod
1de7d2cd6f Merge pull request #571 from google/gbg/a2dp-player
a2dp player
2024-10-19 07:40:43 -07:00
Gilles Boccon-Gibod
68db78c833 remove unnecessary import 2024-10-19 07:32:11 -07:00
Yegor Pomortsev
e1714c16cc HFP: Fix reading multiple AT commands from a single data packet
The `data` received in `_read_at` may have multiple commands.

This fixes `execute_command` timing out when waiting for an `OK`
response when it is in the same data buffer, e.g. during SLC
initialization: b'\r\n+BRSF: 3904\r\n\r\nOK\r\n'
2024-10-18 13:21:24 -07:00
Gilles Boccon-Gibod
0a20f14ea9 address PR comments 2024-10-15 15:26:19 -07:00
William Escande
23f46b36b3 HAP: wait for pairing event (#551) 2024-10-10 11:34:44 -07:00
Gilles Boccon-Gibod
009649abd1 remove unused section 2024-10-09 21:43:47 -07:00
Gilles Boccon-Gibod
855a007116 fix type checker 2024-10-09 21:34:03 -07:00
Gilles Boccon-Gibod
d064de35e0 wip 2024-10-09 21:34:03 -07:00
Gilles Boccon-Gibod
dab4d13303 wip 2024-10-09 21:34:03 -07:00
Gilles Boccon-Gibod
2bed50b353 add mobly to dev deps 2024-10-09 21:22:35 -07:00
Gilles Boccon-Gibod
1fe3778a74 adjust mypy excludes 2024-10-08 22:02:43 -07:00
Gilles Boccon-Gibod
f5443a9826 Merge pull request #564 from initializedd/fix-typo-in-comment
Fix typo in comment
2024-10-08 21:56:06 -07:00
zxzxwu
db723a5196 Merge pull request #569 from wpiet/cig-example-fix
examples/run_cig_setup: Fix the address type and CIG params
2024-10-05 17:20:32 +08:00
Gilles Boccon-Gibod
5e31bcf23d add mobly example 2024-10-04 18:17:56 -07:00
Gilles Boccon-Gibod
fe429cb2eb wip 2024-10-04 18:13:31 -07:00
Gilles Boccon-Gibod
c91695c23a wip 2024-10-04 18:13:31 -07:00
Gilles Boccon-Gibod
55f99e6887 wip 2024-10-04 18:13:31 -07:00
Gilles Boccon-Gibod
b190069f48 add snippets lib 2024-10-04 18:13:31 -07:00
Wojciech Pietraszewski
e16be1a8f4 docs/examples: Add run_cig_setup description
Adds basic information to the `examples` section of the documentation.
2024-10-02 18:51:11 +02:00
Wojciech Pietraszewski
2fa8075fb0 examples/run_cig_setup: Fix the address type and CIG params
Changes the address type used during connecting to what is actually advertised
by Device 0 by default (random address).

Amends CIG Parameters to use values allowed by the Core specification.

Updates the usage of the script and the example that show when executed incorrectly.
2024-10-02 18:50:57 +02:00
zxzxwu
566ca13d23 Merge pull request #561 from wpiet/csis-usage
run_csis_servers: Update `usage` and add docs entry
2024-10-01 17:34:22 +08:00
zxzxwu
e5666c0510 Merge pull request #565 from zxzxwu/controller
Add codecs info in controller info app
2024-10-01 15:35:32 +08:00
Slvr
46ec39ccfb avatar: update to latest version to correct flakiness (#568) 2024-10-01 00:19:41 -07:00
Slvr
eef418ae5f Collect Mobly logs (#566) 2024-09-30 15:21:19 -07:00
initializedd
9e663ad051 Clarify Bluetooth address comments 2024-09-30 18:39:02 +01:00
Wojciech Pietraszewski
f28eac4c14 docs/examples: Fix typo
Corrects the typo in the section's description.
2024-09-30 15:26:39 +02:00
Wojciech Pietraszewski
669bb3f3a8 run_csis_servers: Update usage and add docs entry
Amends the usage of the script and the example that show when executed incorrectly.
Adds basic information to the `examples` section of the documentation.
2024-09-30 15:25:40 +02:00
Josh Wu
347fe8b272 Add codecs info in controller info app 2024-09-30 00:24:06 +08:00
Gilles Boccon-Gibod
d56c4d0a11 Merge pull request #563 from initializedd/fix-whitespace
Fix whitespace
2024-09-27 18:31:59 -07:00
Gilles Boccon-Gibod
034140ccbd Merge pull request #562 from initializedd/support-netsim-ini-tmpdir-on-linux
Support netsim.ini tmpdir on linux
2024-09-27 14:08:47 -07:00
initializedd
35bef7d7b7 Fix whitespace 2024-09-27 20:49:30 +01:00
initializedd
d069708c79 Support netsim.ini tmpdir on linux 2024-09-27 19:25:49 +01:00
Slvr
bdba5c9d95 pyusb: check devices_in_use before removal (#559) 2024-09-24 13:40:58 -07:00
zxzxwu
ff659383f9 Merge pull request #556 from zxzxwu/default
Replace mutable default values
2024-09-21 16:18:13 +08:00
Josh Wu
f06a35713f Replace unsafe default values 2024-09-18 21:09:08 +08:00
Slvr
737abdc481 aics: make it a secondary service (#555)
* aics: make it a secondary service
---------

Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2024-09-17 16:06:47 -07:00
Gilles Boccon-Gibod
02eb4d2e1c Merge pull request #554 from google/gbg/pair-app-fixes
add support for selecting the identity address
2024-09-15 17:21:06 -07:00
Gilles Boccon-Gibod
e7f9acb421 add support for selecting the identity address 2024-09-14 15:14:10 -07:00
zxzxwu
976e6cce57 Merge pull request #553 from zxzxwu/profiles
Remove att.CommonErrorCode
2024-09-14 18:12:27 +08:00
Josh Wu
dfdf37019c Remove att.CommonErrorCode 2024-09-14 00:50:19 +08:00
zxzxwu
56ca19600b Merge pull request #552 from zxzxwu/hci
Add some HCI commands and events
2024-09-13 13:46:19 +08:00
Slvr
cd9feeb455 Implement AICS (#547)
* aics: Implement AICS and tests
2024-09-12 08:51:20 -07:00
Josh Wu
f8e5b88be6 Add some HCI commands and events 2024-09-12 22:31:54 +08:00
Gilles Boccon-Gibod
0f71a63b42 Merge pull request #534 from hkpeprah/ford/bug/rtk-edimax-2
[Bug] Edimax BLE Dongle Fails After Teardown and Re-Instantiation
2024-09-11 09:00:02 -07:00
Ford Peprah
b7259abe3c Fix typing errors. 2024-09-10 10:59:46 -04:00
William Escande
00e660d410 Implement Hap support (#532)
* Implement Hap
2024-09-09 16:24:22 -07:00
Ford Peprah
88e3a2b87f Fix linting errors. 2024-09-09 10:54:01 -04:00
Ford Peprah
aa658418bc Bug: Edimax BLE Dongle Fails After Teardown and Re-Instantiation
This patch addresses an issue where the some RTK BLE dongles fail to perform
an HCI reset after the transport is torn down and re-instantiated. To address
that, we prevent crashing the background threads when invalid data comes in,
and time out if no response is received within a fixed amount of time. When
the timeout occurs, we retry the reset, and ultimately skip over reading the
local version information if that fails.
2024-09-09 10:54:01 -04:00
zxzxwu
ac0cff43b6 Merge pull request #549 from zxzxwu/gatt
Return ATT_Error_Response on rejected write request
2024-09-09 21:23:05 +08:00
Josh Wu
8051c23375 Return ATT_Error_Response on rejected write 2024-09-08 01:12:51 +08:00
zxzxwu
7b34bb4050 Merge pull request #548 from zxzxwu/gatt
Fix TBS Characteristics UUID
2024-09-05 22:58:50 +08:00
Josh Wu
fe38ab35cf Fix TBS Characteristics UUID 2024-09-05 17:59:28 +08:00
zxzxwu
65a9102ba1 Merge pull request #545 from google/pandora_l2cap_service
Pandora: refactor l2cap service
2024-09-05 11:14:03 +08:00
Charlie Boutier
1256170985 Pandora: refactor l2cap service
* Craft the PandoraChannel from the connection_handle and the source_cid
* Fix race on waitDisconnection
* Add ChannelContext to enable mutliple channels on the service
2024-09-03 15:52:40 +00:00
zxzxwu
4394a36332 Merge pull request #526 from Gopi-SB/oob
DH Key compute check modification for OOB Pairing
2024-08-29 16:56:45 +08:00
Gopi Sakshihally Bhuthaiah
0c9fd64434 DH Key compute check modification for OOB Pairing 2024-08-29 08:46:53 +00:00
Samad Atoro
2e99153696 Pandora: Add L2CAP service 2024-08-23 16:38:29 -07:00
zxzxwu
54a6f3cb36 Merge pull request #536 from zxzxwu/asha
Refactor ASHA service implementation and examples
2024-08-24 01:19:42 +08:00
Charlie Boutier
4a691c11d4 pyusb: allow to detect multiple usb dongle
Allow to detect multiple usb dongle by just provind the pid/vid
2024-08-23 08:22:43 -07:00
Gilles Boccon-Gibod
b114c0d63f Merge pull request #539 from google/gbg/usb-thread-hotfix
hotfix for usb transport
2024-08-22 22:36:24 -07:00
Gilles Boccon-Gibod
a311c3f723 hotfix for usb transport 2024-08-22 22:26:44 -07:00
Josh Wu
04311b4c90 Refactor ASHA service and integrate with examples 2024-08-22 12:53:19 +08:00
zxzxwu
b2bb82a432 Merge pull request #537 from zxzxwu/smp
Ignore invalid RPA
2024-08-21 13:54:02 +08:00
Josh Wu
597560ff80 Ignore invalid local resolvable address 2024-08-21 00:11:14 +08:00
Gilles Boccon-Gibod
db383bb3e6 Merge pull request #531 from AlanRosenthal/btbench-scan
BtBench: Add Scan functionality
2024-08-14 11:59:13 -07:00
Alan Rosenthal
ccc5bbdad4 BtBench: Scan 2024-08-14 11:26:31 -04:00
zxzxwu
11c8229017 Merge pull request #533 from zxzxwu/hid
Correct HID type annotations
2024-08-14 12:08:53 +08:00
Josh Wu
2248f9ae5e Correct HID type annotations 2024-08-13 23:13:33 +08:00
Gopi Sakshihally Bhuthaiah
c44c89cc6e DH Key compute check modification for OOB Pairing 2024-08-13 02:10:41 +00:00
Gilles Boccon-Gibod
03c79aacb2 Merge pull request #529 from google/gbg/broadcast-assistant
basic broadcast assistant functionality
2024-08-12 13:02:50 -07:00
zxzxwu
0c31713a8e Merge pull request #528 from zxzxwu/rpa
Fix CTKD failure introduced by Host RPA generation
2024-08-13 01:30:19 +08:00
Gilles Boccon-Gibod
9dd814f32e strict compliance check 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
ab6e595bcb fix typing 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f08fac8c8a catch ATT errors 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
a699520188 fix after rebase merge 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f66633459e wip 2024-08-12 08:31:40 -07:00
Gilles Boccon-Gibod
f3b776c343 wip 2024-08-12 08:31:37 -07:00
Gilles Boccon-Gibod
de7b99ce34 wip 2024-08-12 08:29:32 -07:00
Gilles Boccon-Gibod
c0b17d9aff Merge pull request #530 from google/gbg/usb-no-parser
don't user a parser for a usb source
2024-08-12 08:21:19 -07:00
zxzxwu
3c12be59c5 Merge pull request #527 from zxzxwu/scan
Support Interlaced Scan config
2024-08-12 15:15:49 +08:00
Josh Wu
c6b3deb8df Fix CTKD failure introduced by Host RPA generation 2024-08-12 15:13:40 +08:00
Gopi Sakshihally Bhuthaiah
414f2f3efb DH Key compute check modification for OOB Pairing 2024-08-12 07:00:51 +00:00
Gilles Boccon-Gibod
a0b5606047 don't user a parser for a usb source 2024-08-11 20:57:45 -07:00
Gopi Sakshihally Bhuthaiah
ed00d44ae1 DH Key compute check modification for OOB Pairing 2024-08-09 17:30:19 +00:00
Josh Wu
3824e38485 Support Interlaced Scan config 2024-08-09 22:09:26 +08:00
Gopi Sakshihally Bhuthaiah
b164524380 DH Key compute check modification for OOB Pairing 2024-08-08 10:31:26 +00:00
Gopi Sakshihally Bhuthaiah
29e4a843df DH Key compute check modification for OOB Pairing 2024-08-08 08:48:58 +00:00
Gopi Sakshihally Bhuthaiah
619b32d36e DH Key compute check modification for OOB Pairing 2024-08-08 07:53:05 +00:00
Gilles Boccon-Gibod
4433184048 Merge pull request #522 from google/gbg/rpa2
add basic RPA support
2024-08-06 10:35:39 -07:00
Gilles Boccon-Gibod
312fc8db36 support controller-generated rpa 2024-08-05 08:59:05 -07:00
Gilles Boccon-Gibod
615691ec81 add basic RPA support 2024-08-01 15:37:11 -07:00
zxzxwu
ae8b83f294 Merge pull request #521 from zxzxwu/bap
Add Metadata LTV serializer and adapt Unicast
2024-07-31 11:36:46 +08:00
Josh Wu
4a8e21f4db Add Metadata LTV serializer and adapt Unicast 2024-07-31 01:20:28 +08:00
zxzxwu
3462e7c437 Merge pull request #439 from zxzxwu/mcp
Media Control Service Client implementation
2024-07-24 23:45:00 +08:00
Josh Wu
0f2e5239ad MCP constants and Client implementation 2024-07-24 22:57:26 +08:00
Gilles Boccon-Gibod
ee48cdc63f Merge pull request #517 from AlanRosenthal/scanner_pyee
Update scanner.py to use pyee.EventEmitter
2024-07-18 12:53:00 -07:00
Gilles Boccon-Gibod
1c278bec93 Merge pull request #518 from google/gbg/usb-queue
USB: better packet queue logic
2024-07-18 12:51:00 -07:00
Gilles Boccon-Gibod
6a51166af7 better packet queue logic 2024-07-17 17:48:26 -07:00
Alan Rosenthal
85d79fa914 Update scanner.py to use pyee.EventEmitter 2024-07-17 16:53:50 -04:00
zxzxwu
142bdce94a Merge pull request #515 from zxzxwu/unix
Add UNIX socket transport
2024-07-17 16:04:38 +08:00
Josh Wu
881a5a64b5 Add UNIX socket transport 2024-07-17 00:41:04 +08:00
zxzxwu
5aae44b610 Merge pull request #501 from zxzxwu/exception
Reorganize exceptions
2024-07-12 15:44:58 +08:00
Gilles Boccon-Gibod
e3ea167827 Merge pull request #506 from google/gbg/a2dp-fixes
a2dp: emit delay_report
2024-07-11 18:46:06 -07:00
Gilles Boccon-Gibod
eec145e095 add type hint 2024-07-11 18:39:02 -07:00
Gilles Boccon-Gibod
87fa02d6e5 Merge pull request #507 from google/packageFile
Create `inv web.build`
2024-07-11 18:35:29 -07:00
Gilles Boccon-Gibod
ad94c1e1f3 Merge pull request #509 from AlanRosenthal/discover
device.py: Add discover_all() api
2024-07-11 18:34:29 -07:00
Gilles Boccon-Gibod
546a0bce8d Merge pull request #510 from AlanRosenthal/get_characteristics_by_uuid
device.py: Update get_characteristics_by_uuid()
2024-07-11 18:33:45 -07:00
Gilles Boccon-Gibod
cb7ca44a1c Merge pull request #512 from AlanRosenthal/favicon
Add favicon.ico to docs folder
2024-07-11 18:27:19 -07:00
Gilles Boccon-Gibod
4081b93407 Merge pull request #513 from AlanRosenthal/devcontainer
Add devcontainer.json
2024-07-11 18:24:09 -07:00
Alan Rosenthal
26203ebaad Add devcontainer.json
devcontainer.json allows github's codespaces to be created with bumble's dependencies already installed
2024-07-11 18:47:32 +00:00
Alan Rosenthal
3389e3e1ed device.py: Update get_characteristics_by_uuid()
`get_characteristics_by_uuid()` now allows a UUID to be passed to the
service param. This allows for users to easily query for a service uuid
and characteristic uuid with one API.
2024-07-11 18:05:41 +00:00
Alan Rosenthal
7e1f01c01e Add favicon.ico to docs folder
Generated via: realfavicongenerator.net

validated via:
```
$ icotool -l favicon.ico
--icon --index=1 --width=48 --height=48 --bit-depth=32 --palette-size=0
--icon --index=2 --width=32 --height=32 --bit-depth=32 --palette-size=0
--icon --index=3 --width=16 --height=16 --bit-depth=32 --palette-size=0
```
2024-07-11 09:47:19 -04:00
Gilles Boccon-Gibod
613e15548a Merge pull request #511 from AlanRosenthal/random
console.py: Use Address.generate_static_address
2024-07-10 13:45:52 -07:00
Alan Rosenthal
e09c91df8e console.py: Use Address.generate_static_address 2024-07-10 18:51:46 +00:00
Alan Rosenthal
df206667b6 device.py: Add discover_all() api 2024-07-10 13:24:08 -04:00
Gilles Boccon-Gibod
0f19dd5263 Merge pull request #508 from google/web-readme
Add tip about disabling caching to web's readme
2024-07-09 09:17:25 -07:00
Alan Rosenthal
b98e4937f3 Add tip about disabling caching to web's readme 2024-07-09 13:48:53 +00:00
Alan Rosenthal
c2c46e9ace Create inv web.build
This command will build a wheel, copy it in the web directory, and create a file `packageFile` with the name of the wheel. If the correct override param is given, bumble.js will read `packageFile` and load that package.
2024-07-09 09:32:21 -04:00
Gilles Boccon-Gibod
27791cf218 emit delay_report 2024-07-03 13:51:15 -07:00
Gilles Boccon-Gibod
32a41a815d Merge pull request #502 from google/gbg/extended-advertising-termination-reverse
support out of order advertising set termination / connection events
2024-06-18 16:42:06 -07:00
Gilles Boccon-Gibod
df5fc2ddfe add test 2024-06-12 10:13:57 -07:00
Gilles Boccon-Gibod
79122313a6 Merge pull request #489 from google/gbg/basic-auracast-app
basic auracast app
2024-06-12 10:06:30 -07:00
Gilles Boccon-Gibod
d7d03e2e92 Merge pull request #504 from google/gbg/bench-role-change
bench role change
2024-06-12 10:06:11 -07:00
Gilles Boccon-Gibod
ea493480a9 remove duplicated lines 2024-06-11 13:23:35 -07:00
Gilles Boccon-Gibod
658f641a53 add manufacturer data 2024-06-11 13:21:04 -07:00
Josh Wu
f8a2d4f0e0 Reorganize exceptions
* Add BaseBumbleException as a "real" root error
* Add several core error classes and properly replace builtin errors
  with them
* Add several error classes for specific modules (transport, device)
2024-06-11 16:13:08 +08:00
Gilles Boccon-Gibod
00edd1fbf8 post-rebase fixes 2024-06-10 10:30:59 -07:00
Gilles Boccon-Gibod
999d7b07e1 wip 2024-06-09 11:39:44 -07:00
Gilles Boccon-Gibod
2e3aeb8648 support out of order advertising set termination / connection events 2024-06-05 16:29:31 -07:00
Gilles Boccon-Gibod
f910a696ad Merge pull request #499 from google/gbg/rfcomm-bridge
rfcomm bridge app
2024-06-05 11:18:13 -07:00
Gilles Boccon-Gibod
e1d10bc482 add rfcomm disconnect test 2024-06-05 10:03:27 -07:00
Gilles Boccon-Gibod
181467f11b Merge pull request #500 from google/gbg/fix-advertising-auto-restart
fix legacy advertising auto restart
2024-06-04 06:39:54 -07:00
Gilles Boccon-Gibod
394137b6f7 fix legacy advertising auto restart 2024-06-03 19:08:46 -07:00
Gilles Boccon-Gibod
dea907be86 attempt to fix pandora test (+3 squashed commits)
Squashed commits:
[759372d] address PR comments
[2f2a275] wip
[cc86b98] wip

wip

address PR comments

attempt to fix pandora test
2024-06-03 18:22:29 -07:00
Gilles Boccon-Gibod
f5baf51132 improve DLC parameters 2024-06-03 18:11:13 -07:00
Gilles Boccon-Gibod
f2dc8bd84e wip (+2 squashed commits)
Squashed commits:
[451a295] wip
[ed7b5b6] wip (+1 squashed commit)
Squashed commits:
[9d938c8] wip

wip

wip
2024-05-30 14:59:22 -07:00
zxzxwu
090309302f Merge pull request #372 from zxzxwu/source
ASCS Source Implementation
2024-05-29 13:17:51 +08:00
Charlie Boutier
28e6229b24 Fix: Preserve transport metadata
Preserve transport metadata when wrapping with SnoopingTransport
2024-05-28 09:20:53 -07:00
Josh Wu
1b66f03dbe ASCS: Add Source ASE operations 2024-05-27 14:48:23 +08:00
Gilles Boccon-Gibod
e34f6b5fd3 Merge pull request #484 from google/gbg/quick-fix-002
fix incorrect var reference
2024-05-13 16:11:42 -07:00
Gilles Boccon-Gibod
8a0482c947 Merge pull request #485 from google/gbg/gh-action-py312
add python 3.12 to GH actions
2024-05-13 16:11:25 -07:00
zxzxwu
938a189f3f Merge pull request #478 from zxzxwu/config
Make DeviceConfiguration dataclass
2024-05-13 16:57:15 +08:00
Gilles Boccon-Gibod
2005b4a11b python 3.12 compatibility 2024-05-12 12:54:52 -07:00
Gilles Boccon-Gibod
951fdc8bdd add python 3.12 to GH actions 2024-05-12 12:07:05 -07:00
Gilles Boccon-Gibod
12af7a526c fix incorrect var reference 2024-05-12 11:59:05 -07:00
zxzxwu
8781943646 Merge pull request #483 from zxzxwu/rfc
RFCOMM: Handle packets received before DLC sink set
2024-05-10 16:34:57 +08:00
Gilles Boccon-Gibod
7fbfdb634c Merge pull request #481 from google/gbg/command-status-fix
allow checking results for HCI_Command_Status_Event
2024-05-09 19:50:10 -07:00
Josh Wu
9682077f6b RFCOMM: Avoid receive packets before DLC sink set 2024-05-09 17:57:13 +08:00
Gilles Boccon-Gibod
22eb405fde Merge pull request #482 from servusdei2018/main
bumble.js(PacketSink): Implement asynchronous packet processing
2024-05-08 20:16:04 -07:00
zxzxwu
593c61973f Merge pull request #480 from zxzxwu/hfp-ag
HFP: Add AG example and fix errors
2024-05-07 17:50:01 +08:00
Josh Wu
ccff32102f HFP: Add example and fix AG errors 2024-05-07 00:36:52 +08:00
Nate
851d62c6c9 bumble.js(PacketSink): Implement asynchronous packet processing 2024-05-05 15:03:22 -04:00
Josh Wu
a5ac5f26e2 Make DeviceConfiguration dataclass 2024-05-05 17:25:01 +08:00
Gilles Boccon-Gibod
090158820f allow checking results for HCI_Command_Status_Event 2024-05-04 12:17:05 -07:00
zxzxwu
26e6650038 Merge pull request #477 from zxzxwu/hfp-ag
Fix HFP query call status
2024-05-02 01:17:17 +08:00
Josh Wu
c48568aabe Fix HFP query call status 2024-04-30 03:13:38 +00:00
zxzxwu
1b33c9eb74 Merge pull request #475 from zxzxwu/hfp-ag
Add more HFP command suppport
2024-04-26 12:01:20 +08:00
zxzxwu
6633228975 Add more HFP command suppport
* Support all Call Hold Operation
* Support CLI Presentation
* Support Voice Recognition
* Support RING and Volume Changes
* [AG] Support Enhanced Call Status
* Minor fixes
2024-04-24 15:29:48 +00:00
Gilles Boccon-Gibod
e9cba788a4 Merge pull request #473 from google/barbibulle-patch-2
quick fix: revert to protobuf 3.12.4
2024-04-22 11:46:04 +02:00
Gilles Boccon-Gibod
98822cfc6b quick fix: revert to protobuf 3.12.4
The upgrade to 4.x wasn't really needed, and breaks some users.
2024-04-18 21:20:18 -07:00
Gilles Boccon-Gibod
97ad7e5741 Merge pull request #472 from google/gbg/update-pandora-deps
update protobuf dep and make pandora install optional
2024-04-18 11:21:29 -07:00
Charlie Boutier
71df062e07 pyusb: power_cycle if '!' is present at the start of the transport 2024-04-17 14:12:55 -07:00
Charlie Boutier
049f9021e9 pyusb: powercycle the dongle 2024-04-17 14:12:55 -07:00
Gilles Boccon-Gibod
50eae2ef54 add pandora to code-check action 2024-04-17 13:19:07 -07:00
Gilles Boccon-Gibod
c8883a7d0f update protobuf dep and make pandora install optional 2024-04-17 13:14:21 -07:00
zxzxwu
51321caf5b Merge pull request #470 from zxzxwu/examples
Type hint all examples
2024-04-16 02:56:08 +08:00
zxzxwu
51a94288e2 Type hint all examples 2024-04-15 12:48:21 +00:00
zxzxwu
8758856e8c Merge pull request #465 from zxzxwu/hfp-ag
HFP AG implementation
2024-04-12 22:15:25 +08:00
Josh Wu
deba181857 HFP AG implementation 2024-04-10 09:51:37 +00:00
zxzxwu
c65188dcbf Merge pull request #466 from zxzxwu/format
Fix format presubmit error
2024-04-09 02:59:36 +08:00
Josh Wu
21d607898d Fix format presubmit error 2024-04-09 01:44:04 +08:00
Gilles Boccon-Gibod
2698d4534e Merge pull request #435 from jeru/main
open_tcp_server_transport: allow explicit sock as input.
2024-04-04 19:17:07 -07:00
zxzxwu
bbcd64286a Merge pull request #463 from zxzxwu/hfp
Correct HFP AG indicator index
2024-04-04 12:53:19 +08:00
Gilles Boccon-Gibod
9140afbf8c Merge pull request #456 from google/gbg/update-dependencies
update some dependencies
2024-04-03 17:50:18 -06:00
Gilles Boccon-Gibod
90a682c71b bump to avatar 0.0.9 2024-04-03 16:26:07 -07:00
Gilles Boccon-Gibod
e8737a8243 update to more recent versions 2024-04-03 10:00:11 -07:00
Gilles Boccon-Gibod
72fceca72e update some dependencies 2024-04-03 10:00:09 -07:00
Gilles Boccon-Gibod
732294abbc Merge pull request #462 from google/gbg/461
fix #461
2024-04-03 10:56:05 -06:00
Josh Wu
dc1204531e Correct HFP AG indicator index 2024-04-03 17:58:04 +08:00
Gilles Boccon-Gibod
962114379c fix #461 2024-04-02 23:14:32 -07:00
Gilles Boccon-Gibod
e6913a3055 Merge pull request #457 from google/gbg/bench-ascyncio-main
delay creation of runner object
2024-04-02 21:39:37 -06:00
Gilles Boccon-Gibod
e21d122aef Merge pull request #458 from google/gbg/update-formatter
update black formatter to version 24
2024-04-02 21:39:24 -06:00
Gilles Boccon-Gibod
58d4ab913a update black formatter to version 24 2024-04-01 14:44:46 -07:00
Gilles Boccon-Gibod
76bca03fe3 format with the project's version of black 2024-04-01 14:39:34 -07:00
Gilles Boccon-Gibod
f1e5c9e59e delay creation of runner object 2024-04-01 14:25:38 -07:00
zxzxwu
ec82242462 Merge pull request #440 from zxzxwu/hfp
Rework HFP example
2024-03-27 16:54:41 +08:00
zxzxwu
a4efdd3f3e Merge pull request #442 from zxzxwu/unicast_ad
Implement Unicast Server Advertising Data
2024-03-27 16:54:06 +08:00
Gilles Boccon-Gibod
69c6643bb8 Merge pull request #452 from marshallpierce/mp/rust-0.2.0
Bumble crate 0.2.0
2024-03-21 17:15:43 -07:00
Marshall Pierce
b8214bf948 Bumble crate 0.2.0 2024-03-21 12:36:32 -06:00
Charlie Boutier
a9c62c44b3 pandora host: change AdvertisingType
change advertising type from high duty to low duty

Test: python le_host_test.py -c config.yml --test_bed android.bumbles --tests "test_scan('connectable','non_scannable','directed',0)" -v
2024-03-20 11:17:50 -07:00
Charlie Boutier
7d0b4ef4e0 pandora_server: Parse FLAGS into advertising data
Bug: 328089785
2024-03-18 09:20:55 -07:00
Charlie Boutier
313340f1c6 intel driver: check the vendorId and productId 2024-03-15 10:53:33 -07:00
Charlie Boutier
e8ed69fb09 pyusb: Collect vendorId and productId as metadata 2024-03-15 10:53:33 -07:00
David Duarte
16d5cf6770 usb: Add usb path moniker
Add a new moniker for usb and pyusb driver allowing
to select the usb device using its bus id and port
path like `usb:3-3.4.1`.
2024-03-15 09:17:39 -07:00
Gilles Boccon-Gibod
a2caf1deb2 Merge pull request #448 from BenjaminLawson/bump-avatar
Bump pandora-avatar to 0.0.8
2024-03-14 20:49:28 -07:00
Ben Lawson
01bfdd2c98 Bump pandora-avater to 0.0.8 2024-03-14 14:13:27 -07:00
Gilles Boccon-Gibod
4a60df108a Merge pull request #447 from BenjaminLawson/bump-rootcanal
Bump rootcanal to 1.9.0
2024-03-14 14:00:36 -07:00
Ben Lawson
ad48109748 Bump rootcanal to 1.9.0 2024-03-14 13:15:02 -07:00
Cheng Sheng
1ceeccbbc0 open_tcp_server_transport: allow explicit sock as input.
When a user doesn't need an exact port, but cares more about getting
SOME unused port, they can do:
* Create a socket outside with port=None or port=0.
* Use socket.getsockname()[1] to get the allocated port and pass to the
TCP client somehow.
* Use the created socket to create a TCP server transport.

Use-case: unit-testing embedded software that implements a BLE host. The
controller will be a Bumble controller, connected to the host via a TCP
channel.
* The host will have a TCP-client HCI transport for testing.
* The pytest setup code will allocate the TCP server and pass the port
number to the host.

Also add some unittests with python mock.
2024-03-13 19:34:05 +01:00
Gilles Boccon-Gibod
44c51c13ac Merge pull request #445 from google/gbg/driver-probe-fix
fix intel driver probe
2024-03-12 12:51:08 -07:00
Gilles Boccon-Gibod
7507be1eab update metadata when setting the host controller directly 2024-03-12 11:50:47 -07:00
Gilles Boccon-Gibod
cbe9446dcf fix intel driver probe 2024-03-12 09:54:20 -07:00
Charlie Boutier
174930399a intel: send vsc INTEL_DDC_CONFIG_WRITE
This VSC enable host-initiated role-switching after connection.

Implement this VSC in a driver fashion.

Test: avatar security_test with the Bluetooth Dongle Intel BE200
2024-03-11 09:15:18 -07:00
Josh Wu
35db4a4c93 Implement Unicast Server Advertising Data 2024-03-08 16:48:37 +08:00
Gilles Boccon-Gibod
1f3aee5566 Merge pull request #438 from BenjaminLawson/pandora-extended-advertising
Implement Pandora extended advertising
2024-03-07 20:36:56 -08:00
Ben Lawson
256044a789 Implement Pandora extended advertising
Support setting the PHY of Pandora scans.
2024-03-07 16:18:49 -08:00
Josh Wu
6205199d7f Rework HFP example 2024-03-05 20:53:28 +08:00
Gilles Boccon-Gibod
e554bd1033 Merge pull request #434 from google/gbg/show-timestamps
show timestamps from snoop logs
2024-02-29 11:44:23 -08:00
Gilles Boccon-Gibod
38981cefa1 pad index field 2024-02-28 11:46:35 -08:00
Gilles Boccon-Gibod
f2d601f411 show timestamps from snoop logs 2024-02-27 16:40:37 -08:00
zxzxwu
6e7c64c1de Merge pull request #431 from zxzxwu/rust
Bump Rust to 1.76.0
2024-02-23 15:14:30 +08:00
Josh Wu
565d51f4db Bump Rust to 1.76.0
```
error: failed to compile `cargo-all-features v1.10.0`, intermediate artifacts can be found at `/tmp/cargo-installshCmAG`

Caused by:
  package `clap v4.5.1` cannot be built because it requires rustc 1.74 or newer, while the currently active rustc version is 1.70.0
  Try re-running cargo install with `--locked`

```
2024-02-22 15:22:20 +08:00
Gilles Boccon-Gibod
de8f3d9c1e Merge pull request #426 from akuker/patch-1
Add clarification to short circuit list feature
2024-02-12 21:22:14 -08:00
Tony Kuker
cde6d48690 Add clarification to short circuit list feature 2024-02-12 12:22:36 -06:00
zxzxwu
02180088b3 Merge pull request #425 from zxzxwu/command
Refactor command supporting list
2024-02-07 21:45:52 +08:00
zxzxwu
90f49267d1 Merge pull request #424 from zxzxwu/adv
Fix double-disable legacy advertising set
2024-02-06 16:13:51 +08:00
Josh Wu
0e6d69cd7b Refactor command supporting list 2024-02-06 12:06:00 +08:00
Josh Wu
9eccc583d5 Fix double-disable legacy advertising set
When legacy advertising set is disabled passively(by set termination),
the legacy advertising set won't be released, and the next
stop_advertising() call will try to disable it again and cause an error.
2024-02-06 12:00:30 +08:00
Gilles Boccon-Gibod
f4aeaa6eb3 Merge pull request #422 from google/gbg/bench-rfcomm-params
add rfcomm options and fix l2cap mtu negotiation
2024-02-05 09:14:16 -08:00
Gilles Boccon-Gibod
d7489a644a update websockets version (for better typecheck) 2024-02-05 09:07:39 -08:00
Gilles Boccon-Gibod
a877283360 add rfcomm options and fix l2cap mtu negotiation 2024-02-05 08:56:59 -08:00
zxzxwu
6d91e7e79b Merge pull request #423 from zxzxwu/vcp
Fix Lint error in VCP example
2024-02-06 00:40:05 +08:00
Josh Wu
567146b143 Fix Lint error in VCP example 2024-02-04 21:23:22 +08:00
zxzxwu
1a3272d7ca Merge pull request #412 from zxzxwu/vcp
Add Volume Control Service
2024-02-04 00:42:51 +08:00
zxzxwu
1ee1ff0b62 Merge pull request #420 from zxzxwu/rfc
Add RFCOMM and SDP context manager and search helper
2024-02-04 00:42:24 +08:00
zxzxwu
729fd97748 Merge pull request #419 from zxzxwu/feat
Add local LMP feature reader
2024-02-03 13:51:19 +08:00
Josh Wu
e308051885 Add LMP feature reader 2024-02-03 13:29:25 +08:00
Josh Wu
10e53553d7 Add RFCOMM and SDP helpers 2024-02-03 13:13:35 +08:00
Gilles Boccon-Gibod
ef0b30d059 Merge pull request #382 from google/gbg/extended-advertising-v2
extended advertising v2
2024-02-02 20:43:28 -08:00
Gilles Boccon-Gibod
e7e9f9509a update rootcanal version 2024-02-02 20:33:19 -08:00
zxzxwu
c6cfd101df Merge pull request #415 from zxzxwu/hfp
HFP: State memory and event emission
2024-02-02 11:36:53 +08:00
Josh Wu
d2dcf063ee HFP: State memory and event emit 2024-02-01 12:08:43 +08:00
Michael Mogenson
d15bc7d664 Merge pull request #417 from mogenson/controller-loopback-cid-range
controller_loopback: LE support and max packet count
2024-01-31 21:13:21 -05:00
zxzxwu
e4364d18a7 Merge pull request #418 from zxzxwu/rfc
RFCOMM: Slightly refactor and correct constants
2024-02-01 01:30:53 +08:00
Josh Wu
6a34c9f224 RFCOMM: Slightly refactor and correct constants 2024-02-01 01:18:56 +08:00
Michael Mogenson
2a764fd6bb controller_loopback: LE support and max packet count
Bound the packet count CLI option. We're using the L2CAP header CID for
a paket ID, so the max packet count value has to fit into this 16-bit
field.

Add support for controllers that are LE only by checking the
le_acl_packet_queue.max_size.

Tested with 65535 max packet count. Took 138 seconds at 481 kB/s with a
USB BT dongle.
2024-01-31 10:26:51 -05:00
Josh Wu
3e8ce38eba Add Volume Control Service 2024-01-31 10:04:30 +08:00
Gilles Boccon-Gibod
8d2f37aa7a inclusive language 2024-01-28 19:09:39 -08:00
Gilles Boccon-Gibod
b7b70ebcbb address PR comments 2024-01-28 19:09:37 -08:00
Gilles Boccon-Gibod
8ba91f4986 fix assert 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
79a5e953bc comply with limits for certain advertising event types 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
20de5ea250 format 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
bad9ce272c add doc 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
d3273ffa8c format (+3 squashed commits)
Squashed commits:
[60e610f] wip
[eeab73d] wip
[3cdd5b8] basic first pass
2024-01-28 19:02:30 -08:00
zxzxwu
071fc2723a Merge pull request #376 from zxzxwu/host
Manage lifecycle of CIS and SCO links in host
2024-01-28 22:09:08 +08:00
zxzxwu
ef4ea86f58 Merge pull request #381 from zxzxwu/offload
Support non-directed address generation offload
2024-01-28 22:08:32 +08:00
Gilles Boccon-Gibod
dfdaa149d0 Merge pull request #337 from google/gbg/avrcp
Add AVRCP support
2024-01-28 01:27:52 -08:00
Gilles Boccon-Gibod
986343a807 support multiple type checkers for pandora 2024-01-28 01:21:50 -08:00
Gilles Boccon-Gibod
5211d7ba96 revert to older pytest_asyncio 2024-01-28 01:10:31 -08:00
Gilles Boccon-Gibod
a167342778 deal with SupportsBytes for python <= 3.10 2024-01-28 01:04:13 -08:00
Gilles Boccon-Gibod
1efb8cdbee use matrixed python version 2024-01-28 00:34:42 -08:00
Gilles Boccon-Gibod
80d83e6a70 upgrade to mypy 1.8.0 2024-01-28 00:26:50 -08:00
Gilles Boccon-Gibod
31ec1c41ce cleanup 2024-01-28 00:07:31 -08:00
Gilles Boccon-Gibod
aba1ac0cea use a dict instead of a series of ifs (+6 squashed commits)
Squashed commits:
[90f2024] fix import order
[0edd321] add a few docstrings
[77a0ac0] wip
[adcf159] wip
[96cbd67] wip
[d8bfbab] wip (+1 squashed commit)
Squashed commits:
[43b4d66] wip (+2 squashed commits)
Squashed commits:
[3dafaa8] wip
[5844026] wip (+1 squashed commit)
Squashed commits:
[4cbb35a] wip (+1 squashed commit)
Squashed commits:
[4d2b6d3] wip (+4 squashed commits)
Squashed commits:
[f2da510] wip
[318c119] wip
[923b4eb] wip
[9d46365] wip

use a dict instead of a series of ifs (+6 squashed commits)
Squashed commits:
[90f2024] fix import order
[0edd321] add a few docstrings
[77a0ac0] wip
[adcf159] wip
[96cbd67] wip
[d8bfbab] wip
2024-01-27 16:26:17 -08:00
Josh Wu
c40824e51c Support non-directed address generation offload 2024-01-26 16:02:40 +08:00
Gilles Boccon-Gibod
2920f05dae Merge pull request #411 from AlanRosenthal/main
Add bumble-controller-loopback console_scripts
2024-01-24 13:20:19 -08:00
Alan Rosenthal
bc911d6da0 Add bumble-controller-loopback console_scripts 2024-01-24 14:07:35 -05:00
Gilles Boccon-Gibod
4f87f587e4 Merge pull request #409 from google/gbg/root-canal-update
update to rootcanal 1.4
2024-01-22 15:07:20 -08:00
Gilles Boccon-Gibod
3e38ab3638 update to rootcanal 1.4 2024-01-22 12:19:12 -08:00
Gilles Boccon-Gibod
21bb911fea Merge pull request #408 from suneeshs/btbench-update
Update the Bumble BT Bench AndroidManifest.xml
2024-01-22 12:15:35 -08:00
Suneesh Sasikumar
744dfa33a2 Update the Bumble BT Bench AndroidManifest.xml 2024-01-22 13:46:55 -05:00
zxzxwu
ec5f8535a8 Merge pull request #405 from zxzxwu/adv
Make Advertisement dataclass
2024-01-20 11:05:41 +08:00
Gilles Boccon-Gibod
5a83734a00 Merge pull request #388 from google/gbg/scan-with-irk
allow passing IRKs as arguments to scan.py
2024-01-19 11:37:02 -08:00
Josh Wu
b4ae8af3a7 Typing Advertisement 2024-01-19 15:16:24 +08:00
Josh Wu
da60386385 Manage lifecycle of CIS and SCO links in host 2024-01-18 11:56:38 +08:00
zxzxwu
45c4c4f4c5 Merge pull request #404 from zxzxwu/cis
Fix HCI_LE_Set_Host_Feature_Command
2024-01-18 10:56:05 +08:00
zxzxwu
9187c75d68 Merge pull request #397 from zxzxwu/controller
Controller: CIS implementation
2024-01-18 10:55:37 +08:00
zxzxwu
abeec22546 Merge pull request #402 from zxzxwu/key
Save Link Key in CTKD over BR/EDR
2024-01-18 10:55:14 +08:00
Josh Wu
a6bab755cf Fix HCI_LE_Set_Host_Feature_Command 2024-01-17 22:15:15 +08:00
Josh Wu
acd9d994c3 Save link_key in CTKD over BR/EDR
Since keystore.update() overwrites all existing keys, the existing link
key will be wiped out. To avoid this, SMP also need to keep the key.
2024-01-17 19:30:02 +08:00
Gilles Boccon-Gibod
37afda3ed3 Merge pull request #399 from google/gbg/hotfix-002
fix uninitialized variable
2024-01-16 17:29:48 -08:00
Gilles Boccon-Gibod
54f2981267 fix uninitialized variable 2024-01-16 16:49:06 -08:00
Charlie Boutier
bb025514e7 PandoraHost: compute advertising_interval_max with interval_range 2024-01-12 14:36:22 -08:00
Charlie Boutier
e228597269 Pandora host: support advertising interval in advertise 2024-01-12 14:36:22 -08:00
Gilles Boccon-Gibod
95b0d6c6f2 Merge pull request #398 from google/gbg/rfcomm-no-sink
update credits even without a sink
2024-01-11 13:18:15 -08:00
Josh Wu
fa4df6e3a2 Controller: CIS implementation 2024-01-11 01:16:42 +08:00
zxzxwu
46ceea7ecd Merge pull request #391 from zxzxwu/remote_feature
LE read remote features
2024-01-10 15:51:32 +08:00
Gilles Boccon-Gibod
30f89d5739 simplify 2024-01-09 18:01:34 -08:00
Gilles Boccon-Gibod
481cf40831 update credits even without a sink 2024-01-09 17:58:52 -08:00
Josh Wu
eff05afb7a LE read remote features 2024-01-09 11:30:08 +08:00
zxzxwu
d8e6700611 Merge pull request #383 from zxzxwu/controller
Controller: SCO implementation
2024-01-09 09:39:13 +08:00
Gilles Boccon-Gibod
56eb5a933b Merge pull request #394 from google/gbg/hci-latency
add support for HCI latency probing
2024-01-08 09:21:00 -08:00
Gilles Boccon-Gibod
caacc0c133 Merge pull request #395 from google/gbg/loopback-quick-fix
compatibility with recent host ACL property changes
2024-01-08 09:20:45 -08:00
Gilles Boccon-Gibod
5f377c024b format 2024-01-05 12:26:54 -08:00
Gilles Boccon-Gibod
00cd8fbdd0 compatibility with recent host ACL property changes 2024-01-05 12:17:09 -08:00
Gilles Boccon-Gibod
aeeff18428 add support for HCI latency probing 2024-01-05 10:26:04 -08:00
Michael Mogenson
c48e3f5e9c Merge pull request #393 from mogenson/controller-loopback
apps: Add a controller loopback throughput test app
2024-01-05 13:13:30 -05:00
Michael Mogenson
d6bbc1145a apps: Add a controller loopback throughput test app
Add a command line utility to open a transport to a BT controller, put
the controller into local loopback mode, and send and receive ACL data
packets. Record the time it takes to send and receive all packets and
calculate a throughput measurement in kB/s.

This utility is usefull for characterizing the speed of a transport to a
BT controller (such as a TCP socket or serial port) without having to
deal with a connected peer or the variability of over the air
transmissions.

The transport CLI argument is required. The packet size and packet
count arguments are optional. They default to the same values as the
bumble-bench app.
2024-01-05 10:01:24 -05:00
zxzxwu
e2fec67bd9 Merge pull request #390 from zxzxwu/csip
CSIP: Encrypted SIRK implementation
2024-01-04 13:28:23 +08:00
Josh Wu
88cb3b2a4d IWYU in CSIP 2024-01-04 13:22:09 +08:00
zxzxwu
9ebb03be46 Merge pull request #389 from zxzxwu/gitignore
.gitignore: Add venv directories
2024-01-04 12:54:30 +08:00
Gilles Boccon-Gibod
80d84af76c Merge pull request #392 from google/gbg/l2cap-drain
l2cap & rfcomm drain support
2024-01-03 09:59:36 -08:00
Gilles Boccon-Gibod
8f4721758f fix typo 2024-01-03 09:53:17 -08:00
Gilles Boccon-Gibod
8864af4acd format 2024-01-02 11:35:11 -08:00
Gilles Boccon-Gibod
8980fb8cc7 add drain support and a few tool options 2024-01-02 11:07:52 -08:00
Josh Wu
2c5f3472a9 CSIP: Encrypted SIRK implementation 2023-12-30 16:06:42 +08:00
Josh Wu
f18277ac78 Ignore venv directories 2023-12-30 14:23:35 +08:00
Josh Wu
8d46bc04d2 Controller: SCO implementation 2023-12-30 14:22:58 +08:00
Gilles Boccon-Gibod
09e5ea5dec Merge pull request #387 from google/gbg/async-gatt-server
support async read/write for characteristic values
2023-12-29 11:28:22 -08:00
Gilles Boccon-Gibod
d43281c57e allow passing IRKs as arguments 2023-12-28 14:35:23 -08:00
Gilles Boccon-Gibod
6810865670 Merge pull request #385 from google/gbg/android-enable-dle
request MTU change after connection
2023-12-28 13:46:25 -08:00
Gilles Boccon-Gibod
3e9e06a02c Merge pull request #386 from AlanRosenthal/main
app/bench.py: use logging rather than print()
2023-12-28 13:42:17 -08:00
Alan Rosenthal
ccd12f6591 app/bench.py: use logging rather than print() 2023-12-28 16:06:50 -05:00
Gilles Boccon-Gibod
f9a7843f7e request MTU change after connection 2023-12-28 11:17:18 -08:00
Gilles Boccon-Gibod
210c334db7 Merge pull request #380 from google/gbg/classic-buffer-size
support per-transport ACL queues
2023-12-28 09:24:52 -08:00
Gilles Boccon-Gibod
f297cdfcce Merge pull request #384 from eukub/string-concatination-to-fstring
сhanged concatenation of strings to f-strings to improve readability
2023-12-28 09:24:25 -08:00
eukub
5b536d00ab сhanged concatenation of strings to f-strings to improve readability and unify with the rest of code 2023-12-28 16:27:36 +03:00
Gilles Boccon-Gibod
b4af46ebd5 use TCP_NODELAY on socket 2023-12-27 12:11:20 -08:00
Gilles Boccon-Gibod
c08da3193e format 2023-12-27 11:56:06 -08:00
Gilles Boccon-Gibod
f2925ca647 support async read/write for characteristic values 2023-12-27 11:52:22 -08:00
Gilles Boccon-Gibod
fd4d68e5c0 print controller flow control info 2023-12-26 13:24:24 -08:00
Gilles Boccon-Gibod
5d83deffa4 Merge pull request #345 from rdhavan/bumble_hid_device
Bumble hid device implementation - Application and hid profile
2023-12-26 11:10:34 -08:00
Gilles Boccon-Gibod
2878cca478 Merge pull request #378 from benquike/pair_linger
Improve the linger option of bumble-pair
2023-12-26 10:55:28 -08:00
Gilles Boccon-Gibod
53934716db Merge pull request #377 from benquike/irk
Add functions/tool for gen/verifying BLE IRK/RPA
2023-12-26 10:54:18 -08:00
Hui Peng
d885d45824 Add functions/tool for gen/verifying BLE IRK/RPA 2023-12-26 09:34:19 -08:00
Gilles Boccon-Gibod
b90d0f8710 fix tests 2023-12-26 09:09:20 -08:00
zxzxwu
8ccfc90fe6 Merge pull request #379 from zxzxwu/addr
Add random address generation methods
2023-12-25 17:28:49 +08:00
Josh Wu
92aa7e9e2a Add random address generation methods 2023-12-24 18:07:40 +08:00
Gilles Boccon-Gibod
afc6d19e04 address PR comments 2023-12-23 14:21:44 -08:00
Gilles Boccon-Gibod
c05f073b33 Update bumble/host.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2023-12-23 14:15:53 -08:00
Gilles Boccon-Gibod
2b4c2a22f4 format 2023-12-22 14:22:08 -08:00
Gilles Boccon-Gibod
47fe93a148 support per-transport ACL queues 2023-12-22 13:52:33 -08:00
zxzxwu
6139ca8045 Merge pull request #374 from zxzxwu/csip
Complete CSIP and CAP
2023-12-23 02:49:35 +08:00
Josh Wu
87c76a4a0e Complete CSIP and CAP
Also add random address generation functions.
2023-12-23 02:14:32 +08:00
Hui Peng
f7b66db873 Improve the linger option in pair tool
No matter pairing fails or not, make linger effective
2023-12-21 17:25:42 -08:00
skarnataki
0b314bd7f7 Updated absctract class and method for on_ctrl_pdu in hid.py 2023-12-18 13:36:25 +00:00
skarnataki
9da2e32ad7 Review comment Fix 3 - rename json file and usage of Optional in parameters 2023-12-15 09:42:57 +00:00
Snehal Karnataki
93c0875740 Merge branch 'google:main' into bumble_hid_device 2023-12-13 09:51:27 +00:00
Gilles Boccon-Gibod
a286700239 Merge pull request #368 from google/gbg/driver-load-before-reset
support drivers that can't use reset directly.
2023-12-11 18:06:23 -08:00
Gilles Boccon-Gibod
98ed772e8a address PR comments and add some typing 2023-12-11 17:52:04 -08:00
Gilles Boccon-Gibod
f0b55a4f97 Merge pull request #367 from google/gbg/android-bench-update
Android bench app: add support for 2M phy
2023-12-11 10:20:56 -08:00
zxzxwu
b74503d345 Merge pull request #359 from zxzxwu/ascs
Audio Stream Control Service
2023-12-12 00:47:03 +08:00
Josh Wu
f911163e49 Improve ASCS logging 2023-12-12 00:36:24 +08:00
Gilles Boccon-Gibod
b083cc99ad fix spec parsing 2023-12-08 18:57:02 -08:00
Gilles Boccon-Gibod
d35643524e allow specifying the address type 2023-12-08 18:46:25 -08:00
Gilles Boccon-Gibod
62a8ced447 support drivers that can't use reset directly. 2023-12-08 17:28:57 -08:00
Gilles Boccon-Gibod
085f163c92 add support for 2M phy 2023-12-08 10:14:38 -08:00
Josh Wu
81a6b1e097 Replace 3.9 dict merger 2023-12-08 11:10:17 +08:00
Josh Wu
dd090c9e6b Add ASCS tests 2023-12-08 11:00:44 +08:00
Josh Wu
11faa48422 Fix ASE state change 2023-12-08 09:53:14 +08:00
Josh Wu
55596176c2 ffplay routing 2023-12-08 09:53:14 +08:00
Josh Wu
4d6822d312 Remove ISO data path on release 2023-12-08 09:53:14 +08:00
Josh Wu
985c365e6d Setup data path after CIS established 2023-12-08 09:53:14 +08:00
Josh Wu
af57762227 Parse CodecSpecificConfiguration 2023-12-08 09:53:14 +08:00
Josh Wu
3575f9030e Add Audio Stream Control Service 2023-12-08 09:53:14 +08:00
zxzxwu
698d947d85 Merge pull request #366 from zxzxwu/extadv
Add advertiser classes and handle adv set terminated events
2023-12-08 09:52:42 +08:00
Josh Wu
ff6528d2bf Add Advertising unit tests 2023-12-08 01:38:01 +08:00
Josh Wu
72ac75a98d Add advertiser classes and handle adv set terminated events
* Convert hci.OwnAddressType to enum
* Add LegacyAdvertiser and ExtendedAdvertiser classes
* Rename start/stop_advertising() => start/stop_legacy_advertising()
* Handle HCI_Advertising_Set_Terminated
* Properly restart advertisement on disconnection
2023-12-07 15:51:51 +08:00
skarnataki
5e3ecb74e4 Review comment fix -2 2023-12-05 13:41:30 +00:00
Snehal Karnataki
c59be293c8 Merge branch 'google:main' into bumble_hid_device 2023-12-05 13:07:36 +00:00
zxzxwu
88b4cbdf1a Merge pull request #364 from zxzxwu/iso
Fix ISO packet issues
2023-12-05 00:41:56 +08:00
Josh Wu
d6afbc6f4e Fix ISO packet issues 2023-12-04 20:31:11 +08:00
Gilles Boccon-Gibod
fc90de3e7b Merge pull request #351 from google/dependabot/cargo/rust/openssl-0.10.60
Bump openssl from 0.10.57 to 0.10.60 in /rust
2023-12-04 00:41:27 -08:00
Gilles Boccon-Gibod
847c2ef114 Merge pull request #362 from google/gbg/more-le-features-constants
a few more HCI constants from the spec
2023-12-04 00:38:02 -08:00
Gilles Boccon-Gibod
a0bf0c1f4d Merge pull request #363 from google/gbg/android-remote-proxy-cli
android remote proxy cli
2023-12-04 00:37:49 -08:00
Snehal Karnataki
6d22ed80ec Merge branch 'google:main' into bumble_hid_device 2023-12-04 07:29:04 +00:00
Gilles Boccon-Gibod
843466c822 a few more constants from the spec 2023-12-03 17:16:25 -08:00
zxzxwu
3adcc8be09 Merge pull request #360 from zxzxwu/hci
Remove # type: ignore[call-arg] in HCI_Command builders
2023-12-03 19:18:04 +08:00
zxzxwu
c853d56302 Merge pull request #361 from zxzxwu/hci-bug
Fix typo
2023-12-03 04:22:59 +08:00
Josh Wu
dc97be5b35 Fix typo 2023-12-02 23:42:21 +08:00
zxzxwu
73dbdfff9f Merge pull request #356 from zxzxwu/bap
Add Published Audio Capabilities Service
2023-12-02 23:34:57 +08:00
Josh Wu
dff14e1258 Add Published Audio Capabilities Service 2023-12-02 23:16:37 +08:00
Josh Wu
10a3833893 Remove # type: ignore[call-arg] in HCI_Command builders 2023-12-02 19:18:54 +08:00
Snehal Karnataki
ffb3eca68b Merge branch 'google:main' into bumble_hid_device 2023-11-30 04:50:05 +00:00
dependabot[bot]
7eb493990f Bump openssl from 0.10.57 to 0.10.60 in /rust
Bumps [openssl](https://github.com/sfackler/rust-openssl) from 0.10.57 to 0.10.60.
- [Release notes](https://github.com/sfackler/rust-openssl/releases)
- [Commits](https://github.com/sfackler/rust-openssl/compare/openssl-v0.10.57...openssl-v0.10.60)

---
updated-dependencies:
- dependency-name: openssl
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-28 21:43:18 +00:00
skarnataki
403a13e4c6 Review comment fix HID device 2023-11-28 13:42:25 +00:00
Snehal Karnataki
ad0f035df5 Merge branch 'google:main' into bumble_hid_device 2023-11-28 13:06:32 +00:00
skarnataki
07f71fc895 Project format and lint error fix. Redefination if Device class needs to be discussed 2023-11-27 13:04:54 +00:00
Fahad Afroze
f47b9178ad Added GET_REPORT and SET_REPORT changes
Added changes to handle invalid cases
2023-11-27 11:55:35 +00:00
SneKarnataki
4f399249bd Merge branch 'google:main' into bumble_hid_device 2023-11-27 09:00:44 +00:00
skarnataki
9324237828 send_data comment fix and lint error fix 2023-11-24 11:13:20 +00:00
Fahad Afroze
d1033c018a Modified DeviceData class 2023-11-24 05:42:31 +00:00
Fahad Afroze
0f29052ade Added mousemove changes
Also modified keyboard data on keyup
2023-11-23 17:46:55 +00:00
skarnataki
0578e84586 Menu and name change review comments fix 2023-11-23 15:43:22 +00:00
Fahad Afroze
6ab41c466f Add review comment changes 3 2023-11-23 12:27:56 +00:00
Fahad Afroze
98a1093ebf Add review comment changes 2
Also corrected sending mouseData
2023-11-23 09:53:16 +00:00
dhavan
caf04373f3 keyboard data moved to DeviceData class 2023-11-23 08:01:07 +00:00
SneKarnataki
d4e8526766 Merge branch 'google:main' into bumble_hid_device 2023-11-23 07:59:43 +00:00
dhavan
515b83a8c7 deleted: bumble/classic3.json
modified:   examples/keyboard.html
2023-11-23 06:10:52 +00:00
dhavan
dc18595c8a MTU size check added 2023-11-23 05:17:44 +00:00
SneKarnataki
488bcfe9c6 Merge branch 'google:main' into bumble_hid_device 2023-11-23 04:03:53 +00:00
dhavan
d6cefdff8e Renamed the status message class 2023-11-22 17:14:24 +00:00
dhavan
dc410b14c4 SET_REPORT and GET_REPORT implemented 2023-11-22 16:05:33 +00:00
dhavan
4c49ef9403 SET_REPORT implemented 2023-11-22 12:31:34 +00:00
dhavan
ba85dcbda5 Get the changes from hid_device to bumble_hid_device
Modified the get_report_cb
2023-11-22 11:06:27 +00:00
270 changed files with 34132 additions and 5608 deletions

View File

@@ -0,0 +1,30 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/universal:2",
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand":
"python -m pip install '.[build,test,development,documentation]'",
// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python"
]
}
}
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
fail-fast: false
steps:
@@ -29,7 +29,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -40,4 +40,11 @@ jobs:
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
if: always()
run: cat rootcanal.log
- name: Upload Mobly logs
if: always()
uses: actions/upload-artifact@v3
with:
name: mobly-logs
path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
fail-fast: false
steps:
@@ -46,8 +46,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
rust-version: [ "1.70.0", "stable" ]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
rust-version: [ "1.76.0", "stable" ]
fail-fast: false
steps:
- name: Check out from Git

7
.gitignore vendored
View File

@@ -6,7 +6,14 @@ dist/
docs/mkdocs/site
test-results.xml
__pycache__
# Vim
.*.sw*
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
.vscode/settings.json
/.idea
venv/
.venv/
# snoop logs
out/

10
.vscode/settings.json vendored
View File

@@ -1,6 +1,7 @@
{
"cSpell.words": [
"Abortable",
"aiohttp",
"altsetting",
"ansiblue",
"ansicyan",
@@ -9,10 +10,13 @@
"ansired",
"ansiyellow",
"appendleft",
"ascs",
"ASHA",
"asyncio",
"ATRAC",
"avctp",
"avdtp",
"avrcp",
"bitpool",
"bitstruct",
"BSCP",
@@ -22,6 +26,7 @@
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
@@ -32,6 +37,7 @@
"dhkey",
"diversifier",
"endianness",
"ESCO",
"Fitbit",
"GATTLINK",
"HANDSFREE",
@@ -39,6 +45,7 @@
"keyup",
"levelname",
"libc",
"liblc",
"libusb",
"MITM",
"MSBC",
@@ -70,8 +77,11 @@
"substates",
"tobytes",
"tsep",
"UNMUTE",
"unmuted",
"usbmodem",
"vhci",
"wasmtime",
"websockets",
"xcursor",
"ycursor"

701
apps/auracast.py Normal file
View File

@@ -0,0 +1,701 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import logging
import os
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
import click
import pyee
from bumble.colors import color
import bumble.company_ids
import bumble.core
import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.bass
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
AURACAST_DEFAULT_ATT_MTU = 256
# -----------------------------------------------------------------------------
# Scan For Broadcasts
# -----------------------------------------------------------------------------
class BroadcastScanner(pyee.EventEmitter):
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str | None
sync: bumble.device.PeriodicAdvertisingSync
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
] = None
broadcast_audio_announcement: Optional[
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None
def __post_init__(self) -> None:
super().__init__()
self.sync.on('establishment', self.on_sync_establishment)
self.sync.on('loss', self.on_sync_loss)
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if (
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
)
continue
if (
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
)
continue
self.appearance = advertisement.data.get( # type: ignore[assignment]
bumble.core.AdvertisingData.APPEARANCE
)
if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
):
assert isinstance(manufacturer_data, tuple)
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
),
data,
)
self.emit('update')
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
self.sync.advertiser_address,
color(self.sync.state.name, 'green'),
)
if self.name is not None:
print(f' {color("Name", "cyan")}: {self.name}')
if self.appearance:
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
print(f' {color("RSSI", "cyan")}: {self.rssi}')
print(f' {color("SID", "cyan")}: {self.sync.sid}')
if self.manufacturer_data:
print(
f' {color("Manufacturer Data", "cyan")}: '
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
)
if self.broadcast_audio_announcement:
print(
f' {color("Broadcast ID", "cyan")}: '
f'{self.broadcast_audio_announcement.broadcast_id}'
)
if self.public_broadcast_announcement:
print(
f' {color("Features", "cyan")}: '
f'{self.public_broadcast_announcement.features}'
)
print(
f' {color("Metadata", "cyan")}: '
f'{self.public_broadcast_announcement.metadata}'
)
if self.basic_audio_announcement:
print(color(' Audio:', 'cyan'))
print(
color(' Presentation Delay:', 'magenta'),
self.basic_audio_announcement.presentation_delay,
)
for subgroup in self.basic_audio_announcement.subgroups:
print(color(' Subgroup:', 'magenta'))
print(color(' Codec ID:', 'yellow'))
print(
color(' Coding Format: ', 'green'),
subgroup.codec_id.codec_id.name,
)
print(
color(' Company ID: ', 'green'),
subgroup.codec_id.company_id,
)
print(
color(' Vendor Specific Codec ID:', 'green'),
subgroup.codec_id.vendor_specific_codec_id,
)
print(
color(' Codec Config:', 'yellow'),
subgroup.codec_specific_configuration,
)
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
for bis in subgroup.bis:
print(color(f' BIS [{bis.index}]:', 'yellow'))
print(
color(' Codec Config:', 'green'),
bis.codec_specific_configuration,
)
if self.biginfo:
print(color(' BIG:', 'cyan'))
print(
color(' Number of BIS:', 'magenta'),
self.biginfo.num_bis,
)
print(
color(' PHY: ', 'magenta'),
self.biginfo.phy.name,
)
print(
color(' Framed: ', 'magenta'),
self.biginfo.framed,
)
print(
color(' Encrypted: ', 'magenta'),
self.biginfo.encrypted,
)
def on_sync_establishment(self) -> None:
self.emit('sync_establishment')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('sync_loss')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
) -> None:
if advertisement.data is None:
return
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
)
break
self.emit('change')
def on_biginfo_advertisement(
self, advertisement: bumble.device.BIGInfoAdvertisement
) -> None:
self.biginfo = advertisement
self.emit('change')
def __init__(
self,
device: bumble.device.Device,
filter_duplicates: bool,
sync_timeout: float,
):
super().__init__()
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
device.on('advertisement', self.on_advertisement)
async def start(self) -> None:
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
async def stop(self) -> None:
await self.device.stop_scanning()
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if not (
ads := advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
)
) or not (
any(
ad
for ad in ads
if isinstance(ad, tuple)
and ad[0] == bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
)
):
return
broadcast_name = advertisement.data.get(
bumble.core.AdvertisingData.BROADCAST_NAME
)
assert isinstance(broadcast_name, str) or broadcast_name is None
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
return
bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
)
async def on_new_broadcast(
self, name: str | None, advertisement: bumble.device.Advertisement
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(name, periodic_advertising_sync)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.emit('new_broadcast', broadcast)
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.emit('broadcast_loss', broadcast)
class PrintingBroadcastScanner:
def __init__(
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
) -> None:
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
self.scanner.on('new_broadcast', self.on_new_broadcast)
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
self.scanner.on('update', self.refresh)
self.status_message = ''
async def start(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.scanner.start()
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
)
broadcast.on('change', self.refresh)
broadcast.on('update', self.refresh)
self.refresh()
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
self.status_message = color(
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
)
self.refresh()
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.scanner.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
@contextlib.asynccontextmanager
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device_config = bumble.device.DeviceConfiguration(
name=AURACAST_DEFAULT_DEVICE_NAME,
address=AURACAST_DEFAULT_DEVICE_ADDRESS,
keystore='JsonKeyStore',
)
device = bumble.device.Device.from_config_with_hci(
device_config,
hci_source,
hci_sink,
)
await device.power_on()
yield device
async def find_broadcast_by_name(
device: bumble.device.Device, name: Optional[str]
) -> BroadcastScanner.Broadcast:
result = asyncio.get_running_loop().create_future()
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
if broadcast.basic_audio_announcement and not result.done():
print(color('Broadcast basic audio announcement received', 'green'))
result.set_result(broadcast)
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
if name is None or broadcast.name == name:
print(color('Broadcast found:', 'green'), broadcast.name)
broadcast.on('change', lambda: on_broadcast_change(broadcast))
return
print(color(f'Skipping broadcast {broadcast.name}'))
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
scanner.on('new_broadcast', on_new_broadcast)
await scanner.start()
broadcast = await result
await scanner.stop()
return broadcast
async def run_scan(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
await scanner.start()
await asyncio.get_running_loop().create_future()
async def run_assist(
broadcast_name: Optional[str],
source_id: Optional[int],
command: str,
transport: str,
address: str,
) -> None:
async with create_device(transport) as device:
if not device.supports_le_periodic_advertising:
print(color('Periodic advertising not supported', 'red'))
return
# Connect to the server
print(f'=== Connecting to {address}...')
connection = await device.connect(address)
peer = bumble.device.Peer(connection)
print(f'=== Connected to {peer}')
print("+++ Encrypting connection...")
await peer.connection.encrypt()
print("+++ Connection encrypted")
# Request a larger MTU
mtu = AURACAST_DEFAULT_ATT_MTU
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
await peer.request_mtu(mtu)
# Get the BASS service
bass = await peer.discover_service_and_create_proxy(
bumble.profiles.bass.BroadcastAudioScanServiceProxy
)
# Check that the service was found
if not bass:
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
return
# Subscribe to and read the broadcast receive state characteristics
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
try:
await broadcast_receive_state.subscribe(
lambda value, i=i: print(
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
)
)
except bumble.core.ProtocolError as error:
print(
color(
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
'red',
),
error,
)
value = await broadcast_receive_state.read_value()
print(
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
)
if command == 'monitor-state':
await peer.sustain()
return
if command == 'add-source':
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Add the source
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
await bass.add_source(
broadcast.sync.advertiser_address,
broadcast.sync.sid,
broadcast.broadcast_audio_announcement.broadcast_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
# Initiate a PA Sync Transfer
await broadcast.sync.transfer(peer.connection)
# Notify the sink that we're done scanning.
await bass.remote_scan_stopped()
await peer.sustain()
return
if command == 'modify-source':
if source_id is None:
print(color('!!! modify-source requires --source-id'))
return
# Find the requested broadcast
await bass.remote_scan_started()
if broadcast_name:
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
else:
print(color('Scanning for any broadcast', 'cyan'))
broadcast = await find_broadcast_by_name(device, broadcast_name)
if broadcast.broadcast_audio_announcement is None:
print(color('No broadcast audio announcement found', 'red'))
return
if (
broadcast.basic_audio_announcement is None
or not broadcast.basic_audio_announcement.subgroups
):
print(color('No subgroups found', 'red'))
return
# Modify the source
print(
color('Modifying source:', 'blue'),
source_id,
)
await bass.modify_source(
source_id,
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
0xFFFF,
[
bumble.profiles.bass.SubgroupInfo(
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
)
],
)
await peer.sustain()
return
if command == 'remove-source':
if source_id is None:
print(color('!!! remove-source requires --source-id'))
return
# Remove the source
print(color('Removing source:', 'blue'), source_id)
await bass.remove_source(source_id)
await peer.sustain()
return
print(color(f'!!! invalid command {command}'))
async def run_pair(transport: str, address: str) -> None:
async with create_device(transport) as device:
# Connect to the server
print(f'=== Connecting to {address}...')
async with device.connect_as_gatt(address) as peer:
print(f'=== Connected to {peer}')
print("+++ Initiating pairing...")
await peer.connection.pair()
print("+++ Paired")
def run_async(async_command: Coroutine) -> None:
try:
asyncio.run(async_command)
except bumble.core.ProtocolError as error:
if error.error_namespace == 'att' and error.error_code in list(
bumble.profiles.bass.ApplicationError
):
message = bumble.profiles.bass.ApplicationError(error.error_code).name
else:
message = str(error)
print(
color('!!! An error occurred while executing the command:', 'red'), message
)
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
ctx.ensure_object(dict)
@auracast.command('scan')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def scan(ctx, filter_duplicates, sync_timeout, transport):
"""Scan for public broadcasts"""
run_async(run_scan(filter_duplicates, sync_timeout, transport))
@auracast.command('assist')
@click.option(
'--broadcast-name',
metavar='BROADCAST_NAME',
help='Broadcast Name to tune to',
)
@click.option(
'--source-id',
metavar='SOURCE_ID',
type=int,
help='Source ID (for remove-source command)',
)
@click.option(
'--command',
type=click.Choice(
['monitor-state', 'add-source', 'modify-source', 'remove-source']
),
required=True,
)
@click.argument('transport')
@click.argument('address')
@click.pass_context
def assist(ctx, broadcast_name, source_id, command, transport, address):
"""Scan for broadcasts on behalf of a audio server"""
run_async(run_assist(broadcast_name, source_id, command, transport, address))
@auracast.command('pair')
@click.argument('transport')
@click.argument('address')
@click.pass_context
def pair(ctx, transport, address):
"""Pair with an audio server"""
run_async(run_pair(transport, address))
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

File diff suppressed because it is too large Load Diff

63
apps/ble_rpa_tool.py Normal file
View File

@@ -0,0 +1,63 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.gatt_client import CharacteristicProxy
from bumble.hci import (
Address,
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
@@ -289,11 +290,7 @@ class ConsoleApp:
device_config, hci_source, hci_sink
)
else:
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
random_address = Address.generate_static_address()
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
@@ -503,21 +500,9 @@ class ConsoleApp:
self.show_error('not connected')
return
# Discover all services, characteristics and descriptors
self.append_to_output('discovering services...')
await self.connected_peer.discover_services()
self.append_to_output(
f'found {len(self.connected_peer.services)} services,'
' discovering characteristics...'
)
await self.connected_peer.discover_characteristics()
self.append_to_output('found characteristics, discovering descriptors...')
for service in self.connected_peer.services:
for characteristic in service.characteristics:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.show_remote_services(self.connected_peer.services)
self.append_to_output('Service Discovery starting...')
await self.connected_peer.discover_all()
self.append_to_output('Service Discovery done!')
async def discover_attributes(self):
if not self.connected_peer:
@@ -777,7 +762,7 @@ class ConsoleApp:
if not service:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +781,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [attribute.read_value(None)]
values = [await attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +929,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
attribute.write_value(None, value)
await attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):

View File

@@ -18,24 +18,31 @@
import asyncio
import os
import logging
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
import time
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
CodecID,
LeFeature,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
@@ -44,6 +51,9 @@ from bumble.hci import (
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -59,7 +69,7 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_classic_info(host):
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
@@ -80,7 +90,7 @@ async def get_classic_info(host):
# -----------------------------------------------------------------------------
async def get_le_info(host):
async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -133,11 +143,90 @@ async def get_le_info(host):
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
print(f' {LeFeature(feature).name}')
# -----------------------------------------------------------------------------
async def async_main(transport):
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def get_codecs_info(host: Host) -> None:
print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
)
print(color('Codecs:', 'yellow'))
for codec_id, transport in zip(
response.return_parameters.standard_codec_ids,
response.return_parameters.standard_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
for codec_id, transport in zip(
response.return_parameters.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
if not response.return_parameters.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
)
print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids:
codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -145,6 +234,23 @@ async def async_main(transport):
host = Host(hci_source, hci_sink)
await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version
print(color('Version:', 'yellow'))
print(
@@ -168,19 +274,31 @@ async def async_main(transport):
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Get codec info
await get_codecs_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
for command in host.supported_commands:
print(' ', HCI_Command.command_name(command))
print(f' {HCI_Command.command_name(command)}')
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport')
def main(transport):
def main(latency_probes, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(transport))
asyncio.run(async_main(latency_probes, transport))
# -----------------------------------------------------------------------------

205
apps/controller_loopback.py Normal file
View File

@@ -0,0 +1,205 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=click.IntRange(1, 65535),
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

230
apps/device_info.py Normal file
View File

@@ -0,0 +1,230 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
from typing import Callable, Iterable, Optional
import click
from bumble.core import ProtocolError
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.gatt import Service
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def try_show(function: Callable, *args, **kwargs) -> None:
try:
await function(*args, **kwargs)
except ProtocolError as error:
print(color('ERROR:', 'red'), error)
# -----------------------------------------------------------------------------
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
# -----------------------------------------------------------------------------
async def show_gap_information(
gap_service: GenericAccessServiceProxy,
):
print(color('### Generic Access Profile', 'yellow'))
if gap_service.device_name:
print(
color(' Device Name:', 'green'),
await gap_service.device_name.read_value(),
)
if gap_service.appearance:
print(
color(' Appearance: ', 'green'),
await gap_service.appearance.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_information(
device_information_service: DeviceInformationServiceProxy,
):
print(color('### Device Information', 'yellow'))
if device_information_service.manufacturer_name:
print(
color(' Manufacturer Name:', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number:
print(
color(' Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number:
print(
color(' Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.firmware_revision:
print(
color(' Firmware Revision:', 'green'),
await device_information_service.firmware_revision.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_battery_level(
battery_service: BatteryServiceProxy,
):
print(color('### Battery Information', 'yellow'))
if battery_service.battery_level:
print(
color(' Battery Level:', 'green'),
await battery_service.battery_level.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_tmas(
tmas: TelephonyAndMediaAudioServiceProxy,
):
print(color('### Telephony And Media Audio Service', 'yellow'))
if tmas.role:
print(
color(' Role:', 'green'),
await tmas.role.read_value(),
)
print()
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
await peer.discover_services()
for service in peer.services:
await service.discover_characteristics()
print(color('=== Services ===', 'yellow'))
show_services(peer.services)
print()
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
await try_show(show_gap_information, gap_service)
if device_information_service := peer.create_service_proxy(
DeviceInformationServiceProxy
):
await try_show(show_device_information, device_information_service)
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
await try_show(show_battery_level, battery_service)
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
await try_show(show_tmas, tmas)
if done is not None:
done.set_result(None)
except asyncio.CancelledError:
print(color('!!! Operation canceled', 'red'))
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await show_device_info(Peer(connection), None)
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on(
'connection',
lambda connection: asyncio.create_task(
show_device_info(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
await done
# -----------------------------------------------------------------------------
@click.command()
@click.option('--device-config', help='Device configuration', type=click.Path())
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
@click.argument('transport')
@click.argument('address-or-name', required=False)
def main(device_config, encrypt, transport, address_or_name):
"""
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -75,11 +75,15 @@ async def async_main(device_config, encrypt, transport, address_or_name):
if address_or_name:
# Connect to the target peer
print(color('>>> Connecting...', 'green'))
connection = await device.connect(address_or_name)
print(color('>>> Connected', 'green'))
# Encrypt the connection if required
if encrypt:
print(color('+++ Encrypting connection...', 'blue'))
await connection.encrypt()
print(color('+++ Encryption established', 'blue'))
await dump_gatt_db(Peer(connection), None)
else:

View File

@@ -83,7 +83,7 @@ async def async_main():
return_parameters=bytes([hci.HCI_SUCCESS]),
)
# Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True)
return (bytes(response), True)
return None

View File

@@ -49,14 +49,16 @@ class ServerBridge:
self.tcp_port = tcp_port
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP CoC connections
# Listen for incoming L2CAP channel connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_coc,
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
@@ -73,7 +75,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
def on_channel(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
@@ -83,7 +85,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
l2cap_channel.sink = self.on_channel_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
@@ -128,7 +130,7 @@ class ServerBridge:
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
def on_channel_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
@@ -183,7 +185,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
def on_channel_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
@@ -209,7 +211,7 @@ class ClientBridge:
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
@@ -274,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
default=1024,
)
def cli(
@@ -298,17 +306,17 @@ def cli(
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
# -----------------------------------------------------------------------------

594
apps/lea_unicast/app.py Normal file
View File

@@ -0,0 +1,594 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import click
import aiohttp.web
import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
def _source_pac_record() -> pacs.PacRecord:
return pacs.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1],
min_octets_per_codec_frame=30,
max_octets_per_codec_frame=100,
supported_max_codec_frames_per_sdu=1,
),
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task(
filename: str,
sdu_length: int,
frame_duration_us: int,
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = aiohttp.web.Application()
app.add_routes(
[
aiohttp.web.get('/', self.get_static),
aiohttp.web.get('/index.html', self.get_static),
aiohttp.web.get('/channel', self.get_channel),
]
)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/index.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.lea_unicast")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: Optional[str],
ui_port: int,
transport: str,
lc3_input_file_path: str,
):
self.device_config_path = device_config_path
self.transport = transport
self.lc3_input_file_path = lc3_input_file_path
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
async def run(self) -> None:
await self.ui_server.start_http()
async with await open_transport(self.transport) as hci_transport:
# Create a device
if self.device_config_path:
device_config = DeviceConfiguration.from_file(self.device_config_path)
else:
device_config = DeviceConfiguration(
name="Bumble LE Headphone",
class_of_device=0x244418,
keystore="JsonKeyStore",
advertising_interval_min=25,
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
)
device_config.le_enabled = True
device_config.cis_enabled = True
self.device = Device.from_config_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
self.device.add_service(
pacs.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
available_sink_context=bap.ContextType(0xFFFF), # All context types
sink_audio_locations=(
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
),
sink_pac=[_sink_pac_record()],
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
source_pac=[_source_pac_record()],
)
)
ascs_service = ascs.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs_service)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
codec_config = ase.codec_specific_configuration
if (
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
):
return
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
codec_config = ase.codec_specific_configuration
if ase.state == ascs.AseStateMachine.State.STREAMING:
if ase.role == ascs.AudioRole.SOURCE:
if (
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or ase.cis_link is None
or codec_config.octets_per_codec_frame is None
or codec_config.frame_duration is None
or codec_config.codec_frames_per_sdu is None
):
return
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
filename=self.lc3_input_file_path,
sdu_length=(
codec_config.codec_frames_per_sdu
* codec_config.octets_per_codec_frame
),
frame_duration_us=codec_config.frame_duration.us,
device=self.device,
cis_handle=ase.cis_link.handle,
),
)
else:
if not ase.cis_link:
return
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
if (
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.sampling_frequency is None
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
):
return
if ase.role == ascs.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs_service.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()
await self.device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await hci_transport.source.terminated
@click.command()
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
"""Run the speaker."""
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,68 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
</div>
</nav>
<br>
<div class="container">
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
<button class="btn btn-primary" type="button" disabled>
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
<span role="status" id="ws-status">WebSocket Connecting...</span>
</button>
</div>
<script>
let player = null;
const wsStatus = document.getElementById("ws-status");
const wsStatusSpinner = document.getElementById("ws-status-spinner");
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
socket.binaryType = "arraybuffer";
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
console.log(`channel MESSAGE: ${message.data}`);
} else {
console.log(typeof (message.data))
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
socket.onopen = (message) => {
wsStatusSpinner.remove();
wsStatus.textContent = "WebSocket Connected";
}
socket.onclose = (message) => {
wsStatus.textContent = "WebSocket Disconnected";
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 2,
sampleRate: 48000,
flushTime: 10,
});
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
</script>
</div>
</body>
</html>

BIN
apps/lea_unicast/liblc3.wasm Executable file

Binary file not shown.

View File

@@ -46,17 +46,25 @@ from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self):
self.done.set_result(None)
if not self.linger:
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
@@ -233,8 +241,10 @@ def on_connection(connection, request):
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
connection.on('pairing_failure', on_pairing_failure)
connection.on('pairing', lambda keys: on_pairing(connection, keys))
connection.on(
'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
)
# Listen for encryption changes
connection.on(
@@ -268,19 +278,24 @@ def on_pairing_start():
# -----------------------------------------------------------------------------
def on_pairing(address, keys):
@AsyncRunner.run_in_task()
async def on_pairing(connection, keys):
print(color('***-----------------------------------', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
await asyncio.sleep(POST_PAIRING_DELAY)
await connection.disconnect()
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
def on_pairing_failure(reason):
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
Waiter.instance.terminate()
@@ -291,6 +306,7 @@ async def pair(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -302,7 +318,7 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()
Waiter.instance = Waiter(linger=linger)
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -357,7 +373,9 @@ async def pair(
shared_data = (
None
if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
else OobData.from_ad(
AdvertisingData.from_bytes(bytes.fromhex(oob))
).shared_data
)
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig(
@@ -365,26 +383,36 @@ async def pair(
peer_data=shared_data,
legacy_context=legacy_context,
)
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
if shared_data is None:
oob_data = OobData(
address=device.random_address, shared_data=our_oob_context.share()
)
print(
color(
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
'yellow',
)
)
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory
if identity_address == 'public':
identity_address_type = PairingConfig.AddressType.PUBLIC
elif identity_address == 'random':
identity_address_type = PairingConfig.AddressType.RANDOM
else:
identity_address_type = None
device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
identity_address_type=identity_address_type,
delegate=Delegate(mode, connection, io, prompt),
)
@@ -396,7 +424,6 @@ async def pair(
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
pairing_failure = False
if not request:
try:
@@ -405,11 +432,8 @@ async def pair(
else:
await connection.authenticate()
except ProtocolError as error:
pairing_failure = True
print(color(f'Pairing failed: {error}', 'red'))
if not linger or pairing_failure:
return
else:
if mode == 'le':
# Advertise so that peers can find us and connect
@@ -459,7 +483,11 @@ class LogHandler(logging.Handler):
help='Enable CTKD',
show_default=True,
)
@click.option('--linger', default=True, is_flag=True, help='Linger after pairing')
@click.option(
'--identity-address',
type=click.Choice(['random', 'public']),
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
type=click.Choice(
@@ -495,6 +523,7 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,
@@ -520,6 +549,7 @@ def main(
mitm,
bond,
ctkd,
identity_address,
linger,
io,
oob,

608
apps/player/player.py Normal file
View File

@@ -0,0 +1,608 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import asyncio.subprocess
import os
import logging
from typing import Optional, Union
import click
from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE,
AacFrame,
AacParser,
AacPacketSource,
AacMediaCodecInformation,
SbcFrame,
SbcParser,
SbcPacketSource,
SbcMediaCodecInformation,
OpusPacket,
OpusParser,
OpusPacketSource,
OpusMediaCodecInformation,
)
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.avdtp import (
find_avdtp_service_with_connection,
AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
MediaCodecCapabilities,
MediaPacketPump,
Protocol as AvdtpProtocol,
)
from bumble.colors import color
from bumble.core import (
AdvertisingData,
ConnectionError as BumbleConnectionError,
DeviceClass,
BT_BR_EDR_TRANSPORT,
)
from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant
from bumble.pairing import PairingConfig
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def a2dp_source_sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
# -----------------------------------------------------------------------------
async def sbc_codec_capabilities(read_function) -> MediaCodecCapabilities:
sbc_parser = SbcParser(read_function)
sbc_frame: SbcFrame
async for sbc_frame in sbc_parser.frames:
# We only need the first frame
print(color(f"SBC format: {sbc_frame}", "cyan"))
break
channel_mode = [
SbcMediaCodecInformation.ChannelMode.MONO,
SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL,
SbcMediaCodecInformation.ChannelMode.STEREO,
SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
][sbc_frame.channel_mode]
block_length = {
4: SbcMediaCodecInformation.BlockLength.BL_4,
8: SbcMediaCodecInformation.BlockLength.BL_8,
12: SbcMediaCodecInformation.BlockLength.BL_12,
16: SbcMediaCodecInformation.BlockLength.BL_16,
}[sbc_frame.block_count]
subbands = {
4: SbcMediaCodecInformation.Subbands.S_4,
8: SbcMediaCodecInformation.Subbands.S_8,
}[sbc_frame.subband_count]
allocation_method = [
SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
SbcMediaCodecInformation.AllocationMethod.SNR,
][sbc_frame.allocation_method]
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.from_int(
sbc_frame.sampling_frequency
),
channel_mode=channel_mode,
block_length=block_length,
subbands=subbands,
allocation_method=allocation_method,
minimum_bitpool_value=2,
maximum_bitpool_value=40,
),
)
# -----------------------------------------------------------------------------
async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities:
aac_parser = AacParser(read_function)
aac_frame: AacFrame
async for aac_frame in aac_parser.frames:
# We only need the first frame
print(color(f"AAC format: {aac_frame}", "cyan"))
break
sampling_frequency = AacMediaCodecInformation.SamplingFrequency.from_int(
aac_frame.sampling_frequency
)
channels = (
AacMediaCodecInformation.Channels.MONO
if aac_frame.channel_configuration == 1
else AacMediaCodecInformation.Channels.STEREO
)
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=sampling_frequency,
channels=channels,
vbr=1,
bitrate=128000,
),
)
# -----------------------------------------------------------------------------
async def opus_codec_capabilities(read_function) -> MediaCodecCapabilities:
opus_parser = OpusParser(read_function)
opus_packet: OpusPacket
async for opus_packet in opus_parser.packets:
# We only need the first packet
print(color(f"Opus format: {opus_packet}", "cyan"))
break
if opus_packet.channel_mode == OpusPacket.ChannelMode.MONO:
channel_mode = OpusMediaCodecInformation.ChannelMode.MONO
elif opus_packet.channel_mode == OpusPacket.ChannelMode.STEREO:
channel_mode = OpusMediaCodecInformation.ChannelMode.STEREO
else:
channel_mode = OpusMediaCodecInformation.ChannelMode.DUAL_MONO
if opus_packet.duration == 10:
frame_size = OpusMediaCodecInformation.FrameSize.FS_10MS
else:
frame_size = OpusMediaCodecInformation.FrameSize.FS_20MS
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
media_codec_information=OpusMediaCodecInformation(
channel_mode=channel_mode,
sampling_frequency=OpusMediaCodecInformation.SamplingFrequency.SF_48000,
frame_size=frame_size,
),
)
# -----------------------------------------------------------------------------
class Player:
def __init__(
self,
transport: str,
device_config: Optional[str],
authenticate: bool,
encrypt: bool,
) -> None:
self.transport = transport
self.device_config = device_config
self.authenticate = authenticate
self.encrypt = encrypt
self.avrcp_protocol: Optional[AvrcpProtocol] = None
self.done: Optional[asyncio.Event]
async def run(self, workload) -> None:
self.done = asyncio.Event()
try:
await self._run(workload)
except Exception as error:
print(color(f"!!! ERROR: {error}", "red"))
async def _run(self, workload) -> None:
async with await open_transport(self.transport) as (hci_source, hci_sink):
# Create a device
device_config = DeviceConfiguration()
if self.device_config:
device_config.load_from_file(self.device_config)
else:
device_config.name = "Bumble Player"
device_config.class_of_device = DeviceClass.pack_class_of_device(
DeviceClass.AUDIO_SERVICE_CLASS,
DeviceClass.AUDIO_VIDEO_MAJOR_DEVICE_CLASS,
DeviceClass.AUDIO_VIDEO_UNCATEGORIZED_MINOR_DEVICE_CLASS,
)
device_config.keystore = "JsonKeyStore"
device_config.classic_enabled = True
device_config.le_enabled = False
device_config.le_simultaneous_enabled = False
device_config.classic_sc_enabled = False
device_config.classic_smp_enabled = False
device = Device.from_config_with_hci(device_config, hci_source, hci_sink)
# Setup the SDP records to expose the SRC service
device.sdp_service_records = a2dp_source_sdp_records()
# Setup AVRCP
self.avrcp_protocol = AvrcpProtocol()
self.avrcp_protocol.listen(device)
# Don't require MITM when pairing.
device.pairing_config_factory = lambda connection: PairingConfig(mitm=False)
# Start the controller
await device.power_on()
# Print some of the config/properties
print(
"Player Bluetooth Address:",
color(
device.public_address.to_string(with_type_qualifier=False),
"yellow",
),
)
# Listen for connections
device.on("connection", self.on_bluetooth_connection)
# Run the workload
try:
await workload(device)
except BumbleConnectionError as error:
if error.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
print(color("Connection already established", "blue"))
else:
print(color(f"Failed to connect: {error}", "red"))
# Wait until it is time to exit
assert self.done is not None
await asyncio.wait(
[hci_source.terminated, asyncio.ensure_future(self.done.wait())],
return_when=asyncio.FIRST_COMPLETED,
)
def on_bluetooth_connection(self, connection: Connection) -> None:
print(color(f"--- Connected: {connection}", "cyan"))
connection.on("disconnection", self.on_bluetooth_disconnection)
def on_bluetooth_disconnection(self, reason) -> None:
print(color(f"--- Disconnected: {HCI_Constant.error_name(reason)}", "cyan"))
self.set_done()
async def connect(self, device: Device, address: str) -> Connection:
print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
# Request authentication
if self.authenticate:
print(color("*** Authenticating...", "blue"))
await connection.authenticate()
print(color("*** Authenticated", "blue"))
# Enable encryption
if self.encrypt:
print(color("*** Enabling encryption...", "blue"))
await connection.encrypt()
print(color("*** Encryption on", "blue"))
return connection
async def create_avdtp_protocol(self, connection: Connection) -> AvdtpProtocol:
# Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(connection)
if not avdtp_version:
raise RuntimeError("no A2DP service found")
print(color(f"AVDTP Version: {avdtp_version}"))
# Create a client to interact with the remote device
return await AvdtpProtocol.connect(connection, avdtp_version)
async def stream_packets(
self,
protocol: AvdtpProtocol,
codec_type: int,
vendor_id: int,
codec_id: int,
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
codec_capabilities: MediaCodecCapabilities,
):
# Discover all endpoints on the remote device
endpoints = await protocol.discover_remote_endpoints()
for endpoint in endpoints:
print('@@@', endpoint)
# Select a sink
sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, codec_type, vendor_id, codec_id
)
if sink is None:
print(color('!!! no compatible sink found', 'red'))
return
print(f'### Selected sink: {sink.seid}')
# Check if the sink supports delay reporting
delay_reporting = False
for capability in sink.capabilities:
if capability.service_category == AVDTP_DELAY_REPORTING_SERVICE_CATEGORY:
delay_reporting = True
break
def on_delay_report(delay: int):
print(color(f"*** DELAY REPORT: {delay}", "blue"))
# Adjust the codec capabilities for certain codecs
for capability in sink.capabilities:
if isinstance(capability, MediaCodecCapabilities):
if isinstance(
codec_capabilities.media_codec_information, SbcMediaCodecInformation
) and isinstance(
capability.media_codec_information, SbcMediaCodecInformation
):
codec_capabilities.media_codec_information.minimum_bitpool_value = (
capability.media_codec_information.minimum_bitpool_value
)
codec_capabilities.media_codec_information.maximum_bitpool_value = (
capability.media_codec_information.maximum_bitpool_value
)
print(color("Source media codec:", "green"), codec_capabilities)
# Stream the packets
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(codec_capabilities, packet_pump, delay_reporting)
source.on("delay_report", on_delay_report)
stream = await protocol.create_stream(source, sink)
await stream.start()
await packet_pump.wait_for_completion()
async def discover(self, device: Device) -> None:
@device.listens_to("inquiry_result")
def on_inquiry_result(
address: Address, class_of_device: int, data: AdvertisingData, rssi: int
) -> None:
(
service_classes,
major_device_class,
minor_device_class,
) = DeviceClass.split_class_of_device(class_of_device)
separator = "\n "
print(f">>> {color(address.to_string(False), 'yellow')}:")
print(f" Device Class (raw): {class_of_device:06X}")
major_class_name = DeviceClass.major_device_class_name(major_device_class)
print(" Device Major Class: " f"{major_class_name}")
minor_class_name = DeviceClass.minor_device_class_name(
major_device_class, minor_device_class
)
print(" Device Minor Class: " f"{minor_class_name}")
print(
" Device Services: "
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
)
print(f" RSSI: {rssi}")
if data.ad_structures:
print(f" {data.to_string(separator)}")
await device.start_discovery()
async def pair(self, device: Device, address: str) -> None:
print(color(f"Connecting to {address}...", "green"))
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
print(color("Pairing...", "magenta"))
await connection.authenticate()
print(color("Pairing completed", "magenta"))
self.set_done()
async def inquire(self, device: Device, address: str) -> None:
connection = await self.connect(device, address)
avdtp_protocol = await self.create_avdtp_protocol(connection)
# Discover the remote endpoints
endpoints = await avdtp_protocol.discover_remote_endpoints()
print(f'@@@ Found {len(list(endpoints))} endpoints')
for endpoint in endpoints:
print('@@@', endpoint)
self.set_done()
async def play(
self,
device: Device,
address: Optional[str],
audio_format: str,
audio_file: str,
) -> None:
if audio_format == "auto":
if audio_file.endswith(".sbc"):
audio_format = "sbc"
elif audio_file.endswith(".aac") or audio_file.endswith(".adts"):
audio_format = "aac"
elif audio_file.endswith(".ogg"):
audio_format = "opus"
else:
raise ValueError("Unable to determine audio format from file extension")
device.on(
"connection",
lambda connection: AsyncRunner.spawn(on_connection(connection)),
)
async def on_connection(connection: Connection):
avdtp_protocol = await self.create_avdtp_protocol(connection)
with open(audio_file, 'rb') as input_file:
# NOTE: this should be using asyncio file reading, but blocking reads
# are good enough for this command line app.
async def read_audio_data(byte_count):
return input_file.read(byte_count)
# Obtain the codec capabilities from the stream
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
vendor_id = 0
codec_id = 0
if audio_format == "sbc":
codec_type = A2DP_SBC_CODEC_TYPE
codec_capabilities = await sbc_codec_capabilities(read_audio_data)
packet_source = SbcPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
elif audio_format == "aac":
codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
codec_capabilities = await aac_codec_capabilities(read_audio_data)
packet_source = AacPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
else:
codec_type = A2DP_NON_A2DP_CODEC_TYPE
vendor_id = OpusMediaCodecInformation.VENDOR_ID
codec_id = OpusMediaCodecInformation.CODEC_ID
codec_capabilities = await opus_codec_capabilities(read_audio_data)
packet_source = OpusPacketSource(
read_audio_data,
avdtp_protocol.l2cap_channel.peer_mtu,
)
# Rewind to the start
input_file.seek(0)
try:
await self.stream_packets(
avdtp_protocol,
codec_type,
vendor_id,
codec_id,
packet_source,
codec_capabilities,
)
except Exception as error:
print(color(f"!!! Error while streaming: {error}", "red"))
self.set_done()
if address:
await self.connect(device, address)
else:
print(color("Waiting for an incoming connection...", "magenta"))
def set_done(self) -> None:
if self.done:
self.done.set()
# -----------------------------------------------------------------------------
def create_player(context) -> Player:
return Player(
transport=context.obj["hci_transport"],
device_config=context.obj["device_config"],
authenticate=context.obj["authenticate"],
encrypt=context.obj["encrypt"],
)
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option("--hci-transport", metavar="TRANSPORT", required=True)
@click.option("--device-config", metavar="FILENAME", help="Device configuration file")
@click.option(
"--authenticate",
is_flag=True,
help="Request authentication when connecting",
default=False,
)
@click.option(
"--encrypt", is_flag=True, help="Request encryption when connecting", default=True
)
def player_cli(ctx, hci_transport, device_config, authenticate, encrypt):
ctx.ensure_object(dict)
ctx.obj["hci_transport"] = hci_transport
ctx.obj["device_config"] = device_config
ctx.obj["authenticate"] = authenticate
ctx.obj["encrypt"] = encrypt
@player_cli.command("discover")
@click.pass_context
def discover(context):
"""Discover speakers or headphones"""
player = create_player(context)
asyncio.run(player.run(player.discover))
@player_cli.command("inquire")
@click.pass_context
@click.argument(
"address",
metavar="ADDRESS",
)
def inquire(context, address):
"""Connect to a speaker or headphone and inquire about their capabilities"""
player = create_player(context)
asyncio.run(player.run(lambda device: player.inquire(device, address)))
@player_cli.command("pair")
@click.pass_context
@click.argument(
"address",
metavar="ADDRESS",
)
def pair(context, address):
"""Pair with a speaker or headphone"""
player = create_player(context)
asyncio.run(player.run(lambda device: player.pair(device, address)))
@player_cli.command("play")
@click.pass_context
@click.option(
"--connect",
"address",
metavar="ADDRESS",
help="Address or name to connect to",
)
@click.option(
"-f",
"--audio-format",
type=click.Choice(["auto", "sbc", "aac", "opus"]),
help="Audio file format (use 'auto' to infer the format from the file extension)",
default="auto",
)
@click.argument("audio_file")
def play(context, address, audio_format, audio_file):
"""Play and audio file"""
player = create_player(context)
asyncio.run(
player.run(
lambda device: player.play(device, address, audio_format, audio_file)
)
)
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
player_cli()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

520
apps/rfcomm_bridge.py Normal file
View File

@@ -0,0 +1,520 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
import click
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, Connection
from bumble import core
from bumble import hci
from bumble import rfcomm
from bumble import transport
from bumble import utils
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
DEFAULT_MTU = 4096
DEFAULT_CLIENT_TCP_PORT = 9544
DEFAULT_SERVER_TCP_PORT = 9545
TRACE_MAX_SIZE = 48
# -----------------------------------------------------------------------------
class Tracer:
"""
Trace data buffers transmitted from one endpoint to another, with stats.
"""
def __init__(self, channel_name: str) -> None:
self.channel_name = channel_name
self.last_ts: float = 0.0
def trace_data(self, data: bytes) -> None:
now = time.time()
elapsed_s = now - self.last_ts if self.last_ts else 0
elapsed_ms = int(elapsed_s * 1000)
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
hex_str = data[:TRACE_MAX_SIZE].hex() + (
"..." if len(data) > TRACE_MAX_SIZE else ""
)
print(
f"[{self.channel_name}] {len(data):4} bytes "
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
f" {hex_str}"
)
self.last_ts = now
# -----------------------------------------------------------------------------
class ServerBridge:
"""
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
The RFCOMM channel may be associated with a UUID published in an SDP service
description, or simply be on a system-assigned channel number.
When the connection is made, the bridge connects a TCP socket to a remote host and
bridges the data in both directions, with flow control.
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
and waits for a new channel to be connected.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Optional[Device] = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def start(self, device: Device) -> None:
self.device = device
# Create and register a server
rfcomm_server = rfcomm.Server(self.device)
# Listen for incoming DLC connections
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
# Setup the SDP to advertise this channel
service_record_handle = 0x00010001
self.device.sdp_service_records = {
service_record_handle: rfcomm.make_service_sdp_records(
service_record_handle, self.channel, core.UUID(self.uuid)
)
}
# We're ready for a connection
self.device.on("connection", self.on_connection)
await self.set_available(True)
print(
color(
(
f"### Listening for RFCOMM connection on {device.public_address}, "
f"channel {self.channel}"
),
"yellow",
)
)
async def set_available(self, available: bool):
# Become discoverable and connectable
assert self.device
await self.device.set_connectable(available)
await self.device.set_discoverable(available)
def on_connection(self, connection):
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
connection.on("disconnection", self.on_disconnection)
# Don't accept new connections until we're disconnected
utils.AsyncRunner.spawn(self.set_available(False))
def on_disconnection(self, reason: int):
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
# We're ready for a new connection
utils.AsyncRunner.spawn(self.set_available(True))
# Called when an RFCOMM channel is established
@utils.AsyncRunner.run_in_task()
async def on_rfcomm_channel(self, rfcomm_channel):
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
# Connect to the TCP server
print(
color(
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
"yellow",
)
)
try:
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
except OSError:
print(color("!!! Connection failed", "red"))
await rfcomm_channel.disconnect()
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "cyan"))
writer.close()
def write_rfcomm_data(data):
if self.rfcomm_tracer:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.sink = write_rfcomm_data
rfcomm_channel.on("close", on_rfcomm_channel_closed)
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
writer.close()
await writer.wait_closed()
print(color("~~~ Bye bye", "magenta"))
# -----------------------------------------------------------------------------
class ClientBridge:
"""
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
RFCOMM connection to the device is established, and the data is bridged in both
directions, with flow control.
When the TCP connection is closed by the client, the RFCOMM channel is
disconnected, but the connection to the device remains, ready for a new TCP client
to connect.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
channel: int,
uuid: str,
trace: bool,
address: str,
tcp_host: str,
tcp_port: int,
authenticate: bool,
encrypt: bool,
):
self.channel = channel
self.uuid = uuid
self.trace = trace
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.authenticate = authenticate
self.encrypt = encrypt
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def connect(self) -> None:
if self.connection:
return
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
assert self.device
self.connection = await self.device.connect(
self.address, transport=core.BT_BR_EDR_TRANSPORT
)
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
self.connection.on("disconnection", self.on_disconnection)
if self.authenticate:
print(color("@@@ Authenticating Bluetooth connection", "blue"))
await self.connection.authenticate()
print(color("@@@ Bluetooth connection authenticated", "blue"))
if self.encrypt:
print(color("@@@ Encrypting Bluetooth connection", "blue"))
await self.connection.encrypt()
print(color("@@@ Bluetooth connection encrypted", "blue"))
self.rfcomm_client = rfcomm.Client(self.connection)
try:
self.rfcomm_mux = await self.rfcomm_client.start()
except BaseException as e:
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
raise
async def start(self, device: Device) -> None:
self.device = device
await device.set_connectable(False)
await device.set_discoverable(False)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
print(color("<<< TCP connection", "magenta"))
if self.tcp_connected:
print(
color("!!! TCP connection already active, rejecting new one", "red")
)
writer.close()
return
self.tcp_connected = True
try:
await self.pipe(reader, writer)
except BaseException as error:
print(color("!!! Exception while piping data:", "red"), error)
return
finally:
writer.close()
await writer.wait_closed()
self.tcp_connected = False
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != "_" else None,
port=self.tcp_port,
)
print(
color(
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
)
)
async def pipe(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
# Resolve the channel number from the UUID if needed
if self.channel == 0:
await self.connect()
assert self.connection
channel = await rfcomm.find_rfcomm_channel_with_uuid(
self.connection, self.uuid
)
if channel:
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
else:
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
return
else:
channel = self.channel
# Connect a new RFCOMM channel
await self.connect()
assert self.rfcomm_mux
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
try:
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
except Exception as error:
print(color(f"!!! RFCOMM open failed: {error}", "red"))
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "green"))
def write_rfcomm_data(data):
if self.trace:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.on("close", on_rfcomm_channel_closed)
rfcomm_channel.sink = write_rfcomm_data
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
self.tcp_connected = False
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
print(color("~~~ Bye bye", "magenta"))
def on_disconnection(self, reason: int) -> None:
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
self.connection = None
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print("<<< connecting to HCI...")
async with await transport.open_transport_or_link(hci_transport) as (
hci_source,
hci_sink,
):
print("<<< connected")
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.from_config_with_hci(
DeviceConfiguration(), hci_source, hci_sink
)
device.classic_enabled = True
# Let's go
await device.power_on()
try:
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error:
print(f"Exception while running bridge: {error}")
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option(
"--device-config",
metavar="CONFIG_FILE",
help="Device configuration file",
)
@click.option(
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
)
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
@click.option(
"--channel",
metavar="CHANNEL_NUMER",
help="RFCOMM channel number",
type=int,
default=0,
)
@click.option(
"--uuid",
metavar="UUID",
help="UUID for the RFCOMM channel",
default=DEFAULT_RFCOMM_UUID,
)
def cli(
context,
device_config,
hci_transport,
trace,
channel,
uuid,
):
context.ensure_object(dict)
context.obj["device_config"] = device_config
context.obj["hci_transport"] = hci_transport
context.obj["trace"] = trace
context.obj["channel"] = channel
context.obj["uuid"] = uuid
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option("--tcp-host", help="TCP host", default="localhost")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument("bluetooth-address")
@click.option("--tcp-host", help="TCP host", default="_")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
@click.option("--authenticate", is_flag=True, help="Authenticate the connection")
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
def client(context, bluetooth_address, tcp_host, tcp_port, authenticate, encrypt):
bridge = ClientBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
bluetooth_address,
tcp_host,
tcp_port,
authenticate,
encrypt,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
if __name__ == "__main__":
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public:
type_color = 'cyan'
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
else:
if address.is_static:
if address.is_public:
type_color = 'cyan'
elif address.is_static:
type_color = 'green'
address_qualifier = '(static)'
elif address.is_resolvable:
@@ -116,6 +121,7 @@ async def scan(
phy,
filter_duplicates,
raw,
irks,
keystore_file,
device_config,
transport,
@@ -140,9 +146,21 @@ async def scan(
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -187,8 +205,24 @@ async def scan(
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.argument('transport')
def main(
min_rssi,
@@ -198,6 +232,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
@@ -212,6 +247,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,

View File

@@ -15,7 +15,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import datetime
import logging
import os
import struct
import click
from bumble.colors import color
@@ -24,6 +28,14 @@ from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
@@ -36,12 +48,18 @@ class SnoopPacketReader:
DATALINK_BSCP = 1003
DATALINK_H5 = 1004
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
def __init__(self, source):
self.source = source
self.at_end = False
# Read the header
identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000':
if identification_pattern != self.IDENTIFICATION_PATTERN:
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
@@ -55,19 +73,32 @@ class SnoopPacketReader:
# Read the record header
header = self.source.read(24)
if len(header) < 24:
return (0, None)
self.at_end = True
return (None, 0, None)
# Parse the header
(
original_length,
included_length,
packet_flags,
_cumulative_drops,
_timestamp_seconds,
_timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
timestamp,
) = struct.unpack('>IIIIQ', header)
# Abort on truncated packets
# Skip truncated packets
if original_length != included_length:
return (0, None)
print(
color(
f"!!! truncated packet ({included_length}/{original_length})", "red"
)
)
self.source.read(included_length)
return (None, 0, None)
# Convert the timestamp to a datetime object.
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
microseconds=timestamp - self.TIMESTAMP_DELTA
)
if self.data_link_type == self.DATALINK_H1:
# The packet is un-encapsulated, look at the flags to figure out its type
@@ -89,7 +120,17 @@ class SnoopPacketReader:
bytes([packet_type]) + self.source.read(included_length),
)
return (packet_flags & 1, self.source.read(included_length))
return (ts_dt, packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
class Printer:
def __init__(self):
self.index = 0
def print(self, message: str) -> None:
self.index += 1
print(f"[{self.index:8}]{message}")
# -----------------------------------------------------------------------------
@@ -103,18 +144,18 @@ class SnoopPacketReader:
help='Format of the input file',
)
@click.option(
'--vendors',
'--vendor',
type=click.Choice(['android', 'zephyr']),
multiple=True,
help='Support vendor-specific commands (list one or more)',
)
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, vendors, filename):
for vendor in vendors:
if vendor == 'android':
def main(format, vendor, filename):
for vendor_name in vendor:
if vendor_name == 'android':
import bumble.vendor.android.hci
elif vendor == 'zephyr':
elif vendor_name == 'zephyr':
import bumble.vendor.zephyr.hci
input = open(filename, 'rb')
@@ -122,24 +163,28 @@ def main(format, vendors, filename):
packet_reader = PacketReader(input)
def read_next_packet():
return (0, packet_reader.next_packet())
return (None, 0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet
tracer = PacketTracer(emit_message=print)
printer = Printer()
tracer = PacketTracer(emit_message=printer.print)
while True:
while not packet_reader.at_end:
try:
(direction, packet) = read_next_packet()
if packet is None:
break
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
(timestamp, direction, packet) = read_next_packet()
if packet:
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp)
else:
printer.print(color("[TRUNCATED]", "red"))
except Exception as error:
logger.exception('')
print(color(f'!!! {error}', 'red'))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
main() # pylint: disable=no-value-for-parameter

View File

@@ -44,25 +44,18 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities,
MediaPacket,
Protocol,
)
from bumble.a2dp import (
MPEG_2_AAC_LC_OBJECT_TYPE,
make_audio_sink_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_SNR_ALLOCATION_METHOD,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation,
AacMediaCodecInformation,
)
from bumble.utils import AsyncRunner
from bumble.codecs import AacAudioRtpPacket
from bumble.rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -76,6 +69,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
# -----------------------------------------------------------------------------
class AudioExtractor:
@staticmethod
@@ -92,7 +86,7 @@ class AudioExtractor:
# -----------------------------------------------------------------------------
class AacAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
return AacAudioRtpPacket(packet.payload).to_adts()
return AacAudioRtpPacket.from_bytes(packet.payload).to_adts()
# -----------------------------------------------------------------------------
@@ -450,10 +444,12 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation.from_lists(
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
sampling_frequencies=[48000, 44100],
channels=[1, 2],
media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=AacMediaCodecInformation.SamplingFrequency.SF_48000
| AacMediaCodecInformation.SamplingFrequency.SF_44100,
channels=AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO,
vbr=1,
bitrate=256000,
),
@@ -463,20 +459,23 @@ class Speaker:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
media_codec_information=SbcMediaCodecInformation(
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
block_length=SbcMediaCodecInformation.BlockLength.BL_4
| SbcMediaCodecInformation.BlockLength.BL_8
| SbcMediaCodecInformation.BlockLength.BL_12
| SbcMediaCodecInformation.BlockLength.BL_16,
subbands=SbcMediaCodecInformation.Subbands.S_4
| SbcMediaCodecInformation.Subbands.S_8,
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
| SbcMediaCodecInformation.AllocationMethod.SNR,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),

View File

@@ -24,6 +24,7 @@ from bumble.device import Device
from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address):
if address is None:

View File

@@ -17,12 +17,16 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
import logging
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable
import dataclasses
import enum
import logging
import struct
from typing import Awaitable, Callable
from typing_extensions import ClassVar, Self
from .codecs import AacAudioRtpPacket
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
DataElement,
@@ -42,6 +46,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number,
)
from .rtp import MediaPacket
# -----------------------------------------------------------------------------
@@ -103,6 +108,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
8000,
11025,
@@ -130,6 +137,9 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
# fmt: on
@@ -184,8 +194,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
@@ -234,8 +248,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
@@ -249,38 +267,61 @@ class SbcMediaCodecInformation:
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
sampling_frequency: SamplingFrequency
channel_mode: ChannelMode
block_length: BlockLength
subbands: Subbands
allocation_method: AllocationMethod
minimum_bitpool_value: int
maximum_bitpool_value: int
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
}
class SamplingFrequency(enum.IntFlag):
SF_16000 = 1 << 3
SF_32000 = 1 << 2
SF_44100 = 1 << 1
SF_48000 = 1 << 0
@staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
16000,
32000,
44100,
48000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class ChannelMode(enum.IntFlag):
MONO = 1 << 3
DUAL_CHANNEL = 1 << 2
STEREO = 1 << 1
JOINT_STEREO = 1 << 0
class BlockLength(enum.IntFlag):
BL_4 = 1 << 3
BL_8 = 1 << 2
BL_12 = 1 << 1
BL_16 = 1 << 0
class Subbands(enum.IntFlag):
S_4 = 1 << 1
S_8 = 1 << 0
class AllocationMethod(enum.IntFlag):
SNR = 1 << 1
LOUDNESS = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> Self:
sampling_frequency = cls.SamplingFrequency((data[0] >> 4) & 0x0F)
channel_mode = cls.ChannelMode((data[0] >> 0) & 0x0F)
block_length = cls.BlockLength((data[1] >> 4) & 0x0F)
subbands = cls.Subbands((data[1] >> 2) & 0x03)
allocation_method = cls.AllocationMethod((data[1] >> 0) & 0x03)
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return SbcMediaCodecInformation(
return cls(
sampling_frequency,
channel_mode,
block_length,
@@ -290,52 +331,6 @@ class SbcMediaCodecInformation:
maximum_bitpool_value,
)
@classmethod
def from_discrete_values(
cls,
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
@classmethod
def from_lists(
cls,
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self) -> bytes:
return bytes(
[
@@ -348,23 +343,6 @@ class SbcMediaCodecInformation:
]
)
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
@@ -373,83 +351,66 @@ class AacMediaCodecInformation:
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
object_type: int
sampling_frequency: int
channels: int
rfa: int
object_type: ObjectType
sampling_frequency: SamplingFrequency
channels: Channels
vbr: int
bitrate: int
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
11025: 1 << 10,
12000: 1 << 9,
16000: 1 << 8,
22050: 1 << 7,
24000: 1 << 6,
32000: 1 << 5,
44100: 1 << 4,
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1,
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
class ObjectType(enum.IntFlag):
MPEG_2_AAC_LC = 1 << 7
MPEG_4_AAC_LC = 1 << 6
MPEG_4_AAC_LTP = 1 << 5
MPEG_4_AAC_SCALABLE = 1 << 4
@staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
class SamplingFrequency(enum.IntFlag):
SF_8000 = 1 << 11
SF_11025 = 1 << 10
SF_12000 = 1 << 9
SF_16000 = 1 << 8
SF_22050 = 1 << 7
SF_24000 = 1 << 6
SF_32000 = 1 << 5
SF_44100 = 1 << 4
SF_48000 = 1 << 3
SF_64000 = 1 << 2
SF_88200 = 1 << 1
SF_96000 = 1 << 0
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
sampling_frequencies = [
8000,
11025,
12000,
16000,
22050,
24000,
32000,
44100,
48000,
64000,
88200,
96000,
]
index = sampling_frequencies.index(sampling_frequency)
return cls(1 << (len(sampling_frequencies) - index - 1))
class Channels(enum.IntFlag):
MONO = 1 << 1
STEREO = 1 << 0
@classmethod
def from_bytes(cls, data: bytes) -> AacMediaCodecInformation:
object_type = cls.ObjectType(data[0])
sampling_frequency = cls.SamplingFrequency(
(data[1] << 4) | ((data[2] >> 4) & 0x0F)
)
channels = cls.Channels((data[2] >> 2) & 0x03)
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
@classmethod
def from_discrete_values(
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
@classmethod
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
rfa=0,
vbr=vbr,
bitrate=bitrate,
object_type, sampling_frequency, channels, vbr, bitrate
)
def __bytes__(self) -> bytes:
@@ -464,30 +425,6 @@ class AacMediaCodecInformation:
]
)
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2]
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' ')',
]
)
@dataclasses.dataclass
# -----------------------------------------------------------------------------
@@ -506,7 +443,7 @@ class VendorSpecificMediaCodecInformation:
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
def __str__(self) -> str:
# pylint: disable=line-too-long
@@ -520,13 +457,69 @@ class VendorSpecificMediaCodecInformation:
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
vendor_id: int = dataclasses.field(init=False, repr=False)
codec_id: int = dataclasses.field(init=False, repr=False)
value: bytes = dataclasses.field(init=False, repr=False)
channel_mode: ChannelMode
frame_size: FrameSize
sampling_frequency: SamplingFrequency
class ChannelMode(enum.IntFlag):
MONO = 1 << 0
STEREO = 1 << 1
DUAL_MONO = 1 << 2
class FrameSize(enum.IntFlag):
FS_10MS = 1 << 0
FS_20MS = 1 << 1
class SamplingFrequency(enum.IntFlag):
SF_48000 = 1 << 0
VENDOR_ID: ClassVar[int] = 0x000000E0
CODEC_ID: ClassVar[int] = 0x0001
def __post_init__(self) -> None:
self.vendor_id = self.VENDOR_ID
self.codec_id = self.CODEC_ID
self.value = bytes(
[
self.channel_mode
| (self.frame_size << 3)
| (self.sampling_frequency << 7)
]
)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
"""Create a new instance from the `value` part of the data, not including
the vendor id and codec id"""
channel_mode = cls.ChannelMode(data[0] & 0x07)
frame_size = cls.FrameSize((data[0] >> 3) & 0x03)
sampling_frequency = cls.SamplingFrequency((data[0] >> 7) & 0x01)
return cls(
channel_mode,
frame_size,
sampling_frequency,
)
def __str__(self) -> str:
return repr(self)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
sampling_frequency: int
block_count: int
channel_mode: int
allocation_method: int
subband_count: int
bitpool: int
payload: bytes
@property
@@ -545,8 +538,10 @@ class SbcFrame:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'am={self.allocation_method},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'bp={self.bitpool},'
f'size={len(self.payload)})'
)
@@ -575,6 +570,7 @@ class SbcParser:
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
allocation_method = (header[1] >> 1) & 1
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
@@ -594,7 +590,13 @@ class SbcParser:
# Emit the next frame
yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
sampling_frequency,
blocks,
channel_mode,
allocation_method,
subbands,
bitpool,
payload,
)
return generate_frames()
@@ -602,21 +604,15 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
sample_count = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
@@ -624,27 +620,29 @@ class SbcPacketSource:
# NOTE: this doesn't support frame fragments
sbc_parser = SbcParser(self.read)
async for frame in sbc_parser.frames:
print(frame)
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
):
# Need to flush what has been accumulated so far
logger.debug(f"yielding {len(frames)} frames")
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join(
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
[frame.payload for frame in frames]
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
timestamp += sum((frame.sample_count for frame in frames))
sequence_number &= 0xFFFF
sample_count += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:
@@ -653,3 +651,315 @@ class SbcPacketSource:
frames_size += len(frame.payload)
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacFrame:
class Profile(enum.IntEnum):
MAIN = 0
LC = 1
SSR = 2
LTP = 3
profile: Profile
sampling_frequency: int
channel_configuration: int
payload: bytes
@property
def sample_count(self) -> int:
return 1024
@property
def duration(self) -> float:
return self.sample_count / self.sampling_frequency
def __str__(self) -> str:
return (
f'AAC(sf={self.sampling_frequency},'
f'ch={self.channel_configuration},'
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
ADTS_AAC_SAMPLING_FREQUENCIES = [
96000,
88200,
64000,
48000,
44100,
32000,
24000,
22050,
16000,
12000,
11025,
8000,
7350,
0,
0,
0,
]
# -----------------------------------------------------------------------------
class AacParser:
"""Parser for AAC frames in an ADTS stream"""
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def frames(self) -> AsyncGenerator[AacFrame, None]:
async def generate_frames() -> AsyncGenerator[AacFrame, None]:
while True:
header = await self.read(7)
if not header:
return
sync_word = (header[0] << 4) | (header[1] >> 4)
if sync_word != 0b111111111111:
raise ValueError(f"invalid sync word ({sync_word:06x})")
layer = (header[1] >> 1) & 0b11
profile = AacFrame.Profile((header[2] >> 6) & 0b11)
sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[
(header[2] >> 2) & 0b1111
]
channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6)
frame_length = (
((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5)
)
if layer != 0:
raise ValueError("layer must be 0")
payload = await self.read(frame_length - 7)
if payload:
yield AacFrame(
profile, sampling_frequency, channel_configuration, payload
)
return generate_frames()
# -----------------------------------------------------------------------------
class AacPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
sample_count = 0
aac_parser = AacParser(self.read)
async for frame in aac_parser.frames:
logger.debug("yielding one AAC frame")
# Emit a packet
aac_payload = bytes(
AacAudioRtpPacket.for_simple_aac(
frame.sampling_frequency,
frame.channel_configuration,
frame.payload,
)
)
timestamp_seconds = sample_count / frame.sampling_frequency
timestamp = int(1000 * timestamp_seconds)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
)
packet.timestamp_seconds = timestamp_seconds
yield packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += frame.sample_count
return generate_packets()
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class OpusPacket:
class ChannelMode(enum.IntEnum):
MONO = 0
STEREO = 1
DUAL_MONO = 2
channel_mode: ChannelMode
duration: int # Duration in ms.
sampling_frequency: int
payload: bytes
def __str__(self) -> str:
return (
f'Opus(ch={self.channel_mode.name}, '
f'd={self.duration}ms, '
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class OpusParser:
"""
Parser for Opus packets in an Ogg stream
See RFC 3533
NOTE: this parser only supports bitstreams with a single logical stream.
"""
CAPTURE_PATTERN = b'OggS'
class HeaderType(enum.IntFlag):
CONTINUED = 0x01
FIRST = 0x02
LAST = 0x04
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def packets(self) -> AsyncGenerator[OpusPacket, None]:
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
packet = b''
packet_count = 0
expected_bitstream_serial_number = None
expected_page_sequence_number = 0
channel_mode = OpusPacket.ChannelMode.STEREO
while True:
# Parse the page header
header = await self.read(27)
if len(header) != 27:
logger.debug("end of stream")
break
capture_pattern = header[:4]
if capture_pattern != self.CAPTURE_PATTERN:
print(capture_pattern.hex())
raise ValueError("invalid capture pattern at start of page")
version = header[4]
if version != 0:
raise ValueError(f"version {version} not supported")
header_type = self.HeaderType(header[5])
(
granule_position,
bitstream_serial_number,
page_sequence_number,
crc_checksum,
page_segments,
) = struct.unpack_from("<QIIIB", header, 6)
segment_table = await self.read(page_segments)
if header_type & self.HeaderType.FIRST:
if expected_bitstream_serial_number is None:
# We will only accept pages for the first encountered stream
logger.debug("BOS")
expected_bitstream_serial_number = bitstream_serial_number
expected_page_sequence_number = page_sequence_number
if (
expected_bitstream_serial_number is None
or expected_bitstream_serial_number != bitstream_serial_number
):
logger.debug("skipping page (not the first logical bitstream)")
for lacing_value in segment_table:
if lacing_value:
await self.read(lacing_value)
continue
if expected_page_sequence_number != page_sequence_number:
raise ValueError(
f"expected page sequence number {expected_page_sequence_number}"
f" but got {page_sequence_number}"
)
expected_page_sequence_number = page_sequence_number + 1
# Assemble the page
if not header_type & self.HeaderType.CONTINUED:
packet = b''
for lacing_value in segment_table:
if lacing_value:
packet += await self.read(lacing_value)
if lacing_value < 255:
# End of packet
packet_count += 1
if packet_count == 1:
# The first packet contains the identification header
logger.debug("first packet (header)")
if packet[:8] != b"OpusHead":
raise ValueError("first packet is not OpusHead")
packet_count = (
OpusPacket.ChannelMode.MONO
if packet[9] == 1
else OpusPacket.ChannelMode.STEREO
)
elif packet_count == 2:
# The second packet contains the comment header
logger.debug("second packet (tags)")
if packet[:8] != b"OpusTags":
logger.warning("second packet is not OpusTags")
else:
yield OpusPacket(channel_mode, 20, 48000, packet)
packet = b''
if header_type & self.HeaderType.LAST:
logger.debug("EOS")
return generate_frames()
# -----------------------------------------------------------------------------
class OpusPacketSource:
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
self.read = read
self.mtu = mtu
@property
def packets(self):
async def generate_packets():
sequence_number = 0
elapsed_ms = 0
opus_parser = OpusParser(self.read)
async for opus_packet in opus_parser.packets:
# We only support sending one Opus frame per RTP packet
# TODO: check the spec for the first byte value here
opus_payload = bytes([1]) + opus_packet.payload
elapsed_s = elapsed_ms / 1000
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
rtp_packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
)
rtp_packet.timestamp_seconds = elapsed_s
yield rtp_packet
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
elapsed_ms += opus_packet.duration
return generate_packets()
# -----------------------------------------------------------------------------
# This map should be left at the end of the file so it can refer to the classes
# above
# -----------------------------------------------------------------------------
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
OpusMediaCodecInformation.VENDOR_ID: {
OpusMediaCodecInformation.CODEC_ID: OpusMediaCodecInformation
}
}

View File

@@ -14,13 +14,19 @@
from typing import List, Union
from bumble import core
class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens.
Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string."""
Raises AtParsingError in case of invalid input string."""
tokens = []
in_quotes = False
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise ValueError("open_paren following regular character")
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise ValueError("quote following regular character")
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string."""
Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]]
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren")
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator[-1].append(current)
if len(accumulator) > 1:
raise ValueError("missing close_paren")
raise AtParsingError("missing close_paren")
return accumulator[0]

View File

@@ -23,12 +23,26 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import functools
import inspect
import struct
from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from bumble import utils
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
@@ -43,6 +57,7 @@ if TYPE_CHECKING:
# pylint: disable=line-too-long
ATT_CID = 0x04
ATT_PSM = 0x001F
ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02
@@ -133,43 +148,57 @@ ATT_RESPONSES = [
ATT_EXECUTE_WRITE_RESPONSE
]
ATT_INVALID_HANDLE_ERROR = 0x01
ATT_READ_NOT_PERMITTED_ERROR = 0x02
ATT_WRITE_NOT_PERMITTED_ERROR = 0x03
ATT_INVALID_PDU_ERROR = 0x04
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = 0x05
ATT_REQUEST_NOT_SUPPORTED_ERROR = 0x06
ATT_INVALID_OFFSET_ERROR = 0x07
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = 0x08
ATT_PREPARE_QUEUE_FULL_ERROR = 0x09
ATT_ATTRIBUTE_NOT_FOUND_ERROR = 0x0A
ATT_ATTRIBUTE_NOT_LONG_ERROR = 0x0B
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = 0x0C
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = 0x0D
ATT_UNLIKELY_ERROR_ERROR = 0x0E
ATT_INSUFFICIENT_ENCRYPTION_ERROR = 0x0F
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = 0x10
ATT_INSUFFICIENT_RESOURCES_ERROR = 0x11
class ErrorCode(utils.OpenIntEnum):
'''
See
ATT_ERROR_NAMES = {
ATT_INVALID_HANDLE_ERROR: 'ATT_INVALID_HANDLE_ERROR',
ATT_READ_NOT_PERMITTED_ERROR: 'ATT_READ_NOT_PERMITTED_ERROR',
ATT_WRITE_NOT_PERMITTED_ERROR: 'ATT_WRITE_NOT_PERMITTED_ERROR',
ATT_INVALID_PDU_ERROR: 'ATT_INVALID_PDU_ERROR',
ATT_INSUFFICIENT_AUTHENTICATION_ERROR: 'ATT_INSUFFICIENT_AUTHENTICATION_ERROR',
ATT_REQUEST_NOT_SUPPORTED_ERROR: 'ATT_REQUEST_NOT_SUPPORTED_ERROR',
ATT_INVALID_OFFSET_ERROR: 'ATT_INVALID_OFFSET_ERROR',
ATT_INSUFFICIENT_AUTHORIZATION_ERROR: 'ATT_INSUFFICIENT_AUTHORIZATION_ERROR',
ATT_PREPARE_QUEUE_FULL_ERROR: 'ATT_PREPARE_QUEUE_FULL_ERROR',
ATT_ATTRIBUTE_NOT_FOUND_ERROR: 'ATT_ATTRIBUTE_NOT_FOUND_ERROR',
ATT_ATTRIBUTE_NOT_LONG_ERROR: 'ATT_ATTRIBUTE_NOT_LONG_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR',
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR: 'ATT_INVALID_ATTRIBUTE_LENGTH_ERROR',
ATT_UNLIKELY_ERROR_ERROR: 'ATT_UNLIKELY_ERROR_ERROR',
ATT_INSUFFICIENT_ENCRYPTION_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_ERROR',
ATT_UNSUPPORTED_GROUP_TYPE_ERROR: 'ATT_UNSUPPORTED_GROUP_TYPE_ERROR',
ATT_INSUFFICIENT_RESOURCES_ERROR: 'ATT_INSUFFICIENT_RESOURCES_ERROR'
}
* Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
* Core Specification Supplement: Common Profile And Service Error Codes
'''
INVALID_HANDLE = 0x01
READ_NOT_PERMITTED = 0x02
WRITE_NOT_PERMITTED = 0x03
INVALID_PDU = 0x04
INSUFFICIENT_AUTHENTICATION = 0x05
REQUEST_NOT_SUPPORTED = 0x06
INVALID_OFFSET = 0x07
INSUFFICIENT_AUTHORIZATION = 0x08
PREPARE_QUEUE_FULL = 0x09
ATTRIBUTE_NOT_FOUND = 0x0A
ATTRIBUTE_NOT_LONG = 0x0B
INSUFFICIENT_ENCRYPTION_KEY_SIZE = 0x0C
INVALID_ATTRIBUTE_LENGTH = 0x0D
UNLIKELY_ERROR = 0x0E
INSUFFICIENT_ENCRYPTION = 0x0F
UNSUPPORTED_GROUP_TYPE = 0x10
INSUFFICIENT_RESOURCES = 0x11
DATABASE_OUT_OF_SYNC = 0x12
VALUE_NOT_ALLOWED = 0x13
# 0x80 0x9F: Application Error
# 0xE0 0xFF: Common Profile and Service Error Codes
WRITE_REQUEST_REJECTED = 0xFC
CCCD_IMPROPERLY_CONFIGURED = 0xFD
PROCEDURE_ALREADY_IN_PROGRESS = 0xFE
OUT_OF_RANGE = 0xFF
# Backward Compatible Constants
ATT_INVALID_HANDLE_ERROR = ErrorCode.INVALID_HANDLE
ATT_READ_NOT_PERMITTED_ERROR = ErrorCode.READ_NOT_PERMITTED
ATT_WRITE_NOT_PERMITTED_ERROR = ErrorCode.WRITE_NOT_PERMITTED
ATT_INVALID_PDU_ERROR = ErrorCode.INVALID_PDU
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = ErrorCode.INSUFFICIENT_AUTHENTICATION
ATT_REQUEST_NOT_SUPPORTED_ERROR = ErrorCode.REQUEST_NOT_SUPPORTED
ATT_INVALID_OFFSET_ERROR = ErrorCode.INVALID_OFFSET
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = ErrorCode.INSUFFICIENT_AUTHORIZATION
ATT_PREPARE_QUEUE_FULL_ERROR = ErrorCode.PREPARE_QUEUE_FULL
ATT_ATTRIBUTE_NOT_FOUND_ERROR = ErrorCode.ATTRIBUTE_NOT_FOUND
ATT_ATTRIBUTE_NOT_LONG_ERROR = ErrorCode.ATTRIBUTE_NOT_LONG
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION_KEY_SIZE
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = ErrorCode.INVALID_ATTRIBUTE_LENGTH
ATT_UNLIKELY_ERROR_ERROR = ErrorCode.UNLIKELY_ERROR
ATT_INSUFFICIENT_ENCRYPTION_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = ErrorCode.UNSUPPORTED_GROUP_TYPE
ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
ATT_DEFAULT_MTU = 23
@@ -233,9 +262,9 @@ class ATT_PDU:
def pdu_name(op_code):
return name_or_number(ATT_PDU_NAMES, op_code, 2)
@staticmethod
def error_name(error_code):
return name_or_number(ATT_ERROR_NAMES, error_code, 2)
@classmethod
def error_name(cls, error_code: int) -> str:
return ErrorCode(error_code).name
@staticmethod
def subclass(fields):
@@ -263,9 +292,6 @@ class ATT_PDU:
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
@property
def is_command(self):
return ((self.op_code >> 6) & 1) == 1
@@ -275,7 +301,7 @@ class ATT_PDU:
return ((self.op_code >> 7) & 1) == 1
def __bytes__(self):
return self.to_bytes()
return self.pdu
def __str__(self):
result = color(self.name, 'yellow')
@@ -643,7 +669,7 @@ class ATT_Write_Command(ATT_PDU):
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
('attribute_value', '*'),
# ('authentication_signature', 'TODO')
]
)
@@ -682,7 +708,7 @@ class ATT_Prepare_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([])
@ATT_PDU.subclass([("flags", 1)])
class ATT_Execute_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.3 Execute Write Request
@@ -722,12 +748,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def write(self, connection, value: bytes) -> None:
...
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], Any],
Callable[[Optional[Connection]], Awaitable[Any]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], Any], None],
Callable[[Optional[Connection], Any], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# -----------------------------------------------------------------------------
@@ -757,7 +809,7 @@ class Attribute(EventEmitter):
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str}\nGot: {permissions_str}"
) from exc
# Permission flags(legacy-use only)
@@ -770,13 +822,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[str, bytes, ConnectionValue]
value: Any
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'',
value: Any = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
@@ -794,11 +846,7 @@ class Attribute(EventEmitter):
else:
self.type = attribute_type
# Convert the value to a byte array
if isinstance(value, str):
self.value = bytes(value, 'utf-8')
else:
self.value = value
self.value = value
def encode_value(self, value: Any) -> bytes:
return value
@@ -806,7 +854,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def read_value(self, connection: Optional[Connection]) -> bytes:
async def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -832,6 +880,8 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'):
try:
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -839,9 +889,11 @@ class Attribute(EventEmitter):
else:
value = self.value
self.emit('read', connection, value)
return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -864,7 +916,9 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'):
try:
self.value.write(connection, value) # pylint: disable=not-callable
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle

524
bumble/avc.py Normal file
View File

@@ -0,0 +1,524 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble import core
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise core.InvalidArgumentError(
f"invalid subclass name {subclass.__name__}"
)
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise core.InvalidPacketError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5 or subunit_id == 7:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise core.InvalidPacketError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise core.InvalidPacketError("reserved subunit ID")
else:
raise core.InvalidPacketError("invalid subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)

292
bumble/avctp.py Normal file
View File

@@ -0,0 +1,292 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import core
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise core.InvalidArgumentError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise core.InvalidArgumentError("response handler not registered")
del self.response_handlers[pid]

View File

@@ -17,12 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import struct
import time
import logging
import enum
import warnings
from pyee import EventEmitter
from typing import (
Any,
Awaitable,
@@ -39,10 +37,13 @@ from typing import (
cast,
)
from pyee import EventEmitter
from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
ProtocolError,
InvalidArgumentError,
name_or_number,
)
from .a2dp import (
@@ -50,13 +51,16 @@ from .a2dp import (
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE,
A2DP_SBC_CODEC_TYPE,
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES,
AacMediaCodecInformation,
SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation,
)
from .rtp import MediaPacket
from . import sdp, device, l2cap
from .colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -241,7 +245,10 @@ async def find_avdtp_service_with_sdp_client(
)
if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value:
if len(profile_descriptor.value) >= 2:
if (
profile_descriptor.type == sdp.DataElement.SEQUENCE
and len(profile_descriptor.value) >= 2
):
avdtp_version_major = profile_descriptor.value[1].value >> 8
avdtp_version_minor = profile_descriptor.value[1].value & 0xFF
return (avdtp_version_major, avdtp_version_minor)
@@ -274,90 +281,6 @@ class RealtimeClock:
await asyncio.sleep(duration)
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number
self.timestamp = timestamp
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
class MediaPacketPump:
pump_task: Optional[asyncio.Task]
@@ -368,6 +291,7 @@ class MediaPacketPump:
self.packets = packets
self.clock = clock
self.pump_task = None
self.completed = asyncio.Event()
async def start(self, rtp_channel: l2cap.ClassicChannel) -> None:
async def pump_packets():
@@ -397,6 +321,8 @@ class MediaPacketPump:
)
except asyncio.exceptions.CancelledError:
logger.debug('pump canceled')
finally:
self.completed.set()
# Pump packets
self.pump_task = asyncio.create_task(pump_packets())
@@ -408,6 +334,9 @@ class MediaPacketPump:
await self.pump_task
self.pump_task = None
async def wait_for_completion(self) -> None:
await self.completed.wait()
# -----------------------------------------------------------------------------
class MessageAssembler:
@@ -511,7 +440,8 @@ class MessageAssembler:
try:
self.callback(self.transaction_label, message)
except Exception as error:
logger.warning(color(f'!!! exception in callback: {error}'))
logger.exception(color(f'!!! exception in callback: {error}', 'red'))
self.reset()
@@ -570,10 +500,10 @@ class ServiceCapabilities:
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details: List[str] = []) -> str:
def to_string(self, details: Optional[List[str]] = None) -> str:
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ details
+ (details or [])
)
return f'ServiceCapabilities({attributes})'
@@ -605,11 +535,25 @@ class MediaCodecCapabilities(ServiceCapabilities):
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
self.media_codec_information = (
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
self.media_codec_information = media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
else:
self.media_codec_information = vendor_media_codec_information
def __init__(
self,
@@ -691,7 +635,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
signal_identifier_str = name[:-7]
message_type = Message.MessageType.RESPONSE_REJECT
else:
raise ValueError('invalid class name')
raise InvalidArgumentError('invalid class name')
subclass.message_type = message_type
@@ -1306,10 +1250,20 @@ class Protocol(EventEmitter):
return None
def add_source(
self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump
self,
codec_capabilities: MediaCodecCapabilities,
packet_pump: MediaPacketPump,
delay_reporting: bool = False,
) -> LocalSource:
seid = len(self.local_endpoints) + 1
source = LocalSource(self, seid, codec_capabilities, packet_pump)
service_capabilities = (
[ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)]
if delay_reporting
else []
)
source = LocalSource(
self, seid, codec_capabilities, service_capabilities, packet_pump
)
self.local_endpoints.append(source)
return source
@@ -1362,7 +1316,7 @@ class Protocol(EventEmitter):
return self.remote_endpoints.values()
def find_remote_sink_by_codec(
self, media_type: int, codec_type: int
self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
) -> Optional[DiscoveredStreamEndPoint]:
for endpoint in self.remote_endpoints.values():
if (
@@ -1387,7 +1341,19 @@ class Protocol(EventEmitter):
codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and codec_capabilities.media_codec_type == codec_type
):
has_codec = True
if isinstance(
codec_capabilities.media_codec_information,
VendorSpecificMediaCodecInformation,
):
if (
codec_capabilities.media_codec_information.vendor_id
== vendor_id
and codec_capabilities.media_codec_information.codec_id
== codec_id
):
has_codec = True
else:
has_codec = True
if has_media_transport and has_codec:
return endpoint
@@ -1466,10 +1432,10 @@ class Protocol(EventEmitter):
f'[{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
self.l2cap_channel.peer_mtu - 3
) # Enough space for a 3-byte start packet header
payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu:
if len(payload) + 2 <= self.l2cap_channel.peer_mtu:
# Fits in a single packet
packet_type = self.PacketType.SINGLE_PACKET
else:
@@ -1541,9 +1507,10 @@ class Protocol(EventEmitter):
assert False # Should never reach this
async def get_capabilities(
self, seid: int
) -> Union[Get_Capabilities_Response, Get_All_Capabilities_Response,]:
async def get_capabilities(self, seid: int) -> Union[
Get_Capabilities_Response,
Get_All_Capabilities_Response,
]:
if self.version > (1, 2):
return await self.send_command(Get_All_Capabilities_Command(seid))
@@ -2152,6 +2119,9 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
def on_abort_command(self):
self.emit('abort')
def on_delayreport_command(self, delay: int):
self.emit('delay_report', delay)
def on_rtp_channel_open(self):
self.emit('rtp_channel_open')
@@ -2166,12 +2136,13 @@ class LocalSource(LocalStreamEndPoint):
protocol: Protocol,
seid: int,
codec_capabilities: MediaCodecCapabilities,
other_capabilitiles: Iterable[ServiceCapabilities],
packet_pump: MediaPacketPump,
) -> None:
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities,
]
] + list(other_capabilitiles)
super().__init__(
protocol,
seid,

1938
bumble/avrcp.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -17,6 +17,9 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
from typing_extensions import Self
from bumble import core
# -----------------------------------------------------------------------------
@@ -40,7 +43,7 @@ class BitReader:
""" "Read up to 32 bits."""
if bits > 32:
raise ValueError('maximum read size is 32')
raise core.InvalidArgumentError('maximum read size is 32')
if self.bits_cached >= bits:
# We have enough bits.
@@ -53,7 +56,7 @@ class BitReader:
feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data')
raise core.InvalidArgumentError('trying to read past the data')
self.byte_position += feed_size
# Combine the new cache and the old cache
@@ -68,7 +71,7 @@ class BitReader:
def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data')
raise core.InvalidArgumentError('not enough data')
if self.bit_position % 8:
# Not byte aligned
@@ -99,12 +102,40 @@ class BitReader:
break
# -----------------------------------------------------------------------------
class BitWriter:
"""Simple but not optimized bit stream writer."""
data: int
bit_count: int
def __init__(self) -> None:
self.data = 0
self.bit_count = 0
def write(self, value: int, bit_count: int) -> None:
self.data = (self.data << bit_count) | value
self.bit_count += bit_count
def write_bytes(self, data: bytes) -> None:
bit_count = 8 * len(data)
self.data = (self.data << bit_count) | int.from_bytes(data, 'big')
self.bit_count += bit_count
def __bytes__(self) -> bytes:
return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes(
(self.bit_count + 7) // 8, 'big'
)
# -----------------------------------------------------------------------------
class AacAudioRtpPacket:
"""AAC payload encapsulated in an RTP packet payload"""
audio_mux_element: AudioMuxElement
@staticmethod
def latm_value(reader: BitReader) -> int:
def read_latm_value(reader: BitReader) -> int:
bytes_for_value = reader.read(2)
value = 0
for _ in range(bytes_for_value + 1):
@@ -112,24 +143,33 @@ class AacAudioRtpPacket:
return value
@staticmethod
def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported')
def read_audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return audio_object_type
@dataclass
class GASpecificConfig:
def __init__(
self, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> None:
audio_object_type: int
# NOTE: other fields not supported
@classmethod
def from_bits(
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
self.core_coder_delay = reader.read(14)
core_coder_delay = reader.read(14)
extension_flag = reader.read(1)
if not channel_configuration:
AacAudioRtpPacket.program_config_element(reader)
raise core.InvalidPacketError('program_config_element not supported')
if audio_object_type in (6, 20):
self.layer_nr = reader.read(3)
layer_nr = reader.read(3)
if extension_flag:
if audio_object_type == 22:
num_of_sub_frame = reader.read(5)
@@ -140,16 +180,15 @@ class AacAudioRtpPacket:
aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported')
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod
def audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return cls(audio_object_type)
return audio_object_type
def to_bits(self, writer: BitWriter) -> None:
assert self.audio_object_type in (1, 2)
writer.write(0, 1) # frame_length_flag = 0
writer.write(0, 1) # depends_on_core_coder = 0
writer.write(0, 1) # extension_flag = 0
@dataclass
class AudioSpecificConfig:
@@ -157,6 +196,7 @@ class AacAudioRtpPacket:
sampling_frequency_index: int
sampling_frequency: int
channel_configuration: int
ga_specific_config: AacAudioRtpPacket.GASpecificConfig
sbr_present_flag: int
ps_present_flag: int
extension_audio_object_type: int
@@ -180,44 +220,73 @@ class AacAudioRtpPacket:
7350,
]
def __init__(self, reader: BitReader) -> None:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
self.sampling_frequency_index = reader.read(4)
if self.sampling_frequency_index == 0xF:
self.sampling_frequency = reader.read(24)
else:
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
self.sampling_frequency_index
]
self.channel_configuration = reader.read(4)
self.sbr_present_flag = -1
self.ps_present_flag = -1
if self.audio_object_type in (5, 29):
self.extension_audio_object_type = 5
self.sbc_present_flag = 1
if self.audio_object_type == 29:
self.ps_present_flag = 1
self.extension_sampling_frequency_index = reader.read(4)
if self.extension_sampling_frequency_index == 0xF:
self.extension_sampling_frequency = reader.read(24)
else:
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
self.extension_sampling_frequency_index
]
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
if self.audio_object_type == 22:
self.extension_channel_configuration = reader.read(4)
else:
self.extension_audio_object_type = 0
@classmethod
def for_simple_aac(
cls,
audio_object_type: int,
sampling_frequency: int,
channel_configuration: int,
) -> Self:
if sampling_frequency not in cls.SAMPLING_FREQUENCIES:
raise ValueError(f'invalid sampling frequency {sampling_frequency}')
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
reader, self.channel_configuration, self.audio_object_type
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type)
return cls(
audio_object_type=audio_object_type,
sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index(
sampling_frequency
),
sampling_frequency=sampling_frequency,
channel_configuration=channel_configuration,
ga_specific_config=ga_specific_config,
sbr_present_flag=0,
ps_present_flag=0,
extension_audio_object_type=0,
extension_sampling_frequency_index=0,
extension_sampling_frequency=0,
extension_channel_configuration=0,
)
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
sampling_frequency_index = reader.read(4)
if sampling_frequency_index == 0xF:
sampling_frequency = reader.read(24)
else:
sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index]
channel_configuration = reader.read(4)
sbr_present_flag = 0
ps_present_flag = 0
extension_sampling_frequency_index = 0
extension_sampling_frequency = 0
extension_channel_configuration = 0
extension_audio_object_type = 0
if audio_object_type in (5, 29):
extension_audio_object_type = 5
sbr_present_flag = 1
if audio_object_type == 29:
ps_present_flag = 1
extension_sampling_frequency_index = reader.read(4)
if extension_sampling_frequency_index == 0xF:
extension_sampling_frequency = reader.read(24)
else:
extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[
extension_sampling_frequency_index
]
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
if audio_object_type == 22:
extension_channel_configuration = reader.read(4)
if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits(
reader, channel_configuration, audio_object_type
)
else:
raise ValueError(
f'audioObjectType {self.audio_object_type} not supported'
raise core.InvalidPacketError(
f'audioObjectType {audio_object_type} not supported'
)
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
@@ -246,13 +315,44 @@ class AacAudioRtpPacket:
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
# self.extension_channel_configuration = reader.read(4)
return cls(
audio_object_type,
sampling_frequency_index,
sampling_frequency,
channel_configuration,
ga_specific_config,
sbr_present_flag,
ps_present_flag,
extension_audio_object_type,
extension_sampling_frequency_index,
extension_sampling_frequency,
extension_channel_configuration,
)
def to_bits(self, writer: BitWriter) -> None:
if self.sampling_frequency_index >= 15:
raise ValueError(
f"unsupported sampling frequency index {self.sampling_frequency_index}"
)
if self.audio_object_type not in (1, 2):
raise ValueError(
f"unsupported audio object type {self.audio_object_type} "
)
writer.write(self.audio_object_type, 5)
writer.write(self.sampling_frequency_index, 4)
writer.write(self.channel_configuration, 4)
self.ga_specific_config.to_bits(writer)
@dataclass
class StreamMuxConfig:
other_data_present: int
other_data_len_bits: int
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
def __init__(self, reader: BitReader) -> None:
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
audio_mux_version = reader.read(1)
if audio_mux_version == 1:
@@ -260,31 +360,31 @@ class AacAudioRtpPacket:
else:
audio_mux_version_a = 0
if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported')
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
num_program = reader.read(4)
if num_program != 0:
raise ValueError('num_program != 0 not supported')
raise core.InvalidPacketError('num_program != 0 not supported')
num_layer = reader.read(3)
if num_layer != 0:
raise ValueError('num_layer != 0 not supported')
raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
reader
)
else:
asc_len = AacAudioRtpPacket.latm_value(reader)
asc_len = AacAudioRtpPacket.read_latm_value(reader)
marker = reader.bit_position
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
reader
)
audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len')
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len
reader.skip(asc_len)
frame_length_type = reader.read(3)
@@ -293,38 +393,53 @@ class AacAudioRtpPacket:
elif frame_length_type == 1:
frame_length = reader.read(9)
else:
raise ValueError(f'frame_length_type {frame_length_type} not supported')
raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1)
if self.other_data_present:
other_data_present = reader.read(1)
other_data_len_bits = 0
if other_data_present:
if audio_mux_version == 1:
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader)
else:
self.other_data_len_bits = 0
while True:
self.other_data_len_bits *= 256
other_data_len_bits *= 256
other_data_len_esc = reader.read(1)
self.other_data_len_bits += reader.read(8)
other_data_len_bits += reader.read(8)
if other_data_len_esc == 0:
break
crc_check_present = reader.read(1)
if crc_check_present:
crc_checksum = reader.read(8)
return cls(other_data_present, other_data_len_bits, audio_specific_config)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # audioMuxVersion = 0
writer.write(1, 1) # allStreamsSameTimeFraming = 1
writer.write(0, 6) # numSubFrames = 0
writer.write(0, 4) # numProgram = 0
writer.write(0, 3) # numLayer = 0
self.audio_specific_config.to_bits(writer)
writer.write(0, 3) # frameLengthType = 0
writer.write(0, 8) # latmBufferFullness = 0
writer.write(0, 1) # otherDataPresent = 0
writer.write(0, 1) # crcCheckPresent = 0
@dataclass
class AudioMuxElement:
payload: bytes
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
payload: bytes
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported')
@classmethod
def from_bits(cls, reader: BitReader) -> Self:
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
# (only supports mux_config_present=1)
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader)
# We only support:
# allStreamsSameTimeFraming == 1
@@ -340,19 +455,46 @@ class AacAudioRtpPacket:
if tmp != 255:
break
self.payload = reader.read_bytes(mux_slot_length_bytes)
payload = reader.read_bytes(mux_slot_length_bytes)
if self.stream_mux_config.other_data_present:
reader.skip(self.stream_mux_config.other_data_len_bits)
if stream_mux_config.other_data_present:
reader.skip(stream_mux_config.other_data_len_bits)
# ByteAlign
while reader.bit_position % 8:
reader.read(1)
def __init__(self, data: bytes) -> None:
return cls(stream_mux_config, payload)
def to_bits(self, writer: BitWriter) -> None:
writer.write(0, 1) # useSameStreamMux = 0
self.stream_mux_config.to_bits(writer)
mux_slot_length_bytes = len(self.payload)
while mux_slot_length_bytes > 255:
writer.write(255, 8)
mux_slot_length_bytes -= 255
writer.write(mux_slot_length_bytes, 8)
if mux_slot_length_bytes == 255:
writer.write(0, 8)
writer.write_bytes(self.payload)
@classmethod
def from_bytes(cls, data: bytes) -> Self:
# Parse the bit stream
reader = BitReader(data)
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
return cls(cls.AudioMuxElement.from_bits(reader))
@classmethod
def for_simple_aac(
cls, sampling_frequency: int, channel_configuration: int, payload: bytes
) -> Self:
audio_specific_config = cls.AudioSpecificConfig.for_simple_aac(
2, sampling_frequency, channel_configuration
)
stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config)
audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload)
return cls(audio_mux_element)
def to_adts(self):
# pylint: disable=line-too-long
@@ -379,3 +521,11 @@ class AacAudioRtpPacket:
)
+ self.audio_mux_element.payload
)
def __init__(self, audio_mux_element: AudioMuxElement) -> None:
self.audio_mux_element = audio_mux_element
def __bytes__(self) -> bytes:
writer = BitWriter()
self.audio_mux_element.to_bits(writer)
return bytes(writer)

View File

@@ -16,6 +16,10 @@ from functools import partial
from typing import List, Optional, Union
class ColorError(ValueError):
"""Error raised when a color spec is invalid."""
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
raise ColorError('Invalid color spec "%s"' % spec)
def color(
@@ -72,7 +76,7 @@ def color(
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
raise ColorError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
import asyncio
import dataclasses
import itertools
import random
import struct
@@ -42,6 +43,7 @@ from bumble.hci import (
HCI_LE_1M_PHY,
HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0,
Address,
@@ -53,17 +55,21 @@ from bumble.hci import (
HCI_Connection_Request_Event,
HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event,
HCI_Synchronous_Connection_Complete_Event,
HCI_LE_Advertising_Report_Event,
HCI_LE_CIS_Established_Event,
HCI_LE_CIS_Request_Event,
HCI_LE_Connection_Complete_Event,
HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
HCI_Role_Change_Event,
)
from typing import Optional, Union, Dict, TYPE_CHECKING
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
from bumble.link import LocalLink
from bumble.transport.common import TransportSink
# -----------------------------------------------------------------------------
# Logging
@@ -79,15 +85,27 @@ class DataObject:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
handle: int
cis_id: int
cig_id: int
acl_connection: Optional[Connection] = None
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Connection:
def __init__(self, controller, handle, role, peer_address, link, transport):
self.controller = controller
self.handle = handle
self.role = role
self.peer_address = peer_address
self.link = link
controller: Controller
handle: int
role: int
peer_address: Address
link: Any
transport: int
link_type: int
def __post_init__(self):
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
@@ -106,25 +124,27 @@ class Connection:
class Controller:
def __init__(
self,
name,
name: str,
host_source=None,
host_sink: Optional[TransportSink] = None,
link=None,
link: Optional[LocalLink] = None,
public_address: Optional[Union[bytes, str, Address]] = None,
):
self.name = name
self.hci_sink = None
self.link = link
self.central_connections: Dict[
Address, Connection
] = {} # Connections where this controller is the central
self.peripheral_connections: Dict[
Address, Connection
] = {} # Connections where this controller is the peripheral
self.classic_connections: Dict[
Address, Connection
] = {} # Connections in BR/EDR
self.central_connections: Dict[Address, Connection] = (
{}
) # Connections where this controller is the central
self.peripheral_connections: Dict[Address, Connection] = (
{}
) # Connections where this controller is the peripheral
self.classic_connections: Dict[Address, Connection] = (
{}
) # Connections in BR/EDR
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
@@ -134,12 +154,14 @@ class Controller:
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27
self.hc_total_num_data_packets = 64
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.event_mask = 0
self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_event_mask = 0
@@ -292,7 +314,7 @@ class Controller:
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
)
if self.host:
self.host.on_packet(packet.to_bytes())
self.host.on_packet(bytes(packet))
# This method allows the controller to emulate the same API as a transport source
async def wait_for_termination(self):
@@ -301,7 +323,7 @@ class Controller:
############################################################
# Link connections
############################################################
def allocate_connection_handle(self):
def allocate_connection_handle(self) -> int:
handle = 0
max_handle = 0
for connection in itertools.chain(
@@ -313,6 +335,13 @@ class Controller:
if connection.handle == handle:
# Already used, continue searching after the current max
handle = max_handle + 1
for cis_handle in itertools.chain(
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
):
max_handle = max(max_handle, cis_handle)
if cis_handle == handle:
# Already used, continue searching after the current max
handle = max_handle + 1
return handle
def find_le_connection_by_address(self, address):
@@ -357,12 +386,13 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
self,
connection_handle,
BT_PERIPHERAL_ROLE,
peer_address,
self.link,
BT_LE_TRANSPORT,
controller=self,
handle=connection_handle,
role=BT_PERIPHERAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
@@ -416,12 +446,13 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
self,
connection_handle,
BT_CENTRAL_ROLE,
peer_address,
self.link,
BT_LE_TRANSPORT,
controller=self,
handle=connection_handle,
role=BT_CENTRAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
self.central_connections[peer_address] = connection
logger.debug(
@@ -538,6 +569,104 @@ class Controller:
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
def on_link_cis_request(
self, central_address: Address, cig_id: int, cis_id: int
) -> None:
'''
Called when an incoming CIS request occurs from a central on the link
'''
connection = self.peripheral_connections.get(central_address)
assert connection
pending_cis_link = CisLink(
handle=self.allocate_connection_handle(),
cis_id=cis_id,
cig_id=cig_id,
acl_connection=connection,
)
self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link
self.send_hci_packet(
HCI_LE_CIS_Request_Event(
acl_connection_handle=connection.handle,
cis_connection_handle=pending_cis_link.handle,
cig_id=cig_id,
cis_id=cis_id,
)
)
def on_link_cis_established(self, cig_id: int, cis_id: int) -> None:
'''
Called when an incoming CIS established.
'''
cis_link = next(
cis_link
for cis_link in itertools.chain(
self.central_cis_links.values(), self.peripheral_cis_links.values()
)
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
)
self.send_hci_packet(
HCI_LE_CIS_Established_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
# CIS parameters are ignored.
cig_sync_delay=0,
cis_sync_delay=0,
transport_latency_c_to_p=0,
transport_latency_p_to_c=0,
phy_c_to_p=0,
phy_p_to_c=0,
nse=0,
bn_c_to_p=0,
bn_p_to_c=0,
ft_c_to_p=0,
ft_p_to_c=0,
max_pdu_c_to_p=0,
max_pdu_p_to_c=0,
iso_interval=0,
)
)
def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None:
'''
Called when a CIS disconnected.
'''
if cis_link := next(
(
cis_link
for cis_link in self.peripheral_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Remove peripheral CIS on disconnection.
self.peripheral_cis_links.pop(cis_link.handle)
elif cis_link := next(
(
cis_link
for cis_link in self.central_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command.
cis_link.acl_connection = None
else:
return
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
############################################################
# Classic link connections
############################################################
@@ -566,6 +695,7 @@ class Controller:
peer_address=peer_address,
link=self.link,
transport=BT_BR_EDR_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
self.classic_connections[peer_address] = connection
logger.debug(
@@ -619,6 +749,42 @@ class Controller:
)
)
def on_classic_sco_connection_complete(
self, peer_address: Address, status: int, link_type: int
):
if status == HCI_SUCCESS:
# Allocate (or reuse) a connection handle
connection_handle = self.allocate_connection_handle()
connection = Connection(
controller=self,
handle=connection_handle,
# Role doesn't matter in SCO.
role=BT_CENTRAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_BR_EDR_TRANSPORT,
link_type=link_type,
)
self.classic_connections[peer_address] = connection
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
else:
connection_handle = 0
self.send_hci_packet(
HCI_Synchronous_Connection_Complete_Event(
status=status,
connection_handle=connection_handle,
bd_addr=peer_address,
link_type=link_type,
# TODO: Provide SCO connection parameters.
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
)
############################################################
# Advertising support
############################################################
@@ -721,6 +887,17 @@ class Controller:
else:
# Remove the connection
del self.classic_connections[connection.peer_address]
elif cis_link := (
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
):
if self.link:
self.link.disconnect_cis(
initiator_controller=self,
peer_address=cis_link.acl_connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
# Spec requires handle to be kept after disconnection.
def on_hci_accept_connection_request_command(self, command):
'''
@@ -738,6 +915,68 @@ class Controller:
)
self.link.classic_accept_connection(self, command.bd_addr, command.role)
def on_hci_enhanced_setup_synchronous_connection_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.45 Enhanced Setup Synchronous Connection command
'''
if self.link is None:
return
if not (
connection := self.find_classic_connection_by_handle(
command.connection_handle
)
):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_sco_connect(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.46 Enhanced Accept Synchronous Connection Request command
'''
if self.link is None:
return
if not (connection := self.find_classic_connection_by_address(command.bd_addr)):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_accept_sco_connection(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_switch_role_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
@@ -912,14 +1151,48 @@ class Controller:
'''
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
'''
return bytes([HCI_SUCCESS]) + self.lmp_features
return bytes([HCI_SUCCESS]) + self.lmp_features[:8]
def on_hci_read_local_extended_features_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.4 Read Local Extended Features Command
'''
if command.page_number * 8 > len(self.lmp_features):
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
return (
bytes(
[
# Status
HCI_SUCCESS,
# Page number
command.page_number,
# Max page number
len(self.lmp_features) // 8 - 1,
]
)
# Features of the current page
+ self.lmp_features[command.page_number * 8 : (command.page_number + 1) * 8]
)
def on_hci_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
'''
return struct.pack(
'<BHBHH',
HCI_SUCCESS,
self.hc_data_packet_length,
0,
self.hc_total_num_data_packets,
0,
)
def on_hci_read_bd_addr_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
'''
bd_addr = (
self._public_address.to_bytes()
bytes(self._public_address)
if self._public_address is not None
else bytes(6)
)
@@ -1089,6 +1362,18 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
'''
handle = command.connection_handle
if not self.find_connection_by_handle(handle):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
# First, say that the command is pending
self.send_hci_packet(
HCI_Command_Status_Event(
@@ -1102,7 +1387,7 @@ class Controller:
self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0,
connection_handle=handle,
le_features=bytes.fromhex('dd40000000000000'),
)
)
@@ -1258,8 +1543,191 @@ class Controller:
}
return bytes([HCI_SUCCESS])
def on_hci_le_set_advertising_set_random_address_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS, 0])
def on_hci_le_set_extended_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_scan_response_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_extended_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
Length Command
'''
return struct.pack('<BH', HCI_SUCCESS, 0x0672)
def on_hci_le_read_number_of_supported_advertising_sets_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.58 LE Read Number of Supported
Advertising Set Command
'''
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_data_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_transmit_power_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
def on_hci_le_set_cig_parameters_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.97 LE Set CIG Parameter Command
'''
# Remove old CIG implicitly.
for handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(handle)
handles = []
for cis_id in command.cis_id:
handle = self.allocate_connection_handle()
handles.append(handle)
self.central_cis_links[handle] = CisLink(
cis_id=cis_id,
cig_id=command.cig_id,
handle=handle,
)
return struct.pack(
'<BBB', HCI_SUCCESS, command.cig_id, len(handles)
) + b''.join([struct.pack('<H', handle) for handle in handles])
def on_hci_le_create_cis_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.99 LE Create CIS Command
'''
if not self.link:
return
for cis_handle, acl_handle in zip(
command.cis_connection_handle, command.acl_connection_handle
):
if not (connection := self.find_connection_by_handle(acl_handle)):
logger.error(f'Cannot find connection with handle={acl_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
if not (cis_link := self.central_cis_links.get(cis_handle)):
logger.error(f'Cannot find CIS with handle={cis_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
cis_link.acl_connection = connection
self.link.create_cis(
self,
peripheral_address=connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_remove_cig_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.100 LE Remove CIG Command
'''
status = HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR
for cis_handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(cis_handle)
status = HCI_SUCCESS
return struct.pack('<BH', status, command.cig_id)
def on_hci_le_accept_cis_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.101 LE Accept CIS Request Command
'''
if not self.link:
return
if not (
pending_cis_link := self.peripheral_cis_links.get(command.connection_handle)
):
logger.error(f'Cannot find CIS with handle={command.connection_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
assert pending_cis_link.acl_connection
self.link.accept_cis(
peripheral_controller=self,
central_address=pending_cis_link.acl_connection.peer_address,
cig_id=pending_cis_link.cig_id,
cis_id=pending_cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_setup_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)

File diff suppressed because it is too large Load Diff

View File

@@ -100,6 +100,16 @@ class EccKey:
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def generate_prand() -> bytes:
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y)

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -149,7 +151,7 @@ QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder(object):
class G722Decoder:
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
@@ -157,7 +159,7 @@ class G722Decoder(object):
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self):
def __init__(self) -> None:
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
@@ -165,12 +167,12 @@ class G722Decoder(object):
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data) -> bytearray:
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data) -> int:
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
@@ -198,14 +200,16 @@ class G722Decoder(object):
return result_length
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
def update_decoded_result(
self, xout: int, byte_length: int, byte_array: bytearray
) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits) -> int:
def lower_sub_band_decoder(self, lower_bits: int) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
@@ -258,7 +262,7 @@ class G722Decoder(object):
return rlow
def higher_sub_band_decoder(self, higher_bits) -> int:
def higher_sub_band_decoder(self, higher_bits: int) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
@@ -306,14 +310,14 @@ class G722Decoder(object):
# -----------------------------------------------------------------------------
class Band(object):
"""Structure for G722 decode proccessing."""
class Band:
"""Structure for G722 decode processing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self):
def __init__(self) -> None:
self._sp = 0
self._sz = 0
self._r = [0] * 3

File diff suppressed because it is too large Load Diff

View File

@@ -19,12 +19,17 @@ like loading firmware after a cold start.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
from __future__ import annotations
import logging
import pathlib
import platform
from . import rtk
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk, intel
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
@@ -32,40 +37,31 @@ from . import rtk
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host):
"""Probe all known diver classes until one returns a valid instance for a host,
or none is found.
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
"""
if driver := await rtk.Driver.for_host(host):
logger.debug("Instantiated RTK driver")
return driver
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None

47
bumble/drivers/common.py Normal file
View File

@@ -0,0 +1,47 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
from bumble import core
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""

671
bumble/drivers/intel.py Normal file
View File

@@ -0,0 +1,671 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Support for Intel USB controllers.
Loosely based on the Fuchsia OS implementation.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import os
import pathlib
import platform
import struct
from typing import Any, Deque, Optional, TYPE_CHECKING
from bumble import core
from bumble.drivers import common
from bumble import hci
from bumble import utils
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = {
(0x8087, 0x0032), # AX210
(0x8087, 0x0036), # BE200
}
INTEL_FW_IMAGE_NAMES = [
"ibt-0040-0041",
"ibt-0040-1020",
"ibt-0040-1050",
"ibt-0040-2120",
"ibt-0040-4150",
"ibt-0041-0041",
"ibt-0180-0041",
"ibt-0180-1050",
"ibt-0180-4150",
"ibt-0291-0291",
"ibt-1040-0041",
"ibt-1040-1020",
"ibt-1040-1050",
"ibt-1040-2120",
"ibt-1040-4150",
]
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
_MAX_FRAGMENT_SIZE = 252
_POST_RESET_DELAY = 0.2
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command(
fields=[
("param0", 1),
],
return_parameters_fields=[
("status", hci.STATUS_SPEC),
("tlv", "*"),
],
)
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[("data_type", 1), ("data", "*")],
return_parameters_fields=[
("status", 1),
],
)
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[
("reset_type", 1),
("patch_enable", 1),
("ddc_reload", 1),
("boot_option", 1),
("boot_address", 4),
],
return_parameters_fields=[
("data", "*"),
],
)
class HCI_Intel_Reset_Command(hci.HCI_Command):
pass
@hci.HCI_Command.command(
fields=[("data", "*")],
return_parameters_fields=[
("status", hci.STATUS_SPEC),
("params", "*"),
],
)
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
pass
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
def intel_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Intel firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "intel"
p.mkdir(parents=True, exist_ok=True)
return p
def _find_binary_path(file_name: str) -> pathlib.Path | None:
# First check if an environment variable is set
if INTEL_FIRMWARE_DIR_ENV in os.environ:
if (
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
).is_file():
logger.debug(f"{file_name} found in env dir")
return path
# When the environment variable is set, don't look elsewhere
return None
# Then, look where the firmware download tool writes by default
if (path := intel_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
logger.debug(f"{file_name} found in package dir")
return path
# On Linux, check the system's FW directory
if (
platform.system() == "Linux"
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
):
logger.debug(f"{file_name} found in Linux system FW dir")
return path
# Finally look in the current directory
if (path := pathlib.Path.cwd() / file_name).is_file():
logger.debug(f"{file_name} found in CWD")
return path
return None
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
result: list[tuple[ValueType, Any]] = []
while len(data) >= 2:
value_type = ValueType(data[0])
value_length = data[1]
value = data[2 : 2 + value_length]
typed_value: Any
if value_type == ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]
return result
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class DriverError(core.BaseBumbleError):
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
def __str__(self) -> str:
return f"IntelDriverError({self.message})"
class ValueType(utils.OpenIntEnum):
END = 0x00
CNVI = 0x10
CNVR = 0x11
HARDWARE_INFO = 0x12
DEVICE_REVISION = 0x16
CURRENT_MODE_OF_OPERATION = 0x1C
USB_VENDOR_ID = 0x17
USB_PRODUCT_ID = 0x18
TIMESTAMP = 0x1D
BUILD_TYPE = 0x1E
BUILD_NUMBER = 0x1F
SECURE_BOOT = 0x28
OTP_LOCK = 0x2A
API_LOCK = 0x2B
DEBUG_LOCK = 0x2C
FIRMWARE_BUILD = 0x2D
SECURE_BOOT_ENGINE_TYPE = 0x2F
BLUETOOTH_ADDRESS = 0x30
class HardwarePlatform(utils.OpenIntEnum):
INTEL_37 = 0x37
class HardwareVariant(utils.OpenIntEnum):
# This is a just a partial list.
# Add other constants here as new hardware is encountered and tested.
TYPHOON_PEAK = 0x17
GALE_PEAK = 0x1C
@dataclasses.dataclass
class HardwareInfo:
platform: HardwarePlatform
variant: HardwareVariant
@dataclasses.dataclass
class Timestamp:
week: int
year: int
@dataclasses.dataclass
class FirmwareBuild:
build_number: int
timestamp: Timestamp
class ModeOfOperation(utils.OpenIntEnum):
BOOTLOADER = 0x01
INTERMEDIATE = 0x02
OPERATIONAL = 0x03
class SecureBootEngineType(utils.OpenIntEnum):
RSA = 0x00
ECDSA = 0x01
@dataclasses.dataclass
class BootParams:
css_header_offset: int
css_header_size: int
pki_offset: int
pki_size: int
sig_offset: int
sig_size: int
write_offset: int
_BOOT_PARAMS = {
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
}
class Driver(common.Driver):
def __init__(self, host: Host) -> None:
self.host = host
self.max_in_flight_firmware_load_commands = 1
self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
collections.deque()
)
self.can_send_firmware_load_command = asyncio.Event()
self.can_send_firmware_load_command.set()
self.firmware_load_complete = asyncio.Event()
self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name.
self.ddc_addon: Optional[bytes] = None
self.ddc_override: Optional[bytes] = None
driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"):
for key, value in [
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
]:
if key == "ddc_addon":
self.ddc_addon = bytes.fromhex(value)
elif key == "ddc_override":
self.ddc_override = bytes.fromhex(value)
@staticmethod
def check(host: Host) -> bool:
driver = host.hci_metadata.get("driver")
if driver == "intel" or driver is not None and driver.startswith("intel/"):
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def for_host(cls, host: Host, force: bool = False):
# Only instantiate this driver if explicitly selected
if not force and not cls.check(host):
return None
return cls(host)
def on_packet(self, packet: bytes) -> None:
"""Handler for event packets that are received from an ACL channel"""
event = hci.HCI_Event.from_bytes(packet)
if not isinstance(event, hci.HCI_Command_Complete_Event):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
logger.debug(
"max_in_flight_firmware_load_commands update: "
f"{event.num_hci_command_packets}"
)
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
logger.debug(f"event: {event}")
self.pending_firmware_load_commands.popleft()
in_flight = len(self.pending_firmware_load_commands)
logger.debug(f"event received, {in_flight} still in flight")
if in_flight < self.max_in_flight_firmware_load_commands:
self.can_send_firmware_load_command.set()
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
# Wait until we can send.
await self.can_send_firmware_load_command.wait()
# Send the command and adjust counters.
self.host.send_hci_packet(command)
self.pending_firmware_load_commands.append(command)
in_flight = len(self.pending_firmware_load_commands)
if in_flight >= self.max_in_flight_firmware_load_commands:
logger.debug(f"max commands in flight reached [{in_flight}]")
self.can_send_firmware_load_command.clear()
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
while data:
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
fragment = data[:fragment_size]
data = data[fragment_size:]
await self.send_firmware_load_command(
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
)
async def load_firmware(self) -> None:
self.host.ready = True
device_info = await self.read_device_info()
logger.debug(
"device info: \n%s",
"\n".join(
[
f" {value_type.name}: {value}"
for value_type, value in device_info.items()
]
),
)
# Check if the firmware is already loaded.
if (
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
== ModeOfOperation.OPERATIONAL
):
logger.debug("firmware already loaded")
return
# We only support some platforms and variants.
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
if hardware_info is None:
raise DriverError("hardware info missing")
if hardware_info.platform != HardwarePlatform.INTEL_37:
raise DriverError("hardware platform not supported")
if hardware_info.variant not in (
HardwareVariant.TYPHOON_PEAK,
HardwareVariant.GALE_PEAK,
):
raise DriverError("hardware variant not supported")
# Compute the firmware name.
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = (
"ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
)
logger.debug(f"FW base name: {firmware_base_name}")
firmware_name = f"{firmware_base_name}.sfi"
firmware_path = _find_binary_path(firmware_name)
if not firmware_path:
logger.warning(f"Firmware file {firmware_name} not found")
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
return None
logger.debug(f"loading firmware from {firmware_path}")
firmware_image = firmware_path.read_bytes()
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
if engine_type is None:
raise DriverError("secure boot engine type missing")
if engine_type not in _BOOT_PARAMS:
raise DriverError("secure boot engine type not supported")
boot_params = _BOOT_PARAMS[engine_type]
if len(firmware_image) < boot_params.write_offset:
raise DriverError("firmware image too small")
# Register to receive vendor events.
def on_vendor_event(event: hci.HCI_Vendor_Event):
logger.debug(f"vendor event: {event}")
event_type = event.parameters[0]
if event_type == 0x02:
# Boot event
logger.debug("boot complete")
self.reset_complete.set()
elif event_type == 0x06:
# Firmware load event
logger.debug("download complete")
self.firmware_load_complete.set()
else:
logger.debug(f"ignoring vendor event type {event_type}")
self.host.on("vendor_event", on_vendor_event)
# We need to temporarily intercept packets from the controller,
# because they are formatted as HCI event packets but are received
# on the ACL channel, so the host parser would get confused.
saved_on_packet = self.host.on_packet
self.host.on_packet = self.on_packet # type: ignore
self.firmware_load_complete.clear()
# Send the CSS header
data = firmware_image[
boot_params.css_header_offset : boot_params.css_header_offset
+ boot_params.css_header_size
]
await self.send_firmware_data(0x00, data)
# Send the PKI header
data = firmware_image[
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
]
await self.send_firmware_data(0x03, data)
# Send the Signature header
data = firmware_image[
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
]
await self.send_firmware_data(0x02, data)
# Send the rest of the image.
# The payload consists of command objects, which are sent when they add up
# to a multiple of 4 bytes.
boot_address = 0
offset = boot_params.write_offset
fragment_size = 0
while offset + 3 < len(firmware_image):
(command_opcode,) = struct.unpack_from(
"<H", firmware_image, offset + fragment_size
)
command_size = firmware_image[offset + fragment_size + 2]
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
(boot_address,) = struct.unpack_from(
"<I", firmware_image, offset + fragment_size + 3
)
logger.debug(
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
f"boot_address={boot_address}"
)
fragment_size += 3 + command_size
if fragment_size % 4 == 0:
await self.send_firmware_data(
0x01, firmware_image[offset : offset + fragment_size]
)
logger.debug(f"sent {fragment_size} bytes")
offset += fragment_size
fragment_size = 0
# Wait for the firmware loading to be complete.
logger.debug("waiting for firmware to be loaded")
await self.firmware_load_complete.wait()
logger.debug("firmware loaded")
# Restore the original packet handler.
self.host.on_packet = saved_on_packet # type: ignore
# Reset
self.reset_complete.clear()
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x00,
patch_enable=0x01,
ddc_reload=0x00,
boot_option=0x01,
boot_address=boot_address,
)
)
logger.debug("waiting for reset completion")
await self.reset_complete.wait()
logger.debug("reset complete")
# Load the device config if there is one.
if self.ddc_override:
logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon:
logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon)
async def load_device_config(self, ddc_data: bytes) -> None:
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
async def reboot_bootloader(self) -> None:
self.host.send_hci_packet(
HCI_Intel_Reset_Command(
reset_type=0x01,
patch_enable=0x01,
ddc_reload=0x01,
boot_option=0x00,
boot_address=0,
)
)
await asyncio.sleep(_POST_RESET_DELAY)
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF)
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.
return dict(tlvs)
async def init_controller(self):
await self.load_firmware()

View File

@@ -33,6 +33,7 @@ from typing import Tuple
import weakref
from bumble import core
from bumble.hci import (
hci_vendor_command_op_code,
STATUS_SPEC,
@@ -41,7 +42,7 @@ from bumble.hci import (
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common
# -----------------------------------------------------------------------------
# Logging
@@ -49,6 +50,10 @@ from bumble.hci import (
logger = logging.getLogger(__name__)
class RtkFirmwareError(core.BaseBumbleError):
"""Error raised when RTK firmware initialization fails."""
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -208,15 +213,15 @@ class Firmware:
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise ValueError("Firmware does not start with epatch signature")
raise RtkFirmwareError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig):
raise ValueError("Firmware does not end with extension sig")
raise RtkFirmwareError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header.
epatch_header_size = 14
if len(firmware) < epatch_header_size:
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
# Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig)
@@ -230,7 +235,7 @@ class Firmware:
break
if length == 0:
raise ValueError("Invalid 0-length instruction")
raise RtkFirmwareError("Invalid 0-length instruction")
if opcode == 0 and length == 1:
project_id = firmware[offset - 1]
@@ -239,7 +244,7 @@ class Firmware:
offset -= length
if project_id < 0:
raise ValueError("Project ID not found")
raise RtkFirmwareError("Project ID not found")
self.project_id = project_id
@@ -252,7 +257,7 @@ class Firmware:
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware):
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
@@ -266,7 +271,7 @@ class Firmware:
"<I", firmware, patch_offset_table_offset + 4 * patch_index
)
if patch_offset + patch_length > len(firmware):
raise ValueError("Firmware too short")
raise RtkFirmwareError("Firmware too short")
# Get the SVN version for the patch
(svn_version,) = struct.unpack_from(
@@ -285,7 +290,7 @@ class Firmware:
)
class Driver:
class Driver(common.Driver):
@dataclass
class DriverInfo:
rom: int
@@ -296,6 +301,8 @@ class Driver:
fw_name: str = ""
config_name: str = ""
POST_RESET_DELAY: float = 0.2
DRIVER_INFOS = [
# 8723A
DriverInfo(
@@ -470,8 +477,12 @@ class Driver:
logger.debug("USB metadata not found")
return False
vendor_id = host.hci_metadata.get("vendor_id", None)
product_id = host.hci_metadata.get("product_id", None)
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
@@ -486,9 +497,24 @@ class Driver:
@classmethod
async def driver_info_for_host(cls, host):
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
try:
await host.send_command(
HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True
command = HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
logger.error("failed to probe local version information")
return None
local_version = response.return_parameters
logger.debug(
@@ -638,7 +664,7 @@ class Driver:
):
return await self.download_for_rtl8723b()
raise ValueError("ROM not supported")
raise RtkFirmwareError("ROM not supported")
async def init_controller(self):
await self.download_firmware()

View File

@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):

View File

@@ -23,16 +23,32 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import Optional, Sequence, Iterable, List, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
SupportsBytes,
Type,
Union,
TYPE_CHECKING,
)
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
from bumble.colors import color
from bumble.core import BaseBumbleError, UUID
from bumble.att import Attribute, AttributeValue
from bumble.utils import ByteSerializable
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# -----------------------------------------------------------------------------
@@ -226,22 +242,22 @@ GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
@@ -263,6 +279,18 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# Gaming Audio Service (GMAS)
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
# Hearing Access Service
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC = UUID.from_16_bits(0x2BDC, 'Active Preset Index')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -308,6 +336,11 @@ def show_services(services: Iterable[Service]) -> None:
print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
class InvalidServiceError(BaseBumbleError):
"""The service is not compliant with the spec/profile"""
# -----------------------------------------------------------------------------
class Service(Attribute):
'''
@@ -321,24 +354,26 @@ class Service(Attribute):
def __init__(
self,
uuid: Union[str, UUID],
characteristics: List[Characteristic],
characteristics: Iterable[Characteristic],
primary=True,
included_services: List[Service] = [],
included_services: Iterable[Service] = (),
) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
),
Attribute.READABLE,
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = included_services[:]
self.characteristics = characteristics[:]
self.included_services = list(included_services)
self.characteristics = list(characteristics)
self.primary = primary
def get_advertising_data(self) -> Optional[bytes]:
@@ -368,9 +403,12 @@ class TemplateService(Service):
UUID: UUID
def __init__(
self, characteristics: List[Characteristic], primary: bool = True
self,
characteristics: Iterable[Characteristic],
primary: bool = True,
included_services: Iterable[Service] = (),
) -> None:
super().__init__(self.UUID, characteristics, primary)
super().__init__(self.UUID, characteristics, primary, included_services)
# -----------------------------------------------------------------------------
@@ -383,7 +421,7 @@ class IncludedServiceDeclaration(Attribute):
def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
'<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
)
super().__init__(
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
@@ -463,7 +501,7 @@ class Characteristic(Attribute):
uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'',
value: Any = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -498,7 +536,11 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
def __init__(
self,
characteristic: Characteristic,
value_handle: int,
) -> None:
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
@@ -519,56 +561,43 @@ class CharacteristicDeclaration(Attribute):
# -----------------------------------------------------------------------------
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
read_value: Callable
write_value: Callable
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[Callable, Callable] = (
{}
) # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
else:
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -587,11 +616,13 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -689,7 +720,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
The adapted `read_value` and `write_value` methods return/accept a dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
@@ -719,6 +750,24 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
return value.decode('utf-8')
# -----------------------------------------------------------------------------
class SerializableCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts any class to/from bytes using the class'
`to_bytes` and `__bytes__` methods, respectively.
'''
def __init__(self, characteristic, cls: Type[ByteSerializable]):
super().__init__(characteristic)
self.cls = cls
def encode_value(self, value: SupportsBytes) -> bytes:
return bytes(value)
def decode_value(self, value: bytes) -> Any:
return self.cls.from_bytes(value)
# -----------------------------------------------------------------------------
class Descriptor(Attribute):
'''
@@ -726,13 +775,24 @@ class Descriptor(Attribute):
'''
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
f'value={value_str})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit

View File

@@ -68,7 +68,7 @@ from .att import (
ATT_Error,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
from .core import UUID, InvalidStateError
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -90,6 +90,22 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services: Iterable[ServiceProxy]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
for descriptor in characteristic.descriptors:
print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
# Proxies
# -----------------------------------------------------------------------------
@@ -237,7 +253,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: Type[TemplateService]
@classmethod
def from_client(cls, client: Client) -> ProfileServiceProxy:
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -267,6 +283,8 @@ class Client:
self.services = []
self.cached_values = {}
connection.on('disconnection', self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
@@ -274,7 +292,7 @@ class Client:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: ATT_PDU):
logger.debug(
@@ -292,7 +310,7 @@ class Client:
self.pending_request = request
try:
self.send_gatt_pdu(request.to_bytes())
self.send_gatt_pdu(bytes(request))
response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
@@ -310,14 +328,14 @@ class Client:
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF:
raise ValueError('MTU must be <= 0xFFFF')
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection
if self.mtu_exchange_done:
@@ -327,12 +345,7 @@ class Client:
self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
raise ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
@@ -352,9 +365,7 @@ class Client:
if c.uuid == uuid
]
def get_attribute_grouping(
self, attribute_handle: int
) -> Optional[
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
Union[
ServiceProxy,
Tuple[ServiceProxy, CharacteristicProxy],
@@ -391,7 +402,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -887,6 +898,12 @@ class Client:
) and subscriber in subscribers:
subscribers.remove(subscriber)
# The characteristic itself is added as subscriber. If it is the
# last remaining subscriber, we remove it, such that the clean up
# works correctly. Otherwise the CCCD never is set back to 0.
if len(subscribers) == 1 and characteristic in subscribers:
subscribers.remove(characteristic)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
@@ -920,12 +937,7 @@ class Client:
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
raise ATT_Error(error_code=response.error_code, message=response)
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
@@ -947,12 +959,7 @@ class Client:
ATT_INVALID_OFFSET_ERROR,
):
break
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
raise ATT_Error(error_code=response.error_code, message=response)
part = response.part_attribute_value
attribute_value += part
@@ -1045,12 +1052,7 @@ class Client:
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
)
raise ATT_Error(error_code=response.error_code, message=response)
else:
await self.send_command(
ATT_Write_Command(
@@ -1058,6 +1060,10 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
@@ -1068,7 +1074,7 @@ class Client:
logger.warning('!!! unexpected response, there is no pending request')
return
# Sanity check: the response should match the pending request unless it is
# The response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace(

View File

@@ -28,12 +28,22 @@ import asyncio
import logging
from collections import defaultdict
import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from typing import (
Dict,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from .colors import color
from .core import UUID
from .att import (
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -60,7 +70,7 @@ from .att import (
ATT_Write_Response,
Attribute,
)
from .gatt import (
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -68,12 +78,14 @@ from .gatt import (
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicAdapter,
CharacteristicDeclaration,
CharacteristicValue,
IncludedServiceDeclaration,
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -327,7 +339,7 @@ class Server(EventEmitter):
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check
# Check parameters
if len(value) != 2:
logger.warning('CCCD value not 2 bytes long')
return
@@ -352,7 +364,7 @@ class Server(EventEmitter):
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
self.send_gatt_pdu(connection.handle, bytes(response))
async def notify_subscriber(
self,
@@ -379,7 +391,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -422,7 +434,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -444,12 +456,12 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
pending_confirmation = self.pending_confirmations[connection.handle] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
self.send_gatt_pdu(connection.handle, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
@@ -650,7 +662,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_find_by_type_value_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -658,13 +671,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
async for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -702,7 +715,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_by_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -725,7 +739,7 @@ class Server(EventEmitter):
and pdu_space_available
):
try:
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -767,14 +781,15 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -792,14 +807,15 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_blob_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -836,7 +852,8 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_by_group_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -864,7 +881,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -903,12 +920,13 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_write_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
# Check that the attribute exists
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
@@ -935,13 +953,22 @@ class Server(EventEmitter):
)
return
# Accept the value
attribute.write_value(connection, request.attribute_value)
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
# Done
response = ATT_Write_Response()
self.send_response(connection, response)
# Done
self.send_response(connection, ATT_Write_Response())
def on_att_write_command(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -959,9 +986,9 @@ class Server(EventEmitter):
# Accept the value
try:
attribute.write_value(connection, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
logger.exception(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''

File diff suppressed because it is too large Load Diff

View File

@@ -18,10 +18,17 @@
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import cast, Any
import datetime
from typing import cast, Any, Optional
import logging
from bumble import avc
from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
@@ -37,6 +44,7 @@ from bumble.l2cap import (
L2CAP_Connection_Response,
)
from bumble.hci import (
Address,
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
@@ -46,8 +54,7 @@ from bumble.hci import (
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from bumble.sdp import SDP_PDU, SDP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -57,28 +64,36 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
rfcomm.RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP',
avctp.AVCTP_PSM: 'AVCTP',
# TODO: add more PSM values
}
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# -----------------------------------------------------------------------------
class PacketTracer:
class AclStream:
psms: MutableMapping[int, int]
peer: PacketTracer.AclStream
peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
@@ -100,42 +115,51 @@ class PacketTracer:
connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer:
if psm := self.peer.psms.get(
connection_response.source_cid
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(
self.peer.on_avdtp_message
)
if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid)
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else:
# Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == SDP_PSM:
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
if psm == sdp.SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu)
elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
elif psm == rfcomm.RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
avdtp_assembler.on_pdu(l2cap_pdu.payload)
elif psm == avctp.AVCTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
@@ -152,6 +176,28 @@ class PacketTracer:
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def on_avctp_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet)
@@ -163,6 +209,7 @@ class PacketTracer:
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: Optional[datetime.datetime] = None
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info(
@@ -190,7 +237,10 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle)
def on_packet(self, packet: HCI_Packet) -> None:
def on_packet(
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
@@ -210,13 +260,22 @@ class PacketTracer:
)
def emit(self, message: Any) -> None:
self.emit_message(f'[{self.label}] {message}')
if self.packet_timestamp:
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
else:
prefix = ""
self.emit_message(f'{prefix}[{self.label}] {message}')
def trace(self, packet: HCI_Packet, direction: int = 0) -> None:
def trace(
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if direction == 0:
self.host_to_controller_analyzer.on_packet(packet)
self.host_to_controller_analyzer.on_packet(timestamp, packet)
else:
self.controller_to_host_analyzer.on_packet(packet)
self.controller_to_host_analyzer.on_packet(timestamp, packet)
def __init__(
self,
@@ -232,3 +291,15 @@ class PacketTracer:
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given

File diff suppressed because it is too large Load Diff

View File

@@ -19,16 +19,16 @@ from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, TYPE_CHECKING
from typing import Optional, Callable
from typing_extensions import override
from bumble import l2cap
from bumble.colors import color
from bumble import l2cap, device
from bumble.core import InvalidStateError, ProtocolError
if TYPE_CHECKING:
from bumble.device import Device, Connection
from bumble.hci import Address
# -----------------------------------------------------------------------------
@@ -47,6 +47,7 @@ HID_INTERRUPT_PSM = 0x0013
class Message:
message_type: MessageType
# Report types
class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00
@@ -60,6 +61,7 @@ class Message:
NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.report_id)
packet_bytes.extend(
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
)
if self.report_type == Message.ReportType.OTHER_REPORT:
if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes
else:
return self.header(0x08 | self.report_type) + packet_bytes
return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL
@@ -161,60 +174,72 @@ class VirtualCableUnplug(Message):
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
data: bytes
report_type: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
return self.header(self.report_type) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# -----------------------------------------------------------------------------
class Host(EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
l2cap_intr_channel: Optional[l2cap.ClassicChannel]
class HID(ABC, EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
def __init__(self, device: Device, connection: Connection) -> None:
class Role(enum.IntEnum):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device
self.connection = connection
self.l2cap_ctrl_channel = None
self.l2cap_intr_channel = None
self.role = role
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
@@ -229,9 +254,18 @@ class Host(EventEmitter):
self.l2cap_ctrl_channel = None
await channel.disconnect()
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -242,63 +276,20 @@ class Host(EventEmitter):
self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
# Here we will receive all kinds of packets, parse and then call respective callbacks
message_type = pdu[0] >> 4
param = pdu[0] & 0x0F
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend', pdu)
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend', pdu)
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = None
else:
logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu)
self.l2cap_intr_channel = None
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("data", pdu)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes):
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self):
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int):
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
self.emit("interrupt_data", pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
@@ -308,26 +299,253 @@ class Host(EventEmitter):
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data):
msg = SendData(data)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg)
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
if self.l2cap_intr_channel is not None:
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def suspend(self):
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(msg)
def exit_suspend(self):
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(msg)
def virtual_cable_unplug(self):
def virtual_cable_unplug(self) -> None:
msg = VirtualCableUnplug()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(msg)
self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
@dataclass
class GetSetStatus:
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.GET_REPORT:
logger.debug('<<< HID GET REPORT')
self.handle_get_report(pdu)
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend')
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend')
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code)
hid_message = bytes(msg)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(
self, cb: Callable[[int, int, int], Device.GetSetStatus]
) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(
self, cb: Callable[[int], Device.GetSetStatus]
) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None:
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None:
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None:
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def exit_suspend(self) -> None:
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,8 @@ import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self
from .colors import color
from .hci import Address
@@ -128,10 +129,10 @@ class PairingKeys:
def print(self, prefix=''):
keys_dict = self.to_dict()
for (container_property, value) in keys_dict.items():
for container_property, value in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items():
for key_property, key_value in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(container_property, "cyan")}: {value}')
@@ -158,7 +159,7 @@ class KeyStore:
async def get_resolving_keys(self):
all_keys = await self.get_all()
resolving_keys = []
for (name, keys) in all_keys:
for name, keys in all_keys:
if keys.irk is not None:
if keys.address_type is None:
address_type = Address.RANDOM_DEVICE_ADDRESS
@@ -171,7 +172,7 @@ class KeyStore:
async def print(self, prefix=''):
entries = await self.get_all()
separator = ''
for (name, keys) in entries:
for name, keys in entries:
print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ')
separator = '\n'
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
@classmethod
def from_device(
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
return JsonKeyStore(namespace, filename)
return cls(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it

View File

@@ -41,7 +41,14 @@ from typing import (
from .utils import deprecated
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .core import (
BT_CENTRAL_ROLE,
InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
OutOfResourcesError,
ProtocolError,
)
from .hci import (
HCI_LE_Connection_Update_Command,
HCI_Object,
@@ -70,6 +77,7 @@ L2CAP_LE_SIGNALING_CID = 0x05
L2CAP_MIN_LE_MTU = 23
L2CAP_MIN_BR_EDR_MTU = 48
L2CAP_MAX_BR_EDR_MTU = 65535
L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept
@@ -149,9 +157,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
@@ -172,7 +181,7 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
@dataclasses.dataclass
class ClassicChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_MIN_BR_EDR_MTU
mtu: int = L2CAP_DEFAULT_MTU
@dataclasses.dataclass
@@ -187,14 +196,17 @@ class LeCreditBasedChannelSpec:
self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
raise InvalidArgumentError('max credits out of range')
if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise InvalidArgumentError('MTU out of range')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
raise InvalidArgumentError('MPS out of range')
class L2CAP_PDU:
@@ -204,16 +216,16 @@ class L2CAP_PDU:
@staticmethod
def from_bytes(data: bytes) -> L2CAP_PDU:
# Sanity check
# Check parameters
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
raise InvalidPacketError('not enough data for L2CAP header')
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
l2cap_pdu_payload = data[4:]
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
def to_bytes(self) -> bytes:
def __bytes__(self) -> bytes:
header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload
@@ -221,9 +233,6 @@ class L2CAP_PDU:
self.cid = cid
self.payload = payload
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self) -> str:
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
@@ -321,11 +330,8 @@ class L2CAP_Control_Frame:
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self) -> bytes:
return self.pdu
def __bytes__(self) -> bytes:
return self.to_bytes()
return self.pdu
def __str__(self) -> str:
result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
@@ -745,6 +751,8 @@ class ClassicChannel(EventEmitter):
sink: Optional[Callable[[bytes], Any]]
state: State
connection: Connection
mtu: int
peer_mtu: int
def __init__(
self,
@@ -761,6 +769,7 @@ class ClassicChannel(EventEmitter):
self.signaling_cid = signaling_cid
self.state = self.State.CLOSED
self.mtu = mtu
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
self.psm = psm
self.source_cid = source_cid
self.destination_cid = 0
@@ -808,7 +817,7 @@ class ClassicChannel(EventEmitter):
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
raise InvalidStateError('connection already pending')
self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
@@ -825,7 +834,9 @@ class ClassicChannel(EventEmitter):
# Wait for the connection to succeed or fail
try:
return await self.connection_result
return await self.connection.abort_on(
'disconnection', self.connection_result
)
finally:
self.connection_result = None
@@ -857,7 +868,7 @@ class ClassicChannel(EventEmitter):
[
(
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
struct.pack('<H', L2CAP_DEFAULT_MTU),
struct.pack('<H', self.mtu),
)
]
)
@@ -922,8 +933,8 @@ class ClassicChannel(EventEmitter):
options = L2CAP_Control_Frame.decode_configuration_options(request.options)
for option in options:
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
self.mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'MTU = {self.mtu}')
self.peer_mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'peer MTU = {self.peer_mtu}')
self.send_control_frame(
L2CAP_Configure_Response(
@@ -1022,7 +1033,7 @@ class ClassicChannel(EventEmitter):
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'state={self.state.name})'
)
@@ -1119,7 +1130,7 @@ class LeCreditBasedChannel(EventEmitter):
# Check that we can start a new connection
identifier = self.manager.next_identifier(self.connection)
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
raise InvalidStateError('too many concurrent connection requests')
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
@@ -1506,7 +1517,7 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID available')
raise OutOfResourcesError('no free CID available')
@staticmethod
def find_free_le_cid(channels: Iterable[int]) -> int:
@@ -1519,7 +1530,7 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID')
raise OutOfResourcesError('no free CID')
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
@@ -1566,15 +1577,15 @@ class ChannelManager:
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
raise ValueError('PSM already in use')
raise InvalidArgumentError('PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
raise InvalidArgumentError('invalid PSM (not odd)')
check = spec.psm >> 8
while check:
if check % 2 != 0:
raise ValueError('invalid PSM')
raise InvalidArgumentError('invalid PSM')
check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
@@ -1616,7 +1627,7 @@ class ChannelManager:
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use')
raise InvalidArgumentError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
@@ -1644,12 +1655,13 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {pdu_str}'
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
@@ -1893,6 +1905,7 @@ class ChannelManager:
data = sum(1 << cid for cid in self.fixed_channels).to_bytes(8, 'little')
else:
result = L2CAP_Information_Response.NOT_SUPPORTED
data = b''
self.send_control_frame(
connection,
@@ -1926,7 +1939,7 @@ class ChannelManager:
supervision_timeout=request.timeout,
min_ce_length=0,
max_ce_length=0,
) # type: ignore[call-arg]
)
)
else:
self.send_control_frame(
@@ -2143,10 +2156,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
raise InvalidArgumentError('PSM cannot be None')
# Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
@@ -2195,10 +2208,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
raise InvalidArgumentError('PSM cannot be None')
# Create the channel
logger.debug(
@@ -2217,7 +2230,7 @@ class ChannelManager:
# Connect
try:
await channel.connect()
except Exception as e:
except BaseException as e:
del connection_channels[source_cid]
raise e

View File

@@ -19,16 +19,25 @@ import logging
import asyncio
from functools import partial
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.core import (
BT_PERIPHERAL_ROLE,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError,
)
from bumble.colors import color
from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
)
from bumble import controller
from typing import Optional, Set
# -----------------------------------------------------------------------------
# Logging
@@ -57,6 +66,8 @@ class LocalLink:
Link bus for controllers to communicate with each other
'''
controllers: Set[controller.Controller]
def __init__(self):
self.controllers = set()
self.pending_connection = None
@@ -79,7 +90,9 @@ class LocalLink:
return controller
return None
def find_classic_controller(self, address):
def find_classic_controller(
self, address: Address
) -> Optional[controller.Controller]:
for controller in self.controllers:
if controller.public_address == address:
return controller
@@ -109,6 +122,8 @@ class LocalLink:
elif transport == BT_BR_EDR_TRANSPORT:
destination_controller = self.find_classic_controller(destination_address)
source_address = sender_controller.public_address
else:
raise ValueError("unsupported transport type")
if destination_controller is not None:
destination_controller.on_link_acl_data(source_address, transport, data)
@@ -188,6 +203,60 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
@@ -271,6 +340,52 @@ class LocalLink:
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# -----------------------------------------------------------------------------
class RemoteLink:
@@ -297,12 +412,12 @@ class RemoteLink:
def add_controller(self, controller):
if self.controller:
raise ValueError('controller already set')
raise InvalidStateError('controller already set')
self.controller = controller
def remove_controller(self, controller):
if self.controller != controller:
raise ValueError('controller mismatch')
raise InvalidStateError('controller mismatch')
self.controller = None
def get_pending_connection(self):

View File

@@ -139,16 +139,19 @@ class PairingDelegate:
io_capability: IoCapability
local_initiator_key_distribution: KeyDistribution
local_responder_key_distribution: KeyDistribution
maximum_encryption_key_size: int
def __init__(
self,
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
maximum_encryption_key_size: int = 16,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution
self.maximum_encryption_key_size = maximum_encryption_key_size
@property
def classic_io_capability(self) -> int:

View File

@@ -25,8 +25,10 @@ import grpc.aio
from .config import Config
from .device import PandoraDevice
from .host import HostService
from .l2cap import L2CAPService
from .security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
@@ -77,6 +79,7 @@ async def serve(
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
# call hooks if any.
for hook in _SERVICERS_HOOKS:

View File

@@ -28,12 +28,15 @@ from bumble.core import (
BT_PERIPHERAL_ROLE,
UUID,
AdvertisingData,
Appearance,
ConnectionError,
)
from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingParameters,
AdvertisingEventProperties,
AdvertisingType,
Device,
)
@@ -43,13 +46,17 @@ from bumble.hci import (
HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address,
Phy,
)
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer
from pandora import host_pb2
from pandora.host_pb2 import (
NOT_CONNECTABLE,
NOT_DISCOVERABLE,
DISCOVERABLE_LIMITED,
DISCOVERABLE_GENERAL,
PRIMARY_1M,
PRIMARY_CODED,
SECONDARY_1M,
@@ -65,6 +72,7 @@ from pandora.host_pb2 import (
ConnectResponse,
DataTypes,
DisconnectRequest,
DiscoverabilityMode,
InquiryResponse,
PrimaryPhy,
ReadLocalAddressResponse,
@@ -94,6 +102,25 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
3: SECONDARY_CODED,
}
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
PRIMARY_1M: Phy.LE_1M,
PRIMARY_CODED: Phy.LE_CODED,
}
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_NONE: Phy.LE_1M,
SECONDARY_1M: Phy.LE_1M,
SECONDARY_2M: Phy.LE_2M,
SECONDARY_CODED: Phy.LE_CODED,
}
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
}
class HostService(HostServicer):
waited_connections: Set[int]
@@ -261,9 +288,9 @@ class HostService(HostServicer):
self.log.debug(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[
None
] = asyncio.get_running_loop().create_future()
disconnection_future: asyncio.Future[None] = (
asyncio.get_running_loop().create_future()
)
def on_disconnection(_: None) -> None:
disconnection_future.set_result(None)
@@ -281,14 +308,118 @@ class HostService(HostServicer):
async def Advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if not request.legacy:
raise NotImplementedError(
"TODO: add support for extended advertising in Bumble"
try:
if request.legacy:
async for rsp in self.legacy_advertise(request, context):
yield rsp
else:
async for rsp in self.extended_advertise(request, context):
yield rsp
finally:
pass
async def extended_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
advertising_data = bytes(self.unpack_data_types(request.data))
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
scannable = len(scan_response_data) != 0
advertising_event_properties = AdvertisingEventProperties(
is_connectable=request.connectable,
is_scannable=scannable,
is_directed=request.target is not None,
is_high_duty_cycle_directed_connectable=False,
is_legacy=False,
is_anonymous=False,
include_tx_power=False,
)
peer_address = Address.ANY
if request.target:
# Need to reverse bytes order since Bumble Address is using MSB.
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
else:
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_parameters = AdvertisingParameters(
advertising_event_properties=advertising_event_properties,
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
peer_address=peer_address,
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
request.secondary_phy
],
)
if advertising_interval := request.interval:
advertising_parameters.primary_advertising_interval_min = int(
advertising_interval
)
if request.interval:
raise NotImplementedError("TODO: add support for `request.interval`")
if request.interval_range:
raise NotImplementedError("TODO: add support for `request.interval_range`")
advertising_parameters.primary_advertising_interval_max = int(
advertising_interval
)
if interval_range := request.interval_range:
advertising_parameters.primary_advertising_interval_max += int(
interval_range
)
advertising_set = await self.device.create_advertising_set(
advertising_parameters=advertising_parameters,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None:
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
pending_connection.set_result(connection)
self.device.on('connection', on_connection)
try:
# Advertise until RPC is canceled
while True:
if not advertising_set.enabled:
self.log.debug('Advertise (extended)')
await advertising_set.start()
if not request.connectable:
await asyncio.sleep(1)
continue
connection = await pending_connection
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie))
await asyncio.sleep(1)
finally:
try:
self.log.debug('Stop Advertise (extended)')
await advertising_set.stop()
await advertising_set.remove()
except Exception:
pass
async def legacy_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if advertising_interval := request.interval:
self.device.config.advertising_interval_min = int(advertising_interval)
self.device.config.advertising_interval_max = int(advertising_interval)
if interval_range := request.interval_range:
self.device.config.advertising_interval_max += int(interval_range)
if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy:
@@ -356,14 +487,10 @@ class HostService(HostServicer):
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
else:
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
if request.connectable:
@@ -390,9 +517,9 @@ class HostService(HostServicer):
await asyncio.sleep(1)
continue
pending_connection: asyncio.Future[
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
self.log.debug('Wait for LE connection...')
connection = await pending_connection
@@ -421,23 +548,31 @@ class HostService(HostServicer):
self, request: ScanRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ScanningResponse, None]:
# TODO: modify `start_scanning` to accept floats instead of int for ms values
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
self.log.debug('Scan')
scanning_phys = []
if PRIMARY_1M in request.phys:
scanning_phys.append(int(Phy.LE_1M))
if PRIMARY_CODED in request.phys:
scanning_phys.append(int(Phy.LE_CODED))
if not scanning_phys:
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
await self.device.start_scanning(
legacy=request.legacy,
active=not request.passive,
own_address_type=request.own_address_type,
scan_interval=int(request.interval)
if request.interval
else DEVICE_DEFAULT_SCAN_INTERVAL,
scan_window=int(request.window)
if request.window
else DEVICE_DEFAULT_SCAN_WINDOW,
scan_interval=(
int(request.interval)
if request.interval
else DEVICE_DEFAULT_SCAN_INTERVAL
),
scan_window=(
int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW
),
scanning_phys=scanning_phys,
)
try:
@@ -650,9 +785,11 @@ class HostService(HostServicer):
*struct.pack('<H', dt.peripheral_connection_interval_min),
*struct.pack(
'<H',
dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max
else dt.peripheral_connection_interval_min,
(
dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max
else dt.peripheral_connection_interval_min
),
),
]
),
@@ -734,6 +871,16 @@ class HostService(HostServicer):
)
)
flag_map = {
NOT_DISCOVERABLE: 0x00,
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
}
if dt.le_discoverability_mode:
flags = flag_map[dt.le_discoverability_mode]
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
@@ -842,8 +989,8 @@ class HostService(HostServicer):
dt.random_target_addresses.extend(
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
)
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = i
if appearance := cast(Appearance, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = int(appearance)
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
dt.advertising_interval = i
if s := cast(str, ad.get(AdvertisingData.URI)):

310
bumble/pandora/l2cap.py Normal file
View File

@@ -0,0 +1,310 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import grpc
import json
import logging
from asyncio import Queue as AsyncQueue, Future
from . import utils
from .config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
ClassicChannelServer,
ClassicChannelSpec,
LeCreditBasedChannel,
LeCreditBasedChannelServer,
LeCreditBasedChannelSpec,
)
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
Channel as PandoraChannel,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
DisconnectRequest,
DisconnectResponse,
ReceiveRequest,
ReceiveResponse,
SendRequest,
SendResponse,
WaitConnectionRequest,
WaitConnectionResponse,
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
class ChannelContext:
close_future: Future
sdu_queue: AsyncQueue
class L2CAPService(L2CAPServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'L2CAP', 'device': device}
)
self.device = device
self.config = config
self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future()
sdu_queue: AsyncQueue = AsyncQueue()
def on_channel_sdu(sdu):
sdu_queue.put_nowait(sdu)
def on_close():
close_future.set_result(None)
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_close)
return ChannelContext(close_future, sdu_queue)
@utils.rpc
async def WaitConnection(
self, request: WaitConnectionRequest, context: grpc.ServicerContext
) -> WaitConnectionResponse:
self.log.debug('WaitConnection')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
if channel_type.spsm in self.device.l2cap_channel_manager.le_coc_servers:
l2cap_server = self.device.l2cap_channel_manager.le_coc_servers[
channel_type.spsm
]
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
if channel_type.psm in self.device.l2cap_channel_manager.servers:
l2cap_server = self.device.l2cap_channel_manager.servers[
channel_type.psm
]
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
channel_future: Future[PandoraChannel] = (
asyncio.get_running_loop().create_future()
)
def on_l2cap_channel(l2cap_channel: L2capChannel):
try:
channel_context = self.register_event(l2cap_channel)
pandora_channel: PandoraChannel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
channel_future.set_result(pandora_channel)
except Exception as e:
self.log.error(f'Failed to set channel future: {e}')
if l2cap_server is None:
l2cap_server = self.device.create_l2cap_server(
spec=spec, handler=on_l2cap_channel
)
else:
l2cap_server.on('connection', on_l2cap_channel)
try:
self.log.debug('Waiting for a channel connection.')
pandora_channel: PandoraChannel = await channel_future
return WaitConnectionResponse(channel=pandora_channel)
except Exception as e:
self.log.warning(f'Exception: {e}')
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def WaitDisconnection(
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> WaitDisconnectionResponse:
try:
self.log.debug('WaitDisconnection')
await self.lookup_context(request.channel).close_future
self.log.debug("return WaitDisconnectionResponse")
return WaitDisconnectionResponse(success=empty_pb2.Empty())
except KeyError as e:
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
except Exception as e:
self.log.exception(f'WaitDisonnection failed: {e}')
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Receive(
self, request: ReceiveRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ReceiveResponse, None]:
self.log.debug('Receive')
oneof = request.WhichOneof('source')
self.log.debug(f'Source: {oneof}.')
pandora_channel = getattr(request, oneof)
sdu_queue = self.lookup_context(pandora_channel).sdu_queue
while sdu := await sdu_queue.get():
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
response = ReceiveResponse(data=sdu)
yield response
@utils.rpc
async def Connect(
self, request: ConnectRequest, context: grpc.ServicerContext
) -> ConnectResponse:
self.log.debug('Connect')
if not request.connection:
raise ValueError('A valid connection field must be set')
# find connection on device based on connection cookie value
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
connection = self.device.lookup_connection(connection_handle)
if not connection:
raise ValueError('The connection specified is invalid.')
oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
max_credits=channel_type.initial_credit,
mtu=channel_type.mtu,
mps=channel_type.mps,
)
else:
spec = ClassicChannelSpec(
psm=channel_type.psm,
mtu=channel_type.mtu,
)
try:
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
l2cap_channel = await connection.create_l2cap_channel(spec=spec)
channel_context = self.register_event(l2cap_channel)
pandora_channel = self.craft_pandora_channel(
connection_handle, l2cap_channel
)
self.channels[pandora_channel.cookie.value] = channel_context
return ConnectResponse(channel=pandora_channel)
except OutOfResourcesError as e:
self.log.error(e)
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
except InvalidArgumentError as e:
self.log.error(e)
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Disconnect(
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> DisconnectResponse:
try:
self.log.debug('Disconnect')
l2cap_channel = self.lookup_channel(request.channel)
if not l2cap_channel:
self.log.warning('Disconnect: Unable to find the channel')
return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
await l2cap_channel.disconnect()
return DisconnectResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return DisconnectResponse(error=COMMAND_NOT_UNDERSTOOD)
@utils.rpc
async def Send(
self, request: SendRequest, context: grpc.ServicerContext
) -> SendResponse:
self.log.debug('Send')
try:
oneof = request.WhichOneof('sink')
self.log.debug(f'Sink: {oneof}.')
pandora_channel = getattr(request, oneof)
l2cap_channel = self.lookup_channel(pandora_channel)
if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.send_pdu(request.data)
else:
l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty())
except Exception as e:
self.log.exception(f'Disonnect failed: {e}')
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
def craft_pandora_channel(
self,
connection_handle: int,
l2cap_channel: L2capChannel,
) -> PandoraChannel:
parameters = {
"connection_handle": connection_handle,
"source_cid": l2cap_channel.source_cid,
}
cookie = any_pb2.Any()
cookie.value = json.dumps(parameters).encode()
return PandoraChannel(cookie=cookie)
def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
(connection_handle, source_cid) = json.loads(
pandora_channel.cookie.value
).values()
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
return self.channels[pandora_channel.cookie.value]

View File

@@ -110,7 +110,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -125,7 +125,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -140,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -157,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -383,9 +383,9 @@ class SecurityService(SecurityServicer):
connection.transport
] == request.level_variant()
wait_for_security: asyncio.Future[
str
] = asyncio.get_running_loop().create_future()
wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future()
)
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None

504
bumble/profiles/aics.py Normal file
View File

@@ -0,0 +1,504 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Audio Input Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import gatt
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
SerializableCharacteristicAdapter,
PackedCharacteristicAdapter,
TemplateService,
CharacteristicValue,
UTF8CharacteristicAdapter,
GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
CHANGE_COUNTER_MAX_VALUE = 0xFF
GAIN_SETTINGS_MIN_VALUE = 0
GAIN_SETTINGS_MAX_VALUE = 255
class ErrorCode(OpenIntEnum):
'''
Cf. 1.6 Application error codes
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
MUTE_DISABLED = 0x82
VALUE_OUT_OF_RANGE = 0x83
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
class Mute(OpenIntEnum):
'''
Cf. 2.2.1.2 Mute Field
'''
NOT_MUTED = 0x00
MUTED = 0x01
DISABLED = 0x02
class GainMode(OpenIntEnum):
'''
Cf. 2.2.1.3 Gain Mode
'''
MANUAL_ONLY = 0x00
AUTOMATIC_ONLY = 0x01
MANUAL = 0x02
AUTOMATIC = 0x03
class AudioInputStatus(OpenIntEnum):
'''
Cf. 3.4 Audio Input Status
'''
INACTIVE = 0x00
ACTIVE = 0x01
class AudioInputControlPointOpCode(OpenIntEnum):
'''
Cf. 3.5.1 Audio Input Control Point procedure requirements
'''
SET_GAIN_SETTING = 0x01
UNMUTE = 0x02
MUTE = 0x03
SET_MANUAL_GAIN_MODE = 0x04
SET_AUTOMATIC_GAIN_MODE = 0x05
# -----------------------------------------------------------------------------
@dataclass
class AudioInputState:
'''
Cf. 2.2.1 Audio Input State
'''
gain_settings: int = 0
mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return bytes(
[self.gain_settings, self.mute, self.gain_mode, self.change_counter]
)
@classmethod
def from_bytes(cls, data: bytes):
gain_settings, mute, gain_mode, change_counter = struct.unpack("BBBB", data)
return cls(gain_settings, mute, gain_mode, change_counter)
def update_gain_settings_unit(self, gain_settings_unit: int) -> None:
self.gain_settings_unit = gain_settings_unit
def increment_gain_settings(self, gain_settings_unit: int) -> None:
self.gain_settings += gain_settings_unit
self.increment_change_counter()
def decrement_gain_settings(self) -> None:
self.gain_settings -= self.gain_settings_unit
self.increment_change_counter()
def increment_change_counter(self):
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=bytes(self)
)
@dataclass
class GainSettingsProperties:
'''
Cf. 3.2 Gain Settings Properties
'''
gain_settings_unit: int = 1
gain_settings_minimum: int = GAIN_SETTINGS_MIN_VALUE
gain_settings_maximum: int = GAIN_SETTINGS_MAX_VALUE
@classmethod
def from_bytes(cls, data: bytes):
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
struct.unpack('BBB', data)
)
return GainSettingsProperties(
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
)
def __bytes__(self) -> bytes:
return bytes(
[
self.gain_settings_unit,
self.gain_settings_minimum,
self.gain_settings_maximum,
]
)
@dataclass
class AudioInputControlPoint:
'''
Cf. 3.5.2 Audio Input Control Point
'''
audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
gain_settings_operand = value[2]
await self._set_gain_settings(connection, gain_settings_operand)
elif opcode == AudioInputControlPointOpCode.UNMUTE:
await self._unmute(connection)
elif opcode == AudioInputControlPointOpCode.MUTE:
change_counter_operand = value[1]
await self._mute(connection, change_counter_operand)
elif opcode == AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE:
await self._set_manual_gain_mode(connection)
elif opcode == AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE:
await self._set_automatic_gain_mode(connection)
else:
logger.error(f"OpCode value is incorrect: {opcode}")
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
async def _set_gain_settings(
self, connection: Connection, gain_settings_operand: int
) -> None:
'''Cf. 3.5.2.1 Set Gain Settings Procedure'''
gain_mode = self.audio_input_state.gain_mode
logger.error(f"set_gain_setting: gain_mode: {gain_mode}")
if not (gain_mode == GainMode.MANUAL or gain_mode == GainMode.MANUAL_ONLY):
logger.warning(
"GainMode should be either MANUAL or MANUAL_ONLY Cf Spec Audio Input Control Service 3.5.2.1"
)
return
if (
gain_settings_operand < self.gain_settings_properties.gain_settings_minimum
or gain_settings_operand
> self.gain_settings_properties.gain_settings_maximum
):
logger.error("gain_settings value out of range")
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
if self.audio_input_state.gain_settings != gain_settings_operand:
self.audio_input_state.gain_settings = gain_settings_operand
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _unmute(self, connection: Connection):
'''Cf. 3.5.2.2 Unmute procedure'''
logger.error(f'unmute: {self.audio_input_state.mute}')
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("unmute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if mute == Mute.NOT_MUTED:
return
self.audio_input_state.mute = Mute.NOT_MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _mute(self, connection: Connection, change_counter_operand: int) -> None:
'''Cf. 3.5.5.2 Mute procedure'''
change_counter = self.audio_input_state.change_counter
mute = self.audio_input_state.mute
if mute == Mute.DISABLED:
logger.error("mute: Cannot change Mute value, Mute state is DISABLED")
raise ATT_Error(ErrorCode.MUTE_DISABLED)
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if mute == Mute.MUTED:
return
self.audio_input_state.mute = Mute.MUTED
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_manual_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.4 Set Manual Gain Mode procedure'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.MANUAL:
return
self.audio_input_state.gain_mode = GainMode.MANUAL
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
async def _set_automatic_gain_mode(self, connection: Connection) -> None:
'''Cf. 3.5.2.5 Set Automatic Gain Mode'''
gain_mode = self.audio_input_state.gain_mode
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
if gain_mode == GainMode.AUTOMATIC:
return
self.audio_input_state.gain_mode = GainMode.AUTOMATIC
self.audio_input_state.increment_change_counter()
await self.audio_input_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioInputDescription:
'''
Cf. 3.6 Audio Input Description
'''
audio_input_description: str = "Bluetooth"
attribute_value: Optional[CharacteristicValue] = None
def on_read(self, _connection: Optional[Connection]) -> str:
return self.audio_input_description
async def on_write(self, connection: Optional[Connection], value: str) -> None:
assert connection
assert self.attribute_value
self.audio_input_description = value
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value
)
class AICSService(TemplateService):
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
def __init__(
self,
audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_type: str = "local",
audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: Optional[AudioInputDescription] = None,
):
self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state
)
self.gain_settings_properties = (
GainSettingsProperties()
if gain_settings_properties is None
else gain_settings_properties
)
self.audio_input_status = (
AudioInputStatus.ACTIVE
if audio_input_status is None
else audio_input_status
)
self.audio_input_description = (
AudioInputDescription()
if audio_input_description is None
else audio_input_description
)
self.audio_input_control_point: AudioInputControlPoint = AudioInputControlPoint(
self.audio_input_state, self.gain_settings_properties
)
self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=self.audio_input_state,
),
AudioInputState,
)
self.audio_input_state.attribute_value = (
self.audio_input_state_characteristic.value
)
self.gain_settings_properties_characteristic = (
SerializableCharacteristicAdapter(
Characteristic(
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=self.gain_settings_properties,
),
GainSettingsProperties,
)
)
self.audio_input_type_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes(audio_input_type, 'utf-8'),
)
self.audio_input_status_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([self.audio_input_status]),
)
self.audio_input_control_point_characteristic = Characteristic(
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(write=self.audio_input_control_point.on_write),
)
self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
properties=Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(
write=self.audio_input_description.on_write,
read=self.audio_input_description.on_read,
),
)
)
self.audio_input_description.attribute_value = (
self.audio_input_control_point_characteristic.value
)
super().__init__(
characteristics=[
self.audio_input_state_characteristic, # type: ignore
self.gain_settings_properties_characteristic, # type: ignore
self.audio_input_type_characteristic, # type: ignore
self.audio_input_status_characteristic, # type: ignore
self.audio_input_control_point_characteristic, # type: ignore
self.audio_input_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class AICSServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = AICSService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
self.audio_input_state = SerializableCharacteristicAdapter(
characteristics[0], AudioInputState
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Gain Settings Attribute Characteristic not found"
)
self.gain_settings_properties = SerializableCharacteristicAdapter(
characteristics[0], GainSettingsProperties
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Status Characteristic not found"
)
self.audio_input_status = PackedCharacteristicAdapter(characteristics[0], 'B')
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Control Point Characteristic not found"
)
self.audio_input_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Audio Input Description Characteristic not found"
)
self.audio_input_description = UTF8CharacteristicAdapter(characteristics[0])

739
bumble/profiles/ascs.py Normal file
View File

@@ -0,0 +1,739 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Audio Stream Control Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import logging
import struct
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
# -----------------------------------------------------------------------------
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# -----------------------------------------------------------------------------
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
metadata = le_audio.Metadata()
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if (
cig_id == self.cig_id
and cis_id == self.cis_id
and self.state == self.State.ENABLING
):
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if (
cis_link.cig_id == self.cig_id
and cis_link.cis_id == self.cis_id
and self.state == self.State.ENABLING
):
cis_link.on('disconnection', self.on_cis_disconnection)
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
if self.role == AudioRole.SINK:
self.state = self.State.STREAMING
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
self.cis_link = cis_link
def on_cis_disconnection(self, _reason) -> None:
self.cis_link = None
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
if self.role == AudioRole.SINK:
self.state = self.State.QOS_CONFIGURED
else:
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if (
self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
self.emit('state_change')
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
metadata_bytes = bytes(self.metadata)
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
# -----------------------------------------------------------------------------
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
_active_client: Optional[device.Connection] = None
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = (),
sink_ase_id: Sequence[int] = (),
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def _on_client_disconnected(self, _reason: int) -> None:
for ase in self.ase_state_machines.values():
ase.state = AseStateMachine.State.IDLE
self._active_client = None
def on_write_ase_control_point(self, connection, data):
if not self._active_client and connection:
self._active_client = connection
connection.once('disconnection', self._on_client_disconnected)
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# -----------------------------------------------------------------------------
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]

295
bumble/profiles/asha.py Normal file
View File

@@ -0,0 +1,295 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import struct
import logging
from typing import List, Optional, Callable, Union, Any
from bumble import l2cap
from bumble import utils
from bumble import gatt
from bumble import gatt_client
from bumble.core import AdvertisingData
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
_logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class DeviceCapabilities(enum.IntFlag):
IS_RIGHT = 0x01
IS_DUAL = 0x02
CSIS_SUPPORTED = 0x04
class FeatureMap(enum.IntFlag):
LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED = 0x01
class AudioType(utils.OpenIntEnum):
UNKNOWN = 0x00
RINGTONE = 0x01
PHONE_CALL = 0x02
MEDIA = 0x03
class OpCode(utils.OpenIntEnum):
START = 1
STOP = 2
STATUS = 3
class Codec(utils.OpenIntEnum):
G_722_16KHZ = 1
class SupportedCodecs(enum.IntFlag):
G_722_16KHZ = 1 << Codec.G_722_16KHZ
class PeripheralStatus(utils.OpenIntEnum):
"""Status update on the other peripheral."""
OTHER_PERIPHERAL_DISCONNECTED = 1
OTHER_PERIPHERAL_CONNECTED = 2
CONNECTION_PARAMETER_UPDATED = 3
class AudioStatus(utils.OpenIntEnum):
"""Status report field for the audio control point."""
OK = 0
UNKNOWN_COMMAND = -1
ILLEGAL_PARAMETERS = -2
# -----------------------------------------------------------------------------
class AshaService(gatt.TemplateService):
UUID = gatt.GATT_ASHA_SERVICE
audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None
volume: Optional[int] = None
other_state: Optional[int] = None
connection: Optional[Connection] = None
def __init__(
self,
capability: int,
hisyncid: Union[List[int], bytes],
device: Device,
psm: int = 0,
audio_sink: Optional[Callable[[bytes], Any]] = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01,
render_delay_milliseconds: int = 0,
supported_codecs: int = SupportedCodecs.G_722_16KHZ,
) -> None:
if len(hisyncid) != 8:
_logger.warning('HiSyncId should have a length of 8, got %d', len(hisyncid))
self.hisyncid = bytes(hisyncid)
self.capability = capability
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
self.audio_sink = audio_sink
self.protocol_version = protocol_version
self.read_only_properties_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack(
"<BB8sBH2sH",
protocol_version,
capability,
self.hisyncid,
feature_map,
render_delay_milliseconds,
b'\x00\x00',
supported_codecs,
),
)
self.audio_control_point_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
)
self.audio_status_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.READABLE,
bytes([AudioStatus.OK]),
)
self.volume_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(write=self._on_volume_write),
)
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=self._on_connection,
).psm
self.le_psm_out_characteristic = gatt.Characteristic(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
gatt.Characteristic.Properties.READ,
gatt.Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self) -> bytes:
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(gatt.GATT_ASHA_SERVICE)
+ bytes([self.protocol_version, self.capability])
+ self.hisyncid[:4],
),
]
)
)
# Handler for audio control commands
async def _on_audio_control_point_write(
self, connection: Optional[Connection], value: bytes
) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == OpCode.START:
# Start
self.active_codec = Codec(value[1])
self.audio_type = AudioType(value[2])
self.volume = value[3]
self.other_state = value[4]
_logger.debug(
f'### START: codec={self.active_codec.name}, '
f'audio_type={self.audio_type.name}, '
f'volume={self.volume}, '
f'other_state={self.other_state}'
)
self.emit('started')
elif opcode == OpCode.STOP:
_logger.debug('### STOP')
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('stopped')
elif opcode == OpCode.STATUS:
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
if self.connection is None and connection:
self.connection = connection
def on_disconnection(_reason) -> None:
self.connection = None
self.active_codec = None
self.audio_type = None
self.volume = None
self.other_state = None
self.emit('disconnected')
connection.once('disconnection', on_disconnection)
# OPCODE_STATUS does not need audio status point update
if opcode != OpCode.STATUS:
await self.device.notify_subscribers(
self.audio_status_characteristic, force=True
)
# Handler for volume control
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0]
self.emit('volume_changed')
# Register an L2CAP CoC server
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
def on_data(data: bytes) -> None:
if self.audio_sink: # pylint: disable=not-callable
self.audio_sink(data)
channel.sink = on_data
# -----------------------------------------------------------------------------
class AshaServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AshaService
read_only_properties_characteristic: gatt_client.CharacteristicProxy
audio_control_point_characteristic: gatt_client.CharacteristicProxy
audio_status_point_characteristic: gatt_client.CharacteristicProxy
volume_characteristic: gatt_client.CharacteristicProxy
psm_characteristic: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
for uuid, attribute_name in (
(
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
'read_only_properties_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
'audio_control_point_characteristic',
),
(
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
'audio_status_point_characteristic',
),
(
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
'volume_characteristic',
),
(
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
'psm_characteristic',
),
):
if not (
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
):
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
setattr(self, attribute_name, characteristics[0])

View File

@@ -1,193 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from bumble import l2cap
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

623
bumble/profiles/bap.py Normal file
View File

@@ -0,0 +1,623 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Sequence
import dataclasses
import enum
import struct
import functools
import logging
from typing import List
from typing_extensions import Self
from bumble import core
from bumble import hci
from bumble import gatt
from bumble import utils
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class AudioLocation(enum.IntFlag):
'''Bluetooth Assigned Numbers, Section 6.12.1 - Audio Location'''
# fmt: off
NOT_ALLOWED = 0x00000000
FRONT_LEFT = 0x00000001
FRONT_RIGHT = 0x00000002
FRONT_CENTER = 0x00000004
LOW_FREQUENCY_EFFECTS_1 = 0x00000008
BACK_LEFT = 0x00000010
BACK_RIGHT = 0x00000020
FRONT_LEFT_OF_CENTER = 0x00000040
FRONT_RIGHT_OF_CENTER = 0x00000080
BACK_CENTER = 0x00000100
LOW_FREQUENCY_EFFECTS_2 = 0x00000200
SIDE_LEFT = 0x00000400
SIDE_RIGHT = 0x00000800
TOP_FRONT_LEFT = 0x00001000
TOP_FRONT_RIGHT = 0x00002000
TOP_FRONT_CENTER = 0x00004000
TOP_CENTER = 0x00008000
TOP_BACK_LEFT = 0x00010000
TOP_BACK_RIGHT = 0x00020000
TOP_SIDE_LEFT = 0x00040000
TOP_SIDE_RIGHT = 0x00080000
TOP_BACK_CENTER = 0x00100000
BOTTOM_FRONT_CENTER = 0x00200000
BOTTOM_FRONT_LEFT = 0x00400000
BOTTOM_FRONT_RIGHT = 0x00800000
FRONT_LEFT_WIDE = 0x01000000
FRONT_RIGHT_WIDE = 0x02000000
LEFT_SURROUND = 0x04000000
RIGHT_SURROUND = 0x08000000
@property
def channel_count(self) -> int:
return bin(self.value).count('1')
class AudioInputType(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type'''
# fmt: off
UNSPECIFIED = 0x00
BLUETOOTH = 0x01
MICROPHONE = 0x02
ANALOG = 0x03
DIGITAL = 0x04
RADIO = 0x05
STREAMING = 0x06
AMBIENT = 0x07
class ContextType(enum.IntFlag):
'''Bluetooth Assigned Numbers, Section 6.12.3 - Context Type'''
# fmt: off
PROHIBITED = 0x0000
UNSPECIFIED = 0x0001
CONVERSATIONAL = 0x0002
MEDIA = 0x0004
GAME = 0x0008
INSTRUCTIONAL = 0x0010
VOICE_ASSISTANTS = 0x0020
LIVE = 0x0040
SOUND_EFFECTS = 0x0080
NOTIFICATIONS = 0x0100
RINGTONE = 0x0200
ALERTS = 0x0400
EMERGENCY_ALARM = 0x0800
class SamplingFrequency(utils.OpenIntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off
FREQ_8000 = 0x01
FREQ_11025 = 0x02
FREQ_16000 = 0x03
FREQ_22050 = 0x04
FREQ_24000 = 0x05
FREQ_32000 = 0x06
FREQ_44100 = 0x07
FREQ_48000 = 0x08
FREQ_88200 = 0x09
FREQ_96000 = 0x0A
FREQ_176400 = 0x0B
FREQ_192000 = 0x0C
FREQ_384000 = 0x0D
# fmt: on
@classmethod
def from_hz(cls, frequency: int) -> SamplingFrequency:
return {
8000: SamplingFrequency.FREQ_8000,
11025: SamplingFrequency.FREQ_11025,
16000: SamplingFrequency.FREQ_16000,
22050: SamplingFrequency.FREQ_22050,
24000: SamplingFrequency.FREQ_24000,
32000: SamplingFrequency.FREQ_32000,
44100: SamplingFrequency.FREQ_44100,
48000: SamplingFrequency.FREQ_48000,
88200: SamplingFrequency.FREQ_88200,
96000: SamplingFrequency.FREQ_96000,
176400: SamplingFrequency.FREQ_176400,
192000: SamplingFrequency.FREQ_192000,
384000: SamplingFrequency.FREQ_384000,
}[frequency]
@property
def hz(self) -> int:
return {
SamplingFrequency.FREQ_8000: 8000,
SamplingFrequency.FREQ_11025: 11025,
SamplingFrequency.FREQ_16000: 16000,
SamplingFrequency.FREQ_22050: 22050,
SamplingFrequency.FREQ_24000: 24000,
SamplingFrequency.FREQ_32000: 32000,
SamplingFrequency.FREQ_44100: 44100,
SamplingFrequency.FREQ_48000: 48000,
SamplingFrequency.FREQ_88200: 88200,
SamplingFrequency.FREQ_96000: 96000,
SamplingFrequency.FREQ_176400: 176400,
SamplingFrequency.FREQ_192000: 192000,
SamplingFrequency.FREQ_384000: 384000,
}[self]
class SupportedSamplingFrequency(enum.IntFlag):
'''Bluetooth Assigned Numbers, Section 6.12.4.1 - Sample Frequency'''
# fmt: off
FREQ_8000 = 1 << (SamplingFrequency.FREQ_8000 - 1)
FREQ_11025 = 1 << (SamplingFrequency.FREQ_11025 - 1)
FREQ_16000 = 1 << (SamplingFrequency.FREQ_16000 - 1)
FREQ_22050 = 1 << (SamplingFrequency.FREQ_22050 - 1)
FREQ_24000 = 1 << (SamplingFrequency.FREQ_24000 - 1)
FREQ_32000 = 1 << (SamplingFrequency.FREQ_32000 - 1)
FREQ_44100 = 1 << (SamplingFrequency.FREQ_44100 - 1)
FREQ_48000 = 1 << (SamplingFrequency.FREQ_48000 - 1)
FREQ_88200 = 1 << (SamplingFrequency.FREQ_88200 - 1)
FREQ_96000 = 1 << (SamplingFrequency.FREQ_96000 - 1)
FREQ_176400 = 1 << (SamplingFrequency.FREQ_176400 - 1)
FREQ_192000 = 1 << (SamplingFrequency.FREQ_192000 - 1)
FREQ_384000 = 1 << (SamplingFrequency.FREQ_384000 - 1)
# fmt: on
@classmethod
def from_hz(cls, frequencies: Sequence[int]) -> SupportedSamplingFrequency:
MAPPING = {
8000: SupportedSamplingFrequency.FREQ_8000,
11025: SupportedSamplingFrequency.FREQ_11025,
16000: SupportedSamplingFrequency.FREQ_16000,
22050: SupportedSamplingFrequency.FREQ_22050,
24000: SupportedSamplingFrequency.FREQ_24000,
32000: SupportedSamplingFrequency.FREQ_32000,
44100: SupportedSamplingFrequency.FREQ_44100,
48000: SupportedSamplingFrequency.FREQ_48000,
88200: SupportedSamplingFrequency.FREQ_88200,
96000: SupportedSamplingFrequency.FREQ_96000,
176400: SupportedSamplingFrequency.FREQ_176400,
192000: SupportedSamplingFrequency.FREQ_192000,
384000: SupportedSamplingFrequency.FREQ_384000,
}
return functools.reduce(
lambda x, y: x | MAPPING[y],
frequencies,
cls(0),
)
class FrameDuration(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.2 - Frame Duration'''
# fmt: off
DURATION_7500_US = 0x00
DURATION_10000_US = 0x01
@property
def us(self) -> int:
return {
FrameDuration.DURATION_7500_US: 7500,
FrameDuration.DURATION_10000_US: 10000,
}[self]
class SupportedFrameDuration(enum.IntFlag):
'''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration'''
# fmt: off
DURATION_7500_US_SUPPORTED = 0b0001
DURATION_10000_US_SUPPORTED = 0b0010
DURATION_7500_US_PREFERRED = 0b0001
DURATION_10000_US_PREFERRED = 0b0010
class AnnouncementType(utils.OpenIntEnum):
'''Basic Audio Profile, 3.5.3. Additional Audio Stream Control Service requirements'''
# fmt: off
GENERAL = 0x00
TARGETED = 0x01
@dataclasses.dataclass
class UnicastServerAdvertisingData:
"""Advertising Data for ASCS."""
announcement_type: AnnouncementType = AnnouncementType.TARGETED
available_audio_contexts: ContextType = ContextType.MEDIA
metadata: bytes = b''
def __bytes__(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
struct.pack(
'<2sBIB',
bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
self.announcement_type,
self.available_audio_contexts,
len(self.metadata),
)
+ self.metadata,
)
]
)
)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def bits_to_channel_counts(data: int) -> List[int]:
pos = 0
counts = []
while data != 0:
# Bit 0 = count 1
# Bit 1 = count 2, and so on
pos += 1
if data & 1:
counts.append(pos)
data >>= 1
return counts
def channel_counts_to_bits(counts: Sequence[int]) -> int:
return sum(set([1 << (count - 1) for count in counts]))
# -----------------------------------------------------------------------------
# Structures
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CodecSpecificCapabilities:
'''See:
* Bluetooth Assigned Numbers, 6.12.4 - Codec Specific Capabilities LTV Structures
* Basic Audio Profile, 4.3.1 - Codec_Specific_Capabilities LTV requirements
'''
class Type(enum.IntEnum):
# fmt: off
SAMPLING_FREQUENCY = 0x01
FRAME_DURATION = 0x02
AUDIO_CHANNEL_COUNT = 0x03
OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05
supported_sampling_frequencies: SupportedSamplingFrequency
supported_frame_durations: SupportedFrameDuration
supported_audio_channel_count: Sequence[int]
min_octets_per_codec_frame: int
max_octets_per_codec_frame: int
supported_max_codec_frames_per_sdu: int
@classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities:
offset = 0
# Allowed default values.
supported_audio_channel_count = [1]
supported_max_codec_frames_per_sdu = 1
while offset < len(data):
length, type = struct.unpack_from('BB', data, offset)
offset += 2
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment
return CodecSpecificCapabilities(
supported_sampling_frequencies=supported_sampling_frequencies,
supported_frame_durations=supported_frame_durations,
supported_audio_channel_count=supported_audio_channel_count,
min_octets_per_codec_frame=min_octets_per_sample,
max_octets_per_codec_frame=max_octets_per_sample,
supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu,
)
def __bytes__(self) -> bytes:
return struct.pack(
'<BBHBBBBBBBBHHBBB',
3,
CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY,
self.supported_sampling_frequencies,
2,
CodecSpecificCapabilities.Type.FRAME_DURATION,
self.supported_frame_durations,
2,
CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT,
channel_counts_to_bits(self.supported_audio_channel_count),
5,
CodecSpecificCapabilities.Type.OCTETS_PER_FRAME,
self.min_octets_per_codec_frame,
self.max_octets_per_codec_frame,
2,
CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU,
self.supported_max_codec_frames_per_sdu,
)
@dataclasses.dataclass
class CodecSpecificConfiguration:
'''See:
* Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures
* Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements
'''
class Type(utils.OpenIntEnum):
# fmt: off
SAMPLING_FREQUENCY = 0x01
FRAME_DURATION = 0x02
AUDIO_CHANNEL_ALLOCATION = 0x03
OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05
sampling_frequency: SamplingFrequency | None = None
frame_duration: FrameDuration | None = None
audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int | None = None
@classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
offset = 0
sampling_frequency: SamplingFrequency | None = None
frame_duration: FrameDuration | None = None
audio_channel_allocation: AudioLocation | None = None
octets_per_codec_frame: int | None = None
codec_frames_per_sdu: int | None = None
while offset < len(data):
length, type = struct.unpack_from('BB', data, offset)
offset += 2
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY:
sampling_frequency = SamplingFrequency(value)
elif type == CodecSpecificConfiguration.Type.FRAME_DURATION:
frame_duration = FrameDuration(value)
elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION:
audio_channel_allocation = AudioLocation(value)
elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME:
octets_per_codec_frame = value
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
codec_frames_per_sdu = value
return CodecSpecificConfiguration(
sampling_frequency=sampling_frequency,
frame_duration=frame_duration,
audio_channel_allocation=audio_channel_allocation,
octets_per_codec_frame=octets_per_codec_frame,
codec_frames_per_sdu=codec_frames_per_sdu,
)
def __bytes__(self) -> bytes:
return b''.join(
[
struct.pack(fmt, length, tag, value)
for fmt, length, tag, value in [
(
'<BBB',
2,
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
self.sampling_frequency,
),
(
'<BBB',
2,
CodecSpecificConfiguration.Type.FRAME_DURATION,
self.frame_duration,
),
(
'<BBI',
5,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
self.audio_channel_allocation,
),
(
'<BBH',
3,
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
self.octets_per_codec_frame,
),
(
'<BBB',
2,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
self.codec_frames_per_sdu,
),
]
if value is not None
]
)
@dataclasses.dataclass
class BroadcastAudioAnnouncement:
broadcast_id: int
@classmethod
def from_bytes(cls, data: bytes) -> Self:
return cls(int.from_bytes(data[:3], 'little'))
def __bytes__(self) -> bytes:
return self.broadcast_id.to_bytes(3, 'little')
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)
@dataclasses.dataclass
class BasicAudioAnnouncement:
@dataclasses.dataclass
class BIS:
index: int
codec_specific_configuration: CodecSpecificConfiguration
def __bytes__(self) -> bytes:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
return (
bytes([self.index, len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
@dataclasses.dataclass
class Subgroup:
codec_id: hci.CodingFormat
codec_specific_configuration: CodecSpecificConfiguration
metadata: le_audio.Metadata
bis: List[BasicAudioAnnouncement.BIS]
def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata)
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
return (
bytes([len(self.bis)])
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
+ b''.join(map(bytes, self.bis))
)
presentation_delay: int
subgroups: List[BasicAudioAnnouncement.Subgroup]
@classmethod
def from_bytes(cls, data: bytes) -> Self:
presentation_delay = int.from_bytes(data[:3], 'little')
subgroups = []
offset = 4
for _ in range(data[3]):
num_bis = data[offset]
offset += 1
codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
offset += 5
codec_specific_configuration_length = data[offset]
offset += 1
codec_specific_configuration = data[
offset : offset + codec_specific_configuration_length
]
offset += codec_specific_configuration_length
metadata_length = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(
data[offset : offset + metadata_length]
)
offset += metadata_length
bis = []
for _ in range(num_bis):
bis_index = data[offset]
offset += 1
bis_codec_specific_configuration_length = data[offset]
offset += 1
bis_codec_specific_configuration = data[
offset : offset + bis_codec_specific_configuration_length
]
offset += bis_codec_specific_configuration_length
bis.append(
cls.BIS(
bis_index,
CodecSpecificConfiguration.from_bytes(
bis_codec_specific_configuration
),
)
)
subgroups.append(
cls.Subgroup(
codec_id,
CodecSpecificConfiguration.from_bytes(codec_specific_configuration),
metadata,
bis,
)
)
return cls(presentation_delay, subgroups)
def __bytes__(self) -> bytes:
return (
self.presentation_delay.to_bytes(3, 'little')
+ bytes([len(self.subgroups)])
+ b''.join(map(bytes, self.subgroups))
)
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
)
]
)
)

437
bumble/profiles/bass.py Normal file
View File

@@ -0,0 +1,437 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Broadcast Audio Scan Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, List, Optional, Sequence
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
from bumble import utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
OPCODE_NOT_SUPPORTED = 0x80
INVALID_SOURCE_ID = 0x81
# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
return bytes([len(subgroups)]) + b"".join(
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
+ subgroup.metadata
for subgroup in subgroups
)
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
num_subgroups = data[0]
offset = 1
subgroups = []
for _ in range(num_subgroups):
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
metadata_length = data[offset + 4]
metadata = data[offset + 5 : offset + 5 + metadata_length]
offset += 5 + metadata_length
subgroups.append(SubgroupInfo(bis_sync, metadata))
return subgroups
# -----------------------------------------------------------------------------
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
@dataclasses.dataclass
class SubgroupInfo:
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
bis_sync: int
metadata: bytes
class ControlPointOperation:
class OpCode(utils.OpenIntEnum):
REMOTE_SCAN_STOPPED = 0x00
REMOTE_SCAN_STARTED = 0x01
ADD_SOURCE = 0x02
MODIFY_SOURCE = 0x03
SET_BROADCAST_CODE = 0x04
REMOVE_SOURCE = 0x05
op_code: OpCode
parameters: bytes
@classmethod
def from_bytes(cls, data: bytes) -> ControlPointOperation:
op_code = data[0]
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
return RemoteScanStoppedOperation()
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
return RemoteScanStartedOperation()
if op_code == cls.OpCode.ADD_SOURCE:
return AddSourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.MODIFY_SOURCE:
return ModifySourceOperation.from_parameters(data[1:])
if op_code == cls.OpCode.SET_BROADCAST_CODE:
return SetBroadcastCodeOperation.from_parameters(data[1:])
if op_code == cls.OpCode.REMOVE_SOURCE:
return RemoveSourceOperation.from_parameters(data[1:])
raise core.InvalidArgumentError("invalid op code")
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
self.op_code = op_code
self.parameters = parameters
def __bytes__(self) -> bytes:
return bytes([self.op_code]) + self.parameters
class RemoteScanStoppedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
class RemoteScanStartedOperation(ControlPointOperation):
def __init__(self) -> None:
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
class AddSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
instance.parameters = parameters
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
parameters, 1
)[1]
instance.advertising_sid = parameters[7]
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
instance.subgroups = decode_subgroups(parameters[14:])
return instance
def __init__(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.ADD_SOURCE,
struct.pack(
"<B6sB3sBH",
advertiser_address.address_type,
bytes(advertiser_address),
advertising_sid,
broadcast_id.to_bytes(3, "little"),
pa_sync,
pa_interval,
)
+ encode_subgroups(subgroups),
)
self.advertiser_address = advertiser_address
self.advertising_sid = advertising_sid
self.broadcast_id = broadcast_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class ModifySourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
instance.subgroups = decode_subgroups(parameters[4:])
return instance
def __init__(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
super().__init__(
ControlPointOperation.OpCode.MODIFY_SOURCE,
struct.pack("<BBH", source_id, pa_sync, pa_interval)
+ encode_subgroups(subgroups),
)
self.source_id = source_id
self.pa_sync = pa_sync
self.pa_interval = pa_interval
self.subgroups = list(subgroups)
class SetBroadcastCodeOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
instance.parameters = parameters
instance.source_id = parameters[0]
instance.broadcast_code = parameters[1:17]
return instance
def __init__(
self,
source_id: int,
broadcast_code: bytes,
) -> None:
super().__init__(
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
bytes([source_id]) + broadcast_code,
)
self.source_id = source_id
self.broadcast_code = broadcast_code
if len(self.broadcast_code) != 16:
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
class RemoveSourceOperation(ControlPointOperation):
@classmethod
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
instance = cls.__new__(cls)
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
instance.parameters = parameters
instance.source_id = parameters[0]
return instance
def __init__(self, source_id: int) -> None:
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
self.source_id = source_id
@dataclasses.dataclass
class BroadcastReceiveState:
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
NOT_SYNCHRONIZED_TO_PA = 0x00
SYNCINFO_REQUEST = 0x01
SYNCHRONIZED_TO_PA = 0x02
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
NO_PAST = 0x04
class BigEncryption(utils.OpenIntEnum):
NOT_ENCRYPTED = 0x00
BROADCAST_CODE_REQUIRED = 0x01
DECRYPTING = 0x02
BAD_CODE = 0x03
source_id: int
source_address: hci.Address
source_adv_sid: int
broadcast_id: int
pa_sync_state: PeriodicAdvertisingSyncState
big_encryption: BigEncryption
bad_code: bytes
subgroups: List[SubgroupInfo]
@classmethod
def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
source_id = data[0]
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
source_adv_sid = data[8]
broadcast_id = int.from_bytes(data[9:12], "little")
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
big_encryption = cls.BigEncryption(data[13])
if big_encryption == cls.BigEncryption.BAD_CODE:
bad_code = data[14:30]
subgroups = decode_subgroups(data[30:])
else:
bad_code = b""
subgroups = decode_subgroups(data[14:])
return cls(
source_id,
source_address,
source_adv_sid,
broadcast_id,
pa_sync_state,
big_encryption,
bad_code,
subgroups,
)
def __bytes__(self) -> bytes:
return (
struct.pack(
"<BB6sB3sBB",
self.source_id,
self.source_address.address_type,
bytes(self.source_address),
self.source_adv_sid,
self.broadcast_id.to_bytes(3, "little"),
self.pa_sync_state,
self.big_encryption,
)
+ self.bad_code
+ encode_subgroups(self.subgroups)
)
# -----------------------------------------------------------------------------
class BroadcastAudioScanService(gatt.TemplateService):
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
def __init__(self):
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
gatt.Characteristic.WRITEABLE,
gatt.CharacteristicValue(
write=self.on_broadcast_audio_scan_control_point_write
),
)
self.broadcast_receive_state_characteristic = gatt.Characteristic(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
b"12", # TEST
)
super().__init__([self.battery_level_characteristic])
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
) -> None:
pass
# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BroadcastAudioScanService
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
broadcast_receive_states: List[gatt.SerializableCharacteristicAdapter]
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Audio Scan Control Point characteristic not found"
)
self.broadcast_audio_scan_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
)
):
raise gatt.InvalidServiceError(
"Broadcast Receive State characteristic not found"
)
self.broadcast_receive_states = [
gatt.SerializableCharacteristicAdapter(
characteristic, BroadcastReceiveState
)
for characteristic in characteristics
]
async def send_control_point_operation(
self, operation: ControlPointOperation
) -> None:
await self.broadcast_audio_scan_control_point.write_value(
bytes(operation), with_response=True
)
async def remote_scan_started(self) -> None:
await self.send_control_point_operation(RemoteScanStartedOperation())
async def remote_scan_stopped(self) -> None:
await self.send_control_point_operation(RemoteScanStoppedOperation())
async def add_source(
self,
advertiser_address: hci.Address,
advertising_sid: int,
broadcast_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
AddSourceOperation(
advertiser_address,
advertising_sid,
broadcast_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def modify_source(
self,
source_id: int,
pa_sync: PeriodicAdvertisingSyncParams,
pa_interval: int,
subgroups: Sequence[SubgroupInfo],
) -> None:
await self.send_control_point_operation(
ModifySourceOperation(
source_id,
pa_sync,
pa_interval,
subgroups,
)
)
async def remove_source(self, source_id: int) -> None:
await self.send_control_point_operation(RemoveSourceOperation(source_id))

52
bumble/profiles/cap.py Normal file
View File

@@ -0,0 +1,52 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -19,8 +19,11 @@
from __future__ import annotations
import enum
import struct
from typing import Optional
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
@@ -28,6 +31,9 @@ from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
@@ -43,9 +49,47 @@ class MemberLock(enum.IntEnum):
# -----------------------------------------------------------------------------
# Utils
# Crypto Toolbox
# -----------------------------------------------------------------------------
# TODO: Implement RSI Generator
def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# -----------------------------------------------------------------------------
@@ -54,6 +98,7 @@ class MemberLock(enum.IntEnum):
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
@@ -62,19 +107,26 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise core.InvalidArgumentError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
# TODO: Implement encrypted SIRK reader.
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
@@ -83,7 +135,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
@@ -94,7 +146,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READABLE
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
@@ -105,13 +157,45 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# -----------------------------------------------------------------------------
# Client
@@ -145,3 +229,29 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise core.InvalidPacketError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)

View File

@@ -19,8 +19,8 @@
import struct
from typing import Optional, Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
@@ -59,12 +59,15 @@ class DeviceInformationService(TemplateService):
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None
ieee_regulatory_certification_data_list: Optional[bytes] = None,
# TODO: pnp_id
):
characteristics = [
Characteristic(
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field
uuid,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(field, 'utf-8'),
)
for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
@@ -104,10 +107,19 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService
def __init__(self, service_proxy):
manufacturer_name: Optional[UTF8CharacteristicAdapter]
model_number: Optional[UTF8CharacteristicAdapter]
serial_number: Optional[UTF8CharacteristicAdapter]
hardware_revision: Optional[UTF8CharacteristicAdapter]
firmware_revision: Optional[UTF8CharacteristicAdapter]
software_revision: Optional[UTF8CharacteristicAdapter]
system_id: Optional[DelegatedCharacteristicAdapter]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
for (field, uuid) in (
for field, uuid in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),

110
bumble/profiles/gap.py Normal file
View File

@@ -0,0 +1,110 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic Access Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Tuple, Union
from bumble.core import Appearance
from bumble.gatt import (
TemplateService,
Characteristic,
CharacteristicAdapter,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(TemplateService):
UUID = GATT_GENERIC_ACCESS_SERVICE
def __init__(
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
self.appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', appearance_int),
)
super().__init__(
[self.device_name_characteristic, self.appearance_characteristic]
)
# -----------------------------------------------------------------------------
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicAdapter]
appearance: Optional[DelegatedCharacteristicAdapter]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC
):
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
else:
self.device_name = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_APPEARANCE_CHARACTERISTIC
):
self.appearance = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Appearance.from_int(
struct.unpack_from('<H', value, 0)[0],
),
)
else:
self.appearance = None

198
bumble/profiles/gmap.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Gaming Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional
from bumble.gatt import (
TemplateService,
DelegatedCharacteristicAdapter,
Characteristic,
GATT_GAMING_AUDIO_SERVICE,
GATT_GMAP_ROLE_CHARACTERISTIC,
GATT_UGG_FEATURES_CHARACTERISTIC,
GATT_UGT_FEATURES_CHARACTERISTIC,
GATT_BGS_FEATURES_CHARACTERISTIC,
GATT_BGR_FEATURES_CHARACTERISTIC,
InvalidServiceError,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from enum import IntFlag
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class GmapRole(IntFlag):
UNICAST_GAME_GATEWAY = 1 << 0
UNICAST_GAME_TERMINAL = 1 << 1
BROADCAST_GAME_SENDER = 1 << 2
BROADCAST_GAME_RECEIVER = 1 << 3
class UggFeatures(IntFlag):
UGG_MULTIPLEX = 1 << 0
UGG_96_KBPS_SOURCE = 1 << 1
UGG_MULTISINK = 1 << 2
class UgtFeatures(IntFlag):
UGT_SOURCE = 1 << 0
UGT_80_KBPS_SOURCE = 1 << 1
UGT_SINK = 1 << 2
UGT_64_KBPS_SINK = 1 << 3
UGT_MULTIPLEX = 1 << 4
UGT_MULTISINK = 1 << 5
UGT_MULTISOURCE = 1 << 6
class BgsFeatures(IntFlag):
BGS_96_KBPS = 1 << 0
class BgrFeatures(IntFlag):
BGR_MULTISINK = 1 << 0
BGR_MULTIPLEX = 1 << 1
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic
ugg_features: Optional[Characteristic] = None
ugt_features: Optional[Characteristic] = None
bgs_features: Optional[Characteristic] = None
bgr_features: Optional[Characteristic] = None
def __init__(
self,
gmap_role: GmapRole,
ugg_features: Optional[UggFeatures] = None,
ugt_features: Optional[UgtFeatures] = None,
bgs_features: Optional[BgsFeatures] = None,
bgr_features: Optional[BgrFeatures] = None,
) -> None:
characteristics = []
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
self.gmap_role = Characteristic(
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', gmap_role),
)
characteristics.append(self.gmap_role)
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
self.ugg_features = Characteristic(
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugg_features),
)
characteristics.append(self.ugg_features)
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
self.ugt_features = Characteristic(
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', ugt_features),
)
characteristics.append(self.ugt_features)
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
self.bgs_features = Characteristic(
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgs_features),
)
characteristics.append(self.bgs_features)
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
self.bgr_features = Characteristic(
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
properties=Characteristic.Properties.READ,
permissions=Characteristic.Permissions.READABLE,
value=struct.pack('B', bgr_features),
)
characteristics.append(self.bgr_features)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_GMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError("GMAP Role Characteristic not found")
self.gmap_role = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: GmapRole(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGG_FEATURES_CHARACTERISTIC
):
self.ugg_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: UggFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_UGT_FEATURES_CHARACTERISTIC
):
self.ugt_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: UgtFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGS_FEATURES_CHARACTERISTIC
):
self.bgs_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: BgsFeatures(value[0]),
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BGR_FEATURES_CHARACTERISTIC
):
self.bgr_features = DelegatedCharacteristicAdapter(
characteristic=characteristics[0],
decode=lambda value: BgrFeatures(value[0]),
)

674
bumble/profiles/hap.py Normal file
View File

@@ -0,0 +1,674 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import functools
from bumble import att, gatt, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Device, Connection
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.hci import Address
from dataclasses import dataclass, field
import logging
from typing import Any, Dict, List, Optional, Set, Union
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(OpenIntEnum):
'''See Hearing Access Service 2.4. Attribute Profile error codes.'''
INVALID_OPCODE = 0x80
WRITE_NAME_NOT_ALLOWED = 0x81
PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
PRESET_OPERATION_NOT_POSSIBLE = 0x83
INVALID_PARAMETERS_LENGTH = 0x84
class HearingAidType(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
BINAURAL_HEARING_AID = 0b00
MONAURAL_HEARING_AID = 0b01
BANDED_HEARING_AID = 0b10
class PresetSynchronizationSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
class IndependentPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
IDENTICAL_PRESET_RECORD = 0b0
DIFFERENT_PRESET_RECORD = 0b1
class DynamicPresets(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
PRESET_RECORDS_MAY_CHANGE = 0b1
class WritablePresetsSupport(OpenIntEnum):
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
class HearingAidPresetControlPointOpcode(OpenIntEnum):
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
# fmt: off
READ_PRESETS_REQUEST = 0x01
READ_PRESET_RESPONSE = 0x02
PRESET_CHANGED = 0x03
WRITE_PRESET_NAME = 0x04
SET_ACTIVE_PRESET = 0x05
SET_NEXT_PRESET = 0x06
SET_PREVIOUS_PRESET = 0x07
SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY = 0x08
SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY = 0x09
SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
@dataclass
class HearingAidFeatures:
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
hearing_aid_type: HearingAidType
preset_synchronization_support: PresetSynchronizationSupport
independent_presets: IndependentPresets
dynamic_presets: DynamicPresets
writable_presets_support: WritablePresetsSupport
def __bytes__(self) -> bytes:
return bytes(
[
(self.hearing_aid_type << 0)
| (self.preset_synchronization_support << 2)
| (self.independent_presets << 3)
| (self.dynamic_presets << 4)
| (self.writable_presets_support << 5)
]
)
def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
return HearingAidFeatures(
HearingAidType(data & 0b11),
PresetSynchronizationSupport(data >> 2 & 0b1),
IndependentPresets(data >> 3 & 0b1),
DynamicPresets(data >> 4 & 0b1),
WritablePresetsSupport(data >> 5 & 0b1),
)
@dataclass
class PresetChangedOperation:
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
class ChangeId(OpenIntEnum):
# fmt: off
GENERIC_UPDATE = 0x00
PRESET_RECORD_DELETED = 0x01
PRESET_RECORD_AVAILABLE = 0x02
PRESET_RECORD_UNAVAILABLE = 0x03
@dataclass
class Generic:
prev_index: int
preset_record: PresetRecord
def __bytes__(self) -> bytes:
return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId
additional_parameters: Union[Generic, int]
def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
additional_parameters_bytes = bytes(self.additional_parameters)
else:
additional_parameters_bytes = bytes([self.additional_parameters])
return (
bytes(
[
HearingAidPresetControlPointOpcode.PRESET_CHANGED,
self.change_id,
is_last,
]
)
+ additional_parameters_bytes
)
class PresetChangedOperationDeleted(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
self.additional_parameters = index
class PresetChangedOperationAvailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
self.additional_parameters = index
class PresetChangedOperationUnavailable(PresetChangedOperation):
def __init__(self, index) -> None:
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
self.additional_parameters = index
@dataclass
class PresetRecord:
'''See Hearing Access Service 2.8. Preset record.'''
@dataclass
class Property:
class Writable(OpenIntEnum):
CANNOT_BE_WRITTEN = 0b0
CAN_BE_WRITTEN = 0b1
class IsAvailable(OpenIntEnum):
IS_UNAVAILABLE = 0b0
IS_AVAILABLE = 0b1
writable: Writable = Writable.CAN_BE_WRITTEN
is_available: IsAvailable = IsAvailable.IS_AVAILABLE
def __bytes__(self) -> bytes:
return bytes([self.writable | (self.is_available << 1)])
index: int
name: str
properties: Property = field(default_factory=Property)
def __bytes__(self) -> bytes:
return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8')
def is_available(self) -> bool:
return (
self.properties.is_available
== PresetRecord.Property.IsAvailable.IS_AVAILABLE
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
UUID = gatt.GATT_HEARING_ACCESS_SERVICE
hearing_aid_features_characteristic: gatt.Characteristic
hearing_aid_preset_control_point: gatt.Characteristic
active_preset_index_characteristic: gatt.Characteristic
active_preset_index: int
active_preset_index_per_device: Dict[Address, int]
device: Device
server_features: HearingAidFeatures
preset_records: Dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool
preset_changed_operations_history_per_device: Dict[
Address, List[PresetChangedOperation]
]
# Keep an updated list of connected client to send notification to
currently_connected_clients: Set[Connection]
def __init__(
self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
) -> None:
self.active_preset_index_per_device = {}
self.read_presets_request_in_progress = False
self.preset_changed_operations_history_per_device = {}
self.currently_connected_clients = set()
self.device = device
self.server_features = features
if len(presets) < 1:
raise InvalidArgumentError(f'Invalid presets: {presets}')
self.preset_records = {}
for p in presets:
if len(p.name.encode()) < 1 or len(p.name.encode()) > 40:
raise InvalidArgumentError(f'Invalid name: {p.name}')
self.preset_records[p.index] = p
# associate the lowest index as the current active preset at startup
self.active_preset_index = sorted(self.preset_records.keys())[0]
@device.on('connection') # type: ignore
def on_connection(connection: Connection) -> None:
@connection.on('disconnection') # type: ignore
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
@connection.on('pairing') # type: ignore
def on_pairing(*_: Any) -> None:
self.on_incoming_paired_connection(connection)
if connection.peer_resolvable_address:
self.on_incoming_paired_connection(connection)
self.hearing_aid_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes(self.server_features),
)
self.hearing_aid_preset_control_point = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.INDICATE
),
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(
write=self._on_write_hearing_aid_preset_control_point
),
)
self.active_preset_index_characteristic = gatt.Characteristic(
uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
)
super().__init__(
[
self.hearing_aid_features_characteristic,
self.hearing_aid_preset_control_point,
self.active_preset_index_characteristic,
]
)
def on_incoming_paired_connection(self, connection: Connection):
'''Setup initial operations to handle a remote bonded HAP device'''
# TODO Should we filter on HAP device only ?
self.currently_connected_clients.add(connection)
if (
connection.peer_address
not in self.preset_changed_operations_history_per_device
):
self.preset_changed_operations_history_per_device[
connection.peer_address
] = []
return
async def on_connection_async() -> None:
# Send all the PresetChangedOperation that occur when not connected
await self._preset_changed_operation(connection)
# Update the active preset index if needed
await self.notify_active_preset_for_connection(connection)
connection.abort_on('disconnection', on_connection_async())
def _on_read_active_preset_index(
self, __connection__: Optional[Connection]
) -> bytes:
return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded
def on_forget(self, addr: Address) -> None:
self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point(
self, connection: Optional[Connection], value: bytes
):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value)
async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}')
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
self.read_presets_request_in_progress = True
start_index = value[1]
if start_index == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
num_presets = value[2]
if num_presets == 0x00:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
# Sending `num_presets` presets ordered by increasing index field, starting from start_index
presets = [
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
AsyncRunner.spawn(self._read_preset_response(connection, presets))
async def _read_preset_response(
self, connection: Connection, presets: List[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=bytes(
[
HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
i == len(presets) - 1,
]
)
+ bytes(preset),
)
finally:
# indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
self.read_presets_request_in_progress = False
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
if index == self.active_preset_index:
raise InvalidStateError('Cannot set active preset as unavailable')
preset = self.preset_records[index]
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
op_list = self.preset_changed_operations_history_per_device.get(
connection.peer_address, []
)
# Notification will be sent in index order
def get_op_index(op: PresetChangedOperation) -> int:
if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
return op.additional_parameters.prev_index
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
try:
await connection.device.indicate_subscriber(
connection,
self.hearing_aid_preset_control_point,
value=op_list[0].to_bytes(len(op_list) == 1),
)
# Remove item once sent, and keep the non sent item in the list
op_list.pop(0)
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(
self, connection: Optional[Connection], value: bytes
):
assert connection
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.writable
== PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
):
raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)
name = value[2:].decode('utf-8')
if not name or len(name) > 40:
raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)
preset.name = name
await self.generic_update(
PresetChangedOperation(
PresetChangedOperation.ChangeId.GENERIC_UPDATE,
PresetChangedOperation.Generic(index, preset),
)
)
async def notify_active_preset_for_connection(self, connection: Connection) -> None:
if (
self.active_preset_index_per_device.get(connection.peer_address, 0x00)
== self.active_preset_index
):
# Nothing to do, peer is already updated
return
await connection.device.notify_subscriber(
connection,
attribute=self.active_preset_index_characteristic,
value=bytes([self.active_preset_index]),
)
self.active_preset_index_per_device[connection.peer_address] = (
self.active_preset_index
)
async def notify_active_preset(self) -> None:
for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection)
async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
index = value[1]
preset = self.preset_records.get(index, None)
if (
not preset
or preset.properties.is_available
!= PresetRecord.Property.IsAvailable.IS_AVAILABLE
):
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if index == self.active_preset_index:
# Already at correct value
return
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(
self, connection: Optional[Connection], value: bytes
):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
'''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
if not first_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset(
self, connection: Optional[Connection], __value__: bytes
) -> None:
await self.set_next_or_previous_preset(connection, True)
async def _on_set_active_preset_synchronized_locally(
self, connection: Optional[Connection], value: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_active_preset(connection, value)
# TODO (low priority) inform other server of the change
async def _on_set_next_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, False)
# TODO (low priority) inform other server of the change
async def _on_set_previous_preset_synchronized_locally(
self, connection: Optional[Connection], __value__: bytes
):
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(connection, True)
# TODO (low priority) inform other server of the change
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.server_features = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
)[0],
'B',
)
self.hearing_aid_preset_control_point = (
service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
)[0]
)
self.active_preset_index = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
)[0],
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
)

View File

@@ -19,6 +19,7 @@
from enum import IntEnum
import struct
from bumble import core
from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error
from ..gatt import (
@@ -29,6 +30,7 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
SerializableCharacteristicAdapter,
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
)
@@ -59,17 +61,17 @@ class HeartRateService(TemplateService):
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
@@ -149,15 +151,14 @@ class HeartRateService(TemplateService):
body_sensor_location=None,
reset_energy_expended=None,
):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
# pylint: disable=unnecessary-lambda
encode=lambda value: bytes(value),
HeartRateService.HeartRateMeasurement,
)
characteristics = [self.heart_rate_measurement_characteristic]
@@ -203,9 +204,8 @@ class HeartRateServiceProxy(ProfileServiceProxy):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes,
self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], HeartRateService.HeartRateMeasurement
)
else:
self.heart_rate_measurement = None

View File

@@ -0,0 +1,83 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
from typing import List, Type
from typing_extensions import Self
from bumble import utils
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
again outside the lib.
'''
class Tag(utils.OpenIntEnum):
# fmt: off
PREFERRED_AUDIO_CONTEXTS = 0x01
STREAMING_AUDIO_CONTEXTS = 0x02
PROGRAM_INFO = 0x03
LANGUAGE = 0x04
CCID_LIST = 0x05
PARENTAL_RATING = 0x06
PROGRAM_INFO_URI = 0x07
AUDIO_ACTIVE_STATE = 0x08
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
ASSISTED_LISTENING_STREAM = 0x0A
BROADCAST_NAME = 0x0B
EXTENDED_METADATA = 0xFE
VENDOR_SPECIFIC = 0xFF
@dataclasses.dataclass
class Entry:
tag: Metadata.Tag
data: bytes
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
def __bytes__(self) -> bytes:
return bytes([len(self.data) + 1, self.tag]) + self.data
entries: List[Entry] = dataclasses.field(default_factory=list)
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = []
offset = 0
length = len(data)
while offset < length:
entry_length = data[offset]
offset += 1
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
offset += entry_length
return cls(entries)
def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries])

448
bumble/profiles/mcp.py Normal file
View File

@@ -0,0 +1,448 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class PlayingOrder(utils.OpenIntEnum):
'''See Media Control Service 3.15. Playing Order.'''
SINGLE_ONCE = 0x01
SINGLE_REPEAT = 0x02
IN_ORDER_ONCE = 0x03
IN_ORDER_REPEAT = 0x04
OLDEST_ONCE = 0x05
OLDEST_REPEAT = 0x06
NEWEST_ONCE = 0x07
NEWEST_REPEAT = 0x08
SHUFFLE_ONCE = 0x09
SHUFFLE_REPEAT = 0x0A
class PlayingOrderSupported(enum.IntFlag):
'''See Media Control Service 3.16. Playing Orders Supported.'''
SINGLE_ONCE = 0x0001
SINGLE_REPEAT = 0x0002
IN_ORDER_ONCE = 0x0004
IN_ORDER_REPEAT = 0x0008
OLDEST_ONCE = 0x0010
OLDEST_REPEAT = 0x0020
NEWEST_ONCE = 0x0040
NEWEST_REPEAT = 0x0080
SHUFFLE_ONCE = 0x0100
SHUFFLE_REPEAT = 0x0200
class MediaState(utils.OpenIntEnum):
'''See Media Control Service 3.17. Media State.'''
INACTIVE = 0x00
PLAYING = 0x01
PAUSED = 0x02
SEEKING = 0x03
class MediaControlPointOpcode(utils.OpenIntEnum):
'''See Media Control Service 3.18. Media Control Point.'''
PLAY = 0x01
PAUSE = 0x02
FAST_REWIND = 0x03
FAST_FORWARD = 0x04
STOP = 0x05
MOVE_RELATIVE = 0x10
PREVIOUS_SEGMENT = 0x20
NEXT_SEGMENT = 0x21
FIRST_SEGMENT = 0x22
LAST_SEGMENT = 0x23
GOTO_SEGMENT = 0x24
PREVIOUS_TRACK = 0x30
NEXT_TRACK = 0x31
FIRST_TRACK = 0x32
LAST_TRACK = 0x33
GOTO_TRACK = 0x34
PREVIOUS_GROUP = 0x40
NEXT_GROUP = 0x41
FIRST_GROUP = 0x42
LAST_GROUP = 0x43
GOTO_GROUP = 0x44
class MediaControlPointResultCode(enum.IntFlag):
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
SUCCESS = 0x01
OPCODE_NOT_SUPPORTED = 0x02
MEDIA_PLAYER_INACTIVE = 0x03
COMMAND_CANNOT_BE_COMPLETED = 0x04
class MediaControlPointOpcodeSupported(enum.IntFlag):
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
PLAY = 0x00000001
PAUSE = 0x00000002
FAST_REWIND = 0x00000004
FAST_FORWARD = 0x00000008
STOP = 0x00000010
MOVE_RELATIVE = 0x00000020
PREVIOUS_SEGMENT = 0x00000040
NEXT_SEGMENT = 0x00000080
FIRST_SEGMENT = 0x00000100
LAST_SEGMENT = 0x00000200
GOTO_SEGMENT = 0x00000400
PREVIOUS_TRACK = 0x00000800
NEXT_TRACK = 0x00001000
FIRST_TRACK = 0x00002000
LAST_TRACK = 0x00004000
GOTO_TRACK = 0x00008000
PREVIOUS_GROUP = 0x00010000
NEXT_GROUP = 0x00020000
FIRST_GROUP = 0x00040000
LAST_GROUP = 0x00080000
GOTO_GROUP = 0x00100000
class SearchControlPointItemType(utils.OpenIntEnum):
'''See Media Control Service 3.20. Search Control Point.'''
TRACK_NAME = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
GROUP_NAME = 0x04
EARLIEST_YEAR = 0x05
LATEST_YEAR = 0x06
GENRE = 0x07
ONLY_TRACKS = 0x08
ONLY_GROUPS = 0x09
class ObjectType(utils.OpenIntEnum):
'''See Media Control Service 4.4.1. Object Type field.'''
TASK = 0
GROUP = 1
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectId(int):
'''See Media Control Service 4.4.2. Object ID field.'''
@classmethod
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(int.from_bytes(data, byteorder='little', signed=False))
def __bytes__(self) -> bytes:
return self.to_bytes(6, 'little')
@dataclasses.dataclass
class GroupObjectType:
'''See Media Control Service 4.4. Group Object Type.'''
object_type: ObjectType
object_id: ObjectId
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(
object_type=ObjectType(data[0]),
object_id=ObjectId.create_from_bytes(data[1:]),
)
def __bytes__(self) -> bytes:
return bytes([self.object_type]) + bytes(self.object_id)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class MediaControlService(gatt.TemplateService):
'''Media Control Service server implementation, only for testing currently.'''
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player',
)
self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_title_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_duration_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_position_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_state_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_control_point_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.content_control_id_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
super().__init__(
[
self.media_player_name_characteristic,
self.track_changed_characteristic,
self.track_title_characteristic,
self.track_duration_characteristic,
self.track_position_characteristic,
self.media_state_characteristic,
self.media_control_point_characteristic,
self.media_control_point_opcodes_supported_characteristic,
self.content_control_id_characteristic,
]
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(
connection,
self.media_control_point_characteristic,
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
)
class GenericMediaControlService(MediaControlService):
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class MediaControlServiceProxy(
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
):
SERVICE_CLASS = MediaControlService
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
}
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
track_changed: Optional[gatt_client.CharacteristicProxy] = None
track_title: Optional[gatt_client.CharacteristicProxy] = None
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
media_state: Optional[gatt_client.CharacteristicProxy] = None
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
None
)
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
utils.CompositeEventEmitter.__init__(self)
self.service_proxy = service_proxy
self.lock = asyncio.Lock()
self.media_control_point_notifications = asyncio.Queue()
for field, uuid in self._CHARACTERISTICS.items():
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, field, characteristics[0])
async def subscribe_characteristics(self) -> None:
if self.media_control_point:
await self.media_control_point.subscribe(self._on_media_control_point)
if self.media_state:
await self.media_state.subscribe(self._on_media_state)
if self.track_changed:
await self.track_changed.subscribe(self._on_track_changed)
if self.track_title:
await self.track_title.subscribe(self._on_track_title)
if self.track_duration:
await self.track_duration.subscribe(self._on_track_duration)
if self.track_position:
await self.track_position.subscribe(self._on_track_position)
async def write_control_point(
self, opcode: MediaControlPointOpcode
) -> MediaControlPointResultCode:
'''Writes a Media Control Point Opcode to peer and waits for the notification.
The write operation will be executed when there isn't other pending commands.
Args:
opcode: opcode defined in `MediaControlPointOpcode`.
Returns:
Response code provided in `MediaControlPointResultCode`
Raises:
InvalidOperationError: Server does not have Media Control Point Characteristic.
InvalidStateError: Server replies a notification with mismatched opcode.
'''
if not self.media_control_point:
raise core.InvalidOperationError("Peer does not have media control point")
async with self.lock:
await self.media_control_point.write_value(
bytes([opcode]),
with_response=False,
)
(
response_opcode,
response_code,
) = await self.media_control_point_notifications.get()
if response_opcode != opcode:
raise core.InvalidStateError(
f"Expected {opcode} notification, but get {response_opcode}"
)
return MediaControlPointResultCode(response_code)
def _on_media_control_point(self, data: bytes) -> None:
self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None:
del data
self.emit('track_changed')
def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
SERVICE_CLASS = GenericMediaControlService

210
bumble/profiles/pacs.py Normal file
View File

@@ -0,0 +1,210 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for
"""LE Audio - Published Audio Capabilities Service"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import Optional, Sequence, Union
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
from bumble.profiles import le_audio
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
def from_bytes(cls, data: bytes) -> PacRecord:
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
codec_specific_capabilities_size = data[offset]
offset += 1
codec_specific_capabilities_bytes = data[
offset : offset + codec_specific_capabilities_size
]
offset += codec_specific_capabilities_size
metadata_size = data[offset]
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
codec_specific_capabilities_bytes
)
return PacRecord(
coding_format=coding_format,
codec_specific_capabilities=codec_specific_capabilities,
metadata=metadata,
)
def __bytes__(self) -> bytes:
capabilities_bytes = bytes(self.codec_specific_capabilities)
metadata_bytes = bytes(self.metadata)
return (
bytes(self.coding_format)
+ bytes([len(capabilities_bytes)])
+ capabilities_bytes
+ bytes([len(metadata_bytes)])
+ metadata_bytes
)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic]
sink_audio_locations: Optional[gatt.Characteristic]
source_pac: Optional[gatt.Characteristic]
source_audio_locations: Optional[gatt.Characteristic]
available_audio_contexts: gatt.Characteristic
supported_audio_contexts: gatt.Characteristic
def __init__(
self,
supported_source_context: ContextType,
supported_sink_context: ContextType,
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
self.supported_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', supported_sink_context, supported_source_context),
)
characteristics.append(self.supported_audio_contexts)
self.available_audio_contexts = gatt.Characteristic(
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<HH', available_sink_context, available_source_context),
)
characteristics.append(self.available_audio_contexts)
if sink_pac:
self.sink_pac = gatt.Characteristic(
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
)
characteristics.append(self.sink_pac)
if sink_audio_locations is not None:
self.sink_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', sink_audio_locations),
)
characteristics.append(self.sink_audio_locations)
if source_pac:
self.source_pac = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
)
characteristics.append(self.source_pac)
if source_audio_locations is not None:
self.source_audio_locations = gatt.Characteristic(
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('<I', source_audio_locations),
)
characteristics.append(self.source_audio_locations)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
source_pac: Optional[gatt_client.CharacteristicProxy] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
available_audio_contexts: gatt_client.CharacteristicProxy
supported_audio_contexts: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_PAC_CHARACTERISTIC
):
self.sink_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
):
self.source_pac = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
):
self.sink_audio_locations = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
):
self.source_audio_locations = characteristics[0]

46
bumble/profiles/pbp.py Normal file
View File

@@ -0,0 +1,46 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
from typing_extensions import Self
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PublicBroadcastAnnouncement:
class Features(enum.IntFlag):
ENCRYPTED = 1 << 0
STANDARD_QUALITY_CONFIGURATION = 1 << 1
HIGH_QUALITY_CONFIGURATION = 1 << 2
features: Features
metadata: le_audio.Metadata
@classmethod
def from_bytes(cls, data: bytes) -> Self:
features = cls.Features(data[0])
metadata_length = data[1]
metadata_ltv = data[1 : 1 + metadata_length]
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)

89
bumble/profiles/tmap.py Normal file
View File

@@ -0,0 +1,89 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LE Audio - Telephony and Media Audio Profile"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
import logging
import struct
from bumble.gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
InvalidServiceError,
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
GATT_TMAP_ROLE_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Role(enum.IntFlag):
CALL_GATEWAY = 1 << 0
CALL_TERMINAL = 1 << 1
UNICAST_MEDIA_SENDER = 1 << 2
UNICAST_MEDIA_RECEIVER = 1 << 3
BROADCAST_MEDIA_SENDER = 1 << 4
BROADCAST_MEDIA_RECEIVER = 1 << 5
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioService(TemplateService):
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
def __init__(self, role: Role):
self.role_characteristic = Characteristic(
GATT_TMAP_ROLE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', int(role)),
)
super().__init__([self.role_characteristic])
# -----------------------------------------------------------------------------
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = TelephonyAndMediaAudioService
role: DelegatedCharacteristicAdapter
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_TMAP_ROLE_CHARACTERISTIC
)
):
raise InvalidServiceError('TMAP Role characteristic not found')
self.role = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: Role(
struct.unpack_from('<H', value, 0)[0],
),
)

230
bumble/profiles/vcp.py Normal file
View File

@@ -0,0 +1,230 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from bumble import att
from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional, Sequence
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME = 0
MAX_VOLUME = 255
class ErrorCode(enum.IntEnum):
'''
See Volume Control Service 1.6. Application error codes.
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
class VolumeFlags(enum.IntFlag):
'''
See Volume Control Service 3.3. Volume Flags.
'''
VOLUME_SETTING_PERSISTED = 0x01
# RFU
class VolumeControlPointOpcode(enum.IntEnum):
'''
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
'''
# fmt: off
RELATIVE_VOLUME_DOWN = 0x00
RELATIVE_VOLUME_UP = 0x01
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
UNMUTE_RELATIVE_VOLUME_UP = 0x03
SET_ABSOLUTE_VOLUME = 0x04
UNMUTE = 0x05
MUTE = 0x06
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic
volume_setting: int
muted: int
change_counter: int
def __init__(
self,
step_size: int = 16,
volume_setting: int = 0,
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
included_services: Sequence[gatt.Service] = (),
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
self.muted = muted
self.change_counter = change_counter
self.volume_state = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
)
self.volume_control_point = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
)
self.volume_flags = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([volume_flags]),
)
super().__init__(
characteristics=[
self.volume_state,
self.volume_control_point,
self.volume_flags,
],
included_services=list(included_services),
)
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]
if change_counter != self.change_counter:
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256
connection.abort_on(
'disconnection',
connection.device.notify_subscribers(
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
)
def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
return self.volume_setting != old_volume
def _on_relative_volume_up(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
return self.volume_setting != old_volume
def _on_unmute_relative_volume_down(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_unmute_relative_volume_up(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
old_volume_setting = self.volume_setting
self.volume_setting = volume_setting
return old_volume_setting != self.volume_setting
def _on_unmute(self) -> bool:
old_muted_state = self.muted
self.muted = 0
return self.muted != old_muted_state
def _on_mute(self) -> bool:
old_muted_state = self.muted
self.muted = 1
return self.muted != old_muted_state
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0],
'BBB',
)
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0]
self.volume_flags = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0],
'B',
)

330
bumble/profiles/vocs.py Normal file
View File

@@ -0,0 +1,330 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from dataclasses import dataclass
from typing import Optional
from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import (
Characteristic,
DelegatedCharacteristicAdapter,
TemplateService,
CharacteristicValue,
UTF8CharacteristicAdapter,
InvalidServiceError,
GATT_VOLUME_OFFSET_CONTROL_SERVICE,
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
GATT_AUDIO_LOCATION_CHARACTERISTIC,
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
)
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble.utils import OpenIntEnum
from bumble.profiles.bap import AudioLocation
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME_OFFSET = -255
MAX_VOLUME_OFFSET = 255
CHANGE_COUNTER_MAX_VALUE = 0xFF
class SetVolumeOffsetOpCode(OpenIntEnum):
SET_VOLUME_OFFSET = 0x01
class ErrorCode(OpenIntEnum):
"""
See Volume Offset Control Service 1.6. Application error codes.
"""
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
VALUE_OUT_OF_RANGE = 0x82
# -----------------------------------------------------------------------------
@dataclass
class VolumeOffsetState:
volume_offset: int = 0
change_counter: int = 0
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return struct.pack('<hB', self.volume_offset, self.change_counter)
@classmethod
def from_bytes(cls, data: bytes):
volume_offset, change_counter = struct.unpack('<hB', data)
return cls(volume_offset, change_counter)
def increment_change_counter(self) -> None:
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
assert self.attribute_value is not None
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=bytes(self)
)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
@dataclass
class VocsAudioLocation:
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
attribute_value: Optional[CharacteristicValue] = None
def __bytes__(self) -> bytes:
return struct.pack('<I', self.audio_location)
@classmethod
def from_bytes(cls, data: bytes):
audio_location = AudioLocation(struct.unpack('<I', data)[0])
return cls(audio_location)
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value
)
@dataclass
class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
change_counter, volume_offset = struct.unpack('<Bh', value[1:])
await self._set_volume_offset(connection, change_counter, volume_offset)
async def _set_volume_offset(
self,
connection: Connection,
change_counter_operand: int,
volume_offset_operand: int,
) -> None:
change_counter = self.volume_offset_state.change_counter
if change_counter != change_counter_operand:
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
if not MIN_VOLUME_OFFSET <= volume_offset_operand <= MAX_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
self.volume_offset_state.volume_offset = volume_offset_operand
self.volume_offset_state.increment_change_counter()
await self.volume_offset_state.notify_subscribers_via_connection(connection)
@dataclass
class AudioOutputDescription:
audio_output_description: str = ''
attribute_value: Optional[CharacteristicValue] = None
@classmethod
def from_bytes(cls, data: bytes):
return cls(audio_output_description=data.decode('utf-8'))
def __bytes__(self) -> bytes:
return self.audio_output_description.encode('utf-8')
def on_read(self, _connection: Optional[Connection]) -> bytes:
return bytes(self)
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
assert self.attribute_value
self.audio_output_description = value.decode('utf-8')
await connection.device.notify_subscribers(
attribute=self.attribute_value, value=value
)
# -----------------------------------------------------------------------------
class VolumeOffsetControlService(TemplateService):
UUID = GATT_VOLUME_OFFSET_CONTROL_SERVICE
def __init__(
self,
volume_offset_state: Optional[VolumeOffsetState] = None,
audio_location: Optional[VocsAudioLocation] = None,
audio_output_description: Optional[AudioOutputDescription] = None,
) -> None:
self.volume_offset_state = (
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
)
self.audio_location = (
VocsAudioLocation() if audio_location is None else audio_location
)
self.audio_output_description = (
AudioOutputDescription()
if audio_output_description is None
else audio_output_description
)
self.volume_offset_control_point: VolumeOffsetControlPoint = (
VolumeOffsetControlPoint(self.volume_offset_state)
)
self.volume_offset_state_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY
),
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=CharacteristicValue(read=self.volume_offset_state.on_read),
),
encode=lambda value: bytes(value),
)
self.audio_location_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_LOCATION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_location.on_read,
write=self.audio_location.on_write,
),
),
encode=lambda value: bytes(value),
decode=VocsAudioLocation.from_bytes,
)
self.audio_location.attribute_value = self.audio_location_characteristic.value
self.volume_offset_control_point_characteristic = Characteristic(
uuid=GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
properties=Characteristic.Properties.WRITE,
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=CharacteristicValue(write=self.volume_offset_control_point.on_write),
)
self.audio_output_description_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
uuid=GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
properties=(
Characteristic.Properties.READ
| Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
),
permissions=(
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
),
value=CharacteristicValue(
read=self.audio_output_description.on_read,
write=self.audio_output_description.on_write,
),
)
)
self.audio_output_description.attribute_value = (
self.audio_output_description_characteristic.value
)
super().__init__(
characteristics=[
self.volume_offset_state_characteristic, # type: ignore
self.audio_location_characteristic, # type: ignore
self.volume_offset_control_point_characteristic, # type: ignore
self.audio_output_description_characteristic, # type: ignore
],
primary=False,
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeOffsetControlServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = VolumeOffsetControlService
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC
)
):
raise InvalidServiceError("Volume Offset State characteristic not found")
self.volume_offset_state = DelegatedCharacteristicAdapter(
characteristics[0], decode=VolumeOffsetState.from_bytes
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_LOCATION_CHARACTERISTIC
)
):
raise InvalidServiceError("Audio Location characteristic not found")
self.audio_location = DelegatedCharacteristicAdapter(
characteristics[0],
encode=lambda value: bytes(value),
decode=VocsAudioLocation.from_bytes,
)
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC
)
):
raise InvalidServiceError(
"Volume Offset Control Point characteristic not found"
)
self.volume_offset_control_point = characteristics[0]
if not (
characteristics := service_proxy.get_characteristics_by_uuid(
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC
)
):
raise InvalidServiceError(
"Audio Output Description characteristic not found"
)
self.audio_output_description = UTF8CharacteristicAdapter(characteristics[0])

View File

@@ -19,30 +19,28 @@ from __future__ import annotations
import logging
import asyncio
import collections
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self
from pyee import EventEmitter
from . import core, l2cap
from bumble import core
from bumble import l2cap
from bumble import sdp
from .colors import color
from .core import (
UUID,
BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID,
InvalidArgumentError,
InvalidStateError,
InvalidPacketError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -59,28 +57,20 @@ logger = logging.getLogger(__name__)
# fmt: off
RFCOMM_PSM = 0x0003
DEFAULT_RX_QUEUE_SIZE = 32
class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
# Frame types
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum):
PN = 0x20
MSC = 0x38
RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}
# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38
# FCS CRC
CRC_TABLE = bytes([
@@ -118,8 +108,11 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_INITIAL_CREDITS = 7
RFCOMM_DEFAULT_MAX_CREDITS = 32
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -130,29 +123,33 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
) -> List[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
]
),
]
@@ -162,15 +159,81 @@ def make_service_sdp_records(
if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
)
)
return records
# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
uuid: UUID of service record to search for.
Returns:
RFCOMM channel number if found, otherwise None.
"""
if isinstance(uuid, str):
uuid = UUID(uuid)
return next(
(
channel
for channel, class_id_list in (
await find_rfcomm_channels(connection)
).items()
if uuid in class_id_list
),
None,
)
# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
@@ -183,7 +246,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame:
def __init__(
self,
frame_type: int,
frame_type: FrameType,
c_r: int,
dlci: int,
p_f: int,
@@ -206,14 +269,11 @@ class RFCOMM_Frame:
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
if frame_type == FrameType.UIH:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2
@@ -237,24 +297,24 @@ class RFCOMM_Frame:
@staticmethod
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)
@staticmethod
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)
@staticmethod
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)
@staticmethod
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)
@staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
@@ -262,7 +322,7 @@ class RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = data[1] & 0xEF
frame_type = FrameType(data[1] & 0xEF)
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -277,7 +337,7 @@ class RFCOMM_Frame:
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs:
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch')
raise InvalidPacketError('fcs mismatch')
return frame
@@ -291,7 +351,7 @@ class RFCOMM_Frame:
def __str__(self) -> str:
return (
f'{color(self.type_name(), "yellow")}'
f'{color(self.type.name, "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
@@ -301,6 +361,7 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
dlci: int
cl: int
@@ -308,25 +369,13 @@ class RFCOMM_MCC_PN:
ack_timer: int
max_frame_size: int
max_retransmissions: int
window_size: int
initial_credits: int
def __init__(
self,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
def __post_init__(self) -> None:
if self.initial_credits < 1 or self.initial_credits > 7:
logger.warning(
f'Initial credits {self.initial_credits} is out of range [1, 7].'
)
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
@@ -337,7 +386,7 @@ class RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7],
initial_credits=data[7] & 0x07,
)
def __bytes__(self) -> bytes:
@@ -350,23 +399,14 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF,
# Only 3 bits are meaningful.
self.initial_credits & 0x07,
]
)
def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
dlci: int
fc: int
@@ -375,16 +415,6 @@ class RFCOMM_MCC_MSC:
ic: int
dv: int
def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC(
@@ -409,16 +439,6 @@ class RFCOMM_MCC_MSC:
]
)
def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
@@ -430,35 +450,58 @@ class DLC(EventEmitter):
DISCONNECTED = 0x04
RESET = 0x05
connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]]
def __init__(
self,
multiplexer: Multiplexer,
dlci: int,
max_frame_size: int,
initial_tx_credits: int,
tx_max_frame_size: int,
tx_initial_credits: int,
rx_max_frame_size: int,
rx_initial_credits: int,
) -> None:
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.rx_max_frame_size = rx_max_frame_size
self.rx_initial_credits = rx_initial_credits
self.rx_max_credits = RFCOMM_DEFAULT_MAX_CREDITS
self.rx_credits = rx_initial_credits
self.rx_credits_threshold = RFCOMM_DEFAULT_CREDIT_THRESHOLD
self.tx_max_frame_size = tx_max_frame_size
self.tx_credits = tx_initial_credits
self.tx_buffer = b''
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None
self.connection_result: Optional[asyncio.Future] = None
self.disconnection_result: Optional[asyncio.Future] = None
self.drained = asyncio.Event()
self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
maxlen=DEFAULT_RX_QUEUE_SIZE
)
self._sink: Optional[Callable[[bytes], None]] = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
tx_max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
)
@property
def sink(self) -> Optional[Callable[[bytes], None]]:
return self._sink
@sink.setter
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
for packet in self._enqueued_rx_packets:
sink(packet) # pylint: disable=not-callable
self._enqueued_rx_packets.clear()
def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
self.state = new_state
@@ -467,7 +510,7 @@ class DLC(EventEmitter):
self.multiplexer.send_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -481,9 +524,7 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -491,22 +532,35 @@ class DLC(EventEmitter):
self.emit('open')
def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
if self.state != DLC.State.CONNECTING:
if self.state == DLC.State.CONNECTING:
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTED)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.multiplexer.on_dlc_open_complete(self)
elif self.state == DLC.State.DISCONNECTING:
self.change_state(DLC.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
self.multiplexer.on_dlc_disconnection(self)
self.emit('close')
else:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
color(
(
'!!! received UA frame when not in '
'CONNECTING or DISCONNECTING state'
),
'red',
)
)
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
# TODO: handle all states
@@ -534,14 +588,22 @@ class DLC(EventEmitter):
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
if data:
if self._sink:
self._sink(data) # pylint: disable=not-callable
else:
self._enqueued_rx_packets.append(data)
if (
self._enqueued_rx_packets.maxlen
and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
):
logger.warning(f'DLC [{self.dlci}] received packet queue is full')
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -554,9 +616,7 @@ class DLC(EventEmitter):
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
@@ -571,6 +631,19 @@ class DLC(EventEmitter):
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
async def disconnect(self) -> None:
if self.state != DLC.State.CONNECTED:
raise InvalidStateError('invalid state')
self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(DLC.State.DISCONNECTING)
self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=self.dlci
)
)
await self.disconnection_result
def accept(self) -> None:
if self.state != DLC.State.INIT:
raise InvalidStateError('invalid state')
@@ -580,18 +653,18 @@ class DLC(EventEmitter):
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=self.rx_max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
initial_credits=self.rx_initial_credits,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING)
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
if self.rx_credits <= self.rx_credits_threshold:
return self.rx_max_credits - self.rx_credits
return 0
@@ -631,6 +704,8 @@ class DLC(EventEmitter):
)
rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol
def write(self, data: Union[bytes, str]) -> None:
@@ -640,17 +715,37 @@ class DLC(EventEmitter):
# Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8')
else:
raise ValueError('write only accept bytes or strings')
raise InvalidArgumentError('write only accept bytes or strings')
self.tx_buffer += data
self.drained.clear()
self.process_tx()
def drain(self) -> None:
# TODO
pass
async def drain(self) -> None:
await self.drained.wait()
def abort(self) -> None:
logger.debug(f'aborting DLC: {self}')
if self.connection_result:
self.connection_result.cancel()
self.connection_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result = None
self.change_state(DLC.State.RESET)
self.emit('close')
def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})'
return (
f'DLC(dlci={self.dlci}, '
f'state={self.state.name}, '
f'rx_max_frame_size={self.rx_max_frame_size}, '
f'rx_credits={self.rx_credits}, '
f'rx_max_credits={self.rx_max_credits}, '
f'tx_max_frame_size={self.tx_max_frame_size}, '
f'tx_credits={self.tx_credits}'
')'
)
# -----------------------------------------------------------------------------
@@ -671,7 +766,7 @@ class Multiplexer(EventEmitter):
connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], bool]]
acceptor: Optional[Callable[[int], Optional[Tuple[int, int]]]]
dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
@@ -683,11 +778,15 @@ class Multiplexer(EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.open_result = None
self.open_pn: Optional[RFCOMM_MCC_PN] = None
self.open_rx_max_credits = 0
self.acceptor = None
# Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on('close', self.on_l2cap_channel_close)
def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
@@ -704,7 +803,7 @@ class Multiplexer(EventEmitter):
if frame.dlci == 0:
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
if frame.type == FrameType.DM:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
@@ -717,7 +816,7 @@ class Multiplexer(EventEmitter):
dlc.on_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -751,6 +850,7 @@ class Multiplexer(EventEmitter):
'rfcomm',
)
)
self.open_result = None
else:
logger.warning(f'unexpected state for DM: {self}')
@@ -765,10 +865,10 @@ class Multiplexer(EventEmitter):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == RFCOMM_MCC_PN_TYPE:
if mcc_type == MccType.PN:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
elif mcc_type == MccType.MSC:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -788,9 +888,16 @@ class Multiplexer(EventEmitter):
else:
if self.acceptor:
channel_number = pn.dlci >> 1
if self.acceptor(channel_number):
if dlc_params := self.acceptor(channel_number):
# Create a new DLC
dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
dlc = DLC(
self,
dlci=pn.dlci,
tx_max_frame_size=pn.max_frame_size,
tx_initial_credits=pn.initial_credits,
rx_max_frame_size=dlc_params[0],
rx_initial_credits=dlc_params[1],
)
self.dlcs[pn.dlci] = dlc
# Re-emit the handshake completion event
@@ -808,8 +915,17 @@ class Multiplexer(EventEmitter):
# Response
logger.debug(f'>>> PN Response: {pn}')
if self.state == Multiplexer.State.OPENING:
dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
assert self.open_pn
dlc = DLC(
self,
dlci=pn.dlci,
tx_max_frame_size=pn.max_frame_size,
tx_initial_credits=pn.initial_credits,
rx_max_frame_size=self.open_pn.max_frame_size,
rx_initial_credits=self.open_pn.initial_credits,
)
self.dlcs[pn.dlci] = dlc
self.open_pn = None
dlc.connect()
else:
logger.warning('ignoring PN response')
@@ -843,24 +959,31 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(self, channel: int) -> DLC:
async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
) -> DLC:
if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress')
raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN(
self.open_pn = RFCOMM_MCC_PN(
dlci=channel << 1,
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
initial_credits=initial_credits,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
mcc = RFCOMM_Frame.make_mcc(
mcc_type=MccType.PN, c_r=1, data=bytes(self.open_pn)
)
logger.debug(f'>>> Sending MCC: {self.open_pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING)
self.send_frame(
@@ -870,15 +993,31 @@ class Multiplexer(EventEmitter):
information=mcc,
)
)
result = await self.open_result
self.open_result = None
return result
return await self.open_result
def on_dlc_open_complete(self, dlc: DLC) -> None:
logger.debug(f'DLC [{dlc.dlci}] open complete')
self.change_state(Multiplexer.State.CONNECTED)
if self.open_result:
self.open_result.set_result(dlc)
self.open_result = None
def on_dlc_disconnection(self, dlc: DLC) -> None:
logger.debug(f'DLC [{dlc.dlci}] disconnection')
self.dlcs.pop(dlc.dlci, None)
def on_l2cap_channel_close(self) -> None:
logger.debug('L2CAP channel closed, cleaning up')
if self.open_result:
self.open_result.cancel()
self.open_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result = None
for dlc in self.dlcs.values():
dlc.abort()
def __str__(self) -> str:
return f'Multiplexer(state={self.state.name})'
@@ -889,8 +1028,11 @@ class Client:
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]
def __init__(self, connection: Connection) -> None:
def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
self.connection = connection
self.l2cap_mtu = l2cap_mtu
self.l2cap_channel = None
self.multiplexer = None
@@ -898,7 +1040,7 @@ class Client:
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(RFCOMM_PSM)
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
)
except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}')
@@ -921,25 +1063,40 @@ class Client:
self.multiplexer = None
# Close the L2CAP channel
# TODO
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None
async def __aenter__(self) -> Multiplexer:
return await self.start()
async def __aexit__(self, *args) -> None:
await self.shutdown()
# -----------------------------------------------------------------------------
class Server(EventEmitter):
acceptors: Dict[int, Callable[[DLC], None]]
def __init__(self, device: Device) -> None:
def __init__(
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
super().__init__()
self.device = device
self.multiplexer = None
self.acceptors = {}
self.acceptors: Dict[int, Callable[[DLC], None]] = {}
self.dlc_configs: Dict[int, Tuple[int, int]] = {}
# Register ourselves with the L2CAP channel manager
device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu),
handler=self.on_connection,
)
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
def listen(
self,
acceptor: Callable[[DLC], None],
channel: int = 0,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
) -> int:
if channel:
if channel in self.acceptors:
# Busy
@@ -959,6 +1116,8 @@ class Server(EventEmitter):
return 0
self.acceptors[channel] = acceptor
self.dlc_configs[channel] = (max_frame_size, initial_credits)
return channel
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
@@ -976,13 +1135,18 @@ class Server(EventEmitter):
# Notify
self.emit('start', multiplexer)
def accept_dlc(self, channel_number: int) -> bool:
return channel_number in self.acceptors
def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]:
return self.dlc_configs.get(channel_number)
def on_dlc(self, dlc: DLC) -> None:
logger.debug(f'@@@ new DLC connected: {dlc}')
# Let the acceptor know
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
if acceptor := self.acceptors.get(dlc.dlci >> 1):
acceptor(dlc)
def __enter__(self) -> Self:
return self
def __exit__(self, *args) -> None:
self.l2cap_server.close()

110
bumble/rtp.py Normal file
View File

@@ -0,0 +1,110 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List
# -----------------------------------------------------------------------------
class MediaPacket:
@staticmethod
def from_bytes(data: bytes) -> MediaPacket:
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
padding,
extension,
marker,
sequence_number,
timestamp,
ssrc,
csrc_list,
payload_type,
payload,
)
def __init__(
self,
version: int,
padding: int,
extension: int,
marker: int,
sequence_number: int,
timestamp: int,
ssrc: int,
csrc_list: List[int],
payload_type: int,
payload: bytes,
) -> None:
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number & 0xFFFF
self.timestamp = timestamp & 0xFFFFFFFF
self.timestamp_seconds = 0.0
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self) -> bytes:
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack(
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
def __str__(self) -> str:
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)

View File

@@ -19,10 +19,11 @@ from __future__ import annotations
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self
from . import core, l2cap
from .colors import color
from .core import InvalidStateError
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING:
@@ -97,7 +98,8 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
@@ -115,7 +117,8 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
}
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
@@ -186,7 +189,9 @@ class DataElement:
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise ValueError('integer types must have a value size specified')
raise InvalidArgumentError(
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
@@ -262,7 +267,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
@@ -278,7 +283,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
@@ -339,9 +344,6 @@ class DataElement:
] # Keep a copy so we can re-serialize to an exact replica
return result
def to_bytes(self):
return bytes(self)
def __bytes__(self):
# Return early if we have a cache
if self.bytes:
@@ -351,7 +353,7 @@ class DataElement:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
@@ -362,7 +364,7 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
@@ -373,7 +375,7 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
@@ -389,7 +391,7 @@ class DataElement:
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise ValueError('NIL must be empty')
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
@@ -407,7 +409,7 @@ class DataElement:
elif size == 16:
size_index = 4
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
@@ -424,11 +426,13 @@ class DataElement:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise ValueError('boolean must be 1 byte')
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
else:
raise RuntimeError("internal error - self.type not supported")
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
return self.bytes
@@ -616,11 +620,8 @@ class SDP_PDU:
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self):
return self.to_bytes()
return self.pdu
def __str__(self):
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
@@ -822,11 +823,13 @@ class Client:
)
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
(
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
)
@@ -878,11 +881,13 @@ class Client:
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
(
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
)
@@ -918,6 +923,13 @@ class Client:
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(self, *args) -> None:
await self.disconnect()
# -----------------------------------------------------------------------------
class Server:
@@ -983,7 +995,7 @@ class Server:
try:
handler(sdp_pdu)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,

View File

@@ -55,6 +55,7 @@ from .core import (
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData,
InvalidArgumentError,
ProtocolError,
name_or_number,
)
@@ -297,11 +298,8 @@ class SMP_Command:
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
return self.pdu
def __bytes__(self):
return self.to_bytes()
return self.pdu
def __str__(self):
result = color(self.name, 'yellow')
@@ -697,6 +695,7 @@ class Session:
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
@@ -737,12 +736,16 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
self.pairing_result: Optional[asyncio.Future[None]] = (
asyncio.get_running_loop().create_future()
)
else:
self.pairing_result = None
self.maximum_encryption_key_size = (
pairing_config.delegate.maximum_encryption_key_size
)
# Key Distribution (default values before negotiation)
self.initiator_key_distribution = (
pairing_config.delegate.local_initiator_key_distribution
@@ -763,11 +766,16 @@ class Session:
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
self.oob_data_flag = 0 if pairing_config.oob is None else 1
self.oob_data_flag = (
1 if pairing_config.oob and pairing_config.oob.peer_data else 0
)
# Set up addresses
self_address = connection.self_address
self_address = connection.self_resolvable_address or connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address
logger.debug(
f"pairing with self_address={self_address}, peer_address={peer_address}"
)
if self.is_initiator:
self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0
@@ -784,7 +792,7 @@ class Session:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise ValueError(
raise InvalidArgumentError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
@@ -793,7 +801,7 @@ class Session:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise ValueError(
raise InvalidArgumentError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
@@ -990,7 +998,7 @@ class Session:
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution,
)
@@ -1002,7 +1010,7 @@ class Session:
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
maximum_encryption_key_size=self.maximum_encryption_key_size,
initiator_key_distribution=self.initiator_key_distribution,
responder_key_distribution=self.responder_key_distribution,
)
@@ -1010,8 +1018,10 @@ class Session:
self.send_command(response)
def send_pairing_confirm_command(self) -> None:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.pairing_method != PairingMethod.OOB:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
@@ -1074,11 +1084,19 @@ class Session:
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.connection.self_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
}[self.pairing_config.identity_address_type]
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
identity_address = self.manager.device.public_address
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
identity_address = self.manager.device.static_address
else:
# No identity address type set. If the controller has a public address, it
# will be more responsible to be the identity address.
if self.manager.device.public_address != Address.ANY:
logger.debug("No identity address type set, using PUBLIC")
identity_address = self.manager.device.public_address
else:
logger.debug("No identity address type set, using RANDOM")
identity_address = self.manager.device.static_address
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
@@ -1090,7 +1108,7 @@ class Session:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
HCI_LE_Enable_Encryption_Command(
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
@@ -1134,8 +1152,10 @@ class Session:
async def get_link_key_and_derive_ltk(self) -> None:
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
if link_key is None:
self.link_key = await self.manager.device.get_link_key(
self.connection.peer_address
)
if self.link_key is None:
logging.warning(
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
@@ -1143,7 +1163,7 @@ class Session:
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(link_key, self.ct2)
self.ltk = self.derive_ltk(self.link_key, self.ct2)
def distribute_keys(self) -> None:
# Distribute the keys as required
@@ -1721,7 +1741,6 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
@@ -1729,6 +1748,22 @@ class Session:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
elif self.pairing_method == PairingMethod.OOB:
if self.is_initiator:
if self.peer_oob_data:
rb = self.peer_oob_data.r
ra = self.r
else:
rb = bytes(16)
ra = self.r
else:
if self.peer_oob_data:
ra = self.peer_oob_data.r
rb = self.r
else:
ra = bytes(16)
rb = self.r
else:
return
@@ -1806,7 +1841,7 @@ class Session:
if self.is_initiator:
if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command()
else:
elif self.pairing_method == PairingMethod.PASSKEY:
self.send_pairing_confirm_command()
else:
if self.pairing_method == PairingMethod.PASSKEY:
@@ -1916,7 +1951,7 @@ class Manager(EventEmitter):
f'{connection.peer_address}: {command}'
)
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
connection.send_l2cap_pdu(cid, bytes(command))
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
@@ -1991,10 +2026,8 @@ class Manager(EventEmitter):
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
self.device.abort_on(
'flush', self.device.update_keys(str(identity_address), keys)
)
# Make sure on_pairing emits after key update.
await self.device.update_keys(str(identity_address), keys)
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

View File

@@ -23,6 +23,7 @@ import datetime
from typing import BinaryIO, Generator
import os
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
raise core.InvalidArgumentError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')

View File

@@ -18,8 +18,9 @@
from contextlib import asynccontextmanager
import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport:
"""
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The name must be <type>:<metadata><parameters>
Where <parameters> depend on the type (and may be empty for some types), and
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are:
* serial
* udp
@@ -71,89 +80,113 @@ async def open_transport(name: str) -> Transport:
* android-netsim
"""
return _wrap_transport(await _open_transport(name))
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
return await open_serial_transport(spec)
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
return await open_udp_transport(spec)
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
return await open_tcp_client_transport(spec)
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
return await open_tcp_server_transport(spec)
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
return await open_ws_client_transport(spec)
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
return await open_ws_server_transport(spec)
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
return await open_pty_transport(spec)
if scheme == 'file':
from .file import open_file_transport
assert spec is not None
return await open_file_transport(spec[0])
return await open_file_transport(spec)
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
return await open_vhci_transport(spec)
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
return await open_hci_socket_transport(spec)
if scheme == 'usb':
from .usb import open_usb_transport
assert spec is not None
return await open_usb_transport(spec[0])
assert spec
return await open_usb_transport(spec)
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
assert spec is not None
return await open_pyusb_transport(spec[0])
assert spec
return await open_pyusb_transport(spec)
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
return await open_android_emulator_transport(spec)
if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec[0] if spec else None)
return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme')
if scheme == 'unix':
from .unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
raise TransportSpecError('unknown transport scheme')
# -----------------------------------------------------------------------------
@@ -170,12 +203,13 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link)
controller = Controller('remote', link=link) # type:ignore[arg-type]
class LinkTransport(Transport):
async def close(self):

View File

@@ -20,7 +20,13 @@ import grpc.aio
from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
from .common import (
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -69,7 +75,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
mode = 'host'
server_host = 'localhost'
server_port = '8554'
if spec is not None:
if spec:
params = spec.split(',')
for param in params:
if param.startswith('mode='):
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
elif ':' in param:
server_host, server_port = param.split(':')
else:
raise ValueError('invalid parameter')
raise TransportSpecError('invalid parameter')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
service = VhciForwardingServiceStub(channel)
hci_device = HciDevice(service.attachVhci())
else:
raise ValueError('invalid mode')
raise TransportSpecError('invalid mode')
# Create the transport object
class EmulatorTransport(PumpedTransport):

View File

@@ -20,29 +20,33 @@ import atexit
import logging
import os
import pathlib
import platform
import sys
from typing import Dict, Optional
import grpc.aio
from .common import (
import bumble
from bumble.transport.common import (
ParserSource,
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
TransportInitError,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.packet_streamer_pb2_grpc import (
from .grpc_protobuf.netsim.packet_streamer_pb2_grpc import (
PacketStreamerStub,
PacketStreamerServicer,
add_PacketStreamerServicer_to_server,
)
from .grpc_protobuf.packet_streamer_pb2 import PacketRequest, PacketResponse
from .grpc_protobuf.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
from .grpc_protobuf.common_pb2 import ChipKind
from .grpc_protobuf.netsim.packet_streamer_pb2 import PacketRequest, PacketResponse
from .grpc_protobuf.netsim.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.netsim.startup_pb2 import Chip, ChipInfo, DeviceInfo
from .grpc_protobuf.netsim.common_pb2 import ChipKind
# -----------------------------------------------------------------------------
@@ -56,6 +60,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
DEFAULT_NAME = 'bumble0'
DEFAULT_MANUFACTURER = 'Bumble'
DEFAULT_VARIANT = ''
# -----------------------------------------------------------------------------
@@ -68,6 +73,9 @@ def get_ini_dir() -> Optional[pathlib.Path]:
elif sys.platform == 'linux':
if xdg_runtime_dir := os.environ.get('XDG_RUNTIME_DIR', None):
return pathlib.Path(xdg_runtime_dir)
tmpdir = os.environ.get('TMPDIR', '/tmp')
if pathlib.Path(tmpdir).is_dir():
return pathlib.Path(tmpdir)
elif sys.platform == 'win32':
if local_app_data_dir := os.environ.get('LOCALAPPDATA', None):
return pathlib.Path(local_app_data_dir) / 'Temp'
@@ -135,7 +143,7 @@ async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport:
if not server_port:
raise ValueError('invalid port')
raise TransportSpecError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -194,7 +202,6 @@ async def open_android_netsim_controller_transport(
data = (
bytes([request.hci_packet.packet_type]) + request.hci_packet.packet
)
logger.debug(f'<<< PACKET: {data.hex()}')
self.on_data_received(data)
async def send_packet(self, data):
@@ -248,7 +255,7 @@ async def open_android_netsim_controller_transport(
# Check that we don't already have a device
if self.device:
logger.debug('busy, already serving a device')
logger.debug('Busy, already serving a device')
return PacketResponse(error='Busy')
# Instantiate a new device
@@ -288,7 +295,7 @@ async def open_android_netsim_host_transport_with_address(
instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number)
if not server_port:
raise RuntimeError('gRPC server port not found')
raise TransportInitError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
@@ -307,16 +314,24 @@ async def open_android_netsim_host_transport_with_channel(
):
# Wrapper for I/O operations
class HciDevice:
def __init__(self, name, manufacturer, hci_device):
def __init__(self, name, variant, manufacturer, hci_device):
self.name = name
self.variant = variant
self.manufacturer = manufacturer
self.hci_device = hci_device
async def start(self): # Send the startup info
chip_info = ChipInfo(
device_info = DeviceInfo(
name=self.name,
chip=Chip(kind=ChipKind.BLUETOOTH, manufacturer=self.manufacturer),
kind='BUMBLE',
version=bumble.__version__,
sdk_version=platform.python_version(),
build_id=platform.platform(),
arch=platform.machine(),
variant=self.variant,
)
chip = Chip(kind=ChipKind.BLUETOOTH, manufacturer=self.manufacturer)
chip_info = ChipInfo(name=self.name, chip=chip, device_info=device_info)
logger.debug(f'Sending chip info to netsim: {chip_info}')
await self.hci_device.write(PacketRequest(initial_info=chip_info))
@@ -326,7 +341,7 @@ async def open_android_netsim_host_transport_with_channel(
if response_type == 'error':
logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error)
raise TransportInitError(response.error)
if response_type == 'hci_packet':
return (
@@ -334,7 +349,7 @@ async def open_android_netsim_host_transport_with_channel(
+ response.hci_packet.packet
)
raise ValueError('unsupported response type')
raise TransportSpecError('unsupported response type')
async def write(self, packet):
await self.hci_device.write(
@@ -344,12 +359,16 @@ async def open_android_netsim_host_transport_with_channel(
)
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
variant = (
DEFAULT_VARIANT if options is None else options.get('variant', DEFAULT_VARIANT)
)
manufacturer = DEFAULT_MANUFACTURER
# Connect as a host
service = PacketStreamerStub(channel)
hci_device = HciDevice(
name=name,
variant=variant,
manufacturer=manufacturer,
hci_device=service.StreamPackets(),
)
@@ -399,6 +418,9 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
The "chip" name, used to identify the "chip" instance. This
may be useful when several clients are connected, since each needs to use a
different name.
variant=<variant>
The device info variant field, which may be used to convey a device or
application type (ex: "virtual-speaker", or "keyboard")
In `controller` mode:
The <host>:<port> part is required. <host> may be the address of a local network
@@ -429,7 +451,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
options: Dict[str, str] = {}
for param in params[params_offset:]:
if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>')
raise TransportSpecError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=')
options[option_name] = option_value
@@ -440,7 +462,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
)
if mode == 'controller':
if host is None:
raise ValueError('<host>:<port> missing')
raise TransportSpecError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option')
raise TransportSpecError('invalid mode option')

View File

@@ -21,8 +21,9 @@ import struct
import asyncio
import logging
import io
from typing import ContextManager, Tuple, Optional, Protocol, Dict
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import core
from bumble import hci
from bumble.colors import color
from bumble.snoop import Snooper
@@ -42,31 +43,36 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
}
# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(Exception):
"""
The Transport has been lost/disconnected.
"""
class TransportLostError(core.BaseBumbleError, RuntimeError):
"""The Transport has been lost/disconnected."""
class TransportInitError(core.BaseBumbleError, RuntimeError):
"""Error raised when the transport cannot be initialized."""
class TransportSpecError(core.BaseBumbleError, ValueError):
"""Error raised when the transport spec is invalid."""
# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None:
...
def on_packet(self, packet: bytes) -> None: ...
class TransportSource(Protocol):
terminated: asyncio.Future[None]
def set_packet_sink(self, sink: TransportSink) -> None:
...
def set_packet_sink(self, sink: TransportSink) -> None: ...
# -----------------------------------------------------------------------------
@@ -133,7 +139,9 @@ class PacketParser:
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
raise core.InvalidPacketError(
f'invalid packet type {packet_type}'
)
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
@@ -167,29 +175,31 @@ class PacketReader:
def __init__(self, source: io.BufferedReader) -> None:
self.source = source
self.at_end = False
def next_packet(self) -> Optional[bytes]:
# Get the packet type
packet_type = self.source.read(1)
if len(packet_type) != 1:
self.at_end = True
return None
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
header = self.source.read(header_size)
if len(header) != header_size:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
# Read the body
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
body = self.source.read(body_length)
if len(body) != body_length:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
return packet_type + header + body
@@ -210,7 +220,7 @@ class AsyncPacketReader:
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
@@ -238,26 +248,28 @@ class AsyncPipeSink:
# -----------------------------------------------------------------------------
class ParserSource:
class BaseSource:
"""
Base class designed to be subclassed by transport-specific source classes
"""
terminated: asyncio.Future[None]
parser: PacketParser
sink: Optional[TransportSink]
def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
self.sink = None
def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink)
self.sink = sink
def on_transport_lost(self) -> None:
self.terminated.set_result(None)
if self.parser.sink:
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost()
if not self.terminated.done():
self.terminated.set_result(None)
if self.sink:
if hasattr(self.sink, 'on_transport_lost'):
self.sink.on_transport_lost()
async def wait_for_termination(self) -> None:
"""
@@ -270,6 +282,23 @@ class ParserSource:
pass
# -----------------------------------------------------------------------------
class ParserSource(BaseSource):
"""
Base class for sources that use an HCI parser.
"""
parser: PacketParser
def __init__(self) -> None:
super().__init__()
self.parser = PacketParser()
def set_packet_sink(self, sink: TransportSink) -> None:
super().set_packet_sink(sink)
self.parser.set_packet_sink(sink)
# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None:
@@ -341,11 +370,13 @@ class PumpedPacketSource(ParserSource):
self.parser.feed_data(packet)
except asyncio.CancelledError:
logger.debug('source pump task done')
self.terminated.set_result(None)
if not self.terminated.done():
self.terminated.set_result(None)
break
except Exception as error:
logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_exception(error)
if not self.terminated.done():
self.terminated.set_exception(error)
break
self.pump_task = asyncio.create_task(pump_packets())
@@ -419,11 +450,15 @@ class SnoopingTransport(Transport):
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
raise core.UnreachableError() # Satisfy the type checker
class Source:
sink: TransportSink
@property
def metadata(self) -> dict[str, Any]:
return getattr(self.source, 'metadata', {})
def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source
self.snooper = snooper

View File

@@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: hci_packet.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10hci_packet.proto\x12\rnetsim.packet\"\xb2\x01\n\tHCIPacket\x12\x38\n\x0bpacket_type\x18\x01 \x01(\x0e\x32#.netsim.packet.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"[\n\nPacketType\x12\x1a\n\x16HCI_PACKET_UNSPECIFIED\x10\x00\x12\x0b\n\x07\x43OMMAND\x10\x01\x12\x07\n\x03\x41\x43L\x10\x02\x12\x07\n\x03SCO\x10\x03\x12\t\n\x05\x45VENT\x10\x04\x12\x07\n\x03ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hci_packet_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start=36
_HCIPACKET._serialized_end=214
_HCIPACKET_PACKETTYPE._serialized_start=123
_HCIPACKET_PACKETTYPE._serialized_end=214
# @@protoc_insertion_point(module_scope)

View File

@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: common.proto
# source: netsim/common.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -13,13 +14,13 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x63ommon.proto\x12\rnetsim.common*=\n\x08\x43hipKind\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\r\n\tBLUETOOTH\x10\x01\x12\x08\n\x04WIFI\x10\x02\x12\x07\n\x03UWB\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13netsim/common.proto\x12\rnetsim.common*S\n\x08\x43hipKind\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\r\n\tBLUETOOTH\x10\x01\x12\x08\n\x04WIFI\x10\x02\x12\x07\n\x03UWB\x10\x03\x12\x14\n\x10\x42LUETOOTH_BEACON\x10\x04\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'common_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'netsim.common_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_CHIPKIND._serialized_start=31
_CHIPKIND._serialized_end=92
_globals['_CHIPKIND']._serialized_start=38
_globals['_CHIPKIND']._serialized_end=121
# @@protoc_insertion_point(module_scope)

View File

@@ -2,11 +2,17 @@ from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from typing import ClassVar as _ClassVar
BLUETOOTH: ChipKind
DESCRIPTOR: _descriptor.FileDescriptor
UNSPECIFIED: ChipKind
UWB: ChipKind
WIFI: ChipKind
class ChipKind(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
__slots__ = ()
UNSPECIFIED: _ClassVar[ChipKind]
BLUETOOTH: _ClassVar[ChipKind]
WIFI: _ClassVar[ChipKind]
UWB: _ClassVar[ChipKind]
BLUETOOTH_BEACON: _ClassVar[ChipKind]
UNSPECIFIED: ChipKind
BLUETOOTH: ChipKind
WIFI: ChipKind
UWB: ChipKind
BLUETOOTH_BEACON: ChipKind

View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: netsim/hci_packet.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17netsim/hci_packet.proto\x12\rnetsim.packet\"\xb2\x01\n\tHCIPacket\x12\x38\n\x0bpacket_type\x18\x01 \x01(\x0e\x32#.netsim.packet.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"[\n\nPacketType\x12\x1a\n\x16HCI_PACKET_UNSPECIFIED\x10\x00\x12\x0b\n\x07\x43OMMAND\x10\x01\x12\x07\n\x03\x41\x43L\x10\x02\x12\x07\n\x03SCO\x10\x03\x12\t\n\x05\x45VENT\x10\x04\x12\x07\n\x03ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'netsim.hci_packet_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_globals['_HCIPACKET']._serialized_start=43
_globals['_HCIPACKET']._serialized_end=221
_globals['_HCIPACKET_PACKETTYPE']._serialized_start=130
_globals['_HCIPACKET_PACKETTYPE']._serialized_end=221
# @@protoc_insertion_point(module_scope)

View File

@@ -6,17 +6,23 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class HCIPacket(_message.Message):
__slots__ = ["packet", "packet_type"]
__slots__ = ("packet_type", "packet")
class PacketType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
ACL: HCIPacket.PacketType
COMMAND: HCIPacket.PacketType
EVENT: HCIPacket.PacketType
__slots__ = ()
HCI_PACKET_UNSPECIFIED: _ClassVar[HCIPacket.PacketType]
COMMAND: _ClassVar[HCIPacket.PacketType]
ACL: _ClassVar[HCIPacket.PacketType]
SCO: _ClassVar[HCIPacket.PacketType]
EVENT: _ClassVar[HCIPacket.PacketType]
ISO: _ClassVar[HCIPacket.PacketType]
HCI_PACKET_UNSPECIFIED: HCIPacket.PacketType
ISO: HCIPacket.PacketType
PACKET_FIELD_NUMBER: _ClassVar[int]
PACKET_TYPE_FIELD_NUMBER: _ClassVar[int]
COMMAND: HCIPacket.PacketType
ACL: HCIPacket.PacketType
SCO: HCIPacket.PacketType
packet: bytes
EVENT: HCIPacket.PacketType
ISO: HCIPacket.PacketType
PACKET_TYPE_FIELD_NUMBER: _ClassVar[int]
PACKET_FIELD_NUMBER: _ClassVar[int]
packet_type: HCIPacket.PacketType
packet: bytes
def __init__(self, packet_type: _Optional[_Union[HCIPacket.PacketType, str]] = ..., packet: _Optional[bytes] = ...) -> None: ...

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,238 @@
from bumble.transport.grpc_protobuf.netsim import common_pb2 as _common_pb2
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from bumble.transport.grpc_protobuf.rootcanal import configuration_pb2 as _configuration_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class PhyKind(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NONE: _ClassVar[PhyKind]
BLUETOOTH_CLASSIC: _ClassVar[PhyKind]
BLUETOOTH_LOW_ENERGY: _ClassVar[PhyKind]
WIFI: _ClassVar[PhyKind]
UWB: _ClassVar[PhyKind]
WIFI_RTT: _ClassVar[PhyKind]
NONE: PhyKind
BLUETOOTH_CLASSIC: PhyKind
BLUETOOTH_LOW_ENERGY: PhyKind
WIFI: PhyKind
UWB: PhyKind
WIFI_RTT: PhyKind
class Position(_message.Message):
__slots__ = ("x", "y", "z")
X_FIELD_NUMBER: _ClassVar[int]
Y_FIELD_NUMBER: _ClassVar[int]
Z_FIELD_NUMBER: _ClassVar[int]
x: float
y: float
z: float
def __init__(self, x: _Optional[float] = ..., y: _Optional[float] = ..., z: _Optional[float] = ...) -> None: ...
class Orientation(_message.Message):
__slots__ = ("yaw", "pitch", "roll")
YAW_FIELD_NUMBER: _ClassVar[int]
PITCH_FIELD_NUMBER: _ClassVar[int]
ROLL_FIELD_NUMBER: _ClassVar[int]
yaw: float
pitch: float
roll: float
def __init__(self, yaw: _Optional[float] = ..., pitch: _Optional[float] = ..., roll: _Optional[float] = ...) -> None: ...
class Chip(_message.Message):
__slots__ = ("kind", "id", "name", "manufacturer", "product_name", "bt", "ble_beacon", "uwb", "wifi", "offset")
class Radio(_message.Message):
__slots__ = ("state", "range", "tx_count", "rx_count")
STATE_FIELD_NUMBER: _ClassVar[int]
RANGE_FIELD_NUMBER: _ClassVar[int]
TX_COUNT_FIELD_NUMBER: _ClassVar[int]
RX_COUNT_FIELD_NUMBER: _ClassVar[int]
state: bool
range: float
tx_count: int
rx_count: int
def __init__(self, state: bool = ..., range: _Optional[float] = ..., tx_count: _Optional[int] = ..., rx_count: _Optional[int] = ...) -> None: ...
class Bluetooth(_message.Message):
__slots__ = ("low_energy", "classic", "address", "bt_properties")
LOW_ENERGY_FIELD_NUMBER: _ClassVar[int]
CLASSIC_FIELD_NUMBER: _ClassVar[int]
ADDRESS_FIELD_NUMBER: _ClassVar[int]
BT_PROPERTIES_FIELD_NUMBER: _ClassVar[int]
low_energy: Chip.Radio
classic: Chip.Radio
address: str
bt_properties: _configuration_pb2.Controller
def __init__(self, low_energy: _Optional[_Union[Chip.Radio, _Mapping]] = ..., classic: _Optional[_Union[Chip.Radio, _Mapping]] = ..., address: _Optional[str] = ..., bt_properties: _Optional[_Union[_configuration_pb2.Controller, _Mapping]] = ...) -> None: ...
class BleBeacon(_message.Message):
__slots__ = ("bt", "address", "settings", "adv_data", "scan_response")
class AdvertiseSettings(_message.Message):
__slots__ = ("advertise_mode", "milliseconds", "tx_power_level", "dbm", "scannable", "timeout")
class AdvertiseMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
LOW_POWER: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseMode]
BALANCED: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseMode]
LOW_LATENCY: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseMode]
LOW_POWER: Chip.BleBeacon.AdvertiseSettings.AdvertiseMode
BALANCED: Chip.BleBeacon.AdvertiseSettings.AdvertiseMode
LOW_LATENCY: Chip.BleBeacon.AdvertiseSettings.AdvertiseMode
class AdvertiseTxPower(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
ULTRA_LOW: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower]
LOW: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower]
MEDIUM: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower]
HIGH: _ClassVar[Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower]
ULTRA_LOW: Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower
LOW: Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower
MEDIUM: Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower
HIGH: Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower
ADVERTISE_MODE_FIELD_NUMBER: _ClassVar[int]
MILLISECONDS_FIELD_NUMBER: _ClassVar[int]
TX_POWER_LEVEL_FIELD_NUMBER: _ClassVar[int]
DBM_FIELD_NUMBER: _ClassVar[int]
SCANNABLE_FIELD_NUMBER: _ClassVar[int]
TIMEOUT_FIELD_NUMBER: _ClassVar[int]
advertise_mode: Chip.BleBeacon.AdvertiseSettings.AdvertiseMode
milliseconds: int
tx_power_level: Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower
dbm: int
scannable: bool
timeout: int
def __init__(self, advertise_mode: _Optional[_Union[Chip.BleBeacon.AdvertiseSettings.AdvertiseMode, str]] = ..., milliseconds: _Optional[int] = ..., tx_power_level: _Optional[_Union[Chip.BleBeacon.AdvertiseSettings.AdvertiseTxPower, str]] = ..., dbm: _Optional[int] = ..., scannable: bool = ..., timeout: _Optional[int] = ...) -> None: ...
class AdvertiseData(_message.Message):
__slots__ = ("include_device_name", "include_tx_power_level", "manufacturer_data", "services")
class Service(_message.Message):
__slots__ = ("uuid", "data")
UUID_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
uuid: str
data: bytes
def __init__(self, uuid: _Optional[str] = ..., data: _Optional[bytes] = ...) -> None: ...
INCLUDE_DEVICE_NAME_FIELD_NUMBER: _ClassVar[int]
INCLUDE_TX_POWER_LEVEL_FIELD_NUMBER: _ClassVar[int]
MANUFACTURER_DATA_FIELD_NUMBER: _ClassVar[int]
SERVICES_FIELD_NUMBER: _ClassVar[int]
include_device_name: bool
include_tx_power_level: bool
manufacturer_data: bytes
services: _containers.RepeatedCompositeFieldContainer[Chip.BleBeacon.AdvertiseData.Service]
def __init__(self, include_device_name: bool = ..., include_tx_power_level: bool = ..., manufacturer_data: _Optional[bytes] = ..., services: _Optional[_Iterable[_Union[Chip.BleBeacon.AdvertiseData.Service, _Mapping]]] = ...) -> None: ...
BT_FIELD_NUMBER: _ClassVar[int]
ADDRESS_FIELD_NUMBER: _ClassVar[int]
SETTINGS_FIELD_NUMBER: _ClassVar[int]
ADV_DATA_FIELD_NUMBER: _ClassVar[int]
SCAN_RESPONSE_FIELD_NUMBER: _ClassVar[int]
bt: Chip.Bluetooth
address: str
settings: Chip.BleBeacon.AdvertiseSettings
adv_data: Chip.BleBeacon.AdvertiseData
scan_response: Chip.BleBeacon.AdvertiseData
def __init__(self, bt: _Optional[_Union[Chip.Bluetooth, _Mapping]] = ..., address: _Optional[str] = ..., settings: _Optional[_Union[Chip.BleBeacon.AdvertiseSettings, _Mapping]] = ..., adv_data: _Optional[_Union[Chip.BleBeacon.AdvertiseData, _Mapping]] = ..., scan_response: _Optional[_Union[Chip.BleBeacon.AdvertiseData, _Mapping]] = ...) -> None: ...
KIND_FIELD_NUMBER: _ClassVar[int]
ID_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
PRODUCT_NAME_FIELD_NUMBER: _ClassVar[int]
BT_FIELD_NUMBER: _ClassVar[int]
BLE_BEACON_FIELD_NUMBER: _ClassVar[int]
UWB_FIELD_NUMBER: _ClassVar[int]
WIFI_FIELD_NUMBER: _ClassVar[int]
OFFSET_FIELD_NUMBER: _ClassVar[int]
kind: _common_pb2.ChipKind
id: int
name: str
manufacturer: str
product_name: str
bt: Chip.Bluetooth
ble_beacon: Chip.BleBeacon
uwb: Chip.Radio
wifi: Chip.Radio
offset: Position
def __init__(self, kind: _Optional[_Union[_common_pb2.ChipKind, str]] = ..., id: _Optional[int] = ..., name: _Optional[str] = ..., manufacturer: _Optional[str] = ..., product_name: _Optional[str] = ..., bt: _Optional[_Union[Chip.Bluetooth, _Mapping]] = ..., ble_beacon: _Optional[_Union[Chip.BleBeacon, _Mapping]] = ..., uwb: _Optional[_Union[Chip.Radio, _Mapping]] = ..., wifi: _Optional[_Union[Chip.Radio, _Mapping]] = ..., offset: _Optional[_Union[Position, _Mapping]] = ...) -> None: ...
class ChipCreate(_message.Message):
__slots__ = ("kind", "address", "name", "manufacturer", "product_name", "ble_beacon", "bt_properties")
class BleBeaconCreate(_message.Message):
__slots__ = ("address", "settings", "adv_data", "scan_response")
ADDRESS_FIELD_NUMBER: _ClassVar[int]
SETTINGS_FIELD_NUMBER: _ClassVar[int]
ADV_DATA_FIELD_NUMBER: _ClassVar[int]
SCAN_RESPONSE_FIELD_NUMBER: _ClassVar[int]
address: str
settings: Chip.BleBeacon.AdvertiseSettings
adv_data: Chip.BleBeacon.AdvertiseData
scan_response: Chip.BleBeacon.AdvertiseData
def __init__(self, address: _Optional[str] = ..., settings: _Optional[_Union[Chip.BleBeacon.AdvertiseSettings, _Mapping]] = ..., adv_data: _Optional[_Union[Chip.BleBeacon.AdvertiseData, _Mapping]] = ..., scan_response: _Optional[_Union[Chip.BleBeacon.AdvertiseData, _Mapping]] = ...) -> None: ...
KIND_FIELD_NUMBER: _ClassVar[int]
ADDRESS_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
PRODUCT_NAME_FIELD_NUMBER: _ClassVar[int]
BLE_BEACON_FIELD_NUMBER: _ClassVar[int]
BT_PROPERTIES_FIELD_NUMBER: _ClassVar[int]
kind: _common_pb2.ChipKind
address: str
name: str
manufacturer: str
product_name: str
ble_beacon: ChipCreate.BleBeaconCreate
bt_properties: _configuration_pb2.Controller
def __init__(self, kind: _Optional[_Union[_common_pb2.ChipKind, str]] = ..., address: _Optional[str] = ..., name: _Optional[str] = ..., manufacturer: _Optional[str] = ..., product_name: _Optional[str] = ..., ble_beacon: _Optional[_Union[ChipCreate.BleBeaconCreate, _Mapping]] = ..., bt_properties: _Optional[_Union[_configuration_pb2.Controller, _Mapping]] = ...) -> None: ...
class Device(_message.Message):
__slots__ = ("id", "name", "visible", "position", "orientation", "chips")
ID_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
VISIBLE_FIELD_NUMBER: _ClassVar[int]
POSITION_FIELD_NUMBER: _ClassVar[int]
ORIENTATION_FIELD_NUMBER: _ClassVar[int]
CHIPS_FIELD_NUMBER: _ClassVar[int]
id: int
name: str
visible: bool
position: Position
orientation: Orientation
chips: _containers.RepeatedCompositeFieldContainer[Chip]
def __init__(self, id: _Optional[int] = ..., name: _Optional[str] = ..., visible: bool = ..., position: _Optional[_Union[Position, _Mapping]] = ..., orientation: _Optional[_Union[Orientation, _Mapping]] = ..., chips: _Optional[_Iterable[_Union[Chip, _Mapping]]] = ...) -> None: ...
class DeviceCreate(_message.Message):
__slots__ = ("name", "position", "orientation", "chips")
NAME_FIELD_NUMBER: _ClassVar[int]
POSITION_FIELD_NUMBER: _ClassVar[int]
ORIENTATION_FIELD_NUMBER: _ClassVar[int]
CHIPS_FIELD_NUMBER: _ClassVar[int]
name: str
position: Position
orientation: Orientation
chips: _containers.RepeatedCompositeFieldContainer[ChipCreate]
def __init__(self, name: _Optional[str] = ..., position: _Optional[_Union[Position, _Mapping]] = ..., orientation: _Optional[_Union[Orientation, _Mapping]] = ..., chips: _Optional[_Iterable[_Union[ChipCreate, _Mapping]]] = ...) -> None: ...
class Scene(_message.Message):
__slots__ = ("devices",)
DEVICES_FIELD_NUMBER: _ClassVar[int]
devices: _containers.RepeatedCompositeFieldContainer[Device]
def __init__(self, devices: _Optional[_Iterable[_Union[Device, _Mapping]]] = ...) -> None: ...
class Capture(_message.Message):
__slots__ = ("id", "chip_kind", "device_name", "state", "size", "records", "timestamp", "valid")
ID_FIELD_NUMBER: _ClassVar[int]
CHIP_KIND_FIELD_NUMBER: _ClassVar[int]
DEVICE_NAME_FIELD_NUMBER: _ClassVar[int]
STATE_FIELD_NUMBER: _ClassVar[int]
SIZE_FIELD_NUMBER: _ClassVar[int]
RECORDS_FIELD_NUMBER: _ClassVar[int]
TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
VALID_FIELD_NUMBER: _ClassVar[int]
id: int
chip_kind: _common_pb2.ChipKind
device_name: str
state: bool
size: int
records: int
timestamp: _timestamp_pb2.Timestamp
valid: bool
def __init__(self, id: _Optional[int] = ..., chip_kind: _Optional[_Union[_common_pb2.ChipKind, str]] = ..., device_name: _Optional[str] = ..., state: bool = ..., size: _Optional[int] = ..., records: _Optional[int] = ..., timestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., valid: bool = ...) -> None: ...

Some files were not shown because too many files have changed in this diff Show More