Compare commits

...

166 Commits

Author SHA1 Message Date
Josh Wu 72d821b1f6 Merge pull request #928 from zxzxwu/avdtp
AVDTP: Avoid explicit in_use management
2026-05-26 16:33:08 +08:00
Josh Wu afe064b4ea AVDTP: Make local stream endpoint in_use dyanmic property 2026-05-22 15:58:11 +08:00
Josh Wu 8d0cef70c2 AVDTP: Add keyword argument to long __init__ 2026-05-20 16:19:06 +08:00
Josh Wu 9cefde1c3e Merge pull request #923 from laozhuxinlu/avdtp_abort_issue_fix
Fix AVDTP endpoint resource leak by clearing the in_use flag on strea…
2026-05-20 16:18:35 +08:00
Clay Zhu ffb9d5f117 Fix AVDTP endpoint resource leak by clearing the in_use flag on stream close and abort commands. 2026-05-11 18:44:56 +08:00
Josh Wu 7d3be8157a Merge pull request #922 from zxzxwu/typing
Type some optional attributes
2026-05-09 16:15:42 +08:00
Josh Wu 9dc9c348e5 Merge pull request #920 from zxzxwu/avdtp
AVDTP: Make all handlers async
2026-05-07 17:39:24 +08:00
Josh Wu b18555539e Type some optional attributes 2026-05-06 17:16:40 +08:00
Josh Wu 8a853d5b2f AVDTP: Make all handlers async 2026-05-05 01:44:10 +08:00
Josh Wu 8988a85245 Merge pull request #919 from zxzxwu/sdp
SDP: Move parser functions to parser class
2026-04-29 13:21:13 +08:00
Josh Wu 0813da2278 SDP: Move parser functions to parser class 2026-04-28 13:27:50 +08:00
Gilles Boccon-Gibod a1ff183d44 Merge pull request #915 from dlech/notify-subscribers-type-hints
improve type hints for notify/indicate subscriber(s) methods
2026-04-27 21:45:38 +02:00
Gilles Boccon-Gibod 7adf44eddf Merge pull request #916 from dlech/fix-crash-in-attribute-repr
fix crash in `bumble.att.Attribute.__repr__`
2026-04-27 21:41:41 +02:00
Josh Wu 05accbf805 Merge pull request #918 from ibondarenko1/fix/avdtp-empty-pdu-guard
avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
2026-04-27 10:01:51 +08:00
Josh Wu 80f54f2a09 Merge pull request #917 from dlech/fix-regex-with-backslash
Fix regex syntax warning in sdp_test.py.
2026-04-27 09:55:36 +08:00
ibondarenko1 07b5e33e09 avdtp: address review nits — use truthy checks
Per @zxzxwu review on #918:
- bumble/avdtp.py: replace `if len(pdu) < 1:` with `if not pdu:`
- tests/avdtp_test.py: replace `assert completed == []` with
  `assert not completed`

Both are idiomatic Python truthy checks; behavior identical.
2026-04-26 18:49:55 -07:00
ibondarenko1 b874e26a4f avdtp: bound message assembler to drop truncated PDUs (DoS prevention)
A remote peer can send an AVDTP frame shorter than the assembler expects.
The current MessageAssembler.on_pdu() unconditionally accesses pdu[0],
pdu[1], and (for START packets) pdu[2], so a 0-, 1-, or 2-byte frame
raises IndexError. The exception propagates up through L2CAP's read loop
and tears down the channel — same DoS class as #912 (empty ATT PDU) and
#914 (unbounded SDP recursion).

Fix: validate length before each access. Empty PDUs and packets shorter
than the type-specific minimum are logged and dropped; the assembler
stays alive so the L2CAP channel is not torn down.

- bumble/avdtp.py: length guards in MessageAssembler.on_pdu before
  accessing pdu[0], pdu[1], pdu[2].
- tests/avdtp_test.py: regression test covering empty PDU, 1-byte SINGLE,
  1-byte START, 2-byte START — all four would have raised IndexError
  pre-fix; assembler now drops without raising.
2026-04-26 18:16:15 -07:00
David Lechner baa5257780 improve type hints for notify/indicate subscriber(s) methods
Pyright expects generic type parameters to be specified for the
Attribute class, otherwise it treats the type as Unknown which can
trigger reportUnknownMemberType errors.

This can be solved by using a generic type parameter for these methods
which also has the benefit of making sure that the value parameter has
the correct type for the attribute.

In some cases, a new local `value_as_bytes` variable is needed to avoid
type errors and makes the code less confusing by not overwriting the
original `value` variable.
2026-04-26 09:43:40 -05:00
David Lechner a91ea9110c Fix regex syntax warning in sdp_test.py.
Change regex match string to raw string to avoid syntax warning:

    tests/sdp_test.py:218: SyntaxWarning: invalid escape sequence '\d'
    assert not re.search("Expect \d+ bytes, got \d+", caplog.text)

In the future, this will become an error, so we should fix it now.
2026-04-26 09:31:18 -05:00
Josh Wu 1686c5b11b Merge pull request #914 from ibondarenko1/fix/sdp-recursion-depth-limit
sdp: bound DataElement parse recursion to prevent RecursionError DoS
2026-04-26 17:22:59 +08:00
David Lechner d9481992bb fix crash in bumble.att.Attribute.__repr__
If an attribute does not contains a bytes value, it would crash with
something like:

    AttributeError: 'NoneType' object has no attribute 'hex'

Clearly, the intention here was to use `value_str` to avoid this
possibility.
2026-04-25 17:01:25 -05:00
ibondarenko1 16d0ed56cf sdp: address review nits (import at top, InvalidPacketError)
- bumble/sdp.py: replace raise ValueError with raise InvalidPacketError
  in DataElement.list_from_bytes depth guard. InvalidPacketError
  already imported at line 34 and extends ValueError so the existing
  regression test continues to match.
- tests/sdp_test.py: remove duplicate 'import pytest' inside
  test_nested_sequence_recursion_guard; pytest already imported at
  module top (line 23).

Threading.local counter left as-is per zxzxwu's 'leave it here and
refactor later' comment on the PR.
2026-04-24 11:42:49 -07:00
Ievgen Bondarenko c55eb156b8 sdp: fix lint formatting (black: blank line after import pytest) 2026-04-24 00:06:56 -07:00
ibondarenko1 8614881fb3 sdp: bound DataElement parse recursion to prevent RecursionError DoS
DataElement.from_bytes -> list_from_bytes -> (SEQUENCE/ALTERNATIVE
constructor dispatches back to list_from_bytes) had no depth limit. A
malicious SDP peer could send a PDU of a few kilobytes containing ~1000
nested SEQUENCE tags and exhaust the Python recursion stack, crashing the
host with an unhandled RecursionError propagating out of the SDP handler.

Reachable via: any remote Bluetooth device that Bumble performs SDP
service discovery against (default during Classic connection setup).

Same family as PR #912 (ATT_PDU.from_bytes empty PDU IndexError) - remote
unchecked-input parser crash in the Bluetooth stack.

Fix: thread-local depth counter, cap nesting at 32 (well above anything a
legitimate service record uses). Added two regression tests covering the
deep-nesting reject path and normal 16-level-nested SEQUENCE parsing.

Reproducer (4.5 KB payload, deterministic crash on 0.0.228):

    from bumble.sdp import DataElement
    inner = b"\x35\x00"
    for _ in range(1500):
        size = len(inner)
        if size < 65535:
            inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
    DataElement.from_bytes(inner)  # RecursionError before fix

Signed-off-by: ibondarenko1 <ibondarenko1@users.noreply.github.com>
2026-04-23 00:53:06 -07:00
Josh Wu 27d02ef18d Merge pull request #913 from zxzxwu/sdp
SDP: Fix wrong parameter size
2026-04-20 16:32:37 +08:00
Josh Wu c0725e2a4a SDP: Fix wrong parameter size 2026-04-20 16:23:19 +08:00
Josh Wu bf0784dde4 Merge pull request #912 from ibondarenko1/fix/empty-pdu-crash
fix: add input validation to prevent remote crash from empty/malforme…
2026-04-20 14:36:48 +08:00
Ievgen Bondarenko 444f43f6a3 fix: address review feedback - use InvalidPacketError and abort on buffer overflow
- att.py: raise core.InvalidPacketError instead of generic ValueError
- smp.py: raise core.InvalidPacketError instead of generic ValueError
- hfp.py: add MAX_BUFFER_SIZE class constant (64KB)
- hfp.py: drop incoming data when it would overflow buffer instead of
  truncating, preserving existing partial-packet state

Per review comments on PR #912 by @zxzxwu.
2026-04-16 11:24:09 -07:00
Gilles Boccon-Gibod 2420c47cf1 Merge pull request #911 from google/gbg/issue-910
release command semaphore after timeout
2026-04-16 18:11:57 +02:00
Ievgen Bondarenko 0a78e7506b fix: add input validation to prevent remote crash from empty/malformed PDUs
Add length checks in from_bytes() for ATT and SMP protocol parsers
to prevent IndexError crashes from empty PDUs sent by remote Bluetooth
devices. Also add buffer size limit and UTF-8 error handling in HFP
protocol to prevent memory exhaustion and decode crashes.

- bumble/att.py: validate PDU is non-empty before accessing pdu[0]
- bumble/smp.py: validate PDU is non-empty before accessing pdu[0]
- bumble/hfp.py: limit buffer to 64KB, handle invalid UTF-8 gracefully

These issues can be triggered by a remote Bluetooth device sending
malformed packets, causing denial of service on the host.
2026-04-16 01:43:41 -07:00
Gilles Boccon-Gibod f7cc6f6657 release command semaphore after timeout 2026-04-15 16:54:54 +02:00
Josh Wu f2824ee6b8 Merge pull request #907 from zxzxwu/example-gatt-client-and-server
Advertise in run_gatt_client_and_server
2026-04-13 16:31:19 +08:00
Josh Wu 7188ef08de Advertise in run_gatt_client_and_server 2026-04-13 15:31:32 +08:00
Josh Wu 3ded9014d3 Merge pull request #905 from markusjellitsch/feature/debug-keys
Feature  - Add SMP Debug Mode  (Core Vol.3, Part H)
2026-04-09 15:36:42 +08:00
Josh Wu b6125bdfb1 Merge pull request #904 from zxzxwu/keys
Keys: Remove appdirs and improve typing
2026-04-09 15:30:39 +08:00
Markus Jellitsch dc17f4f1ca remove asserts 2026-04-08 20:58:47 +02:00
Markus Jellitsch 3f65380c20 remove comment 2026-04-03 23:19:43 +02:00
Markus Jellitsch 25a0056ecc remove uncommented line 2026-04-03 23:08:16 +02:00
Markus Jellitsch 85f6b10983 run formatter 2026-04-03 23:06:24 +02:00
Markus Jellitsch e85f041e9d add test for smp debug mode 2026-04-03 23:04:48 +02:00
Markus Jellitsch ee09e6f10d add smp_debug_mode config flag to enable debug keys during device init 2026-04-03 23:03:51 +02:00
Markus Jellitsch c3daf4a7e1 implement debug mode for smp manager using defined private / public key pair 2026-04-03 23:02:15 +02:00
Josh Wu 3af623be7e Keys: Remove appdirs and improve typing 2026-03-31 16:25:15 +08:00
Gilles Boccon-Gibod 4e76d3057b Merge pull request #903 from sameer/micropip-install-crypto-issue
Fix Hive demo install failure
2026-03-28 15:35:32 -04:00
Sameer Puri eda7360222 Upgrade pyodide in web fixes import error
Prior to this, these web pages fail to load with
`ImportError: cannot import name 'TypeIs' from 'typing_extensions'
(/lib/python3.11/site-packages/typing_extensions.py)`
2026-03-26 18:39:07 +00:00
Sameer Puri a4c15c00de Downgrade cryptography, fixes micropip failure
Prior to this, these web pages fail to load with

`ValueError: Can't find a pure Python 3 wheel for 'cryptography>=44.0.3;
platform_system == "Emscripten"'.`
2026-03-26 18:38:12 +00:00
Josh Wu cba4df4aef Merge pull request #900 from zxzxwu/lmp-feat
Add read classic remote features support
2026-03-24 14:03:29 +08:00
Josh Wu ceb8b448e9 Merge pull request #901 from zxzxwu/rust
Add --locked to allow installing cargo-all-features
2026-03-21 03:45:47 +08:00
Josh Wu 311b716d5c Add --locked to allow installing cargo-all-features 2026-03-20 18:44:49 +08:00
Josh Wu 0ba9e5c317 Add read classic remote features support 2026-03-20 18:32:52 +08:00
Josh Wu 3517225b62 Merge pull request #898 from zxzxwu/phy
Make ConnectionPHY dataclass
2026-03-13 12:04:45 +08:00
Josh Wu ad4bb1578b Make ConnectionPHY dataclass 2026-03-11 21:41:48 +08:00
Josh Wu 4af65b381b Merge pull request #820 from zxzxwu/sdp
SDP: Migrate to dataclasses
2026-03-04 13:45:39 +08:00
Josh Wu a5cd3365ae Merge pull request #895 from zxzxwu/uuid
Hash and cache 128 bytes of UUID
2026-03-04 00:29:43 +08:00
Josh Wu 2915cb8bb6 Add test for UUID hash 2026-03-04 00:22:50 +08:00
Josh Wu 28e485b7b3 Hash and cache 128 bytes of UUID 2026-03-03 17:54:27 +08:00
Josh Wu 1198f2c3f5 SDP: Make PDU dataclasses 2026-03-03 02:07:08 +08:00
Josh Wu 80aaf6a2b9 SDP: Make DataElement and ServiceAttribute dataclasses 2026-03-03 01:28:40 +08:00
Josh Wu eb64debb62 Merge pull request #893 from zxzxwu/le-emu
Emulation: Support LE Read features
2026-03-01 17:01:11 +08:00
Josh Wu c158f25b1e Emulation: Support LE Read features 2026-03-01 02:24:55 +08:00
Josh Wu 1330e83517 Merge pull request #892 from zxzxwu/hfp
HFP: Fix response handling
2026-02-26 13:18:03 +08:00
Josh Wu d9c9bea6cb HFP: Fix response handling 2026-02-25 00:39:45 +08:00
Gilles Boccon-Gibod 3b937631b3 Merge pull request #891 from a-detiste/main 2026-02-18 21:13:09 -08:00
Alexandre Detiste f8aa309111 fix pyproject.toml format 2026-02-18 16:39:09 +01:00
Alexandre Detiste 673281ed71 use tomllib from standard library on Python3.11+ 2026-02-18 11:11:49 +01:00
Josh Wu 3ac7af4683 Merge pull request #886 from zxzxwu/controller-status
Controller: Use new return parameter types and add _send_hci_command_status
2026-02-11 13:27:32 +08:00
Josh Wu 5ebfaae74e Controller: Use new return parameter types and add _send_hci_command_status() 2026-02-11 13:21:47 +08:00
Josh Wu e6175f85fe Merge pull request #887 from zxzxwu/gap
Remove bumble.gap
2026-02-11 13:15:39 +08:00
Josh Wu f9ba527508 Merge pull request #821 from zxzxwu/smp
Migrate most enums
2026-02-11 13:15:22 +08:00
Josh Wu a407c4cabf Merge pull request #883 from zxzxwu/avrcp
AVRCP: More delegation and bugfix
2026-02-11 13:13:16 +08:00
Josh Wu 6c2d6dddb5 Merge pull request #885 from zxzxwu/match-case
Replace long if-else with match-case
2026-02-11 13:12:38 +08:00
Josh Wu 797cd216d4 SMP: Migrate all enums 2026-02-10 20:08:01 +08:00
Josh Wu e2e8c90e47 Remove bumble.gap 2026-02-10 17:40:22 +08:00
Josh Wu 3d5648cdc3 Replace long if-else with match-case 2026-02-10 17:35:39 +08:00
Josh Wu d810d93aaf Merge pull request #884 from timrid/fix-multiple-le-connections
Connecting multiple times to a LE device is working correctly again
2026-02-06 11:25:44 +08:00
timrid 81d9adb983 delete only the required connection 2026-02-05 20:50:58 +01:00
Josh Wu 377fa896f7 Merge pull request #881 from google/dependabot/cargo/rust/cargo-f6ecf5c85a
Bump bytes from 1.5.0 to 1.11.1 in /rust in the cargo group across 1 directory
2026-02-05 23:55:37 +08:00
timrid 79e5974946 Multiple le connections are now working correctly 2026-02-05 13:15:57 +01:00
Josh Wu 657451474e AVRCP: Address type errors 2026-02-05 16:01:21 +08:00
Josh Wu 9f730dce6f AVRCP: Delegate Track Changed 2026-02-05 15:50:06 +08:00
Josh Wu 1a6be95a7e AVRCP: Delegate UID and Addressed Player 2026-02-05 15:44:11 +08:00
Josh Wu aea5320d71 AVRCP: Add Play Item delegation 2026-02-05 15:34:03 +08:00
Josh Wu 91cb1b1df3 AVRCP: Add available player changed event 2026-02-05 15:25:17 +08:00
Josh Wu 81bdc86e52 AVRCP: Delegate Player App Settings 2026-02-05 15:22:11 +08:00
Josh Wu f23cad34e3 AVRCP: Use match-case 2026-02-04 22:23:53 +08:00
Josh Wu 30fde2c00b AVRCP: Fix wrong packet field specs 2026-02-04 18:05:25 +08:00
Josh Wu 256a1a7405 Merge pull request #882 from zxzxwu/hci
Fix wrong LE event codes
2026-02-04 17:40:54 +08:00
Josh Wu 116d9b26bb Fix wrong LE event codes 2026-02-04 15:03:08 +08:00
dependabot[bot] aabe2ca063 Bump bytes in /rust in the cargo group across 1 directory
Bumps the cargo group with 1 update in the /rust directory: [bytes](https://github.com/tokio-rs/bytes).


Updates `bytes` from 1.5.0 to 1.11.1
- [Release notes](https://github.com/tokio-rs/bytes/releases)
- [Changelog](https://github.com/tokio-rs/bytes/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tokio-rs/bytes/compare/v1.5.0...v1.11.1)

---
updated-dependencies:
- dependency-name: bytes
  dependency-version: 1.11.1
  dependency-type: direct:production
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-03 20:46:35 +00:00
Gilles Boccon-Gibod 2d17a5f742 Merge pull request #880 from google/gbg/command-status-hack
add workaround for some buggy controllers
2026-02-02 23:37:52 -08:00
Gilles Boccon-Gibod 3894b14467 better handling of complete/status events 2026-02-02 23:28:40 -08:00
Gilles Boccon-Gibod e62f947430 add workaround for some buggy controllers 2026-02-02 13:19:55 -08:00
Gilles Boccon-Gibod dcb8a4b607 Merge pull request #877 from google/gbg/hci-fixes
fix a few HCI types and make the bridge more robust
2026-02-02 11:19:28 -08:00
Gilles Boccon-Gibod 81985c47a9 remove superfluous statement 2026-02-02 11:12:28 -08:00
Gilles Boccon-Gibod 7118328b07 Merge pull request #879 from google/gbg/resolve-when-bonded
resolve addresses when connecting to bonded peers
2026-01-31 11:09:55 -08:00
Gilles Boccon-Gibod 5dc01d792a address PR comments 2026-01-31 10:55:58 -08:00
Gilles Boccon-Gibod 255f357975 resolve when bonded 2026-01-30 21:53:01 -08:00
Josh Wu c86920558b Merge pull request #878 from zxzxwu/avrcp
AVRCP: SDP record classes and some delegation
2026-01-31 00:01:55 +08:00
Josh Wu 8e6efd0b2f Fix error in AVRCP example 2026-01-30 23:01:11 +08:00
Gilles Boccon-Gibod 2a59e19283 fix comment 2026-01-29 19:09:46 -08:00
Josh Wu 34f5b81c7d AVRCP: Delegate Company ID capabilities 2026-01-29 22:13:14 +08:00
Josh Wu d34d6a5c98 AVRCP: Delegate Playback Status 2026-01-29 21:33:57 +08:00
Josh Wu aedc971653 AVRCP: Add SDP record class and finder 2026-01-29 16:00:50 +08:00
Josh Wu c6815fb820 AVRCP: Delegate passthrough key event 2026-01-29 14:50:14 +08:00
Gilles Boccon-Gibod f44d013690 make bridge more robust 2026-01-27 09:47:52 -08:00
Gilles Boccon-Gibod e63dc15ede fix handling of return parameters 2026-01-27 09:39:22 -08:00
Gilles Boccon-Gibod c901e15666 fix a few HCI types and make the bridge more robust 2026-01-25 13:47:14 -08:00
Gilles Boccon-Gibod 022323b19c Merge pull request #871 from google/gbg/sci
add basic support for SCI
2026-01-24 10:39:11 -08:00
Gilles Boccon-Gibod a0d24e95e7 fix spacing_type 2026-01-24 10:15:32 -08:00
Josh Wu 7efbd303e0 Merge pull request #876 from ttdennis/await_termination_fix
Update apps and examples to await .terminated instead of wait_for_termination()
2026-01-24 11:44:19 +08:00
Dennis Heinze 49530d8d6d Update apps and examples to await .terminated 2026-01-24 00:20:55 +01:00
Gilles Boccon-Gibod 85b78b46f8 Merge pull request #870 from antipatico/feat_AV53C1 2026-01-23 13:43:12 -08:00
Josh Wu 3f9ef5aac2 Merge pull request #873 from zxzxwu/l2cap
L2CAP: Fix wrong CID on reject
2026-01-23 12:44:59 +08:00
Josh Wu e488ea9783 Merge pull request #872 from zxzxwu/avrcp
AVRCP: Fix wrong field specs
2026-01-23 12:36:14 +08:00
Josh Wu 21d937c2f1 Merge pull request #865 from willnix/pcapsnoop
Added a PcapSnooper class
2026-01-23 12:33:15 +08:00
Frieder Steinmetz a8396e6cce Formatted with black again. 2026-01-22 17:49:58 +01:00
Josh Wu 7e1b1c8f78 L2CAP: Fix wrong CID on reject 2026-01-22 23:16:25 +08:00
Josh Wu 55719bf6de AVRCP: Fix wrong field specs 2026-01-22 22:18:58 +08:00
Frieder Steinmetz 5059920696 Please mypy.\n\nTwo calls to open(), some more annotations and a rescoped global were needed. 2026-01-22 10:40:08 +01:00
Gilles Boccon-Gibod c577f17c99 add basic support for SCI 2026-01-20 15:32:55 -08:00
Gilles Boccon-Gibod 252f3e49b6 Merge pull request #870 from antipatico/feat_AV53C1 2026-01-20 10:46:52 -08:00
Jacopo Scannella f3ecf04479 Added support for STA-AV53C1-USB-BLUETOOTH StarTech(dot)com dongle - RTL8761BUE 2026-01-20 09:32:51 +01:00
Gilles Boccon-Gibod 4986f55043 Merge pull request #869 from timrid/android-fix
Make bumble work on Android using briefcase/chaquopy
2026-01-19 09:50:08 -08:00
Gilles Boccon-Gibod 7e89c8a7f8 Merge pull request #868 from google/gbg/return-parameters
typing support for HCI commands return parameters
2026-01-19 09:49:15 -08:00
timrid 085905a7bf Make bumble work on Android using briefcase that is using chaquopy under the hood. 2026-01-18 23:32:37 +01:00
Gilles Boccon-Gibod 7523118581 typing surrport for HCI commands return parameters 2026-01-17 13:19:36 -08:00
zxzxwu c619f1f21b Merge pull request #867 from zxzxwu/fix-import-error
Fix missing ClassVar import
2026-01-16 15:33:07 +08:00
Josh Wu d4b0da9265 Fix missing ClassVar import 2026-01-16 15:21:26 +08:00
zxzxwu f1058e4d4e Merge pull request #859 from istemon/att-read-by-type-request-fix
Return 'invalid handle' for malformed read by type request
2026-01-16 15:09:20 +08:00
zxzxwu 454d477d7e Merge pull request #864 from zxzxwu/hci-packets-typing
Add HCI Packets annotations and send_sco_sdu
2026-01-16 15:08:42 +08:00
zxzxwu 6966228d74 Merge pull request #863 from zxzxwu/eatt-mtu
Correct ATT_MTU in enhanced bearers
2026-01-16 15:08:12 +08:00
zxzxwu f4271a5646 Merge pull request #862 from zxzxwu/gatt-multiple
GATT: Support Multiple Requests
2026-01-16 15:08:02 +08:00
zxzxwu 534209f0af Merge pull request #861 from zxzxwu/l2cap
Replace send_pdu() with write()
2026-01-16 15:07:54 +08:00
zxzxwu 549b82999a Merge pull request #860 from zxzxwu/address
Improve Address type annotations
2026-01-16 14:04:56 +08:00
zxzxwu 551f577b2a Merge pull request #866 from zxzxwu/template-service
Fix GATT TemplateSerivce annotations
2026-01-16 09:41:48 +08:00
Frieder Steinmetz c69c1532cc Fix comments that were messed up by black 2026-01-15 19:06:03 +01:00
Frieder Steinmetz f95b2054c8 Formatted with 2026-01-15 10:50:33 +01:00
Josh Wu 84a6453dda Fix GATT TemplateSerivce annotations 2026-01-15 12:06:05 +08:00
Frieder Steinmetz 3fdd7ee45e Added the PcapSnooper class.
The class implements a bumble snooper that writes PCAP records.
It can write to either a file or a named pipe.
The latter is useful to bridge with wireshark extcap for live logging.
2026-01-14 23:40:59 +01:00
Gilles Boccon-Gibod 591ed61686 Merge pull request #858 from klow68/feat/add-usb-probe-filtering 2026-01-13 08:54:55 -08:00
Josh Wu 3d3acbb374 Add HCI Packets annotations and send_sco_sdu 2026-01-13 17:58:37 +08:00
Stryxion 671f306a27 fix: black 2026-01-13 09:42:40 +01:00
Josh Wu f7364db992 Correct ATT_MTU in enhanced bearers 2026-01-12 21:03:14 +08:00
Josh Wu 0fb2b3bd66 GATT: Support Multiple Requests 2026-01-12 20:51:38 +08:00
Stryxion 9e270d4d62 fix: mypy 2026-01-12 09:36:35 +01:00
Josh Wu cf60b5ffbb Replace send_pdu() with write() 2026-01-12 13:16:49 +08:00
Josh Wu aa4c57d105 Improve Address type annotations
* Add missing annotations
* Declare address constants as ClassVar
2026-01-12 13:07:04 +08:00
Istemon 61a601e6e2 Return 'invalid handle' for malformed read by type request 2026-01-10 01:43:30 +00:00
Stryxion 05fd4fbfc6 fix: review 2026-01-09 08:46:31 +01:00
Gilles Boccon-Gibod 2cad743f8c Merge pull request #854 from TinyServal/rtl8761cu
Add support for RTL8761CU
2026-01-08 18:37:21 -08:00
Stryxion 6aa9e0bdf7 feat: Add filtering options for usb probe 2026-01-08 14:54:58 +01:00
zxzxwu 255414f315 Merge pull request #857 from zxzxwu/testing
Add test for Heart Rate and Battery Service
2026-01-08 17:52:12 +08:00
Josh Wu d2df76f6f4 Add test for Heart Rate and Battery Service 2026-01-08 16:42:05 +08:00
zxzxwu 884b1c20e4 Merge pull request #856 from zxzxwu/typing
Add annotation for Heart Rate and Battery Service
2026-01-08 15:29:50 +08:00
Josh Wu 91a2b4f676 Add annotation for Heart Rate and Battery Service 2026-01-08 14:43:27 +08:00
Bowen Yan 5831f79d62 Add support for the RTL8761CU 2026-01-08 16:50:11 +11:00
zxzxwu 36f81b798c Merge pull request #853 from zxzxwu/l2cap
L2CAP: Fix segmentation and frame ack
2026-01-08 09:40:13 +08:00
Gilles Boccon-Gibod 985183001f Merge pull request #855 from encarbassotnopot/patch-1
docs: fix a small error in hci socket up/down commands
2026-01-07 14:26:15 -08:00
Josh Wu b153d0fcde L2CAP: Fix Enhanced Retransmission Segmentation 2026-01-07 23:49:57 +08:00
Eina Safor 30d912d66e docs: fix a small error in hci socket up/down commands 2026-01-07 15:59:14 +01:00
Bowen Yan 054dc70f3f Exclude macOS xattr files 2026-01-07 15:00:21 +11:00
zxzxwu 8ac8724cd8 Merge pull request #851 from zxzxwu/fix
Fix some typos and annotations
2026-01-06 14:02:40 +08:00
Josh Wu 4c3746a5b2 Fix some typos and annotations 2026-01-05 23:53:22 +08:00
zxzxwu 566ef967f4 Merge pull request #836 from zxzxwu/eatt
Add EATT Support
2026-01-05 22:26:17 +08:00
Josh Wu df697c6513 Add EATT Support 2026-01-04 21:51:50 +08:00
Gilles Boccon-Gibod e3e1b7bc5b Merge pull request #849 from google/gbg/auracast-multi-broadcast 2026-01-02 09:02:15 -08:00
103 changed files with 8742 additions and 4896 deletions
+1 -1
View File
@@ -69,7 +69,7 @@ jobs:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
run: cargo install cargo-all-features --version 1.11.0 --locked # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
+3
View File
@@ -17,3 +17,6 @@ venv/
.venv/
# snoop logs
out/
# macOS
.DS_Store
._*
+7 -2
View File
@@ -24,13 +24,18 @@ import dataclasses
import functools
import logging
import secrets
import sys
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
from typing import (
Any,
)
import click
import tomli
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
try:
import lc3 # type: ignore # pylint: disable=E0401
@@ -114,7 +119,7 @@ def parse_broadcast_list(filename: str) -> Sequence[Broadcast]:
broadcasts: list[Broadcast] = []
with open(filename, "rb") as config_file:
config = tomli.load(config_file)
config = tomllib.load(config_file)
for broadcast in config.get("broadcasts", []):
sources = []
for source in broadcast.get("sources", []):
+90 -114
View File
@@ -27,23 +27,17 @@ from bumble.core import name_or_number
from bumble.hci import (
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BD_ADDR_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
CodecID,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_Read_Minimum_Supported_Connection_Interval_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_BD_ADDR_Command,
HCI_Read_Buffer_Size_Command,
@@ -59,85 +53,81 @@ from bumble.host import Host
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
print()
print(
color('Public Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command())
print()
print(
color('Public Address:', 'yellow'),
response1.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if command_succeeded(response):
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
response2 = await host.send_sync_command(HCI_Read_Local_Name_Command())
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response2.local_name),
)
# -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
host.number_of_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
host.maximum_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
response1 = await host.send_sync_command(
HCI_LE_Read_Maximum_Data_Length_Command()
)
print(
color('LE Maximum Data Length:', 'yellow'),
(
f'tx:{response1.supported_max_tx_octets}/'
f'{response1.supported_max_tx_time}, '
f'rx:{response1.supported_max_rx_octets}/'
f'{response1.supported_max_rx_time}'
),
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
response2 = await host.send_sync_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Suggested Default Data Length:', 'yellow'),
f'{response2.suggested_max_tx_octets}/'
f'{response2.suggested_max_tx_time}',
'\n',
)
if host.supports_command(HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND):
response3 = await host.send_sync_command(
HCI_LE_Read_Minimum_Supported_Connection_Interval_Command()
)
print(
color('LE Minimum Supported Connection Interval:', 'yellow'),
f'{response3.minimum_supported_connection_interval * 125} µs',
)
for group in range(len(response3.group_min)):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
f' Group {group}: '
f'{response3.group_min[group] * 125} µs to '
f'{response3.group_max[group] * 125} µs '
'by increments of '
f'{response3.group_stride[group] * 125} µs',
'\n',
)
@@ -151,37 +141,31 @@ async def get_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command())
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
f'{response1.hc_total_num_acl_data_packets} '
f'packets of size {response1.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response2.total_num_le_acl_data_packets} '
f'packets of size {response2.le_acl_data_packet_length}',
)
print(
color('LE ISO Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response.return_parameters.iso_data_packet_length}',
f'{response2.total_num_iso_data_packets} '
f'packets of size {response2.iso_data_packet_length}',
)
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command())
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
f'{response3.total_num_le_acl_data_packets} '
f'packets of size {response3.le_acl_data_packet_length}',
)
@@ -190,52 +174,44 @@ async def get_codecs_info(host: Host) -> None:
print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
response1 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_V2_Command()
)
print(color('Codecs:', 'yellow'))
for codec_id, transport in zip(
response.return_parameters.standard_codec_ids,
response.return_parameters.standard_codec_transports,
response1.standard_codec_ids,
response1.standard_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
print(f' {codec_id.name} - {transport.name}')
for codec_id, transport in zip(
response.return_parameters.vendor_specific_codec_ids,
response.return_parameters.vendor_specific_codec_transports,
for vendor_codec_id, vendor_transport in zip(
response1.vendor_specific_codec_ids,
response1.vendor_specific_codec_transports,
):
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}')
if not response.return_parameters.standard_codec_ids:
if not response1.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response1.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
response2 = await host.send_sync_command(
HCI_Read_Local_Supported_Codecs_Command()
)
print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response.return_parameters.standard_codec_ids:
codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for codec_id in response2.standard_codec_ids:
print(f' {codec_id.name}')
for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF}')
for vendor_codec_id in response2.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF}')
if not response.return_parameters.standard_codec_ids:
if not response2.standard_codec_ids:
print(' No standard codecs')
if not response.return_parameters.vendor_specific_codec_ids:
if not response2.vendor_specific_codec_ids:
print(' No Vendor-specific codecs')
+11 -9
View File
@@ -85,7 +85,7 @@ class Loopback:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
async def run(self) -> None:
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as (
@@ -100,11 +100,15 @@ class Loopback:
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
packet_queue = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
)
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
@@ -128,20 +132,18 @@ class Loopback:
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
await host.send_sync_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command())
if response.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
+1 -1
View File
@@ -352,7 +352,7 @@ async def run(
await bridge.start()
# Wait until the source terminates
await hci_source.wait_for_termination()
await hci_source.terminated
@click.command()
+3 -1
View File
@@ -81,7 +81,9 @@ async def async_main():
response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]),
return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode.SUCCESS
),
)
# Return a packet with 'respond to sender' set to True
return (bytes(response), True)
+1 -1
View File
@@ -268,7 +268,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
await hci_source.terminated
# -----------------------------------------------------------------------------
+1
View File
@@ -298,6 +298,7 @@ class Speaker:
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
)
device_config.le_enabled = True
+49 -45
View File
@@ -15,14 +15,17 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import os
from typing import ClassVar
import click
from prompt_toolkit.shortcuts import PromptSession
from bumble import data_types
from bumble import data_types, smp
from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
@@ -38,7 +41,7 @@ from bumble.core import (
PhysicalTransport,
ProtocolError,
)
from bumble.device import Device, Peer
from bumble.device import Connection, Device, Peer
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -51,7 +54,6 @@ from bumble.hci import OwnAddressType
from bumble.keys import JsonKeyStore
from bumble.pairing import OobData, PairingConfig, PairingDelegate
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -63,7 +65,7 @@ POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance = None
instance: ClassVar[Waiter | None] = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
@@ -317,35 +319,36 @@ async def on_classic_pairing(connection):
# -----------------------------------------------------------------------------
@AsyncRunner.run_in_task()
async def on_pairing_failure(connection, reason):
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color(f'*** Pairing failed: {reason.name}', 'red'))
print(color('***-----------------------------------', 'red'))
await connection.disconnect()
Waiter.instance.terminate()
if Waiter.instance:
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
async def pair(
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuids,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuids: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
):
Waiter.instance = Waiter(linger=linger)
@@ -403,6 +406,7 @@ async def pair(
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
legacy_context: OobLegacyContext | None
if oob == '-':
shared_data = None
legacy_context = OobLegacyContext()
@@ -661,25 +665,25 @@ class LogHandler(logging.Handler):
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuid,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuid: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
):
# Setup logging
log_handler = LogHandler()
+1 -1
View File
@@ -421,7 +421,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
await hci_source.terminated
except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error:
+10 -4
View File
@@ -22,7 +22,7 @@ import click
import bumble.logging
from bumble import data_types
from bumble.colors import color
from bumble.device import Advertisement, Device
from bumble.device import Advertisement, Device, DeviceConfiguration
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
@@ -144,8 +144,14 @@ async def scan(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
device = Device.from_config_with_hci(
DeviceConfiguration(
name='Bumble',
address=Address('F0:F1:F2:F3:F4:F5'),
keystore='JsonKeyStore',
),
hci_source,
hci_sink,
)
await device.power_on()
@@ -190,7 +196,7 @@ async def scan(
scanning_phys=scanning_phys,
)
await hci_source.wait_for_termination()
await hci_source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -726,7 +726,7 @@ class Speaker:
print("Waiting for connection...")
await self.advertise()
await hci_source.wait_for_termination()
await hci_source.terminated
for output in self.outputs:
await output.stop()
+15 -2
View File
@@ -26,6 +26,8 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from typing import Any
import click
import usb1
@@ -166,13 +168,16 @@ def is_bluetooth_hci(device):
# -----------------------------------------------------------------------------
@click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose):
@click.option('--hci-only', is_flag=True, default=False, help='only show HCI device')
@click.option('--manufacturer', help='filter by manufacturer')
@click.option('--product', help='filter by product')
def main(verbose: bool, manufacturer: str, product: str, hci_only: bool):
bumble.logging.setup_basic_logging('WARNING')
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
devices: dict[tuple[Any, Any], list[str | None]] = {}
for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass()
@@ -234,6 +239,14 @@ def main(verbose):
f'{basic_transport_name}/{device_serial_number}'
)
# Filter
if product and device_product != product:
continue
if manufacturer and device_manufacturer != manufacturer:
continue
if not is_bluetooth_hci(device) and hci_only:
continue
# Print the results
print(
color(
+20 -38
View File
@@ -88,13 +88,6 @@ SBC_DUAL_CHANNEL_MODE = 0x01
SBC_STEREO_CHANNEL_MODE = 0x02
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
SBC_CHANNEL_MODE_NAMES = {
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
}
SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
SBC_SUBBANDS = [4, 8]
@@ -102,11 +95,6 @@ SBC_SUBBANDS = [4, 8]
SBC_SNR_ALLOCATION_METHOD = 0x00
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
SBC_ALLOCATION_METHOD_NAMES = {
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
@@ -129,13 +117,6 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
@@ -267,26 +248,27 @@ class MediaCodecInformation:
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
match media_codec_type:
case CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
case CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
case CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod
+32 -30
View File
@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string."""
tokens = []
tokens: list[bytearray] = []
in_quotes = False
token = bytearray()
for b in buffer:
@@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
match char:
case b' ':
pass
case b',' | b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
case b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
case b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
case _:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
@@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
current: bytes | list = b''
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = b''
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
match token:
case b',':
accumulator[-1].append(current)
current = b''
case b'(':
accumulator.append([])
case b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
case _:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:
+208 -67
View File
@@ -29,17 +29,20 @@ import enum
import functools
import inspect
import struct
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
TypeAlias,
TypeVar,
)
from bumble import hci, utils
from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.core import UUID, InvalidOperationError, InvalidPacketError, ProtocolError
from bumble.hci import HCI_Object
# -----------------------------------------------------------------------------
@@ -50,6 +53,14 @@ if TYPE_CHECKING:
_T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -58,36 +69,39 @@ _T = TypeVar('_T')
ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02
ATT_EXCHANGE_MTU_RESPONSE = 0x03
ATT_FIND_INFORMATION_REQUEST = 0x04
ATT_FIND_INFORMATION_RESPONSE = 0x05
ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06
ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07
ATT_READ_BY_TYPE_REQUEST = 0x08
ATT_READ_BY_TYPE_RESPONSE = 0x09
ATT_READ_REQUEST = 0x0A
ATT_READ_RESPONSE = 0x0B
ATT_READ_BLOB_REQUEST = 0x0C
ATT_READ_BLOB_RESPONSE = 0x0D
ATT_READ_MULTIPLE_REQUEST = 0x0E
ATT_READ_MULTIPLE_RESPONSE = 0x0F
ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10
ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11
ATT_WRITE_REQUEST = 0x12
ATT_WRITE_RESPONSE = 0x13
ATT_WRITE_COMMAND = 0x52
ATT_SIGNED_WRITE_COMMAND = 0xD2
ATT_PREPARE_WRITE_REQUEST = 0x16
ATT_PREPARE_WRITE_RESPONSE = 0x17
ATT_EXECUTE_WRITE_REQUEST = 0x18
ATT_EXECUTE_WRITE_RESPONSE = 0x19
ATT_HANDLE_VALUE_NOTIFICATION = 0x1B
ATT_HANDLE_VALUE_INDICATION = 0x1D
ATT_HANDLE_VALUE_CONFIRMATION = 0x1E
ATT_ERROR_RESPONSE = 0x01
ATT_EXCHANGE_MTU_REQUEST = 0x02
ATT_EXCHANGE_MTU_RESPONSE = 0x03
ATT_FIND_INFORMATION_REQUEST = 0x04
ATT_FIND_INFORMATION_RESPONSE = 0x05
ATT_FIND_BY_TYPE_VALUE_REQUEST = 0x06
ATT_FIND_BY_TYPE_VALUE_RESPONSE = 0x07
ATT_READ_BY_TYPE_REQUEST = 0x08
ATT_READ_BY_TYPE_RESPONSE = 0x09
ATT_READ_REQUEST = 0x0A
ATT_READ_RESPONSE = 0x0B
ATT_READ_BLOB_REQUEST = 0x0C
ATT_READ_BLOB_RESPONSE = 0x0D
ATT_READ_MULTIPLE_REQUEST = 0x0E
ATT_READ_MULTIPLE_RESPONSE = 0x0F
ATT_READ_BY_GROUP_TYPE_REQUEST = 0x10
ATT_READ_BY_GROUP_TYPE_RESPONSE = 0x11
ATT_READ_MULTIPLE_VARIABLE_REQUEST = 0x20
ATT_READ_MULTIPLE_VARIABLE_RESPONSE = 0x21
ATT_WRITE_REQUEST = 0x12
ATT_WRITE_RESPONSE = 0x13
ATT_WRITE_COMMAND = 0x52
ATT_SIGNED_WRITE_COMMAND = 0xD2
ATT_PREPARE_WRITE_REQUEST = 0x16
ATT_PREPARE_WRITE_RESPONSE = 0x17
ATT_EXECUTE_WRITE_REQUEST = 0x18
ATT_EXECUTE_WRITE_RESPONSE = 0x19
ATT_HANDLE_VALUE_NOTIFICATION = 0x1B
ATT_HANDLE_VALUE_INDICATION = 0x1D
ATT_HANDLE_VALUE_CONFIRMATION = 0x1E
ATT_REQUESTS = [
Opcode.ATT_EXCHANGE_MTU_REQUEST,
@@ -98,9 +112,10 @@ ATT_REQUESTS = [
Opcode.ATT_READ_BLOB_REQUEST,
Opcode.ATT_READ_MULTIPLE_REQUEST,
Opcode.ATT_READ_BY_GROUP_TYPE_REQUEST,
Opcode.ATT_READ_MULTIPLE_VARIABLE_REQUEST,
Opcode.ATT_WRITE_REQUEST,
Opcode.ATT_PREPARE_WRITE_REQUEST,
Opcode.ATT_EXECUTE_WRITE_REQUEST
Opcode.ATT_EXECUTE_WRITE_REQUEST,
]
ATT_RESPONSES = [
@@ -113,9 +128,10 @@ ATT_RESPONSES = [
Opcode.ATT_READ_BLOB_RESPONSE,
Opcode.ATT_READ_MULTIPLE_RESPONSE,
Opcode.ATT_READ_BY_GROUP_TYPE_RESPONSE,
Opcode.ATT_READ_MULTIPLE_VARIABLE_RESPONSE,
Opcode.ATT_WRITE_RESPONSE,
Opcode.ATT_PREPARE_WRITE_RESPONSE,
Opcode.ATT_EXECUTE_WRITE_RESPONSE
Opcode.ATT_EXECUTE_WRITE_RESPONSE,
]
class ErrorCode(hci.SpecableEnum):
@@ -173,6 +189,18 @@ ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
_SET_OF_HANDLES_METADATA = hci.metadata({
'parser': lambda data, offset: (
len(data),
[
struct.unpack_from('<H', data, i)[0]
for i in range(offset, len(data), 2)
],
),
'serializer': lambda handles: b''.join(
[struct.pack('<H', handle) for handle in handles]
),
})
# fmt: on
# pylint: enable=line-too-long
@@ -221,6 +249,8 @@ class ATT_PDU:
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
if not pdu:
raise InvalidPacketError("Empty ATT PDU")
op_code = pdu[0]
subclass = ATT_PDU.pdu_classes.get(op_code)
@@ -542,7 +572,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
'''
set_of_handles: bytes = dataclasses.field(metadata=hci.metadata("*"))
set_of_handles: Sequence[int] = dataclasses.field(metadata=_SET_OF_HANDLES_METADATA)
# -----------------------------------------------------------------------------
@@ -623,6 +653,55 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass
@dataclasses.dataclass
class ATT_Read_Multiple_Variable_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.11 Read Multiple Variable Request
'''
set_of_handles: Sequence[int] = dataclasses.field(metadata=_SET_OF_HANDLES_METADATA)
# -----------------------------------------------------------------------------
@ATT_PDU.subclass
@dataclasses.dataclass
class ATT_Read_Multiple_Variable_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.12 Read Multiple Variable Response
'''
@classmethod
def _parse_length_value_tuples(
cls, data: bytes, offset: int
) -> tuple[int, list[tuple[int, bytes]]]:
length_value_tuple_list: list[tuple[int, bytes]] = []
while offset < len(data):
length = struct.unpack_from('<H', data, offset)[0]
length_value_tuple_list.append(
(length, data[offset + 2 : offset + 2 + length])
)
offset += 2 + length
return (len(data), length_value_tuple_list)
length_value_tuple_list: Sequence[tuple[int, bytes]] = dataclasses.field(
metadata=hci.metadata(
{
'parser': lambda data, offset: ATT_Read_Multiple_Variable_Response._parse_length_value_tuples(
data, offset
),
'serializer': lambda length_value_tuple_list: b''.join(
[
struct.pack('<H', length) + value
for length, value in length_value_tuple_list
]
),
}
)
)
# -----------------------------------------------------------------------------
@ATT_PDU.subclass
@dataclasses.dataclass
@@ -780,6 +859,43 @@ class AttributeValue(Generic[_T]):
return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
@@ -840,12 +956,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
match attribute_type:
case str():
self.type = UUID(attribute_type)
case bytes():
self.type = UUID.from_bytes(attribute_type)
case _:
self.type = attribute_type
self.value = value
@@ -855,7 +972,8 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, connection: Connection) -> bytes:
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -879,25 +997,38 @@ class Attribute(utils.EventEmitter, Generic[_T]):
)
value: _T | None
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
match self.value:
case AttributeValue():
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
value = self.value
self.emit(self.EVENT_READ, connection, b'' if value is None else value)
return b'' if value is None else self.encode_value(value)
async def write_value(self, connection: Connection, value: bytes) -> None:
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
@@ -922,17 +1053,27 @@ class Attribute(utils.EventEmitter, Generic[_T]):
decoded_value = self.decode_value(value)
if isinstance(self.value, AttributeValue):
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value
match self.value:
case AttributeValue():
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
self.value = decoded_value
self.emit(self.EVENT_WRITE, connection, decoded_value)
@@ -942,7 +1083,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}'
value_string = f', value={value_str}'
else:
value_string = ''
return (
+1 -1
View File
@@ -235,7 +235,7 @@ class Protocol:
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
self.l2cap_channel.write(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
+143 -80
View File
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
import asyncio
import enum
import logging
@@ -268,7 +269,7 @@ class MediaPacketPump:
await self.clock.sleep(delay)
# Emit
rtp_channel.send_pdu(bytes(packet))
rtp_channel.write(bytes(packet))
logger.debug(
f'{color(">>> sending RTP packet:", "green")} {packet}'
)
@@ -311,6 +312,13 @@ class MessageAssembler:
def on_pdu(self, pdu: bytes) -> None:
self.packet_count += 1
# Drop empty PDUs sent by remote — accessing pdu[0] below would
# raise IndexError, propagating up to the L2CAP read loop and
# tearing down the channel. Same class as #912 (ATT empty PDU).
if not pdu:
logger.warning('AVDTP message assembler: empty PDU dropped')
return
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
message_type = Message.MessageType(pdu[0] & 3)
@@ -324,6 +332,23 @@ class MessageAssembler:
Protocol.PacketType.SINGLE_PACKET,
Protocol.PacketType.START_PACKET,
):
# Both single and start packets carry the signal identifier in
# pdu[1]; start packets additionally carry the packet count in
# pdu[2]. Guard each access so a malformed remote frame can't
# crash the message assembler.
if len(pdu) < 2:
logger.warning(
'AVDTP %s packet too short (%d bytes); dropped',
packet_type.name,
len(pdu),
)
return
if packet_type == Protocol.PacketType.START_PACKET and len(pdu) < 3:
logger.warning(
'AVDTP START packet missing signal-packet count; dropped'
)
return
if self.message is not None:
# The previous message has not been terminated
logger.warning(
@@ -1453,8 +1478,23 @@ class Protocol(utils.EventEmitter):
handler = getattr(self, handler_name, None)
if handler:
try:
response = handler(message)
self.send_message(transaction_label, response)
result = handler(message)
if asyncio.iscoroutine(result):
async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)
utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
else:
@@ -1519,7 +1559,7 @@ class Protocol(utils.EventEmitter):
header = bytes([first_header_byte])
# Send one packet
self.l2cap_channel.send_pdu(header + payload[:max_fragment_size])
self.l2cap_channel.write(header + payload[:max_fragment_size])
# Prepare for the next packet
payload = payload[max_fragment_size:]
@@ -1535,7 +1575,7 @@ class Protocol(utils.EventEmitter):
async def send_command(self, command: Message):
# TODO: support timeouts
# Send the command
(transaction_label, transaction_result) = await self.start_transaction()
transaction_label, transaction_result = await self.start_transaction()
self.send_message(transaction_label, command)
# Wait for the response
@@ -1600,14 +1640,14 @@ class Protocol(utils.EventEmitter):
async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid))
def on_discover_command(self, command: Discover_Command) -> Message | None:
async def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
]
return Discover_Response(endpoint_infos)
def on_get_capabilities_command(
async def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1616,7 +1656,7 @@ class Protocol(utils.EventEmitter):
return Get_Capabilities_Response(endpoint.capabilities)
def on_get_all_capabilities_command(
async def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1625,7 +1665,7 @@ class Protocol(utils.EventEmitter):
return Get_All_Capabilities_Response(endpoint.capabilities)
def on_set_configuration_command(
async def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1640,10 +1680,10 @@ class Protocol(utils.EventEmitter):
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream
result = stream.on_set_configuration_command(command.capabilities)
result = await stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response()
def on_get_configuration_command(
async def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
@@ -1652,29 +1692,31 @@ class Protocol(utils.EventEmitter):
if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return endpoint.stream.on_get_configuration_command()
return await endpoint.stream.on_get_configuration_command()
def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_reconfigure_command(command.capabilities)
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response()
def on_open_command(self, command: Open_Command) -> Message | None:
async def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_open_command()
result = await endpoint.stream.on_open_command()
return result or Open_Response()
def on_start_command(self, command: Start_Command) -> Message | None:
async def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1688,12 +1730,12 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_start_command()) is not None:
if (result := await endpoint.stream.on_start_command()) is not None:
return result
return Start_Response()
def on_suspend_command(self, command: Suspend_Command) -> Message | None:
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
@@ -1707,45 +1749,47 @@ class Protocol(utils.EventEmitter):
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_suspend_command()) is not None:
if (result := await endpoint.stream.on_suspend_command()) is not None:
return result
return Suspend_Response()
def on_close_command(self, command: Close_Command) -> Message | None:
async def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR)
result = endpoint.stream.on_close_command()
result = await endpoint.stream.on_close_command()
return result or Close_Response()
def on_abort_command(self, command: Abort_Command) -> Message | None:
async def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None:
return Abort_Response()
endpoint.stream.on_abort_command()
await endpoint.stream.on_abort_command()
return Abort_Response()
def on_security_control_command(
async def on_security_control_command(
self, command: Security_Control_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_security_control_command(command.data)
result = await endpoint.on_security_control_command(command.data)
return result or Security_Control_Response()
def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
result = endpoint.on_delayreport_command(command.delay)
result = await endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response()
@@ -1829,7 +1873,7 @@ class Stream:
def send_media_packet(self, packet: MediaPacket) -> None:
assert self.rtp_channel
self.rtp_channel.send_pdu(bytes(packet))
self.rtp_channel.write(bytes(packet))
async def configure(self) -> None:
if self.state != State.IDLE:
@@ -1903,25 +1947,22 @@ class Stream:
await self.rtp_channel.disconnect()
self.rtp_channel = None
# Release the endpoint
self.local_endpoint.in_use = 0
self.change_state(State.IDLE)
def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_set_configuration_command(configuration)
result = await self.local_endpoint.on_set_configuration_command(configuration)
if result is not None:
return result
self.change_state(State.CONFIGURED)
return None
def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
if self.state not in (
State.CONFIGURED,
State.OPEN,
@@ -1929,25 +1970,25 @@ class Stream:
):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command()
return await self.local_endpoint.on_get_configuration_command()
def on_reconfigure_command(
async def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_reconfigure_command(configuration)
result = await self.local_endpoint.on_reconfigure_command(configuration)
if result is not None:
return result
return None
def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_open_command()
result = await self.local_endpoint.on_open_command()
if result is not None:
return result
@@ -1957,7 +1998,7 @@ class Stream:
self.change_state(State.OPEN)
return None
def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -1966,29 +2007,29 @@ class Stream:
logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_start_command()
result = await self.local_endpoint.on_start_command()
if result is not None:
return result
self.change_state(State.STREAMING)
return None
def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_suspend_command()
result = await self.local_endpoint.on_suspend_command()
if result is not None:
return result
self.change_state(State.OPEN)
return None
def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_close_command()
result = await self.local_endpoint.on_close_command()
if result is not None:
return result
@@ -2003,7 +2044,8 @@ class Stream:
return None
def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
if self.rtp_channel is None:
# No need to wait
self.change_state(State.IDLE)
@@ -2028,7 +2070,6 @@ class Stream:
def on_l2cap_channel_close(self) -> None:
logger.debug(color('<<< stream channel closed', 'magenta'))
self.local_endpoint.on_rtp_channel_close()
self.local_endpoint.in_use = 0
self.rtp_channel = None
if self.state in (State.CLOSING, State.ABORTING):
@@ -2053,7 +2094,6 @@ class Stream:
self.state = State.IDLE
local_endpoint.stream = self
local_endpoint.in_use = 1
def __str__(self) -> str:
return (
@@ -2063,14 +2103,16 @@ class Stream:
# -----------------------------------------------------------------------------
@dataclass
class StreamEndPoint:
class StreamEndPoint(abc.ABC):
seid: int
media_type: MediaType
tsep: StreamEndPointType
in_use: int
capabilities: Iterable[ServiceCapabilities]
@property
def in_use(self) -> int:
raise NotImplementedError()
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
@@ -2110,14 +2152,30 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
in_use: int,
capabilities: Iterable[ServiceCapabilities],
) -> None:
StreamEndPoint.__init__(self, seid, media_type, tsep, in_use, capabilities)
StreamEndPointProxy.__init__(self, protocol, seid)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self._in_use = in_use
self.capabilities = capabilities
StreamEndPointProxy.__init__(self, protocol=protocol, seid=seid)
@property
def in_use(self) -> int:
return self._in_use
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
stream: Stream | None
@property
def in_use(self) -> int:
if self.stream and self.stream.state != State.IDLE:
return 1
return 0
EVENT_CONFIGURATION = "configuration"
EVENT_OPEN = "open"
EVENT_START = "start"
@@ -2140,8 +2198,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
capabilities: Iterable[ServiceCapabilities],
configuration: Iterable[ServiceCapabilities] | None = None,
):
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
utils.EventEmitter.__init__(self)
# StreamEndPoint attributes
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.capabilities = capabilities
self.protocol = protocol
self.configuration = configuration if configuration is not None else []
self.stream = None
@@ -2155,13 +2218,13 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
async def close(self) -> None:
"""[Source Only] Handles when receiving close command."""
def on_reconfigure_command(
async def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities]
) -> Message | None:
del command # unused.
return None
def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
logger.debug(
@@ -2172,34 +2235,34 @@ class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
self.emit(self.EVENT_CONFIGURATION)
return None
def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration)
def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN)
return None
def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START)
return None
def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND)
return None
def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE)
return None
def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT)
return None
def on_delayreport_command(self, delay: int) -> Message | None:
async def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay)
return None
def on_security_control_command(self, data: bytes) -> Message | None:
async def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data)
return None
@@ -2227,12 +2290,12 @@ class LocalSource(LocalStreamEndPoint):
codec_capabilities,
] + list(other_capabilities)
super().__init__(
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SRC,
capabilities=capabilities,
configuration=capabilities,
)
self.packet_pump = packet_pump
@@ -2251,13 +2314,13 @@ class LocalSource(LocalStreamEndPoint):
self.emit(self.EVENT_STOP)
@override
def on_start_command(self) -> Message | None:
asyncio.create_task(self.start())
async def on_start_command(self) -> Message | None:
await self.start()
return None
@override
def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop())
async def on_suspend_command(self) -> Message | None:
await self.stop()
return None
@@ -2271,11 +2334,11 @@ class LocalSink(LocalStreamEndPoint):
codec_capabilities,
]
super().__init__(
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
protocol=protocol,
seid=seid,
media_type=codec_capabilities.media_type,
tsep=AVDTP_TSEP_SNK,
capabilities=capabilities,
)
def on_rtp_channel_open(self) -> None:
+692 -259
View File
File diff suppressed because it is too large Load Diff
+10 -2
View File
@@ -37,7 +37,12 @@ class HCI_Bridge:
def on_packet(self, packet):
# Convert the packet bytes to an object
hci_packet = HCI_Packet.from_bytes(packet)
try:
hci_packet = HCI_Packet.from_bytes(packet)
except Exception:
logger.warning('forwarding unparsed packet as-is')
self.hci_sink.on_packet(packet)
return
# Filter the packet
if self.packet_filter is not None:
@@ -50,7 +55,10 @@ class HCI_Bridge:
return
# Analyze the packet
self.trace(hci_packet)
try:
self.trace(hci_packet)
except Exception:
logger.exception('Exception while tracing packet')
# Bridge the packet
self.hci_sink.on_packet(packet)
+724 -564
View File
File diff suppressed because it is too large Load Diff
+83 -75
View File
@@ -19,6 +19,7 @@ from __future__ import annotations
import dataclasses
import enum
import functools
import struct
from collections.abc import Iterable
from typing import (
@@ -273,6 +274,18 @@ class UUID:
def parse_uuid_2(cls, uuid_as_bytes: bytes, offset: int) -> tuple[int, UUID]:
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
@functools.cached_property
def uuid_128_bytes(self) -> bytes:
match len(self.uuid_bytes):
case 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
case 4:
return self.BASE_UUID + self.uuid_bytes
case 16:
return self.uuid_bytes
case _:
assert False, "unreachable"
def to_bytes(self, force_128: bool = False) -> bytes:
'''
Serialize UUID in little-endian byte-order
@@ -280,14 +293,7 @@ class UUID:
if not force_128:
return self.uuid_bytes
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
else:
assert False, "unreachable"
return self.uuid_128_bytes
def to_pdu_bytes(self) -> bytes:
'''
@@ -317,7 +323,7 @@ class UUID:
def __eq__(self, other: object) -> bool:
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
return self.uuid_128_bytes == other.uuid_128_bytes
if isinstance(other, str):
return UUID(other) == self
@@ -325,7 +331,7 @@ class UUID:
return False
def __hash__(self) -> int:
return hash(self.uuid_bytes)
return hash(self.uuid_128_bytes)
def __str__(self) -> str:
result = self.to_hex_str(separator='-')
@@ -923,7 +929,7 @@ class DeviceClass:
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
def split_class_of_device(class_of_device: int) -> tuple[int, int, int]:
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return (
@@ -1769,66 +1775,71 @@ class AdvertisingData:
@classmethod
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
if ad_type == AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
match ad_type:
case AdvertisingData.FLAGS:
ad_type_str = 'Flags'
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:2])
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:4])
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
ad_type_str = 'Service Data'
uuid = UUID.from_bytes(ad_data[:16])
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
case AdvertisingData.SHORTENED_LOCAL_NAME:
ad_type_str = 'Shortened Local Name'
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
case AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
case AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
)
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
case AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(
struct.unpack_from('<H', ad_data, 0)[0]
)
ad_data_str = str(appearance)
case AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
case _:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
ad_data_str = str(appearance)
elif ad_type == AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
else:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}'
@@ -2105,13 +2116,10 @@ class AdvertisingData:
# -----------------------------------------------------------------------------
# Connection PHY
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ConnectionPHY:
def __init__(self, tx_phy, rx_phy):
self.tx_phy = tx_phy
self.rx_phy = rx_phy
def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
tx_phy: int
rx_phy: int
# -----------------------------------------------------------------------------
+1142 -627
View File
File diff suppressed because it is too large Load Diff
+91 -84
View File
@@ -89,51 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
tlv: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Read_Version_Command(
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters]
):
param0: int = dataclasses.field(metadata=hci.metadata(1))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("tlv", "*"),
]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters)
@dataclasses.dataclass
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
class Hci_Intel_Secure_Send_Command(
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters]
):
data_type: int = dataclasses.field(metadata=hci.metadata(1))
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
("status", 1),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_Command):
class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters):
data: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]):
reset_type: int = dataclasses.field(metadata=hci.metadata(1))
patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4))
return_parameters_fields = [
("data", "*"),
]
@hci.HCI_Command.command
@dataclasses.dataclass
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
params: bytes = hci.field(metadata=hci.metadata('*'))
return_parameters_fields = [
("status", hci.STATUS_SPEC),
("params", "*"),
]
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# -----------------------------------------------------------------------------
@@ -198,50 +201,51 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
value = data[2 : 2 + value_length]
typed_value: Any
if value_type == ValueType.END:
break
match value_type:
case ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
elif value_type in (
ValueType.USB_VENDOR_ID,
ValueType.USB_PRODUCT_ID,
ValueType.DEVICE_REVISION,
):
(typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
elif value_type in (
ValueType.BUILD_TYPE,
ValueType.BUILD_NUMBER,
ValueType.SECURE_BOOT,
ValueType.OTP_LOCK,
ValueType.API_LOCK,
ValueType.DEBUG_LOCK,
ValueType.SECURE_BOOT_ENGINE_TYPE,
):
typed_value = value[0]
elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
else:
typed_value = value
case ValueType.CNVI | ValueType.CNVR:
(v,) = struct.unpack("<I", value)
typed_value = (
(((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8)
)
case ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value)
typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
)
case (
ValueType.USB_VENDOR_ID
| ValueType.USB_PRODUCT_ID
| ValueType.DEVICE_REVISION
):
(typed_value,) = struct.unpack("<H", value)
case ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0])
case (
ValueType.BUILD_TYPE
| ValueType.BUILD_NUMBER
| ValueType.SECURE_BOOT
| ValueType.OTP_LOCK
| ValueType.API_LOCK
| ValueType.DEBUG_LOCK
| ValueType.SECURE_BOOT_ENGINE_TYPE
):
typed_value = value[0]
case ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1])
case ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
case ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
)
case _:
typed_value = value
result.append((value_type, typed_value))
data = data[2 + value_length :]
@@ -402,7 +406,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event)
return
if not event.return_parameters == hci.HCI_SUCCESS:
if not event.return_parameters.status == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -641,8 +645,8 @@ class Driver(common.Driver):
while ddc_data:
ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len]
await self.host.send_command(
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
await self.host.send_sync_command(
HCI_Intel_Write_Device_Config_Command(data=ddc_payload)
)
ddc_data = ddc_data[ddc_len:]
@@ -660,31 +664,34 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True
response = await self.host.send_command(hci.HCI_Reset_Command())
if not (
isinstance(response, hci.HCI_Command_Complete_Event)
and response.return_parameters
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
response1 = await self.host.send_sync_command_raw(hci.HCI_Reset_Command())
if not isinstance(
response1.return_parameters, hci.HCI_StatusReturnParameters
) or response1.return_parameters.status not in (
hci.HCI_UNKNOWN_HCI_COMMAND_ERROR,
hci.HCI_SUCCESS,
):
# When the controller is in operational mode, the response is a
# successful response.
# When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure.
logger.warning(f"unexpected response: {response}")
logger.warning(f"unexpected response: {response1}")
raise DriverError("unexpected HCI response")
# Read the firmware version.
response = await self.host.send_command(
response2 = await self.host.send_sync_command_raw(
HCI_Intel_Read_Version_Command(param0=0xFF)
)
if not isinstance(response, hci.HCI_Command_Complete_Event):
raise DriverError("unexpected HCI response")
if response.return_parameters.status != 0: # type: ignore
if (
not isinstance(
response2.return_parameters, HCI_Intel_Read_Version_ReturnParameters
)
or response2.return_parameters.status != 0
):
raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
tlvs = _parse_tlv(response2.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once.
+102 -43
View File
@@ -16,6 +16,7 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`)
"""
from __future__ import annotations
import asyncio
import enum
@@ -31,10 +32,14 @@ import weakref
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci
from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -77,6 +82,7 @@ class RtlProjectId(enum.IntEnum):
PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25
PROJECT_ID_8761C = 51
RTK_PROJECT_ID_TO_ROM = {
@@ -92,6 +98,7 @@ RTK_PROJECT_ID_TO_ROM = {
18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A,
51: RTK_ROM_LMP_8761A,
}
# List of USB (VendorID, ProductID) for Realtek-based devices.
@@ -122,7 +129,12 @@ RTK_USB_PRODUCTS = {
(0x2357, 0x0604),
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x2C0A, 0x8761),
(0x7392, 0xC611),
# Realtek 8761CUV
(0x0B05, 0x1BF6),
(0x0BDA, 0xC761),
(0x7392, 0xF611),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
@@ -182,23 +194,36 @@ HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals())
@hci.HCI_Command.command
@dataclass
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command):
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)]
class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters):
version: int = field(metadata=hci.metadata(1))
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_Command):
class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass
@dataclass
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters):
index: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters)
@dataclass
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
@hci.HCI_Command.command
@hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command):
class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass
@@ -363,6 +388,15 @@ class Driver(common.Driver):
fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin",
),
# 8761CU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0E, 0x00),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761cu_fw.bin",
config_name="rtl8761cu_config.bin",
),
# 8822C
DriverInfo(
rom=RTK_ROM_LMP_8822B,
@@ -420,9 +454,17 @@ class Driver(common.Driver):
@staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == (
hci_subversion,
hci_version,
if driver_info.rom == lmp_subversion and (
driver_info.hci
== (
hci_subversion,
hci_version,
)
or driver_info.hci
== (
hci_subversion,
0x0,
)
):
return driver_info
@@ -467,7 +509,7 @@ class Driver(common.Driver):
return None
@staticmethod
def check(host):
def check(host: Host) -> bool:
if not host.hci_metadata:
logger.debug("USB metadata not found")
return False
@@ -491,37 +533,44 @@ class Driver(common.Driver):
return True
@staticmethod
async def get_loaded_firmware_version(host):
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command())
if response.return_parameters.status != hci.HCI_SUCCESS:
async def get_loaded_firmware_version(host: Host) -> int | None:
response1 = await host.send_sync_command_raw(HCI_RTK_Read_ROM_Version_Command())
if (
not isinstance(
response1.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
return None
response = await host.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
)
return (
response.return_parameters.hci_subversion << 16
| response.return_parameters.lmp_subversion
response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod
async def driver_info_for_host(cls, host):
async def driver_info_for_host(cls, host: Host) -> DriverInfo | None:
try:
await host.send_command(
await host.send_sync_command(
hci.HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY,
)
host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying")
await host.send_command(hci.HCI_Reset_Command(), check_result=True)
await host.send_sync_command(hci.HCI_Reset_Command())
host.ready = True
command = hci.HCI_Read_Local_Version_Information_Command()
response = await host.send_command(command, check_result=True)
if response.command_opcode != command.op_code:
response = await host.send_sync_command_raw(
hci.HCI_Read_Local_Version_Information_Command()
)
if (
not isinstance(
response.return_parameters,
hci.HCI_Read_Local_Version_Information_ReturnParameters,
)
or response.return_parameters.status != hci.HCI_SUCCESS
):
logger.error("failed to probe local version information")
return None
@@ -546,7 +595,7 @@ class Driver(common.Driver):
return driver_info
@classmethod
async def for_host(cls, host, force=False):
async def for_host(cls, host: Host, force: bool = False):
# Check that a driver is needed for this host
if not force and not cls.check(host):
return None
@@ -601,15 +650,21 @@ class Driver(common.Driver):
# TODO: load the firmware
async def download_for_rtl8723b(self):
async def download_for_rtl8723b(self) -> int | None:
if self.driver_info.has_rom_version:
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response1 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command()
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if (
not isinstance(
response1.return_parameters,
HCI_RTK_Read_ROM_Version_ReturnParameters,
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version")
return None
rom_version = response.return_parameters.version
rom_version = response1.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}")
else:
rom_version = 0
@@ -644,21 +699,25 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment),
check_result=True,
await self.host.send_sync_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment)
)
logger.debug("download complete!")
# Read the version again
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
response2 = await self.host.send_sync_command_raw(
HCI_RTK_Read_ROM_Version_Command()
)
if response.return_parameters.status != hci.HCI_SUCCESS:
if (
not isinstance(
response2.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response2.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version")
else:
rom_version = response.return_parameters.version
rom_version = response2.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:02X}")
return firmware.version
@@ -680,7 +739,7 @@ class Driver(common.Driver):
async def init_controller(self):
await self.download_firmware()
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.host.send_sync_command(hci.HCI_Reset_Command())
logger.info(f"loaded FW image {self.driver_info.fw_name}")
-60
View File
@@ -1,60 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from bumble.gatt import (
GATT_APPEARANCE_CHARACTERISTIC,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
Characteristic,
Service,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)
+4 -4
View File
@@ -29,9 +29,9 @@ import functools
import logging
import struct
from collections.abc import Iterable, Sequence
from typing import TypeVar
from typing import ClassVar, TypeVar
from bumble.att import Attribute, AttributeValue
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color
from bumble.core import UUID, BaseBumbleError
@@ -403,7 +403,7 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
UUID: UUID
UUID: ClassVar[UUID]
def __init__(
self,
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
value_str = '<dynamic>'
else:
value_str = '<...>'
+86 -31
View File
@@ -26,6 +26,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import struct
from collections.abc import Callable, Iterable
@@ -33,11 +34,15 @@ from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
TypeVar,
overload,
)
from bumble import att, core, utils
from typing_extensions import Self
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
@@ -54,12 +59,12 @@ from bumble.gatt import (
)
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T')
# -----------------------------------------------------------------------------
@@ -247,10 +252,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
SERVICE_CLASS: type[TemplateService]
SERVICE_CLASS: ClassVar[type[TemplateService]]
@classmethod
def from_client(cls, client: Client) -> ProfileServiceProxy | None:
def from_client(cls, client: Client) -> Self | None:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -267,8 +272,8 @@ class Client:
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None:
self.connection = connection
def __init__(self, bearer: att.Bearer) -> None:
self.bearer = bearer
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -278,21 +283,76 @@ class Client:
self.services = []
self.cached_values = {}
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -321,10 +381,7 @@ class Client:
def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
@@ -336,7 +393,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return self.mtu
# Send the request
self.mtu_exchange_done = True
@@ -347,9 +404,9 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
self.mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
return self.mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
@@ -942,7 +999,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -966,7 +1023,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
@@ -1062,14 +1119,13 @@ class Client:
)
)
def on_disconnection(self, _) -> None:
def on_disconnection(self, *args) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -1099,8 +1155,7 @@ class Client:
else:
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
'red',
)
+ str(att_pdu)
+313 -143
View File
@@ -32,9 +32,8 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar
from bumble import att, utils
from bumble import att, core, l2cap, utils
from bumble.colors import color
from bumble.core import UUID
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -44,14 +43,13 @@ from bumble.gatt import (
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
IncludedServiceDeclaration,
Service,
)
if TYPE_CHECKING:
from bumble.device import Connection, Device
from bumble.device import Device
# -----------------------------------------------------------------------------
# Logging
@@ -65,6 +63,20 @@ logger = logging.getLogger(__name__)
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
else:
return f'[0x{bearer.handle:04X}]'
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -72,9 +84,9 @@ class Server(utils.EventEmitter):
attributes: list[att.Attribute]
services: list[Service]
attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[int, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, asyncio.futures.Future | None]
subscribers: dict[att.Bearer, dict[int, bytes]]
indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -96,8 +108,28 @@ class Server(utils.EventEmitter):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
def register_eatt(
self, spec: l2cap.LeCreditBasedChannelSpec | None = None
) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug(
"New EATT Bearer Connection=0x%04X CID=0x%04X",
channel.connection.handle,
channel.source_cid,
)
channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu)
)
return self.device.create_l2cap_server(
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
)
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
if att.is_enhanced_bearer(bearer):
bearer.write(pdu)
else:
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
def next_handle(self) -> int:
return 1 + len(self.attributes)
@@ -138,7 +170,7 @@ class Server(utils.EventEmitter):
None,
)
def get_service_attribute(self, service_uuid: UUID) -> Service | None:
def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
return next(
(
attribute
@@ -151,7 +183,7 @@ class Server(utils.EventEmitter):
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
self, service_uuid: core.UUID, characteristic_uuid: core.UUID
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
@@ -176,7 +208,10 @@ class Server(utils.EventEmitter):
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
self,
service_uuid: core.UUID,
characteristic_uuid: core.UUID,
descriptor_uuid: core.UUID,
) -> Descriptor | None:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
@@ -257,14 +292,7 @@ class Server(utils.EventEmitter):
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
att.Attribute.READABLE | att.Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
self.make_descriptor_value(characteristic),
)
)
@@ -280,10 +308,21 @@ class Server(utils.EventEmitter):
for service in services:
self.add_service(service)
def read_cccd(
self, connection: Connection, characteristic: Characteristic
) -> bytes:
subscribers = self.subscribers.get(connection.handle)
def make_descriptor_value(
self, characteristic: Characteristic
) -> att.AttributeValueV2:
# It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
return att.AttributeValueV2(
lambda bearer, characteristic=characteristic: self.read_cccd(
bearer, characteristic
),
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
bearer, characteristic, value
),
)
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
subscribers = self.subscribers.get(bearer)
cccd = None
if subscribers:
cccd = subscribers.get(characteristic.handle)
@@ -292,12 +331,12 @@ class Server(utils.EventEmitter):
def write_cccd(
self,
connection: Connection,
bearer: att.Bearer,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'Subscription update for connection={_bearer_id(bearer)}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
@@ -306,41 +345,60 @@ class Server(utils.EventEmitter):
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
cccds = self.subscribers.setdefault(bearer, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
characteristic.EVENT_SUBSCRIPTION,
connection,
bearer,
notify_enabled,
indicate_enabled,
)
self.emit(
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
connection,
bearer,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, bytes(response))
def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
self.send_gatt_pdu(bearer, bytes(response))
async def notify_subscriber(
self,
connection: Connection,
attribute: att.Attribute,
value: bytes | None = None,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._notify_single_subscriber(bearer, attribute, value, force)
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not notifying, no subscribers')
return
@@ -355,35 +413,54 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
await attribute.read_value(connection)
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
async def indicate_subscriber(
self,
connection: Connection,
attribute: att.Attribute,
value: bytes | None = None,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._indicate_single_bearer(bearer, attribute, value, force)
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
subscribers = self.subscribers.get(bearer)
if not subscribers:
logger.debug('not indicating, no subscribers')
return
@@ -398,74 +475,72 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
await attribute.read_value(connection)
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert self.pending_confirmations[connection.handle] is None
async with self.indication_semaphores[bearer]:
assert self.pending_confirmations[bearer] is None
# Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[connection.handle] = (
pending_confirmation = self.pending_confirmations[bearer] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(connection.handle, bytes(indication))
self.send_gatt_pdu(bearer, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[connection.handle] = None
self.pending_confirmations[bearer] = None
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
# Get all the bearers for which there's at least one subscription
bearers: list[att.Bearer] = [
bearer
for bearer, subscribers in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
if bearers:
coroutine = (
self._indicate_single_bearer
if indicate
else self._notify_single_subscriber
)
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
asyncio.create_task(coroutine(bearer, attribute, value, force))
for bearer in bearers
]
)
async def notify_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -474,27 +549,24 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
del self.indication_semaphores[connection.handle]
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_disconnection(self, bearer: att.Bearer) -> None:
self.subscribers.pop(bearer, None)
self.indication_semaphores.pop(bearer, None)
self.pending_confirmations.pop(bearer, None)
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(connection, att_pdu)
handler(bearer, att_pdu)
except att.ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = att.ATT_Error_Response(
@@ -502,7 +574,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
)
self.send_response(connection, response)
self.send_response(bearer, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = att.ATT_Error_Response(
@@ -510,18 +582,18 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
raise
else:
# No specific handler registered
if att_pdu.op_code in att.ATT_REQUESTS:
# Invoke the generic handler
self.on_att_request(connection, att_pdu)
self.on_att_request(bearer, att_pdu)
else:
# Just ignore
logger.warning(
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(att_pdu)
@@ -530,13 +602,14 @@ class Server(utils.EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
'red',
)
+ str(pdu)
)
@@ -545,29 +618,28 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
def on_att_exchange_mtu_request(
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
bearer.on_att_mtu_update(mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(
self, connection: Connection, request: att.ATT_Find_Information_Request
self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -580,7 +652,7 @@ class Server(utils.EventEmitter):
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
@@ -590,7 +662,7 @@ class Server(utils.EventEmitter):
return
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[att.Attribute] = []
uuid_size = 0
for attribute in (
@@ -632,18 +704,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes = []
response: att.ATT_PDU
async for attribute in (
@@ -652,7 +724,7 @@ class Server(utils.EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and (await attribute.read_value(bearer)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -688,17 +760,17 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request(
self, connection: Connection, request: att.ATT_Read_By_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
response: att.ATT_PDU = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -706,6 +778,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
if (
request.starting_handle == 0x0000
or request.starting_handle > request.ending_handle
):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(bearer, response)
return
attributes: list[tuple[int, bytes]] = []
for attribute in (
attribute
@@ -716,7 +800,7 @@ class Server(utils.EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
except att.ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -729,7 +813,7 @@ class Server(utils.EventEmitter):
break
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 4, 253)
max_attribute_size = min(bearer.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -756,11 +840,11 @@ class Server(utils.EventEmitter):
else:
logging.debug(f"not found {request}")
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_request(
self, connection: Connection, request: att.ATT_Read_Request
self, bearer: att.Bearer, request: att.ATT_Read_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -769,7 +853,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -777,7 +861,7 @@ class Server(utils.EventEmitter):
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
value_size = min(bearer.att_mtu - 1, len(value))
response = att.ATT_Read_Response(attribute_value=value[:value_size])
else:
response = att.ATT_Error_Response(
@@ -785,11 +869,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request(
self, connection: Connection, request: att.ATT_Read_Blob_Request
self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -798,7 +882,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = await attribute.read_value(bearer)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -812,7 +896,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
elif len(value) <= bearer.att_mtu - 1:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -820,7 +904,7 @@ class Server(utils.EventEmitter):
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
bearer.att_mtu - 1, len(value) - request.value_offset
)
response = att.ATT_Read_Blob_Response(
part_attribute_value=value[
@@ -833,11 +917,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -852,10 +936,10 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
return
pdu_space_available = connection.att_mtu - 2
pdu_space_available = bearer.att_mtu - 2
attributes: list[tuple[int, int, bytes]] = []
for attribute in (
attribute
@@ -867,9 +951,9 @@ class Server(utils.EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = await attribute.read_value(bearer)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
max_attribute_size = min(bearer.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -904,11 +988,99 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_multiple_request(
self, bearer: att.Bearer, request: att.ATT_Read_Multiple_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.7 Read Multiple Request.
'''
response: att.ATT_PDU
pdu_space_available = bearer.att_mtu - 1
values: list[bytes] = []
for handle in request.set_of_handles:
if not (attribute := self.get_attribute(handle)):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=handle,
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
return
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(bearer)
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 1, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
# Check if there is enough space
entry_size = len(attribute_value)
if pdu_space_available < entry_size:
break
# Add the attribute to the list
values.append(attribute_value)
pdu_space_available -= entry_size
response = att.ATT_Read_Multiple_Response(set_of_values=b''.join(values))
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_multiple_variable_request(
self, bearer: att.Bearer, request: att.ATT_Read_Multiple_Variable_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.11 Read Multiple Variable Request.
'''
response: att.ATT_PDU
pdu_space_available = bearer.att_mtu - 1
length_value_tuple_list: list[tuple[int, bytes]] = []
for handle in request.set_of_handles:
if not (attribute := self.get_attribute(handle)):
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=handle,
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
return
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(bearer)
length = len(attribute_value)
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 3, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
# Check if there is enough space
entry_size = 2 + len(attribute_value)
# Add the attribute to the list
length_value_tuple_list.append((length, attribute_value))
pdu_space_available -= entry_size
if pdu_space_available <= 0:
break
response = att.ATT_Read_Multiple_Variable_Response(
length_value_tuple_list=length_value_tuple_list
)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_request(
self, connection: Connection, request: att.ATT_Write_Request
self, bearer: att.Bearer, request: att.ATT_Write_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -918,7 +1090,7 @@ class Server(utils.EventEmitter):
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -932,7 +1104,7 @@ class Server(utils.EventEmitter):
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(
connection,
bearer,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -944,7 +1116,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
try:
# Accept the value
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -954,11 +1126,11 @@ class Server(utils.EventEmitter):
else:
# Done
response = att.ATT_Write_Response()
self.send_response(connection, response)
self.send_response(bearer, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_command(
self, connection: Connection, request: att.ATT_Write_Command
self, bearer: att.Bearer, request: att.ATT_Write_Command
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
@@ -977,22 +1149,20 @@ class Server(utils.EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
await attribute.write_value(bearer, request.attribute_value)
except Exception:
logger.exception('!!! ignoring exception')
def on_att_handle_value_confirmation(
self,
connection: Connection,
bearer: att.Bearer,
confirmation: att.ATT_Handle_Value_Confirmation,
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
del confirmation # Unused.
if (
pending_confirmation := self.pending_confirmations[connection.handle]
) is None:
if (pending_confirmation := self.pending_confirmations[bearer]) is None:
# Not expected!
logger.warning(
'!!! unexpected confirmation, there is no pending indication'
+1455 -953
View File
File diff suppressed because it is too large Load Diff
+72 -90
View File
@@ -26,7 +26,7 @@ import logging
import re
import traceback
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar
from typing import Any, ClassVar, Literal, overload
from typing_extensions import Self
@@ -68,6 +68,8 @@ class HfpProtocolError(ProtocolError):
# -----------------------------------------------------------------------------
class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
@@ -84,10 +86,19 @@ class HfpProtocol:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
data = data.decode('utf-8', errors='replace')
logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
@@ -420,61 +431,6 @@ class CmeError(enum.IntEnum):
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = {
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
}
# Unsolicited responses and statuses.
UNSOLICITED_CODES = {
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
}
# Status codes
STATUS_CODES = {
"+CME ERROR",
@@ -727,12 +683,9 @@ class HfProtocol(utils.EventEmitter):
dlc: rfcomm.DLC
command_lock: asyncio.Lock
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
pending_command: str | None = None
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
read_buffer: bytearray
active_codec: AudioCodec
@@ -805,16 +758,39 @@ class HfProtocol(utils.EventEmitter):
self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
if self.pending_command and (
response.code in STATUS_CODES or response.code in self.pending_command
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(
f"dropping unexpected response with code '{response.code}'"
)
self.unsolicited_queue.put_nowait(response)
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
) -> None: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.SINGLE],
) -> AtResponse: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.MULTIPLE],
) -> list[AtResponse]: ...
async def execute_command(
self,
@@ -835,27 +811,34 @@ class HfProtocol(utils.EventEmitter):
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
ProtocolError: the status is not OK.
"""
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
try:
async with self.command_lock:
self.pending_command = cmd
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise HfpProtocolError("NO ANSWER")
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if (
response_type == AtResponseType.SINGLE
and len(responses) != 1
):
raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
finally:
self.pending_command = None
async def initiate_slc(self):
"""4.2.1 Service Level Connection Initialization."""
@@ -1067,7 +1050,6 @@ class HfProtocol(utils.EventEmitter):
responses = await self.execute_command(
"AT+CLCC", response_type=AtResponseType.MULTIPLE
)
assert isinstance(responses, list)
calls = []
for response in responses:
+2 -2
View File
@@ -312,11 +312,11 @@ class HID(ABC, utils.EventEmitter):
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg)
self.l2cap_ctrl_channel.write(msg)
def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
self.l2cap_intr_channel.write(msg)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
+371 -165
View File
@@ -21,13 +21,16 @@ import asyncio
import collections
import dataclasses
import logging
import struct
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, overload
from bumble import drivers, hci, utils
from bumble.colors import color
from bumble.core import ConnectionPHY, InvalidStateError, PhysicalTransport
from bumble.core import (
ConnectionPHY,
InvalidStateError,
PhysicalTransport,
)
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble.transport.common import TransportLostError
@@ -35,7 +38,6 @@ from bumble.transport.common import TransportLostError
if TYPE_CHECKING:
from bumble.transport.common import TransportSink, TransportSource
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -236,6 +238,9 @@ class IsoLink:
# -----------------------------------------------------------------------------
_RP = TypeVar('_RP', bound=hci.HCI_ReturnParameters)
class Host(utils.EventEmitter):
connections: dict[int, Connection]
cis_links: dict[int, IsoLink]
@@ -264,13 +269,20 @@ class Host(utils.EventEmitter):
self.bis_links = {} # BIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None
self.pending_response: asyncio.Future[Any] | None = None
self.pending_command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand | None = None
self.pending_response: (
asyncio.Future[
hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event
]
| None
) = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
self.local_version: (
hci.HCI_Read_Local_Version_Information_ReturnParameters | None
) = None
self.local_supported_commands = 0
self.local_le_features = 0
self.local_le_features = hci.LeFeatureMask(0) # LE features
self.local_lmp_features = hci.LmpFeatureMask(0) # Classic LMP features
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
@@ -312,7 +324,7 @@ class Host(utils.EventEmitter):
self.emit('flush')
self.command_semaphore.release()
async def reset(self, driver_factory=drivers.get_driver_for_host):
async def reset(self, driver_factory=drivers.get_driver_for_host) -> None:
if self.ready:
self.ready = False
await self.flush()
@@ -330,57 +342,61 @@ class Host(utils.EventEmitter):
# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(hci.HCI_Reset_Command(), check_result=True)
await self.send_sync_command(hci.HCI_Reset_Command())
self.ready = True
response = await self.send_command(
hci.HCI_Read_Local_Supported_Commands_Command(), check_result=True
response1 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Commands_Command()
)
self.local_supported_commands = int.from_bytes(
response.return_parameters.supported_commands, 'little'
response1.supported_commands, 'little'
)
if self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
if self.supports_command(hci.HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
self.local_version = await self.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
if self.supports_command(hci.HCI_LE_READ_ALL_LOCAL_SUPPORTED_FEATURES_COMMAND):
response2 = await self.send_sync_command(
hci.HCI_LE_Read_All_Local_Supported_Features_Command()
)
self.local_le_features = hci.LeFeatureMask(
int.from_bytes(response2.le_features, 'little')
)
elif self.supports_command(hci.HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response3 = await self.send_sync_command(
hci.HCI_LE_Read_Local_Supported_Features_Command()
)
self.local_le_features = hci.LeFeatureMask(
int.from_bytes(response3.le_features, 'little')
)
self.local_version = response.return_parameters
if self.supports_command(hci.HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
max_page_number = 0
page_number = 0
lmp_features = 0
while page_number <= max_page_number:
response = await self.send_command(
response4 = await self.send_sync_command(
hci.HCI_Read_Local_Extended_Features_Command(
page_number=page_number
),
check_result=True,
)
)
lmp_features |= int.from_bytes(
response.return_parameters.extended_lmp_features, 'little'
response4.extended_lmp_features, 'little'
) << (64 * page_number)
max_page_number = response.return_parameters.maximum_page_number
max_page_number = response4.maximum_page_number
page_number += 1
self.local_lmp_features = hci.LmpFeatureMask(lmp_features)
elif self.supports_command(hci.HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
hci.HCI_Read_Local_Supported_Features_Command(), check_result=True
response5 = await self.send_sync_command(
hci.HCI_Read_Local_Supported_Features_Command()
)
self.local_lmp_features = hci.LmpFeatureMask(
int.from_bytes(response.return_parameters.lmp_features, 'little')
int.from_bytes(response5.lmp_features, 'little')
)
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Command(
event_mask=hci.HCI_Set_Event_Mask_Command.mask(
[
@@ -437,7 +453,7 @@ class Host(utils.EventEmitter):
)
)
if self.supports_command(hci.HCI_SET_EVENT_MASK_PAGE_2_COMMAND):
await self.send_command(
await self.send_sync_command(
hci.HCI_Set_Event_Mask_Page_2_Command(
event_mask_page_2=hci.HCI_Set_Event_Mask_Page_2_Command.mask(
[hci.HCI_ENCRYPTION_CHANGE_V2_EVENT]
@@ -490,29 +506,28 @@ class Host(utils.EventEmitter):
hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
hci.HCI_LE_SUBRATE_CHANGE_EVENT,
hci.HCI_LE_READ_ALL_REMOTE_FEATURES_COMPLETE_EVENT,
hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT,
hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT,
hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT,
hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT,
hci.HCI_LE_MONITORED_ADVERTISERS_REPORT_EVENT,
hci.HCI_LE_FRAME_SPACE_UPDATE_COMPLETE_EVENT,
hci.HCI_LE_UTP_RECEIVE_EVENT,
hci.HCI_LE_CONNECTION_RATE_CHANGE_EVENT,
]
)
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(hci.HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_Read_Buffer_Size_Command(), check_result=True
)
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
response6 = await self.send_sync_command(hci.HCI_Read_Buffer_Size_Command())
hc_acl_data_packet_length = response6.hc_acl_data_packet_length
hc_total_num_acl_data_packets = response6.hc_total_num_acl_data_packets
logger.debug(
'HCI ACL flow control: '
@@ -531,19 +546,13 @@ class Host(utils.EventEmitter):
iso_data_packet_length = 0
total_num_iso_data_packets = 0
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
)
iso_data_packet_length = response.return_parameters.iso_data_packet_length
total_num_iso_data_packets = (
response.return_parameters.total_num_iso_data_packets
response7 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_V2_Command()
)
le_acl_data_packet_length = response7.le_acl_data_packet_length
total_num_le_acl_data_packets = response7.total_num_le_acl_data_packets
iso_data_packet_length = response7.iso_data_packet_length
total_num_iso_data_packets = response7.total_num_iso_data_packets
logger.debug(
'HCI LE flow control: '
@@ -553,15 +562,11 @@ class Host(utils.EventEmitter):
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
le_acl_data_packet_length = (
response.return_parameters.le_acl_data_packet_length
)
total_num_le_acl_data_packets = (
response.return_parameters.total_num_le_acl_data_packets
response8 = await self.send_sync_command(
hci.HCI_LE_Read_Buffer_Size_Command()
)
le_acl_data_packet_length = response8.le_acl_data_packet_length
total_num_le_acl_data_packets = response8.total_num_le_acl_data_packets
logger.debug(
'HCI LE ACL flow control: '
@@ -592,16 +597,16 @@ class Host(utils.EventEmitter):
) and self.supports_command(
hci.HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
):
response = await self.send_command(
response9 = await self.send_sync_command(
hci.HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
suggested_max_tx_octets = response9.suggested_max_tx_octets
suggested_max_tx_time = response9.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
await self.send_sync_command(
hci.HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
@@ -611,24 +616,28 @@ class Host(utils.EventEmitter):
if self.supports_command(
hci.HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command(),
check_result=True,
)
self.number_of_supported_advertising_sets = (
response.return_parameters.num_supported_advertising_sets
)
try:
response10 = await self.send_sync_command(
hci.HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
self.number_of_supported_advertising_sets = (
response10.num_supported_advertising_sets
)
except hci.HCI_Error:
logger.warning('Failed to read number of supported advertising sets')
if self.supports_command(
hci.HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND
):
response = await self.send_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command(),
check_result=True,
)
self.maximum_advertising_data_length = (
response.return_parameters.max_advertising_data_length
)
try:
response11 = await self.send_sync_command(
hci.HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
self.maximum_advertising_data_length = (
response11.max_advertising_data_length
)
except hci.HCI_Error:
logger.warning('Failed to read maximum advertising data length')
@property
def controller(self) -> TransportSink | None:
@@ -654,56 +663,173 @@ class Host(utils.EventEmitter):
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: int | None = None
):
async def _send_command(
self,
command: hci.HCI_SyncCommand | hci.HCI_AsyncCommand,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event:
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
assert self.pending_response is None
await self.command_semaphore.acquire()
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_command = command
# Create a future value to hold the eventual response
assert self.pending_command is None
assert self.pending_response is None
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_command = command
try:
self.send_hci_packet(command)
await asyncio.wait_for(self.pending_response, timeout=response_timeout)
response = self.pending_response.result()
response: (
hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event | None
) = None
try:
self.send_hci_packet(command)
response = await asyncio.wait_for(
self.pending_response, timeout=response_timeout
)
return response
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
if response is None or (
response.num_hci_command_packets and self.command_semaphore.locked()
):
self.command_semaphore.release()
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
@overload
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP],
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]: ...
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed '
f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
@overload
async def send_command(
self,
command: hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Status_Event: ...
return response
except Exception:
logger.exception(color("!!! Exception while sending command:", "red"))
raise
finally:
self.pending_command = None
self.pending_response = None
async def send_command(
self,
command: hci.HCI_SyncCommand[_RP] | hci.HCI_AsyncCommand,
check_result: bool = False,
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP] | hci.HCI_Command_Status_Event:
response = await self._send_command(command, response_timeout)
# Use this method to send a command from a task
def send_command_sync(self, command: hci.HCI_Command) -> None:
async def send_command(command: hci.HCI_Command) -> None:
await self.send_command(command)
# Check the return parameters if required
if check_result:
if isinstance(response, hci.HCI_Command_Status_Event):
status = response.status # type: ignore[attr-defined]
elif isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
elif isinstance(
response.return_parameters, hci.HCI_GenericReturnParameters
):
# FIXME: temporary workaround
# NO STATUS
status = hci.HCI_SUCCESS
else:
status = response.return_parameters.status
asyncio.create_task(send_command(command))
if status != hci.HCI_SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return response
async def send_sync_command(
self, command: hci.HCI_SyncCommand[_RP], response_timeout: float | None = None
) -> _RP:
response = await self.send_sync_command_raw(command, response_timeout)
return_parameters = response.return_parameters
# Check the return parameters's status
if isinstance(return_parameters, hci.HCI_StatusReturnParameters):
status = return_parameters.status
elif isinstance(return_parameters, hci.HCI_GenericReturnParameters):
# if the payload has at least one byte, assume the first byte is the status
if not return_parameters.data:
raise RuntimeError('no status byte in return parameters')
status = hci.HCI_ErrorCode(return_parameters.data[0])
else:
raise RuntimeError(
f'unexpected return parameters type ({type(return_parameters)})'
)
if status != hci.HCI_ErrorCode.SUCCESS:
logger.warning(
f'{command.name} failed ' f'({hci.HCI_Constant.error_name(status)})'
)
raise hci.HCI_Error(status)
return return_parameters
async def send_sync_command_raw(
self,
command: hci.HCI_SyncCommand[_RP],
response_timeout: float | None = None,
) -> hci.HCI_Command_Complete_Event[_RP]:
response = await self._send_command(command, response_timeout)
# For unknown HCI commands, some controllers return Command Status instead of
# Command Complete.
if (
isinstance(response, hci.HCI_Command_Status_Event)
and response.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
):
return hci.HCI_Command_Complete_Event(
num_hci_command_packets=response.num_hci_command_packets,
command_opcode=command.op_code,
return_parameters=hci.HCI_StatusReturnParameters(
status=hci.HCI_ErrorCode(response.status)
), # type: ignore
)
# Check that the response is of the expected type
assert isinstance(response, hci.HCI_Command_Complete_Event)
return response
async def send_async_command(
self,
command: hci.HCI_AsyncCommand,
check_status: bool = True,
response_timeout: float | None = None,
) -> hci.HCI_ErrorCode:
response = await self._send_command(command, response_timeout)
# For unknown HCI commands, some controllers return Command Complete instead of
# Command Status.
if isinstance(response, hci.HCI_Command_Complete_Event):
# Assume the first byte of the return parameters is the status
if (
status := hci.HCI_ErrorCode(response.parameters[3])
) != hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR:
logger.warning(f'unexpected return paramerers status {status}')
else:
assert isinstance(response, hci.HCI_Command_Status_Event)
status = hci.HCI_ErrorCode(response.status)
# Check the status if required
if check_status:
if status != hci.HCI_CommandStatus.PENDING:
logger.warning(f'{command.name} failed ' f'({status.name})')
raise hci.HCI_Error(status)
return status
@utils.deprecated("Use utils.AsyncRunner.spawn() instead.")
def send_command_sync(self, command: hci.HCI_AsyncCommand) -> None:
utils.AsyncRunner.spawn(self.send_async_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
@@ -728,10 +854,22 @@ class Host(utils.EventEmitter):
data=pdu,
)
logger.debug(
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu
'>>> ACL packet enqueue: (handle=0x%04X) %s',
connection_handle,
pdu.hex(),
)
packet_queue.enqueue(acl_packet, connection_handle)
def send_sco_sdu(self, connection_handle: int, sdu: bytes) -> None:
self.send_hci_packet(
hci.HCI_SynchronousDataPacket(
connection_handle=connection_handle,
packet_status=0,
data_total_length=len(sdu),
data=sdu,
)
)
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu)))
@@ -816,16 +954,18 @@ class Host(utils.EventEmitter):
if self.local_supported_commands & mask
)
def supports_le_features(self, feature: hci.LeFeatureMask) -> bool:
return (self.local_le_features & feature) == feature
def supports_le_features(self, features: hci.LeFeatureMask) -> bool:
return (self.local_le_features & features) == features
def supports_lmp_features(self, feature: hci.LmpFeatureMask) -> bool:
return self.local_lmp_features & (feature) == feature
def supports_lmp_features(self, features: hci.LmpFeatureMask) -> bool:
return self.local_lmp_features & (features) == features
@property
def supported_le_features(self):
def supported_le_features(self) -> list[hci.LeFeature]:
return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
feature
for feature in hci.LeFeature
if self.local_le_features & (1 << feature)
]
# Packet Sink protocol (packets coming from the controller via HCI)
@@ -860,18 +1000,19 @@ class Host(utils.EventEmitter):
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET:
self.on_hci_command_packet(cast(hci.HCI_Command, packet))
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET:
self.on_hci_event_packet(cast(hci.HCI_Event, packet))
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET:
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet))
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET:
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet))
else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
match packet:
case hci.HCI_Command():
self.on_hci_command_packet(packet)
case hci.HCI_Event():
self.on_hci_event_packet(packet)
case hci.HCI_AclDataPacket():
self.on_hci_acl_data_packet(packet)
case hci.HCI_SynchronousDataPacket():
self.on_hci_sco_data_packet(packet)
case hci.HCI_IsoDataPacket():
self.on_hci_iso_data_packet(packet)
case _:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}')
@@ -914,6 +1055,8 @@ class Host(utils.EventEmitter):
self.pending_response.set_result(event)
else:
logger.warning('!!! no pending response future to set')
if event.num_hci_command_packets and self.command_semaphore.locked():
self.command_semaphore.release()
############################################################
# HCI handlers
@@ -925,7 +1068,13 @@ class Host(utils.EventEmitter):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
logger.debug('no-command event for flow control')
# Release the command semaphore if needed
if event.num_hci_command_packets and self.command_semaphore.locked():
logger.debug('command complete event releasing semaphore')
self.command_semaphore.release()
return
return self.on_command_processed(event)
@@ -1106,7 +1255,7 @@ class Host(utils.EventEmitter):
self, event: hci.HCI_LE_Connection_Update_Complete_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle')
logger.warning('!!! CONNECTION UPDATE COMPLETE: unknown handle')
return
# Notify the client
@@ -1123,6 +1272,29 @@ class Host(utils.EventEmitter):
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_connection_rate_change_event(
self, event: hci.HCI_LE_Connection_Rate_Change_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! CONNECTION RATE CHANGE: unknown handle')
return
# Notify the client
if event.status == hci.HCI_SUCCESS:
self.emit(
'le_connection_rate_change',
connection.handle,
event.connection_interval,
event.subrate_factor,
event.peripheral_latency,
event.continuation_number,
event.supervision_timeout,
)
else:
self.emit(
'le_connection_rate_change_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(
self, event: hci.HCI_LE_PHY_Update_Complete_Event
):
@@ -1338,15 +1510,17 @@ class Host(utils.EventEmitter):
# For now, just accept everything
# TODO: delegate the decision
self.send_command_sync(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
utils.AsyncRunner.spawn(
self.send_sync_command(
hci.HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
)
)
)
@@ -1382,9 +1556,9 @@ class Host(utils.EventEmitter):
connection_handle=event.connection_handle
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_long_term_key())
utils.AsyncRunner.spawn(send_long_term_key())
def on_hci_synchronous_connection_complete_event(
self, event: hci.HCI_Synchronous_Connection_Complete_Event
@@ -1484,6 +1658,19 @@ class Host(utils.EventEmitter):
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_read_remote_supported_features_complete_event(
self, event: hci.HCI_Read_Remote_Supported_Features_Complete_Event
) -> None:
# Notify the client
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.lmp_features, 'little'),
0, # page number
0, # max page number
)
def on_hci_encryption_change_v2_event(
self, event: hci.HCI_Encryption_Change_V2_Event
):
@@ -1583,9 +1770,9 @@ class Host(utils.EventEmitter):
bd_addr=event.bd_addr
)
await self.send_command(response)
await self.send_sync_command(response)
asyncio.create_task(send_link_key())
utils.AsyncRunner.spawn(send_link_key())
def on_hci_io_capability_request_event(
self, event: hci.HCI_IO_Capability_Request_Event
@@ -1640,6 +1827,18 @@ class Host(utils.EventEmitter):
rssi,
)
def on_hci_read_remote_extended_features_complete_event(
self, event: hci.HCI_Read_Remote_Extended_Features_Complete_Event
):
self.emit(
'classic_remote_features',
event.connection_handle,
event.status,
int.from_bytes(event.extended_lmp_features, 'little'),
event.page_number,
event.maximum_page_number,
)
def on_hci_extended_inquiry_result_event(
self, event: hci.HCI_Extended_Inquiry_Result_Event
):
@@ -1680,12 +1879,13 @@ class Host(utils.EventEmitter):
self.emit(
'le_remote_features_failure', event.connection_handle, event.status
)
else:
self.emit(
'le_remote_features',
event.connection_handle,
int.from_bytes(event.le_features, 'little'),
)
return
self.emit(
'le_remote_features',
event.connection_handle,
hci.LeFeatureMask(int.from_bytes(event.le_features, 'little')),
)
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(
self, event: hci.HCI_LE_CS_Read_Remote_Supported_Capabilities_Complete_Event
@@ -1718,6 +1918,12 @@ class Host(utils.EventEmitter):
self.emit('cs_subevent_result_continue', event)
def on_hci_le_subrate_change_event(self, event: hci.HCI_LE_Subrate_Change_Event):
if event.status != hci.HCI_SUCCESS:
self.emit(
'le_subrate_change_failure', event.connection_handle, event.status
)
return
self.emit(
'le_subrate_change',
event.connection_handle,
+29 -29
View File
@@ -27,6 +27,7 @@ import dataclasses
import json
import logging
import os
import pathlib
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
@@ -248,29 +249,26 @@ class JsonKeyStore(KeyStore):
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
def __init__(
self, namespace: str | None = None, filename: str | None = None
) -> None:
self.namespace = namespace or self.DEFAULT_NAMESPACE
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
if filename:
self.filename = pathlib.Path(filename).resolve()
self.directory_name = self.filename.parent
else:
self.filename = filename
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
import platformdirs # Deferred import
logger.debug(f'JSON keystore: {self.filename}')
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR)
self.directory_name = base_dir / self.KEYS_DIR
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
safe_name = base_name.lower().replace(':', '-').replace('/', '-')
self.filename = self.directory_name / f"{safe_name}.json"
logger.debug('JSON keystore: %s', self.filename)
@classmethod
def from_device(
@@ -293,7 +291,9 @@ class JsonKeyStore(KeyStore):
return cls(namespace, filename)
async def load(self):
async def load(
self,
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
@@ -312,17 +312,17 @@ class JsonKeyStore(KeyStore):
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
key_map: dict[str, dict[str, Any]] = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db):
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None:
# Create the directory if it doesn't exist
if not os.path.exists(self.directory_name):
os.makedirs(self.directory_name, exist_ok=True)
if not self.directory_name.exists():
self.directory_name.mkdir(parents=True, exist_ok=True)
# Save to a temporary file
temp_filename = self.filename + '.tmp'
temp_filename = self.filename.with_name(self.filename.name + ".tmp")
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)
@@ -334,16 +334,16 @@ class JsonKeyStore(KeyStore):
del key_map[name]
await self.save(db)
async def update(self, name, keys):
async def update(self, name: str, keys: PairingKeys) -> None:
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)
async def get_all(self):
async def get_all(self) -> list[tuple[str, PairingKeys]]:
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self):
async def delete_all(self) -> None:
db, key_map = await self.load()
key_map.clear()
await self.save(db)
+135 -67
View File
@@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio
import dataclasses
import enum
import itertools
import logging
import struct
from collections import deque
@@ -302,11 +303,9 @@ class EnhancedControlField(ControlField):
@dataclasses.dataclass
class InformationEnhancedControlField(EnhancedControlField):
tx_seq: int = 0
tx_seq: int
sar: int
req_seq: int = 0
segmentation_and_reassembly: int = (
EnhancedControlField.SegmentationAndReassembly.UNSEGMENTED
)
final: int = 1
frame_type = EnhancedControlField.FieldType.I_FRAME
@@ -316,15 +315,15 @@ class InformationEnhancedControlField(EnhancedControlField):
return cls(
tx_seq=(data[0] >> 1) & 0b0111111,
final=(data[0] >> 7) & 0b1,
req_seq=(data[1] & 0b001111111),
segmentation_and_reassembly=(data[1] >> 6) & 0b11,
req_seq=(data[1] & 0b00111111),
sar=(data[1] >> 6) & 0b11,
)
def __bytes__(self) -> bytes:
return bytes(
[
self.frame_type | (self.tx_seq << 1) | (self.final << 7),
self.req_seq | (self.segmentation_and_reassembly << 6),
self.req_seq | (self.sar << 6),
]
)
@@ -889,27 +888,38 @@ class EnhancedRetransmissionProcessor(Processor):
class _PendingPdu:
payload: bytes
tx_seq: int
sar: InformationEnhancedControlField.SegmentationAndReassembly
sdu_length: int = 0
req_seq: int = 0
def __bytes__(self) -> bytes:
return (
bytes(
InformationEnhancedControlField(
tx_seq=self.tx_seq, req_seq=self.req_seq
tx_seq=self.tx_seq,
req_seq=self.req_seq,
sar=self.sar,
)
)
+ (
struct.pack('<H', self.sdu_length)
if self.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
else b''
)
+ self.payload
)
_expected_ack_seq: int = 0
_last_acked_tx_seq: int = 0
_last_acked_rx_seq: int = 0
_next_tx_seq: int = 0
_last_tx_seq: int = 0
_req_seq_num: int = 0
_next_seq_num: int = 0
_remote_is_busy: bool = False
_in_sdu: bytes = b''
_num_receiver_ready_polls_sent: int = 0
_pending_pdus: list[_PendingPdu]
_tx_window: list[_PendingPdu]
_monitor_handle: asyncio.TimerHandle | None = None
_receiver_ready_poll_handle: asyncio.TimerHandle | None = None
@@ -917,12 +927,6 @@ class EnhancedRetransmissionProcessor(Processor):
monitor_timeout: float
retransmission_timeout: float
@classmethod
def _num_frames_between(cls, low: int, high: int) -> int:
if high < low:
high += cls.MAX_SEQ_NUM
return high - low
def __init__(
self,
channel: ClassicChannel,
@@ -935,6 +939,7 @@ class EnhancedRetransmissionProcessor(Processor):
self.peer_mps = peer_mps
self.peer_tx_window_size = peer_tx_window_size
self._pending_pdus = []
self._tx_window = []
self.monitor_timeout = spec.monitor_timeout
self.channel = channel
self.retransmission_timeout = spec.retransmission_timeout
@@ -972,12 +977,9 @@ class EnhancedRetransmissionProcessor(Processor):
def _send_receiver_ready_poll(self) -> None:
self._num_receiver_ready_polls_sent += 1
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
def _get_next_tx_seq(self) -> int:
@@ -987,12 +989,35 @@ class EnhancedRetransmissionProcessor(Processor):
@override
def send_sdu(self, sdu: bytes) -> None:
if len(sdu) > self.peer_mps:
raise InvalidArgumentError(
f'SDU size({len(sdu)}) exceeds channel MPS {self.peer_mps}'
if len(sdu) <= self.peer_mps:
pdu = self._PendingPdu(
payload=sdu,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
)
pdu = self._PendingPdu(payload=sdu, tx_seq=self._get_next_tx_seq())
self._pending_pdus.append(pdu)
self._pending_pdus.append(pdu)
else:
for offset in range(0, len(sdu), self.peer_mps):
payload = sdu[offset : offset + self.peer_mps]
if offset == 0:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.START
)
elif offset + len(payload) >= len(sdu):
sar = InformationEnhancedControlField.SegmentationAndReassembly.END
else:
sar = (
InformationEnhancedControlField.SegmentationAndReassembly.CONTINUATION
)
pdu = self._PendingPdu(
payload=payload,
tx_seq=self._get_next_tx_seq(),
req_seq=self._req_seq_num,
sar=sar,
sdu_length=len(sdu),
)
self._pending_pdus.append(pdu)
self._process_output()
@override
@@ -1000,17 +1025,37 @@ class EnhancedRetransmissionProcessor(Processor):
control_field = EnhancedControlField.from_bytes(pdu)
self._update_ack_seq(control_field.req_seq, control_field.final != 0)
if isinstance(control_field, InformationEnhancedControlField):
if control_field.tx_seq != self._next_seq_num:
if control_field.tx_seq != self._req_seq_num:
logger.error(
"tx_seq != self._req_seq_num, tx_seq: %d, self._req_seq_num: %d",
control_field.tx_seq,
self._req_seq_num,
)
return
self._next_seq_num = (self._next_seq_num + 1) % self.MAX_SEQ_NUM
self._req_seq_num = self._next_seq_num
self._req_seq_num = (control_field.tx_seq + 1) % self.MAX_SEQ_NUM
ack_frame = SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
req_seq=self._next_seq_num,
)
self.channel.send_pdu(ack_frame)
self.channel.on_sdu(pdu[2:])
if (
control_field.sar
== InformationEnhancedControlField.SegmentationAndReassembly.START
):
# Drop Control Field(2) + SDU Length(2)
self._in_sdu += pdu[4:]
else:
# Drop Control Field(2)
self._in_sdu += pdu[2:]
if control_field.sar in (
InformationEnhancedControlField.SegmentationAndReassembly.END,
InformationEnhancedControlField.SegmentationAndReassembly.UNSEGMENTED,
):
self.channel.on_sdu(self._in_sdu)
self._in_sdu = b''
# If sink doesn't trigger any I-frame, ack this frame.
if self._req_seq_num != self._last_acked_rx_seq:
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=0,
)
elif isinstance(control_field, SupervisoryEnhancedControlField):
self._remote_is_busy = (
control_field.supervision_function
@@ -1022,56 +1067,66 @@ class EnhancedRetransmissionProcessor(Processor):
SupervisoryEnhancedControlField.SupervisoryFunction.RNR,
):
if control_field.poll:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
req_seq=self._next_seq_num,
)
self._send_s_frame(
supervision_function=SupervisoryEnhancedControlField.SupervisoryFunction.RR,
final=1,
)
else:
# TODO: Handle Retransmission.
pass
def _process_output(self) -> None:
if self._remote_is_busy or self._monitor_handle:
if self._remote_is_busy:
logger.debug("Remote is busy")
return
if self._monitor_handle:
logger.debug("Monitor handle is not None")
return
for pdu in self._pending_pdus:
if self._num_unacked_frames >= self.peer_tx_window_size:
return
self._send_pdu(pdu)
self._last_tx_seq = pdu.tx_seq
pdu_to_send = self.peer_tx_window_size - len(self._tx_window)
for pdu in itertools.islice(self._pending_pdus, pdu_to_send):
self._send_i_frame(pdu)
self._pending_pdus = self._pending_pdus[pdu_to_send:]
@property
def _num_unacked_frames(self) -> int:
if not self._pending_pdus:
return 0
return self._num_frames_between(self._expected_ack_seq, self._last_tx_seq + 1)
def _send_pdu(self, pdu: _PendingPdu) -> None:
def _send_i_frame(self, pdu: _PendingPdu) -> None:
pdu.req_seq = self._req_seq_num
self._start_receiver_ready_poll()
self._tx_window.append(pdu)
self.channel.send_pdu(bytes(pdu))
self._last_acked_rx_seq = self._req_seq_num
def _send_s_frame(
self,
supervision_function: SupervisoryEnhancedControlField.SupervisoryFunction,
final: int,
) -> None:
self.channel.send_pdu(
SupervisoryEnhancedControlField(
supervision_function=supervision_function,
final=final,
req_seq=self._req_seq_num,
)
)
self._last_acked_rx_seq = self._req_seq_num
def _update_ack_seq(self, new_seq: int, is_poll_response: bool) -> None:
num_frames_acked = self._num_frames_between(self._expected_ack_seq, new_seq)
if num_frames_acked > self._num_unacked_frames:
num_frames_acked = (new_seq - self._last_acked_tx_seq) % self.MAX_SEQ_NUM
if num_frames_acked > len(self._tx_window):
logger.error(
"Received acknowledgment for %d frames but only %d frames are pending",
num_frames_acked,
self._num_unacked_frames,
len(self._tx_window),
)
return
if is_poll_response and self._monitor_handle:
self._monitor_handle.cancel()
self._monitor_handle = None
del self._pending_pdus[:num_frames_acked]
self._expected_ack_seq = new_seq
del self._tx_window[:num_frames_acked]
self._last_acked_tx_seq = new_seq
if (
self._expected_ack_seq == self._next_tx_seq
self._last_acked_tx_seq == self._next_tx_seq
and self._receiver_ready_poll_handle
):
self._receiver_ready_poll_handle.cancel()
@@ -1552,6 +1607,7 @@ class LeCreditBasedChannel(utils.EventEmitter):
EVENT_OPEN = "open"
EVENT_CLOSE = "close"
EVENT_ATT_MTU_UPDATE = "att_mtu_update"
def __init__(
self,
@@ -1591,6 +1647,9 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.drained = asyncio.Event()
# Core Specification Vol 3, Part G, 5.3.1 ATT_MTU
# ATT_MTU shall be set to the minimum of the MTU field values of the two devices.
self.att_mtu = min(mtu, peer_mtu)
self.drained.set()
@@ -1821,6 +1880,10 @@ class LeCreditBasedChannel(utils.EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_att_mtu_update(self, mtu: int) -> None:
self.att_mtu = mtu
self.emit(self.EVENT_ATT_MTU_UPDATE, mtu)
def flush_output(self) -> None:
self.out_queue.clear()
self.out_sdu = None
@@ -2279,8 +2342,8 @@ class ChannelManager:
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
destination_cid=0,
source_cid=request.source_cid,
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000,
),
@@ -2292,7 +2355,12 @@ class ChannelManager:
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.spec
manager=self,
connection=connection,
signaling_cid=cid,
psm=request.psm,
source_cid=source_cid,
spec=server.spec,
)
connection_channels[source_cid] = channel
@@ -2309,8 +2377,8 @@ class ChannelManager:
cid,
L2CAP_Connection_Response(
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
destination_cid=0,
source_cid=request.source_cid,
result=L2CAP_Connection_Response.Result.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
+21
View File
@@ -198,3 +198,24 @@ class CisTerminateInd(ControlPdu):
cig_id: int
cis_id: int
error_code: int
@dataclasses.dataclass
class FeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
feature_set: bytes
@dataclasses.dataclass
class FeatureRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
feature_set: bytes
@dataclasses.dataclass
class PeripheralFeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
feature_set: bytes
+35
View File
@@ -322,3 +322,38 @@ class LmpNameRes(Packet):
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))
@Packet.subclass
@dataclass
class LmpFeaturesReq(Packet):
opcode = Opcode.LMP_FEATURES_REQ
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesRes(Packet):
opcode = Opcode.LMP_FEATURES_RES
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesReqExt(Packet):
opcode = Opcode.LMP_FEATURES_REQ_EXT
features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesResExt(Packet):
opcode = Opcode.LMP_FEATURES_RES_EXT
features_page: int = field(metadata=hci.metadata(1))
max_features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
+10 -19
View File
@@ -21,18 +21,9 @@ import enum
import secrets
from dataclasses import dataclass
from bumble import hci
from bumble import hci, smp
from bumble.core import AdvertisingData, LeRole
from bumble.smp import (
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
@@ -96,11 +87,11 @@ class PairingDelegate:
# These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants.
class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY
# Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
@@ -111,10 +102,10 @@ class PairingDelegate:
# Key Distribution [LE only]
class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
+1 -1
View File
@@ -278,7 +278,7 @@ class L2CAPService(L2CAPServicer):
if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.send_pdu(request.data)
l2cap_channel.write(request.data)
else:
l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty())
+37 -39
View File
@@ -664,46 +664,44 @@ class AudioStreamControlService(gatt.TemplateService):
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if isinstance(operation, ASE_Config_Codec):
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
match operation:
case ASE_Config_Codec():
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Config_QOS():
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Enable() | ASE_Update_Metadata():
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case (
ASE_Receiver_Start_Ready()
| ASE_Disable()
| ASE_Receiver_Stop_Ready()
| ASE_Release()
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, ASE_Config_QOS):
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(
operation,
(
ASE_Receiver_Start_Ready,
ASE_Disable,
ASE_Receiver_Stop_Ready,
ASE_Release,
),
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
+12 -11
View File
@@ -333,17 +333,18 @@ class CodecSpecificCapabilities:
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
match type:
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
supported_sampling_frequencies = SupportedSamplingFrequency(value)
case CodecSpecificCapabilities.Type.FRAME_DURATION:
supported_frame_durations = SupportedFrameDuration(value)
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
supported_audio_channel_count = bits_to_channel_counts(value)
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
min_octets_per_sample = value & 0xFFFF
max_octets_per_sample = value >> 16
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment
+24 -33
View File
@@ -16,35 +16,28 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from collections.abc import Callable
from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_BATTERY_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from bumble import device, gatt, gatt_adapters, gatt_client
# -----------------------------------------------------------------------------
class BatteryService(TemplateService):
UUID = GATT_BATTERY_SERVICE
class BatteryService(gatt.TemplateService):
UUID = gatt.GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: Characteristic[int]
battery_level_characteristic: gatt.Characteristic[int]
def __init__(self, read_battery_level):
self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
def __init__(self, read_battery_level: Callable[[device.Connection], int]) -> None:
self.battery_level_characteristic = gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.READABLE,
value=gatt.CharacteristicValue(read=read_battery_level),
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
@@ -52,19 +45,17 @@ class BatteryService(TemplateService):
# -----------------------------------------------------------------------------
class BatteryServiceProxy(ProfileServiceProxy):
class BatteryServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = BatteryService
battery_level: CharacteristicProxy[int] | None
battery_level: gatt_client.CharacteristicProxy[int]
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicProxyAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None
self.battery_level = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
+9 -8
View File
@@ -55,14 +55,15 @@ class GenericAccessService(TemplateService):
def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
):
if isinstance(appearance, int):
appearance_int = appearance
elif isinstance(appearance, tuple):
appearance_int = (appearance[0] << 6) | appearance[1]
elif isinstance(appearance, Appearance):
appearance_int = int(appearance)
else:
raise TypeError()
match appearance:
case int():
appearance_int = appearance
case tuple():
appearance_int = (appearance[0] << 6) | appearance[1]
case Appearance():
appearance_int = int(appearance)
case _:
raise TypeError()
self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
+128 -119
View File
@@ -18,40 +18,30 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
import struct
from enum import IntEnum
from collections.abc import Callable, Sequence
from typing import Any
from bumble import core
from bumble.att import ATT_Error
from bumble.gatt import (
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_HEART_RATE_SERVICE,
Characteristic,
CharacteristicValue,
TemplateService,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
from typing_extensions import Self
from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
class HeartRateService(gatt.TemplateService):
UUID = gatt.GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: Characteristic[int]
heart_rate_measurement_characteristic: gatt.Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: gatt.Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: gatt.Characteristic[int]
class BodySensorLocation(IntEnum):
class BodySensorLocation(utils.OpenIntEnum):
OTHER = 0
CHEST = 1
WRIST = 2
@@ -60,82 +50,90 @@ class HeartRateService(TemplateService):
EAR_LOBE = 5
FOOT = 6
@dataclasses.dataclass
class HeartRateMeasurement:
def __init__(
self,
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
heart_rate: int
sensor_contact_detected: bool | None = None
energy_expended: int | None = None
rr_intervals: Sequence[float] | None = None
class Flag(enum.IntFlag):
INT16_HEART_RATE = 1 << 0
SENSOR_CONTACT_DETECTED = 1 << 1
SENSOR_CONTACT_SUPPORTED = 1 << 2
ENERGY_EXPENDED_STATUS = 1 << 3
RR_INTERVAL = 1 << 4
def __post_init__(self) -> None:
if self.heart_rate < 0 or self.heart_rate > 0xFFFF:
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
if self.energy_expended is not None and (
self.energy_expended < 0 or self.energy_expended > 0xFFFF
):
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if self.rr_intervals:
for rr_interval in self.rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
def from_bytes(cls, data: bytes) -> Self:
flags = data[0]
offset = 1
if flags & 1:
hr = struct.unpack_from('<H', data, offset)[0]
if flags & cls.Flag.INT16_HEART_RATE:
heart_rate = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
hr = struct.unpack_from('B', data, offset)[0]
heart_rate = struct.unpack_from('B', data, offset)[0]
offset += 1
if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0
if flags & cls.Flag.SENSOR_CONTACT_SUPPORTED:
sensor_contact_detected = flags & cls.Flag.SENSOR_CONTACT_DETECTED != 0
else:
sensor_contact_detected = None
if flags & (1 << 3):
if flags & cls.Flag.ENERGY_EXPENDED_STATUS:
energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
energy_expended = None
if flags & (1 << 4):
rr_intervals: Sequence[float] | None = None
if flags & cls.Flag.RR_INTERVAL:
rr_intervals = tuple(
struct.unpack_from('<H', data, offset + i * 2)[0] / 1024
for i in range((len(data) - offset) // 2)
struct.unpack_from('<H', data, i)[0] / 1024
for i in range(offset, len(data), 2)
)
else:
rr_intervals = ()
return cls(hr, sensor_contact_detected, energy_expended, rr_intervals)
return cls(
heart_rate=heart_rate,
sensor_contact_detected=sensor_contact_detected,
energy_expended=energy_expended,
rr_intervals=rr_intervals,
)
def __bytes__(self):
def __bytes__(self) -> bytes:
flags = 0
if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate)
else:
flags = 1
flags |= self.Flag.INT16_HEART_RATE
data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None:
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
flags |= self.Flag.SENSOR_CONTACT_SUPPORTED
if self.sensor_contact_detected:
flags |= self.Flag.SENSOR_CONTACT_DETECTED
if self.energy_expended is not None:
flags |= 1 << 3
flags |= self.Flag.ENERGY_EXPENDED_STATUS
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= 1 << 4
if self.rr_intervals is not None:
flags |= self.Flag.RR_INTERVAL
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
@@ -145,57 +143,67 @@ class HeartRateService(TemplateService):
return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None,
read_heart_rate_measurement: Callable[
[device.Connection], HeartRateMeasurement
],
body_sensor_location: HeartRateService.BodySensorLocation | None = None,
reset_energy_expended: Callable[[device.Connection], Any] | None = None,
):
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
self.heart_rate_measurement_characteristic = (
gatt_adapters.SerializableCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions(0),
value=gatt.CharacteristicValue(read=read_heart_rate_measurement),
),
HeartRateService.HeartRateMeasurement,
)
)
characteristics = [self.heart_rate_measurement_characteristic]
characteristics: list[gatt.Characteristic] = [
self.heart_rate_measurement_characteristic
]
if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)]),
self.body_sensor_location_characteristic = (
gatt_adapters.EnumCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.READABLE,
value=body_sensor_location,
),
cls=self.BodySensorLocation,
length=1,
)
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
def write_heart_rate_control_point_value(
connection: device.Connection, value: bytes
) -> None:
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
reset_energy_expended(connection)
else:
raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
raise att.ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point_characteristic = (
gatt_adapters.PackedCharacteristicAdapter(
gatt.Characteristic(
uuid=gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.WRITEABLE,
value=gatt.CharacteristicValue(
write=write_heart_rate_control_point_value
),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -203,50 +211,51 @@ class HeartRateService(TemplateService):
# -----------------------------------------------------------------------------
class HeartRateServiceProxy(ProfileServiceProxy):
class HeartRateServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
heart_rate_measurement: (
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None
)
heart_rate_measurement: gatt_client.CharacteristicProxy[
HeartRateService.HeartRateMeasurement
]
body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None
gatt_client.CharacteristicProxy[HeartRateService.BodySensorLocation] | None
)
heart_rate_control_point: CharacteristicProxy[int] | None
heart_rate_control_point: gatt_client.CharacteristicProxy[int] | None
def __init__(self, service_proxy):
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = SerializableCharacteristicAdapter(
characteristics[0], HeartRateService.HeartRateMeasurement
self.heart_rate_measurement = (
gatt_adapters.SerializableCharacteristicProxyAdapter(
service_proxy.get_required_characteristic_by_uuid(
gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
),
HeartRateService.HeartRateMeasurement,
)
else:
self.heart_rate_measurement = None
)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
self.body_sensor_location = gatt_adapters.EnumCharacteristicProxyAdapter(
characteristics[0], cls=HeartRateService.BodySensorLocation, length=1
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
self.heart_rate_control_point = (
gatt_adapters.PackedCharacteristicProxyAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
async def reset_energy_expended(self) -> None:
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
+1 -1
View File
@@ -800,7 +800,7 @@ class Multiplexer(utils.EventEmitter):
def send_frame(self, frame: RFCOMM_Frame) -> None:
logger.debug(f'>>> Multiplexer sending {frame}')
self.l2cap_channel.send_pdu(frame)
self.l2cap_channel.write(bytes(frame))
def on_pdu(self, pdu: bytes) -> None:
frame = RFCOMM_Frame.from_bytes(pdu)
+526 -500
View File
File diff suppressed because it is too large Load Diff
+274 -268
View File
@@ -27,18 +27,18 @@ from __future__ import annotations
import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
from bumble import crypto, utils
from bumble import crypto, hci, utils
from bumble.colors import color
from bumble.core import (
AdvertisingData,
InvalidArgumentError,
InvalidPacketError,
PhysicalTransport,
ProtocolError,
name_or_number,
)
from bumble.hci import (
Address,
@@ -46,7 +46,6 @@ from bumble.hci import (
HCI_LE_Enable_Encryption_Command,
HCI_Object,
Role,
key_with_value,
metadata,
)
from bumble.keys import PairingKeys
@@ -71,115 +70,125 @@ logger = logging.getLogger(__name__)
SMP_CID = 0x06
SMP_BR_CID = 0x07
SMP_PAIRING_REQUEST_COMMAND = 0x01
SMP_PAIRING_RESPONSE_COMMAND = 0x02
SMP_PAIRING_CONFIRM_COMMAND = 0x03
SMP_PAIRING_RANDOM_COMMAND = 0x04
SMP_PAIRING_FAILED_COMMAND = 0x05
SMP_ENCRYPTION_INFORMATION_COMMAND = 0x06
SMP_MASTER_IDENTIFICATION_COMMAND = 0x07
SMP_IDENTITY_INFORMATION_COMMAND = 0x08
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND = 0x09
SMP_SIGNING_INFORMATION_COMMAND = 0x0A
SMP_SECURITY_REQUEST_COMMAND = 0x0B
SMP_PAIRING_PUBLIC_KEY_COMMAND = 0x0C
SMP_PAIRING_DHKEY_CHECK_COMMAND = 0x0D
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND = 0x0E
class CommandCode(hci.SpecableEnum):
PAIRING_REQUEST = 0x01
PAIRING_RESPONSE = 0x02
PAIRING_CONFIRM = 0x03
PAIRING_RANDOM = 0x04
PAIRING_FAILED = 0x05
ENCRYPTION_INFORMATION = 0x06
MASTER_IDENTIFICATION = 0x07
IDENTITY_INFORMATION = 0x08
IDENTITY_ADDRESS_INFORMATION = 0x09
SIGNING_INFORMATION = 0x0A
SECURITY_REQUEST = 0x0B
PAIRING_PUBLIC_KEY = 0x0C
PAIRING_DHKEY_CHECK = 0x0D
PAIRING_KEYPRESS_NOTIFICATION = 0x0E
SMP_COMMAND_NAMES = {
SMP_PAIRING_REQUEST_COMMAND: 'SMP_PAIRING_REQUEST_COMMAND',
SMP_PAIRING_RESPONSE_COMMAND: 'SMP_PAIRING_RESPONSE_COMMAND',
SMP_PAIRING_CONFIRM_COMMAND: 'SMP_PAIRING_CONFIRM_COMMAND',
SMP_PAIRING_RANDOM_COMMAND: 'SMP_PAIRING_RANDOM_COMMAND',
SMP_PAIRING_FAILED_COMMAND: 'SMP_PAIRING_FAILED_COMMAND',
SMP_ENCRYPTION_INFORMATION_COMMAND: 'SMP_ENCRYPTION_INFORMATION_COMMAND',
SMP_MASTER_IDENTIFICATION_COMMAND: 'SMP_MASTER_IDENTIFICATION_COMMAND',
SMP_IDENTITY_INFORMATION_COMMAND: 'SMP_IDENTITY_INFORMATION_COMMAND',
SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND: 'SMP_IDENTITY_ADDRESS_INFORMATION_COMMAND',
SMP_SIGNING_INFORMATION_COMMAND: 'SMP_SIGNING_INFORMATION_COMMAND',
SMP_SECURITY_REQUEST_COMMAND: 'SMP_SECURITY_REQUEST_COMMAND',
SMP_PAIRING_PUBLIC_KEY_COMMAND: 'SMP_PAIRING_PUBLIC_KEY_COMMAND',
SMP_PAIRING_DHKEY_CHECK_COMMAND: 'SMP_PAIRING_DHKEY_CHECK_COMMAND',
SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND: 'SMP_PAIRING_KEYPRESS_NOTIFICATION_COMMAND'
}
SMP_DISPLAY_ONLY_IO_CAPABILITY = 0x00
SMP_DISPLAY_YES_NO_IO_CAPABILITY = 0x01
SMP_KEYBOARD_ONLY_IO_CAPABILITY = 0x02
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = 0x03
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = 0x04
class IoCapability(hci.SpecableEnum):
DISPLAY_ONLY = 0x00
DISPLAY_YES_NO = 0x01
KEYBOARD_ONLY = 0x02
NO_INPUT_NO_OUTPUT = 0x03
KEYBOARD_DISPLAY = 0x04
SMP_IO_CAPABILITY_NAMES = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: 'SMP_DISPLAY_ONLY_IO_CAPABILITY',
SMP_DISPLAY_YES_NO_IO_CAPABILITY: 'SMP_DISPLAY_YES_NO_IO_CAPABILITY',
SMP_KEYBOARD_ONLY_IO_CAPABILITY: 'SMP_KEYBOARD_ONLY_IO_CAPABILITY',
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: 'SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY',
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: 'SMP_KEYBOARD_DISPLAY_IO_CAPABILITY'
}
SMP_DISPLAY_ONLY_IO_CAPABILITY = IoCapability.DISPLAY_ONLY
SMP_DISPLAY_YES_NO_IO_CAPABILITY = IoCapability.DISPLAY_YES_NO
SMP_KEYBOARD_ONLY_IO_CAPABILITY = IoCapability.KEYBOARD_ONLY
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY = IoCapability.NO_INPUT_NO_OUTPUT
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY = IoCapability.KEYBOARD_DISPLAY
SMP_PASSKEY_ENTRY_FAILED_ERROR = 0x01
SMP_OOB_NOT_AVAILABLE_ERROR = 0x02
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = 0x03
SMP_CONFIRM_VALUE_FAILED_ERROR = 0x04
SMP_PAIRING_NOT_SUPPORTED_ERROR = 0x05
SMP_ENCRYPTION_KEY_SIZE_ERROR = 0x06
SMP_COMMAND_NOT_SUPPORTED_ERROR = 0x07
SMP_UNSPECIFIED_REASON_ERROR = 0x08
SMP_REPEATED_ATTEMPTS_ERROR = 0x09
SMP_INVALID_PARAMETERS_ERROR = 0x0A
SMP_DHKEY_CHECK_FAILED_ERROR = 0x0B
SMP_NUMERIC_COMPARISON_FAILED_ERROR = 0x0C
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = 0x0D
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = 0x0E
class ErrorCode(hci.SpecableEnum):
PASSKEY_ENTRY_FAILED = 0x01
OOB_NOT_AVAILABLE = 0x02
AUTHENTICATION_REQUIREMENTS = 0x03
CONFIRM_VALUE_FAILED = 0x04
PAIRING_NOT_SUPPORTED = 0x05
ENCRYPTION_KEY_SIZE = 0x06
COMMAND_NOT_SUPPORTED = 0x07
UNSPECIFIED_REASON = 0x08
REPEATED_ATTEMPTS = 0x09
INVALID_PARAMETERS = 0x0A
DHKEY_CHECK_FAILED = 0x0B
NUMERIC_COMPARISON_FAILED = 0x0C
BD_EDR_PAIRING_IN_PROGRESS = 0x0D
CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED = 0x0E
SMP_ERROR_NAMES = {
SMP_PASSKEY_ENTRY_FAILED_ERROR: 'SMP_PASSKEY_ENTRY_FAILED_ERROR',
SMP_OOB_NOT_AVAILABLE_ERROR: 'SMP_OOB_NOT_AVAILABLE_ERROR',
SMP_AUTHENTICATION_REQUIREMENTS_ERROR: 'SMP_AUTHENTICATION_REQUIREMENTS_ERROR',
SMP_CONFIRM_VALUE_FAILED_ERROR: 'SMP_CONFIRM_VALUE_FAILED_ERROR',
SMP_PAIRING_NOT_SUPPORTED_ERROR: 'SMP_PAIRING_NOT_SUPPORTED_ERROR',
SMP_ENCRYPTION_KEY_SIZE_ERROR: 'SMP_ENCRYPTION_KEY_SIZE_ERROR',
SMP_COMMAND_NOT_SUPPORTED_ERROR: 'SMP_COMMAND_NOT_SUPPORTED_ERROR',
SMP_UNSPECIFIED_REASON_ERROR: 'SMP_UNSPECIFIED_REASON_ERROR',
SMP_REPEATED_ATTEMPTS_ERROR: 'SMP_REPEATED_ATTEMPTS_ERROR',
SMP_INVALID_PARAMETERS_ERROR: 'SMP_INVALID_PARAMETERS_ERROR',
SMP_DHKEY_CHECK_FAILED_ERROR: 'SMP_DHKEY_CHECK_FAILED_ERROR',
SMP_NUMERIC_COMPARISON_FAILED_ERROR: 'SMP_NUMERIC_COMPARISON_FAILED_ERROR',
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR: 'SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR',
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR: 'SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR'
}
SMP_PASSKEY_ENTRY_FAILED_ERROR = ErrorCode.PASSKEY_ENTRY_FAILED
SMP_OOB_NOT_AVAILABLE_ERROR = ErrorCode.OOB_NOT_AVAILABLE
SMP_AUTHENTICATION_REQUIREMENTS_ERROR = ErrorCode.AUTHENTICATION_REQUIREMENTS
SMP_CONFIRM_VALUE_FAILED_ERROR = ErrorCode.CONFIRM_VALUE_FAILED
SMP_PAIRING_NOT_SUPPORTED_ERROR = ErrorCode.PAIRING_NOT_SUPPORTED
SMP_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.ENCRYPTION_KEY_SIZE
SMP_COMMAND_NOT_SUPPORTED_ERROR = ErrorCode.COMMAND_NOT_SUPPORTED
SMP_UNSPECIFIED_REASON_ERROR = ErrorCode.UNSPECIFIED_REASON
SMP_REPEATED_ATTEMPTS_ERROR = ErrorCode.REPEATED_ATTEMPTS
SMP_INVALID_PARAMETERS_ERROR = ErrorCode.INVALID_PARAMETERS
SMP_DHKEY_CHECK_FAILED_ERROR = ErrorCode.DHKEY_CHECK_FAILED
SMP_NUMERIC_COMPARISON_FAILED_ERROR = ErrorCode.NUMERIC_COMPARISON_FAILED
SMP_BD_EDR_PAIRING_IN_PROGRESS_ERROR = ErrorCode.BD_EDR_PAIRING_IN_PROGRESS
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR = ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE = 0
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE = 1
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE = 2
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE = 3
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE = 4
SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES = {
SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_STARTED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ENTERED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_DIGIT_ERASED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_CLEARED_KEYPRESS_NOTIFICATION_TYPE',
SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE: 'SMP_PASSKEY_ENTRY_COMPLETED_KEYPRESS_NOTIFICATION_TYPE'
}
class KeypressNotificationType(hci.SpecableEnum):
PASSKEY_ENTRY_STARTED = 0
PASSKEY_DIGIT_ENTERED = 1
PASSKEY_DIGIT_ERASED = 2
PASSKEY_CLEARED = 3
PASSKEY_ENTRY_COMPLETED = 4
# Bit flags for key distribution/generation
SMP_ENC_KEY_DISTRIBUTION_FLAG = 0b0001
SMP_ID_KEY_DISTRIBUTION_FLAG = 0b0010
SMP_SIGN_KEY_DISTRIBUTION_FLAG = 0b0100
SMP_LINK_KEY_DISTRIBUTION_FLAG = 0b1000
class KeyDistribution(hci.SpecableFlag):
ENC_KEY = 0b0001
ID_KEY = 0b0010
SIGN_KEY = 0b0100
LINK_KEY = 0b1000
# AuthReq fields
SMP_BONDING_AUTHREQ = 0b00000001
SMP_MITM_AUTHREQ = 0b00000100
SMP_SC_AUTHREQ = 0b00001000
SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000
class AuthReq(hci.SpecableFlag):
BONDING = 0b00000001
MITM = 0b00000100
SC = 0b00001000
KEYPRESS = 0b00010000
CT2 = 0b00100000
@classmethod
def from_booleans(
cls,
bonding: bool = False,
sc: bool = False,
mitm: bool = False,
keypress: bool = False,
ct2: bool = False,
) -> AuthReq:
auth_req = AuthReq(0)
if bonding:
auth_req |= AuthReq.BONDING
if sc:
auth_req |= AuthReq.SC
if mitm:
auth_req |= AuthReq.MITM
if keypress:
auth_req |= AuthReq.KEYPRESS
if ct2:
auth_req |= AuthReq.CT2
return auth_req
# Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# Diffie-Hellman private / public key pair in Debug Mode (Core - Vol. 3, Part H)
SMP_DEBUG_KEY_PRIVATE = bytes.fromhex(
'3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd'
)
SMP_DEBUG_KEY_PUBLIC_X = bytes.fromhex(
'20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6'
)
SMP_DEBUG_KEY_PUBLIC_Y= bytes.fromhex(
'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b'
)
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
@@ -188,8 +197,6 @@ SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def error_name(error_code: int) -> str:
return name_or_number(SMP_ERROR_NAMES, error_code)
# -----------------------------------------------------------------------------
@@ -201,20 +208,22 @@ class SMP_Command:
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
'''
smp_classes: ClassVar[dict[int, type[SMP_Command]]] = {}
smp_classes: ClassVar[dict[CommandCode, type[SMP_Command]]] = {}
fields: ClassVar[Fields]
code: int = field(default=0, init=False)
code: CommandCode = field(default=CommandCode(0), init=False)
name: str = field(default='', init=False)
_payload: bytes | None = field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
code = pdu[0]
if not pdu:
raise InvalidPacketError("Empty SMP PDU")
code = CommandCode(pdu[0])
subclass = SMP_Command.smp_classes.get(code)
if subclass is None:
instance = SMP_Command()
instance.name = SMP_Command.command_name(code)
instance.name = code.name
instance.code = code
instance.payload = pdu
return instance
@@ -222,59 +231,14 @@ class SMP_Command:
instance.payload = pdu[1:]
return instance
@staticmethod
def command_name(code: int) -> str:
return name_or_number(SMP_COMMAND_NAMES, code)
@staticmethod
def auth_req_str(value: int) -> str:
bonding_flags = value & 3
mitm = (value >> 2) & 1
sc = (value >> 3) & 1
keypress = (value >> 4) & 1
ct2 = (value >> 5) & 1
return (
f'bonding_flags={bonding_flags}, '
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
)
@staticmethod
def io_capability_name(io_capability: int) -> str:
return name_or_number(SMP_IO_CAPABILITY_NAMES, io_capability)
@staticmethod
def key_distribution_str(value: int) -> str:
key_types: list[str] = []
if value & SMP_ENC_KEY_DISTRIBUTION_FLAG:
key_types.append('ENC')
if value & SMP_ID_KEY_DISTRIBUTION_FLAG:
key_types.append('ID')
if value & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
key_types.append('SIGN')
if value & SMP_LINK_KEY_DISTRIBUTION_FLAG:
key_types.append('LINK')
return ','.join(key_types)
@staticmethod
def keypress_notification_type_name(notification_type: int) -> str:
return name_or_number(SMP_KEYPRESS_NOTIFICATION_TYPE_NAMES, notification_type)
_Command = TypeVar("_Command", bound="SMP_Command")
@classmethod
def subclass(cls, subclass: type[_Command]) -> type[_Command]:
subclass.name = subclass.__name__.upper()
subclass.code = key_with_value(SMP_COMMAND_NAMES, subclass.name)
if subclass.code is None:
raise KeyError(
f'Command name {subclass.name} not found in SMP_COMMAND_NAMES'
)
subclass.fields = HCI_Object.fields_from_dataclass(subclass)
subclass.name = subclass.__name__.upper()
# Register a factory for this class
SMP_Command.smp_classes[subclass.code] = subclass
return subclass
@property
@@ -308,19 +272,17 @@ class SMP_Pairing_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.1 Pairing Request
'''
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
code = CommandCode.PAIRING_REQUEST
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
@@ -332,19 +294,17 @@ class SMP_Pairing_Response_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.2 Pairing Response
'''
io_capability: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.io_capability_name})
)
code = CommandCode.PAIRING_RESPONSE
io_capability: IoCapability = field(metadata=IoCapability.type_metadata(1))
oob_data_flag: int = field(metadata=metadata(1))
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
maximum_encryption_key_size: int = field(metadata=metadata(1))
initiator_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
initiator_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
responder_key_distribution: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.key_distribution_str})
responder_key_distribution: KeyDistribution = field(
metadata=KeyDistribution.type_metadata(1)
)
@@ -356,6 +316,8 @@ class SMP_Pairing_Confirm_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.3 Pairing Confirm
'''
code = CommandCode.PAIRING_CONFIRM
confirm_value: bytes = field(metadata=metadata(16))
@@ -367,6 +329,8 @@ class SMP_Pairing_Random_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.4 Pairing Random
'''
code = CommandCode.PAIRING_RANDOM
random_value: bytes = field(metadata=metadata(16))
@@ -378,7 +342,9 @@ class SMP_Pairing_Failed_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.5 Pairing Failed
'''
reason: int = field(metadata=metadata({'size': 1, 'mapper': error_name}))
code = CommandCode.PAIRING_FAILED
reason: ErrorCode = field(metadata=ErrorCode.type_metadata(1))
# -----------------------------------------------------------------------------
@@ -389,6 +355,8 @@ class SMP_Pairing_Public_Key_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.6 Pairing Public Key
'''
code = CommandCode.PAIRING_PUBLIC_KEY
public_key_x: bytes = field(metadata=metadata(32))
public_key_y: bytes = field(metadata=metadata(32))
@@ -401,6 +369,8 @@ class SMP_Pairing_DHKey_Check_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.7 Pairing DHKey Check
'''
code = CommandCode.PAIRING_DHKEY_CHECK
dhkey_check: bytes = field(metadata=metadata(16))
@@ -412,10 +382,10 @@ class SMP_Pairing_Keypress_Notification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.5.8 Keypress Notification
'''
notification_type: int = field(
metadata=metadata(
{'size': 1, 'mapper': SMP_Command.keypress_notification_type_name}
)
code = CommandCode.PAIRING_KEYPRESS_NOTIFICATION
notification_type: KeypressNotificationType = field(
metadata=KeypressNotificationType.type_metadata(1)
)
@@ -427,6 +397,8 @@ class SMP_Encryption_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.2 Encryption Information
'''
code = CommandCode.ENCRYPTION_INFORMATION
long_term_key: bytes = field(metadata=metadata(16))
@@ -438,6 +410,8 @@ class SMP_Master_Identification_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.3 Master Identification
'''
code = CommandCode.MASTER_IDENTIFICATION
ediv: int = field(metadata=metadata(2))
rand: bytes = field(metadata=metadata(8))
@@ -450,6 +424,8 @@ class SMP_Identity_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.4 Identity Information
'''
code = CommandCode.IDENTITY_INFORMATION
identity_resolving_key: bytes = field(metadata=metadata(16))
@@ -461,6 +437,8 @@ class SMP_Identity_Address_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.5 Identity Address Information
'''
code = CommandCode.IDENTITY_ADDRESS_INFORMATION
addr_type: int = field(metadata=metadata(Address.ADDRESS_TYPE_SPEC))
bd_addr: Address = field(metadata=metadata(Address.parse_address_preceded_by_type))
@@ -473,6 +451,8 @@ class SMP_Signing_Information_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.6 Signing Information
'''
code = CommandCode.SIGNING_INFORMATION
signature_key: bytes = field(metadata=metadata(16))
@@ -484,33 +464,22 @@ class SMP_Security_Request_Command(SMP_Command):
See Bluetooth spec @ Vol 3, Part H - 3.6.7 Security Request
'''
auth_req: int = field(
metadata=metadata({'size': 1, 'mapper': SMP_Command.auth_req_str})
)
code = CommandCode.SECURITY_REQUEST
# -----------------------------------------------------------------------------
def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) -> int:
value = 0
if bonding:
value |= SMP_BONDING_AUTHREQ
if mitm:
value |= SMP_MITM_AUTHREQ
if sc:
value |= SMP_SC_AUTHREQ
if keypress:
value |= SMP_KEYPRESS_AUTHREQ
if ct2:
value |= SMP_CT2_AUTHREQ
return value
auth_req: AuthReq = field(metadata=AuthReq.type_metadata(1))
# -----------------------------------------------------------------------------
class AddressResolver:
def __init__(self, resolving_keys):
def __init__(self, resolving_keys: Sequence[tuple[bytes, Address]]) -> None:
self.resolving_keys = resolving_keys
def resolve(self, address):
def can_resolve_to(self, address: Address) -> bool:
return any(
resolved_address == address for _, resolved_address in self.resolving_keys
)
def resolve(self, address: Address) -> Address | None:
address_bytes = bytes(address)
hash_part = address_bytes[0:3]
prand = address_bytes[3:6]
@@ -671,8 +640,8 @@ class Session:
self.ltk_rand = bytes(8)
self.link_key: bytes | None = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.initiator_key_distribution: KeyDistribution = KeyDistribution(0)
self.responder_key_distribution: KeyDistribution = KeyDistribution(0)
self.peer_random_value: bytes | None = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
@@ -723,10 +692,10 @@ class Session:
)
# Key Distribution (default values before negotiation)
self.initiator_key_distribution = (
self.initiator_key_distribution = KeyDistribution(
pairing_config.delegate.local_initiator_key_distribution
)
self.responder_key_distribution = (
self.responder_key_distribution = KeyDistribution(
pairing_config.delegate.local_responder_key_distribution
)
@@ -738,7 +707,7 @@ class Session:
self.ct2: bool = False
# I/O Capabilities
self.io_capability = pairing_config.delegate.io_capability
self.io_capability = IoCapability(pairing_config.delegate.io_capability)
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
@@ -817,8 +786,14 @@ class Session:
return self.nx[0 if self.is_responder else 1]
@property
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def auth_req(self) -> AuthReq:
return AuthReq.from_booleans(
bonding=self.bonding,
sc=self.sc,
mitm=self.mitm,
keypress=self.keypress,
ct2=self.ct2,
)
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
if not self.sc and not self.completed:
@@ -838,7 +813,7 @@ class Session:
if self.connection.transport == PhysicalTransport.BR_EDR:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
if (not self.mitm) and (auth_req & AuthReq.MITM == 0):
self.pairing_method = PairingMethod.JUST_WORKS
return
@@ -856,7 +831,7 @@ class Session:
self.passkey_display = details[1 if self.is_initiator else 2]
def check_expected_value(
self, expected: bytes, received: bytes, error: int
self, expected: bytes, received: bytes, error: ErrorCode
) -> bool:
logger.debug(f'expected={expected.hex()} got={received.hex()}')
if expected != received:
@@ -876,7 +851,7 @@ class Session:
except Exception:
logger.exception('exception while confirm')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -895,7 +870,7 @@ class Session:
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.CONFIRM_VALUE_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -906,13 +881,13 @@ class Session:
passkey = await self.pairing_config.delegate.get_number()
if passkey is None:
logger.debug('Passkey request rejected')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
return
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception:
logger.exception('exception while prompting')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
self.send_pairing_failed(ErrorCode.PASSKEY_ENTRY_FAILED)
self.connection.cancel_on_disconnection(prompt())
@@ -967,7 +942,7 @@ class Session:
def send_command(self, command: SMP_Command) -> None:
self.manager.send_command(self.connection, command)
def send_pairing_failed(self, error: int) -> None:
def send_pairing_failed(self, error: ErrorCode) -> None:
self.send_command(SMP_Pairing_Failed_Command(reason=error))
self.on_pairing_failure(error)
@@ -1139,7 +1114,7 @@ class Session:
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
ErrorCode.CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
@@ -1150,14 +1125,14 @@ class Session:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
and self.initiator_key_distribution & KeyDistribution.ENC_KEY
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.ENC_KEY:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1168,7 +1143,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.initiator_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.ID_KEY:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1178,25 +1153,25 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.initiator_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.SIGN_KEY:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
if self.initiator_key_distribution & KeyDistribution.LINK_KEY:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
else:
# CTKD: Derive LTK from LinkKey
if (
self.connection.transport == PhysicalTransport.BR_EDR
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
and self.responder_key_distribution & KeyDistribution.ENC_KEY
):
self.ctkd_task = self.connection.cancel_on_disconnection(
self.get_link_key_and_derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.ENC_KEY:
self.send_command(
SMP_Encryption_Information_Command(long_term_key=self.ltk)
)
@@ -1207,7 +1182,7 @@ class Session:
)
# Distribute IRK & BD ADDR
if self.responder_key_distribution & SMP_ID_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.ID_KEY:
self.send_command(
SMP_Identity_Information_Command(
identity_resolving_key=self.manager.device.irk
@@ -1217,30 +1192,30 @@ class Session:
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
if self.responder_key_distribution & SMP_SIGN_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.SIGN_KEY:
self.send_command(SMP_Signing_Information_Command(signature_key=csrk))
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
if self.responder_key_distribution & KeyDistribution.LINK_KEY:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
self.peer_expected_distributions = []
if not self.sc and self.connection.transport == PhysicalTransport.LE:
if key_distribution_flags & SMP_ENC_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.ENC_KEY != 0:
self.peer_expected_distributions.append(
SMP_Encryption_Information_Command
)
self.peer_expected_distributions.append(
SMP_Master_Identification_Command
)
if key_distribution_flags & SMP_ID_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.ID_KEY != 0:
self.peer_expected_distributions.append(SMP_Identity_Information_Command)
self.peer_expected_distributions.append(
SMP_Identity_Address_Information_Command
)
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
if key_distribution_flags & KeyDistribution.SIGN_KEY != 0:
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
logger.debug(
'expecting distributions: '
@@ -1253,7 +1228,7 @@ class Session:
logger.warning(
color('received key distribution on a non-encrypted connection', 'red')
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
return
# Check that this command class is expected
@@ -1273,7 +1248,7 @@ class Session:
'red',
)
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
self.send_pairing_failed(ErrorCode.UNSPECIFIED_REASON)
async def pair(self) -> None:
# Start pairing as an initiator
@@ -1384,34 +1359,56 @@ class Session:
)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
def on_pairing_failure(self, reason: ErrorCode) -> None:
logger.warning('pairing failure (%s)', reason.name)
if self.completed:
return
self.completed = True
error = ProtocolError(reason, 'smp', error_name(reason))
error = ProtocolError(reason, 'smp', reason.name)
if self.pairing_result is not None and not self.pairing_result.done():
self.pairing_result.set_exception(error)
self.manager.on_pairing_failure(self, reason)
def on_smp_command(self, command: SMP_Command) -> None:
# Find the handler method
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(command)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
self.send_command(response)
else:
logger.error(color('SMP command not handled???', 'red'))
try:
match command:
case SMP_Pairing_Request_Command():
self.on_smp_pairing_request_command(command)
case SMP_Pairing_Response_Command():
self.on_smp_pairing_response_command(command)
case SMP_Pairing_Confirm_Command():
self.on_smp_pairing_confirm_command(command)
case SMP_Pairing_Random_Command():
self.on_smp_pairing_random_command(command)
case SMP_Pairing_Failed_Command():
self.on_smp_pairing_failed_command(command)
case SMP_Encryption_Information_Command():
self.on_smp_encryption_information_command(command)
case SMP_Master_Identification_Command():
self.on_smp_master_identification_command(command)
case SMP_Identity_Information_Command():
self.on_smp_identity_information_command(command)
case SMP_Identity_Address_Information_Command():
self.on_smp_identity_address_information_command(command)
case SMP_Signing_Information_Command():
self.on_smp_signing_information_command(command)
case SMP_Pairing_Public_Key_Command():
self.on_smp_pairing_public_key_command(command)
case SMP_Pairing_DHKey_Check_Command():
self.on_smp_pairing_dhkey_check_command(command)
# case SMP_Security_Request_Command():
# self.on_smp_security_request_command(command)
# case SMP_Pairing_Keypress_Notification_Command():
# self.on_smp_pairing_keypress_notification_command(command)
case _:
logger.error(color('SMP command not handled', 'red'))
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = SMP_Pairing_Failed_Command(reason=ErrorCode.UNSPECIFIED_REASON)
self.send_command(response)
def on_smp_pairing_request_command(
self, command: SMP_Pairing_Request_Command
@@ -1431,16 +1428,16 @@ class Session:
accepted = False
if not accepted:
logger.debug('pairing rejected by delegate')
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
self.send_pairing_failed(ErrorCode.PAIRING_NOT_SUPPORTED)
return
# Save the request
self.preq = bytes(command)
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
self.ct2 = self.ct2 and (command.auth_req & AuthReq.CT2 != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1451,7 +1448,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1470,8 +1467,11 @@ class Session:
(
self.initiator_key_distribution,
self.responder_key_distribution,
) = await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
) = map(
KeyDistribution,
await self.pairing_config.delegate.key_distribution_response(
command.initiator_key_distribution, command.responder_key_distribution
),
)
self.compute_peer_expected_distributions(self.initiator_key_distribution)
@@ -1509,8 +1509,8 @@ class Session:
self.peer_io_capability = command.io_capability
# Bonding and SC require both sides to request/support it
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.bonding = self.bonding and (command.auth_req & AuthReq.BONDING != 0)
self.sc = self.sc and (command.auth_req & AuthReq.SC != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
@@ -1521,7 +1521,7 @@ class Session:
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
self.send_pairing_failed(ErrorCode.OOB_NOT_AVAILABLE)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
@@ -1541,7 +1541,7 @@ class Session:
command.responder_key_distribution & ~self.responder_key_distribution != 0
):
# The response isn't a subset of the request
self.send_pairing_failed(SMP_INVALID_PARAMETERS_ERROR)
self.send_pairing_failed(ErrorCode.INVALID_PARAMETERS)
return
self.initiator_key_distribution = command.initiator_key_distribution
self.responder_key_distribution = command.responder_key_distribution
@@ -1619,7 +1619,7 @@ class Session:
)
assert self.confirm_value
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1660,7 +1660,7 @@ class Session:
self.pkb, self.pka, command.random_value, bytes([0])
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1673,7 +1673,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1702,7 +1702,7 @@ class Session:
bytes([0x80 + ((self.passkey >> self.passkey_step) & 1)]),
)
if not self.check_expected_value(
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
self.confirm_value, confirm_verifier, ErrorCode.CONFIRM_VALUE_FAILED
):
return
@@ -1819,7 +1819,7 @@ class Session:
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
ErrorCode.CONFIRM_VALUE_FAILED,
):
return
@@ -1853,7 +1853,7 @@ class Session:
expected = self.eb if self.is_initiator else self.ea
assert expected
if not self.check_expected_value(
expected, command.dhkey_check, SMP_DHKEY_CHECK_FAILED_ERROR
expected, command.dhkey_check, ErrorCode.DHKEY_CHECK_FAILED
):
return
@@ -1932,6 +1932,7 @@ class Manager(utils.EventEmitter):
self._ecc_key = None
self.pairing_config_factory = pairing_config_factory
self.session_proxy = Session
self.debug_mode = False
def send_command(self, connection: Connection, command: SMP_Command) -> None:
logger.debug(
@@ -1957,7 +1958,7 @@ class Manager(utils.EventEmitter):
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
if command.code == CommandCode.SECURITY_REQUEST:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
@@ -1978,6 +1979,13 @@ class Manager(utils.EventEmitter):
@property
def ecc_key(self) -> crypto.EccKey:
if self.debug_mode:
# Core - Vol 3, Part H:
# When the Security Manager is placed in a Debug mode it shall use the
# following Diffie-Hellman private / public key pair:
debug_key = crypto.EccKey.from_private_key_bytes(SMP_DEBUG_KEY_PRIVATE)
return debug_key
if self._ecc_key is None:
self._ecc_key = crypto.EccKey.generate()
assert self._ecc_key
@@ -1997,15 +2005,13 @@ class Manager(utils.EventEmitter):
def request_pairing(self, connection: Connection) -> None:
pairing_config = self.pairing_config_factory(connection)
if pairing_config:
auth_req = smp_auth_req(
pairing_config.bonding,
pairing_config.mitm,
pairing_config.sc,
False,
False,
auth_req = AuthReq.from_booleans(
bonding=pairing_config.bonding,
sc=pairing_config.sc,
mitm=pairing_config.mitm,
)
else:
auth_req = 0
auth_req = AuthReq(0)
self.send_command(connection, SMP_Security_Request_Command(auth_req=auth_req))
def on_session_start(self, session: Session) -> None:
@@ -2021,7 +2027,7 @@ class Manager(utils.EventEmitter):
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
def on_pairing_failure(self, session: Session, reason: int) -> None:
def on_pairing_failure(self, session: Session, reason: ErrorCode) -> None:
self.device.on_pairing_failure(session.connection, reason)
def on_session_end(self, session: Session) -> None:
+111 -1
View File
@@ -110,6 +110,53 @@ class BtSnooper(Snooper):
)
# -----------------------------------------------------------------------------
class PcapSnooper(Snooper):
"""
Snooper that saves or streames HCI packets using the PCAP format.
"""
PCAP_MAGIC = 0xA1B2C3D4
DLT_BLUETOOTH_HCI_H4_WITH_PHDR = 201
def __init__(self, output: BinaryIO):
self.output = output
# Write the header
self.output.write(
struct.pack(
"<IHHIIII",
self.PCAP_MAGIC,
2, # Major PCAP Version
4, # Minor PCAP Version
0, # Reserved 1
0, # Reserved 2
65535, # SnapLen
# FCS and f are set to 0 implicitly by the next line
self.DLT_BLUETOOTH_HCI_H4_WITH_PHDR, # The DLT in this PCAP
)
)
def snoop(self, hci_packet: bytes, direction: Snooper.Direction):
now = datetime.datetime.now(datetime.timezone.utc)
sec = int(now.timestamp())
usec = now.microsecond
# Emit the record
self.output.write(
struct.pack(
"<IIII",
sec, # Timestamp (Seconds)
usec, # Timestamp (Microseconds)
len(hci_packet) + 4,
len(hci_packet) + 4, # +4 because of the addtional direction info...
)
+ struct.pack(">I", int(direction)) # ...thats being added here
+ hci_packet
)
self.output.flush() # flush after every packet for live logging
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@@ -140,9 +187,38 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
pid: the current process ID.
instance: the instance ID in the current process.
pcapsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
pid: the current process ID.
instance: the instance ID in the current process.
pipe
The type-specific arguments for this I/O type is a string that is converted
to a path using the python `str.format()` string formatting. The log
records will be written to the named pipe referenced by this path
if it can be opened. The keyword args that may be referenced by the
string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
pcapsnoop:pipe:/tmp/bumble-extcap
"""
if ':' not in spec:
@@ -150,6 +226,8 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
snooper_type, snooper_args = spec.split(':', maxsplit=1)
global _SNOOPER_INSTANCE_COUNT
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
@@ -157,7 +235,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
@@ -173,6 +250,39 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1
return
elif snooper_type == 'pcapsnoop':
if ':' not in snooper_args:
raise core.InvalidArgumentError(
'I/O type for pcapsnoop snooper type missing'
)
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type in {'pipe', 'file'}:
# Process the file name string pattern.
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open a file or pipe
logger.debug(f'PCAP file: {file_path}')
# Pipes we have to open with unbuffered binary I/O
# so we pass ``buffering`` for pipes but not for files
pcap_file: BinaryIO
if io_type == 'pipe':
pcap_file = open(file_path, 'wb', buffering=0)
else:
pcap_file = open(file_path, 'wb')
with pcap_file:
_SNOOPER_INSTANCE_COUNT += 1
yield PcapSnooper(pcap_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
+1 -1
View File
@@ -194,7 +194,7 @@ async def open_android_netsim_controller_transport(
# We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type')
logger.debug('Request for unsupported chip type')
error = PacketResponse(error='Unsupported chip type')
await self.context.write(error)
# return
+108 -86
View File
@@ -43,44 +43,53 @@ hci.HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
class HCI_LE_Get_Vendor_Capabilities_ReturnParameters(hci.HCI_StatusReturnParameters):
max_advt_instances: int = field(metadata=hci.metadata(1), default=0)
offloaded_resolution_of_private_address: int = field(
metadata=hci.metadata(1), default=0
)
total_scan_results_storage: int = field(metadata=hci.metadata(2), default=0)
max_irk_list_sz: int = field(metadata=hci.metadata(1), default=0)
filtering_support: int = field(metadata=hci.metadata(1), default=0)
max_filter: int = field(metadata=hci.metadata(1), default=0)
activity_energy_info_support: int = field(metadata=hci.metadata(1), default=0)
version_supported: int = field(metadata=hci.metadata(2), default=0)
total_num_of_advt_tracked: int = field(metadata=hci.metadata(2), default=0)
extended_scan_support: int = field(metadata=hci.metadata(1), default=0)
debug_logging_supported: int = field(metadata=hci.metadata(1), default=0)
le_address_generation_offloading_support: int = field(
metadata=hci.metadata(1), default=0
)
a2dp_source_offload_capability_mask: int = field(
metadata=hci.metadata(4), default=0
)
bluetooth_quality_report_support: int = field(metadata=hci.metadata(1), default=0)
dynamic_audio_buffer_support: int = field(metadata=hci.metadata(4), default=0)
@hci.HCI_SyncCommand.sync_command(HCI_LE_Get_Vendor_Capabilities_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_Get_Vendor_Capabilities_Command(
hci.HCI_SyncCommand[HCI_LE_Get_Vendor_Capabilities_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = hci.HCI_Object(cls.return_parameters_fields, **nones)
# there are no more bytes to parse, and leave un-signaled parameters set to
# 0
return_parameters = HCI_LE_Get_Vendor_Capabilities_ReturnParameters(
hci.HCI_ErrorCode.SUCCESS
)
try:
offset = 0
for field in cls.return_parameters_fields:
for field in cls.return_parameters_class.fields:
field_name, field_type = field
field_value, field_size = hci.HCI_Object.parse_field(
parameters, offset, field_type
@@ -94,9 +103,30 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# APCF Subcommands
class LeApcfOpcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_Command):
class HCI_LE_APCF_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_LE_APCF_ReturnParameters)
@dataclasses.dataclass
class HCI_LE_APCF_Command(hci.HCI_SyncCommand[HCI_LE_APCF_ReturnParameters]):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
@@ -105,52 +135,52 @@ class HCI_LE_APCF_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
class Opcode(hci.SpecableEnum):
ENABLE = 0x00
SET_FILTERING_PARAMETERS = 0x01
BROADCASTER_ADDRESS = 0x02
SERVICE_UUID = 0x03
SERVICE_SOLICITATION_UUID = 0x04
LOCAL_NAME = 0x05
MANUFACTURER_DATA = 0x06
SERVICE_DATA = 0x07
TRANSPORT_DISCOVERY_SERVICE = 0x08
AD_TYPE_FILTER = 0x09
READ_EXTENDED_FEATURES = 0xFF
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=LeApcfOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command):
class HCI_Get_Controller_Activity_Energy_Info_ReturnParameters(
hci.HCI_StatusReturnParameters
):
total_tx_time_ms: int = field(metadata=hci.metadata(4))
total_rx_time_ms: int = field(metadata=hci.metadata(4))
total_idle_time_ms: int = field(metadata=hci.metadata(4))
total_energy_used: int = field(metadata=hci.metadata(4))
@hci.HCI_SyncCommand.sync_command(
HCI_Get_Controller_Activity_Energy_Info_ReturnParameters
)
@dataclasses.dataclass
class HCI_Get_Controller_Activity_Energy_Info_Command(
hci.HCI_SyncCommand[HCI_Get_Controller_Activity_Energy_Info_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# A2DP Hardware Offload Subcommands
class A2dpHardwareOffloadOpcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
class HCI_A2DP_Hardware_Offload_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_A2DP_Hardware_Offload_ReturnParameters)
@dataclasses.dataclass
class HCI_A2DP_Hardware_Offload_Command(
hci.HCI_SyncCommand[HCI_A2DP_Hardware_Offload_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
@@ -159,25 +189,27 @@ class HCI_A2DP_Hardware_Offload_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
class Opcode(hci.SpecableEnum):
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=A2dpHardwareOffloadOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
# Dynamic Audio Buffer Subcommands
class DynamicAudioBufferOpcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
class HCI_Dynamic_Audio_Buffer_ReturnParameters(hci.HCI_StatusReturnParameters):
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
@hci.HCI_SyncCommand.sync_command(HCI_Dynamic_Audio_Buffer_ReturnParameters)
@dataclasses.dataclass
class HCI_Dynamic_Audio_Buffer_Command(
hci.HCI_SyncCommand[HCI_Dynamic_Audio_Buffer_ReturnParameters]
):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
@@ -186,19 +218,9 @@ class HCI_Dynamic_Audio_Buffer_Command(hci.HCI_Command):
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
class Opcode(hci.SpecableEnum):
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
opcode: int = dataclasses.field(metadata=Opcode.type_metadata(1))
opcode: int = dataclasses.field(metadata=DynamicAudioBufferOpcode.type_metadata(1))
payload: bytes = dataclasses.field(metadata=hci.metadata("*"))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('opcode', Opcode.type_spec(1)),
('payload', '*'),
]
# -----------------------------------------------------------------------------
class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):
+24 -18
View File
@@ -46,9 +46,19 @@ class TX_Power_Level_Command:
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Write_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
selected_tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Write_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Write_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Write_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -61,18 +71,21 @@ class HCI_Write_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
tx_power_level: int = dataclasses.field(metadata=hci.metadata(-1))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
]
# -----------------------------------------------------------------------------
@hci.HCI_Command.command
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
class HCI_Read_Tx_Power_Level_ReturnParameters(hci.HCI_StatusReturnParameters):
handle_type: int = hci.field(metadata=hci.metadata(1))
connection_handle: int = hci.field(metadata=hci.metadata(2))
tx_power_level: int = hci.field(metadata=hci.metadata(-1))
@hci.HCI_SyncCommand.sync_command(HCI_Read_Tx_Power_Level_ReturnParameters)
@dataclasses.dataclass
class HCI_Read_Tx_Power_Level_Command(
hci.HCI_SyncCommand[HCI_Read_Tx_Power_Level_ReturnParameters],
TX_Power_Level_Command,
):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
@@ -83,10 +96,3 @@ class HCI_Read_Tx_Power_Level_Command(hci.HCI_Command, TX_Power_Level_Command):
handle_type: int = dataclasses.field(metadata=hci.metadata(1))
connection_handle: int = dataclasses.field(metadata=hci.metadata(2))
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
]
+1 -1
View File
@@ -63,7 +63,7 @@ HCI sockets provide a way to send/receive HCI packets to/from a Bluetooth contro
See the [HCI Socket Transport page](../transports/hci_socket.md) for details on the `hci-socket` tansport syntax.
The HCI device referenced by an `hci-socket` transport (`hci<X>`, where `<X>` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> up`.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> down`.
!!! tip "HCI Socket Permissions"
By default, when running as a regular user, you won't have the permission to use
+3 -3
View File
@@ -37,7 +37,7 @@ The vendor specific HCI commands to read and write TX power are defined in
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
# set advertising power to -4 dB
response = await host.send_command(
response = await host.send_sync_command(
HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0,
@@ -45,7 +45,7 @@ response = await host.send_command(
)
)
if response.return_parameters.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}")
if response.status == HCI_SUCCESS:
print(f"TX power set to {response.selected_tx_power_level}")
```
+1 -1
View File
@@ -65,7 +65,7 @@ async def main() -> None:
# Go!
await device.power_on()
await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+2 -2
View File
@@ -71,8 +71,8 @@ async def main() -> None:
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
random.randint(900, 1100) // 1000,
random.randint(900, 1100) // 1000,
),
None,
)
+1 -1
View File
@@ -161,7 +161,7 @@ async def main() -> None:
await device.set_discoverable(True)
await device.set_connectable(True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -181,7 +181,7 @@ async def main() -> None:
await device.set_discoverable(True)
await device.set_connectable(True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -70,7 +70,7 @@ async def main() -> None:
await device.power_on()
await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+27 -25
View File
@@ -25,7 +25,7 @@ import sys
import websockets.asyncio.server
import bumble.logging
from bumble import a2dp, avc, avdtp, avrcp, utils
from bumble import a2dp, avc, avdtp, avrcp, sdp, utils
from bumble.core import PhysicalTransport
from bumble.device import Device
from bumble.transport import open_transport
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def sdp_records():
def sdp_records() -> dict[int, list[sdp.ServiceAttribute]]:
a2dp_sink_service_record_handle = 0x00010001
avrcp_controller_service_record_handle = 0x00010002
avrcp_target_service_record_handle = 0x00010003
@@ -43,17 +43,17 @@ def sdp_records():
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
a2dp_sink_service_record_handle
),
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records(
avrcp_controller_service_record_handle: avrcp.ControllerServiceSdpRecord(
avrcp_controller_service_record_handle
),
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records(
avrcp_controller_service_record_handle
),
).to_service_attributes(),
avrcp_target_service_record_handle: avrcp.TargetServiceSdpRecord(
avrcp_target_service_record_handle
).to_service_attributes(),
}
# -----------------------------------------------------------------------------
def codec_capabilities():
def codec_capabilities() -> avdtp.MediaCodecCapabilities:
return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
@@ -81,20 +81,22 @@ def codec_capabilities():
# -----------------------------------------------------------------------------
def on_avdtp_connection(server):
def on_avdtp_connection(server: avdtp.Protocol) -> None:
# Add a sink endpoint to the server
sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet)
sink.on(sink.EVENT_RTP_PACKET, on_rtp_packet)
# -----------------------------------------------------------------------------
def on_rtp_packet(packet):
def on_rtp_packet(packet: avdtp.MediaPacket) -> None:
print(f'RTP: {packet}')
# -----------------------------------------------------------------------------
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer):
async def get_supported_events():
def on_avrcp_start(
avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer
) -> None:
async def get_supported_events() -> None:
events = await avrcp_protocol.get_supported_events()
print("SUPPORTED EVENTS:", events)
websocket_server.send_message(
@@ -130,14 +132,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed():
async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex())
async def monitor_track_changed() -> None:
async for uid in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", hex(uid))
websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
{"type": "track-changed", "params": {"identifier": hex(uid)}}
)
async def monitor_playback_status():
async def monitor_playback_status() -> None:
async for playback_status in avrcp_protocol.monitor_playback_status():
print("PLAYBACK STATUS CHANGED:", playback_status.name)
websocket_server.send_message(
@@ -147,7 +149,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
}
)
async def monitor_playback_position():
async def monitor_playback_position() -> None:
async for playback_position in avrcp_protocol.monitor_playback_position(
playback_interval=1
):
@@ -159,7 +161,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
}
)
async def monitor_player_application_settings():
async def monitor_player_application_settings() -> None:
async for settings in avrcp_protocol.monitor_player_application_settings():
print("PLAYER APPLICATION SETTINGS:", settings)
settings_as_dict = [
@@ -173,14 +175,14 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
}
)
async def monitor_available_players():
async def monitor_available_players() -> None:
async for _ in avrcp_protocol.monitor_available_players():
print("AVAILABLE PLAYERS CHANGED")
websocket_server.send_message(
{"type": "available-players-changed", "params": {}}
)
async def monitor_addressed_player():
async def monitor_addressed_player() -> None:
async for player in avrcp_protocol.monitor_addressed_player():
print("ADDRESSED PLAYER CHANGED")
websocket_server.send_message(
@@ -195,7 +197,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
}
)
async def monitor_uids():
async def monitor_uids() -> None:
async for uid_counter in avrcp_protocol.monitor_uids():
print("UIDS CHANGED")
websocket_server.send_message(
@@ -207,7 +209,7 @@ def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketSe
}
)
async def monitor_volume():
async def monitor_volume() -> None:
async for volume in avrcp_protocol.monitor_volume():
print("VOLUME CHANGED:", volume)
websocket_server.send_message(
@@ -360,7 +362,7 @@ async def main() -> None:
# Create a listener to wait for AVDTP connections
listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
listener.on('connection', on_avdtp_connection)
listener.on(listener.EVENT_CONNECTION, on_avdtp_connection)
avrcp_delegate = Delegate()
avrcp_protocol = avrcp.Protocol(avrcp_delegate)
+1 -1
View File
@@ -112,7 +112,7 @@ async def main() -> None:
await device.set_discoverable(True)
await device.set_connectable(True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -73,7 +73,7 @@ async def main() -> None:
await device.power_on()
await device.start_discovery()
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -57,7 +57,7 @@ async def main() -> None:
print(f'!!! Encryption failed: {error}')
return
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+201
View File
@@ -0,0 +1,201 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import sys
from collections.abc import Callable
import bumble.logging
from bumble.core import BaseError
from bumble.device import Connection, Device
from bumble.hci import Address, LeFeatureMask
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
DEFAULT_CENTRAL_ADDRESS = Address("F0:F0:F0:F0:F0:F0")
DEFAULT_PERIPHERAL_ADDRESS = Address("F1:F1:F1:F1:F1:F1")
# -----------------------------------------------------------------------------
async def run_as_central(
device: Device,
scenario: Callable | None,
) -> None:
# Connect to the peripheral
print(f'=== Connecting to {DEFAULT_PERIPHERAL_ADDRESS}...')
connection = await device.connect(DEFAULT_PERIPHERAL_ADDRESS)
print("=== Connected")
if scenario is not None:
await asyncio.sleep(1)
await scenario(connection)
await asyncio.get_running_loop().create_future()
async def run_as_peripheral(device: Device, scenario: Callable | None) -> None:
# Wait for a connection from the central
print(f'=== Advertising as {DEFAULT_PERIPHERAL_ADDRESS}...')
await device.start_advertising(auto_restart=True)
async def on_connection(connection: Connection) -> None:
assert scenario is not None
await asyncio.sleep(1)
await scenario(connection)
if scenario is not None:
device.on(Device.EVENT_CONNECTION, on_connection)
await asyncio.get_running_loop().create_future()
async def change_parameters(
connection: Connection,
parameter_request_procedure_supported: bool,
subrating_supported: bool,
shorter_connection_intervals_supported: bool,
) -> None:
if parameter_request_procedure_supported:
try:
print(">>> update_parameters(7.5, 200, 0, 4000)")
await connection.update_parameters(7.5, 200, 0, 4000)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
if subrating_supported:
try:
print(">>> update_subrate(1, 2, 2, 1, 4000)")
await connection.update_subrate(1, 2, 2, 1, 4000)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
if shorter_connection_intervals_supported:
try:
print(
">>> update_parameters_with_subrate(7.5, 200, 1, 1, 0, 0, 4000, 5, 1000)"
)
await connection.update_parameters_with_subrate(
7.5, 200, 1, 1, 0, 0, 4000, 5, 1000
)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
try:
print(
">>> update_parameters_with_subrate(0.750, 5, 1, 1, 0, 0, 4000, 0.125, 1000)"
)
await connection.update_parameters_with_subrate(
0.750, 5, 1, 1, 0, 0, 4000, 0.125, 1000
)
await asyncio.sleep(3)
except BaseError as error:
print(f"Error: {error}")
print(">>> done")
def on_connection(connection: Connection) -> None:
print(f"+++ Connection established: {connection}")
def on_le_remote_features_change() -> None:
print(f'... LE Remote Features change: {connection.peer_le_features.name}')
connection.on(
connection.EVENT_LE_REMOTE_FEATURES_CHANGE, on_le_remote_features_change
)
def on_connection_parameters_change() -> None:
print(f'... LE Connection Parameters change: {connection.parameters}')
connection.on(
connection.EVENT_CONNECTION_PARAMETERS_UPDATE, on_connection_parameters_change
)
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_connection_updates.py <transport-spec> '
'central|peripheral initiator|responder'
)
return
print('<<< connecting to HCI...')
async with await open_transport(sys.argv[1]) as hci_transport:
print('<<< connected')
role = sys.argv[2]
direction = sys.argv[3]
device = Device.with_hci(
role,
(
DEFAULT_CENTRAL_ADDRESS
if role == "central"
else DEFAULT_PERIPHERAL_ADDRESS
),
hci_transport.source,
hci_transport.sink,
)
device.le_subrate_enabled = True
device.le_shorter_connection_intervals_enabled = True
await device.power_on()
parameter_request_procedure_supported = device.supports_le_features(
LeFeatureMask.CONNECTION_PARAMETERS_REQUEST_PROCEDURE
)
print(
"Parameters Request Procedure supported: "
f"{parameter_request_procedure_supported}"
)
subrating_supported = device.supports_le_features(
LeFeatureMask.CONNECTION_SUBRATING
)
print(f"Subrating supported: {subrating_supported}")
shorter_connection_intervals_supported = device.supports_le_features(
LeFeatureMask.SHORTER_CONNECTION_INTERVALS
)
print(
"Shorter Connection Intervals supported: "
f"{shorter_connection_intervals_supported}"
)
device.on(Device.EVENT_CONNECTION, on_connection)
async def run(connection: Connection) -> None:
await change_parameters(
connection,
parameter_request_procedure_supported,
subrating_supported,
shorter_connection_intervals_supported,
)
scenario = run if direction == "initiator" else None
if role == "central":
await run_as_central(device, scenario)
else:
await run_as_peripheral(device, scenario)
# -----------------------------------------------------------------------------
bumble.logging.setup_basic_logging('DEBUG')
asyncio.run(main())
+1 -1
View File
@@ -101,7 +101,7 @@ async def main() -> None:
await device.start_advertising()
await device.start_scanning()
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -48,7 +48,7 @@ async def main() -> None:
await device.power_on()
await device.start_scanning()
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+12 -9
View File
@@ -19,10 +19,10 @@ import asyncio
import sys
import bumble.logging
from bumble import gatt_client
from bumble.colors import color
from bumble.core import ProtocolError
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.device import Connection, Device
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
@@ -34,24 +34,27 @@ class Listener(Device.Listener):
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
async def on_connection(self, connection: Connection):
print(f'=== Connected to {connection}')
# Discover all services
print('=== Discovering services')
peer = Peer(connection)
await peer.discover_services()
for service in peer.services:
if connection.device.config.eatt_enabled:
client = await gatt_client.Client.connect_eatt(connection)
else:
client = connection.gatt_client
await client.discover_services()
for service in client.services:
await service.discover_characteristics()
for characteristic in service.characteristics:
await characteristic.discover_descriptors()
print('=== Services discovered')
show_services(peer.services)
gatt_client.show_services(client.services)
# Discover all attributes
print('=== Discovering attributes')
attributes = await peer.discover_attributes()
attributes = await client.discover_attributes()
for attribute in attributes:
print(attribute)
print('=== Attributes discovered')
@@ -59,7 +62,7 @@ class Listener(Device.Listener):
# Read all attributes
for attribute in attributes:
try:
value = await peer.read_value(attribute)
value = await client.read_value(attribute)
print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green'))
except ProtocolError as error:
print(color(f'cannot read {attribute.handle:04X}:', 'red'), error)
+1
View File
@@ -83,6 +83,7 @@ async def main() -> None:
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
server_device.add_service(device_info_service)
await server_device.start_advertising()
# Connect the client to the server
connection = await client_device.connect(server_device.random_address)
+1 -1
View File
@@ -147,7 +147,7 @@ async def main() -> None:
else:
await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
@@ -99,7 +99,7 @@ async def main() -> None:
else:
await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -422,7 +422,7 @@ async def main() -> None:
# Setup a server
await server(device)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+3 -7
View File
@@ -100,13 +100,9 @@ def on_sco_packet(packet: hci.HCI_SynchronousDataPacket):
if source_file and (pcm_data := source_file.read(packet.data_total_length)):
assert ag_protocol
host = ag_protocol.dlc.multiplexer.l2cap_channel.connection.device.host
host.send_hci_packet(
hci.HCI_SynchronousDataPacket(
connection_handle=packet.connection_handle,
packet_status=0,
data_total_length=len(pcm_data),
data=pcm_data,
)
host.send_sco_sdu(
connection_handle=packet.connection_handle,
sdu=pcm_data,
)
+1 -1
View File
@@ -167,7 +167,7 @@ async def main() -> None:
await websockets.asyncio.server.serve(serve, 'localhost', 8989)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -735,7 +735,7 @@ async def main() -> None:
print("Executing in Web mode")
await keyboard_device(hid_device)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -556,7 +556,7 @@ async def main() -> None:
# Interrupt Channel
await hid_host.connect_interrupt_channel()
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -227,7 +227,7 @@ async def main() -> None:
tcp_port = int(sys.argv[5])
asyncio.create_task(tcp_server(tcp_port, session))
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -153,7 +153,7 @@ async def main() -> None:
await device.set_discoverable(True)
await device.set_connectable(True)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+1 -1
View File
@@ -75,7 +75,7 @@ async def main() -> None:
await device.power_on()
await device.start_scanning(filter_duplicates=filter_duplicates)
await hci_transport.source.wait_for_termination()
await hci_transport.source.terminated
# -----------------------------------------------------------------------------
+9 -5
View File
@@ -13,17 +13,21 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.10"
dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'",
"click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
"cryptography >= 44.0.3; platform_system=='Emscripten'",
"cryptography >= 39.0.0; platform_system=='Emscripten'",
# Android wheels for cryptography are not yet available on PyPI, so chaquopy uses
# the builds from https://chaquo.com/pypi-13.1/cryptography/. But these are not regually
# updated. Relax the version requirement since it's better than being completely unable
# to import the package in case of version mismatch.
"cryptography >= 42.0.8; platform_system=='Android'",
"grpcio >= 1.62.1; platform_system!='Emscripten'",
"humanize >= 4.6.0; platform_system!='Emscripten'",
"libusb1 >= 2.0.1; platform_system!='Emscripten'",
"libusb-package == 1.0.26.1; platform_system!='Emscripten'",
"libusb-package == 1.0.26.1; platform_system!='Emscripten' and platform_system!='Android'",
"platformdirs >= 3.10.0; platform_system!='Emscripten'",
"prompt_toolkit >= 3.0.16; platform_system!='Emscripten'",
"prettytable >= 3.6.0; platform_system!='Emscripten'",
@@ -32,7 +36,7 @@ dependencies = [
"pyserial-asyncio >= 0.5; platform_system!='Emscripten'",
"pyserial >= 3.5; platform_system!='Emscripten'",
"pyusb >= 1.2; platform_system!='Emscripten'",
"tomli ~= 2.2.1; platform_system!='Emscripten'",
"tomli ~= 2.2.1; platform_system!='Emscripten' and python_version<'3.11'",
"websockets >= 15.0.1; platform_system!='Emscripten'",
]
+2 -2
View File
@@ -221,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
[[package]]
name = "bytes"
version = "1.5.0"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]]
name = "cc"
+1 -1
View File
@@ -30,7 +30,7 @@ hex = "0.4.3"
itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
bytes = "1.5.0"
bytes = "1.11.1"
pdl-derive = "0.2.0"
pdl-runtime = "0.2.0"
futures = "0.3.28"
+1 -1
View File
@@ -17,6 +17,6 @@ use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len());
assert_eq!(13, DriverInfo::all_drivers()?.len());
Ok(())
}
+25
View File
@@ -120,6 +120,31 @@ def test_messages(message: avdtp.Message):
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'pdu',
(
b'', # empty PDU — would IndexError on pdu[0]
b'\x00', # 1-byte SINGLE_PACKET — would IndexError on pdu[1]
b'\x04', # 1-byte START_PACKET — would IndexError on pdu[1]
b'\x44\x10', # 2-byte START_PACKET — would IndexError on pdu[2]
),
)
def test_message_assembler_truncated_pdu(pdu: bytes):
"""Truncated AVDTP PDUs from a remote peer must NOT raise IndexError —
same DoS class as #912 (ATT empty PDU). The assembler is required to
log + drop and stay alive so the L2CAP channel survives."""
completed = []
def callback(transaction_label, message):
completed.append((transaction_label, message))
assembler = avdtp.MessageAssembler(callback)
# Must not raise; nothing should be delivered to callback either.
assembler.on_pdu(pdu)
assert not completed
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex(
+393 -4
View File
@@ -17,8 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import struct
from collections.abc import Sequence
from unittest import mock
import pytest
@@ -117,8 +119,6 @@ class TwoDevices(test_utils.TwoDevices):
scope=avrcp.Scope.NOW_PLAYING,
uid=0,
uid_counter=1,
start_item=0,
end_item=0,
attributes=[avrcp.MediaAttributeId.DEFAULT_COVER_ART],
),
avrcp.GetTotalNumberOfItemsCommand(scope=avrcp.Scope.NOW_PLAYING),
@@ -135,7 +135,7 @@ def test_command(command: avrcp.Command):
"event,",
[
avrcp.UidsChangedEvent(uid_counter=7),
avrcp.TrackChangedEvent(identifier=b'12356'),
avrcp.TrackChangedEvent(uid=12356),
avrcp.VolumeChangedEvent(volume=9),
avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING),
avrcp.AddressedPlayerChangedEvent(
@@ -233,7 +233,21 @@ def test_event(event: avrcp.Event):
feature_bitmask=avrcp.MediaPlayerItem.Features.ADD_TO_NOW_PLAYING,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Woo",
)
),
avrcp.FolderItem(
folder_uid=1,
folder_type=avrcp.FolderItem.FolderType.ALBUMS,
is_playable=avrcp.FolderItem.Playable.PLAYABLE,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Album",
),
avrcp.MediaElementItem(
media_element_uid=1,
media_type=avrcp.MediaElementItem.MediaType.AUDIO,
character_set_id=avrcp.CharacterSetId.UTF_8,
displayable_name="Song",
attribute_value_entry_list=[],
),
],
),
avrcp.ChangePathResponse(
@@ -408,6 +422,47 @@ def test_passthrough_commands():
assert bytes(parsed) == play_pressed_bytes
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_sdp_records():
two_devices = await TwoDevices.create_with_avdtp()
# Add SDP records to device 1
controller_record = avrcp.ControllerServiceSdpRecord(
service_record_handle=0x10001,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.ControllerFeatures.CATEGORY_1
| avrcp.ControllerFeatures.SUPPORTS_BROWSING
),
)
target_record = avrcp.TargetServiceSdpRecord(
service_record_handle=0x10002,
avctp_version=(1, 4),
avrcp_version=(1, 6),
supported_features=(
avrcp.TargetFeatures.CATEGORY_1 | avrcp.TargetFeatures.SUPPORTS_BROWSING
),
)
two_devices.devices[1].sdp_service_records = {
0x10001: controller_record.to_service_attributes(),
0x10002: target_record.to_service_attributes(),
}
# Find records from device 0
controller_records = await avrcp.ControllerServiceSdpRecord.find(
two_devices.connections[0]
)
assert len(controller_records) == 1
assert controller_records[0] == controller_record
target_records = await avrcp.TargetServiceSdpRecord.find(two_devices.connections[0])
assert len(target_records) == 1
assert target_records[0] == target_record
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
@@ -422,6 +477,340 @@ async def test_get_supported_events():
assert supported_events == [avrcp.EventId.VOLUME_CHANGED]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event():
two_devices = await TwoDevices.create_with_avdtp()
q = asyncio.Queue[tuple[avc.PassThroughFrame.OperationId, bool, bytes]]()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
q.put_nowait((key, pressed, data))
two_devices.protocols[1].delegate = Delegate()
for key, pressed in [
(avc.PassThroughFrame.OperationId.PLAY, True),
(avc.PassThroughFrame.OperationId.PLAY, False),
(avc.PassThroughFrame.OperationId.PAUSE, True),
(avc.PassThroughFrame.OperationId.PAUSE, False),
]:
await two_devices.protocols[0].send_key_event(key, pressed)
assert (await q.get()) == (key, pressed, b'')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_rejected():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise avrcp.Delegate.AvcError(avc.ResponseFrame.ResponseCode.REJECTED)
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_passthrough_key_event_exception():
two_devices = await TwoDevices.create_with_avdtp()
class Delegate(avrcp.Delegate):
async def on_key_event(
self, key: avc.PassThroughFrame.OperationId, pressed: bool, data: bytes
) -> None:
raise Exception()
two_devices.protocols[1].delegate = Delegate()
response = await two_devices.protocols[0].send_key_event(
avc.PassThroughFrame.OperationId.PLAY, True
)
assert response.response == avc.ResponseFrame.ResponseCode.REJECTED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_volume():
two_devices = await TwoDevices.create_with_avdtp()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
response = await two_devices.protocols[1].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL, avrcp.SetAbsoluteVolumeCommand(volume)
)
assert isinstance(response.response, avrcp.SetAbsoluteVolumeResponse)
assert response.response.volume == volume
assert two_devices.protocols[0].delegate.volume == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate.playback_status = status
response = await two_devices.protocols[1].get_play_status()
assert response.play_status == status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_company_ids():
two_devices = await TwoDevices.create_with_avdtp()
for status in avrcp.PlayStatus:
two_devices.protocols[0].delegate = avrcp.Delegate(
supported_company_ids=[avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
)
supported_company_ids = await two_devices.protocols[
1
].get_supported_company_ids()
assert supported_company_ids == [avrcp.AVRCP_BLUETOOTH_SIG_COMPANY_ID]
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_player_application_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: [
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.SINGLE_TRACK_REPEAT,
avrcp.ApplicationSetting.RepeatModeStatus.OFF,
],
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: [
avrcp.ApplicationSetting.ShuffleOnOffStatus.OFF,
avrcp.ApplicationSetting.ShuffleOnOffStatus.ALL_TRACKS_SHUFFLE,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
}
two_devices.protocols[1].delegate = avrcp.Delegate(
supported_player_app_settings=expected_settings
)
actual_settings = await two_devices.protocols[
0
].list_supported_player_app_settings()
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_set_player_app_settings():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.SetPlayerApplicationSettingValueCommand(
attribute=[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
],
value=[
avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
],
),
)
expected_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: avrcp.ApplicationSetting.ShuffleOnOffStatus.GROUP_SHUFFLE,
}
assert delegate.player_app_settings == expected_settings
actual_settings = await two_devices.protocols[0].get_player_app_settings(
[
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.AttributeId.SHUFFLE_ON_OFF,
]
)
assert actual_settings == expected_settings
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_play_item():
two_devices: TwoDevices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate
with mock.patch.object(delegate, delegate.play_item.__name__) as play_item_mock:
await two_devices.protocols[0].send_avrcp_command(
avc.CommandFrame.CommandType.CONTROL,
avrcp.PlayItemCommand(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
),
)
play_item_mock.assert_called_once_with(
scope=avrcp.Scope.MEDIA_PLAYER_LIST, uid=0, uid_counter=1
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_volume():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
volume_iter = two_devices.protocols[0].monitor_volume()
for volume in range(avrcp.SetAbsoluteVolumeCommand.MAXIMUM_VOLUME + 1):
# Interim
two_devices.protocols[1].delegate.volume = 0
assert (await anext(volume_iter)) == 0
# Changed
two_devices.protocols[1].notify_volume_changed(volume)
assert (await anext(volume_iter)) == volume
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_playback_status():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.PLAYBACK_STATUS_CHANGED]
)
playback_status_iter = two_devices.protocols[0].monitor_playback_status()
for playback_status in avrcp.PlayStatus:
# Interim
two_devices.protocols[1].delegate.playback_status = avrcp.PlayStatus.STOPPED
assert (await anext(playback_status_iter)) == avrcp.PlayStatus.STOPPED
# Changed
two_devices.protocols[1].notify_playback_status_changed(playback_status)
assert (await anext(playback_status_iter)) == playback_status
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_now_playing_content():
two_devices = await TwoDevices.create_with_avdtp()
two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.NOW_PLAYING_CONTENT_CHANGED]
)
now_playing_iter = two_devices.protocols[0].monitor_now_playing_content()
for _ in range(2):
# Interim
await anext(now_playing_iter)
# Changed
two_devices.protocols[1].notify_now_playing_content_changed()
await anext(now_playing_iter)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_track_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.TRACK_CHANGED]
)
delegate.current_track_uid = avrcp.TrackChangedEvent.NO_TRACK
track_iter = two_devices.protocols[0].monitor_track_changed()
# Interim
assert (await anext(track_iter)) == avrcp.TrackChangedEvent.NO_TRACK
# Changed
two_devices.protocols[1].notify_track_changed(1)
assert (await anext(track_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_uid_changed():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.UIDS_CHANGED]
)
delegate.uid_counter = 0
uid_iter = two_devices.protocols[0].monitor_uids()
# Interim
assert (await anext(uid_iter)) == 0
# Changed
two_devices.protocols[1].notify_uids_changed(1)
assert (await anext(uid_iter)) == 1
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_addressed_player():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
[avrcp.EventId.ADDRESSED_PLAYER_CHANGED]
)
delegate.uid_counter = 0
delegate.addressed_player_id = 0
addressed_player_iter = two_devices.protocols[0].monitor_addressed_player()
# Interim
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=0, uid_counter=0)
# Changed
two_devices.protocols[1].notify_addressed_player_changed(
avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
)
assert (
await anext(addressed_player_iter)
) == avrcp.AddressedPlayerChangedEvent.Player(player_id=1, uid_counter=1)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_monitor_player_app_settings():
two_devices = await TwoDevices.create_with_avdtp()
delegate = two_devices.protocols[1].delegate = avrcp.Delegate(
supported_events=[avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED]
)
delegate.player_app_settings = {
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE: avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
}
settings_iter = two_devices.protocols[0].monitor_player_application_settings()
# Interim
interim = await anext(settings_iter)
assert interim[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert (
interim[0].value_id
== avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT
)
# Changed
two_devices.protocols[1].notify_player_application_settings_changed(
[
avrcp.PlayerApplicationSettingChangedEvent.Setting(
avrcp.ApplicationSetting.AttributeId.REPEAT_MODE,
avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT,
)
]
)
changed = await anext(settings_iter)
assert changed[0].attribute_id == avrcp.ApplicationSetting.AttributeId.REPEAT_MODE
assert changed[0].value_id == avrcp.ApplicationSetting.RepeatModeStatus.GROUP_REPEAT
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frame_parser()
+34
View File
@@ -0,0 +1,34 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from bumble import device as device_module
from bumble.profiles import battery_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_battery_level():
devices = await test_utils.TwoDevices.create_with_connection()
service = battery_service.BatteryService(lambda _: 1)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(battery_service.BatteryServiceProxy)
assert client
assert await client.battery_level.read_value() == 1
+8
View File
@@ -73,6 +73,14 @@ def test_uuid_to_hex_str() -> None:
)
# -----------------------------------------------------------------------------
def test_uuid_hash() -> None:
uuid = UUID("1234")
uuid_128_bytes = UUID.from_bytes(uuid.to_bytes(force_128=True))
assert uuid in {uuid_128_bytes}
assert uuid_128_bytes in {uuid}
# -----------------------------------------------------------------------------
def test_appearance() -> None:
a = Appearance(Appearance.Category.COMPUTER, Appearance.ComputerSubcategory.LAPTOP)
+49 -9
View File
@@ -42,7 +42,6 @@ from bumble.hci import (
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
HCI_Connection_Request_Event,
@@ -154,10 +153,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
@@ -188,10 +187,10 @@ async def test_device_connect_parallel():
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(
HCI_Command_Complete_Event(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
@@ -292,9 +291,9 @@ async def test_legacy_advertising_disconnection(auto_restart):
await devices[0].start_advertising(
auto_restart=auto_restart, advertising_interval_min=1.0
)
connecion = await devices[1].connect(devices[0].random_address)
connection = await devices[1].connect(devices[0].random_address)
await connecion.disconnect()
await connection.disconnect()
await async_barrier()
await async_barrier()
@@ -310,6 +309,27 @@ async def test_legacy_advertising_disconnection(auto_restart):
assert not devices[0].is_advertising
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_le_multiple_connects():
devices = TwoDevices()
for controller in devices.controllers:
controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING
for dev in devices:
await dev.power_on()
await devices[0].start_advertising(auto_restart=True, advertising_interval_min=1.0)
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
await async_barrier()
await async_barrier()
# a second connection attempt is working
connection = await devices[1].connect(devices[0].random_address)
await connection.disconnect()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_advertising_and_scanning():
@@ -446,7 +466,9 @@ async def test_get_remote_le_features():
devices = TwoDevices()
await devices.setup_connection()
assert (await devices.connections[0].get_remote_le_features()) is not None
assert (
await devices.connections[0].get_remote_le_features()
) == devices.controllers[1].le_features
# -----------------------------------------------------------------------------
@@ -620,7 +642,9 @@ async def test_le_request_subrate():
def on_le_subrate_change():
q.put_nowait(lambda: None)
devices.connections[0].on(Connection.EVENT_LE_SUBRATE_CHANGE, on_le_subrate_change)
devices.connections[0].on(
Connection.EVENT_CONNECTION_PARAMETERS_UPDATE, on_le_subrate_change
)
await devices[0].send_command(
hci.HCI_LE_Subrate_Request_Command(
@@ -802,6 +826,22 @@ async def test_remote_name_request():
assert actual_name == expected_name
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_classic_features():
devices = TwoDevices()
devices[0].classic_enabled = True
devices[1].classic_enabled = True
await devices[0].power_on()
await devices[1].power_on()
connection = await devices[0].connect_classic(devices[1].public_address)
assert (
await asyncio.wait_for(connection.get_remote_classic_features(), _TIMEOUT)
== devices.controllers[1].lmp_features
)
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()
+278 -5
View File
@@ -28,6 +28,7 @@ from unittest.mock import ANY, AsyncMock, Mock
import pytest
from typing_extensions import Self
from bumble import att, gatt_client, l2cap
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
@@ -63,7 +64,6 @@ from bumble.gatt_adapters import (
UTF8CharacteristicAdapter,
UTF8CharacteristicProxyAdapter,
)
from bumble.gatt_client import CharacteristicProxy
from .test_utils import Devices, TwoDevices, async_barrier
@@ -140,7 +140,7 @@ async def test_characteristic_encoding():
await c.write_value(Mock(), bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
class FooProxy(gatt_client.CharacteristicProxy):
def __init__(self, characteristic):
super().__init__(
characteristic.client,
@@ -456,7 +456,7 @@ async def test_CharacteristicProxyAdapter() -> None:
async def write_value(self, handle, value, with_response=False):
self.value = value
class TestAttributeProxy(CharacteristicProxy):
class TestAttributeProxy(gatt_client.CharacteristicProxy):
def __init__(self, value) -> None:
super().__init__(Client(value), 0, 0, None, 0) # type: ignore
@@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid():
await peer.discover_characteristics()
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))
assert len(c) == 2
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD'))
assert len(c) == 1
assert isinstance(c[0], CharacteristicProxy)
assert isinstance(c[0], gatt_client.CharacteristicProxy)
c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA'))
assert len(c) == 0
@@ -1463,6 +1463,279 @@ async def test_write_return_error():
assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_read():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.READ,
Characteristic.Permissions.READABLE,
b'9999',
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
assert await characteristic_proxy.read_value() == b'9999'
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_write():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
write_queue = asyncio.Queue()
characteristic = Characteristic(
'1234',
Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)),
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
client = await gatt_client.Client.connect_eatt(devices.connections[0])
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.write_value(b'9999')
assert await write_queue.get() == (devices.connections[1], b'9999')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_notify():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.NOTIFY,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True)
await devices[1].gatt_server.notify_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.notify_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_indicate():
devices = await TwoDevices.create_with_connection()
devices[1].gatt_server.register_eatt()
characteristic = Characteristic(
'1234',
Characteristic.Properties.INDICATE,
Characteristic.Permissions.WRITEABLE,
)
service = Service('ABCD', [characteristic])
devices[1].add_service(service)
clients = [
(
devices.connections[0].gatt_client,
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
(
await gatt_client.Client.connect_eatt(devices.connections[0]),
asyncio.Queue[bytes](),
),
]
for client, queue in clients:
await client.discover_services()
service_proxy = client.get_services_by_uuid(service.uuid)[0]
await service_proxy.discover_characteristics()
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
for client, queue in clients[:2]:
characteristic_proxy = service_proxy.get_characteristics_by_uuid(
characteristic.uuid
)[0]
await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False)
await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234')
for _, queue in clients[:2]:
assert await queue.get() == b'1234'
assert queue.empty()
assert clients[2][1].empty()
await devices[1].gatt_server.indicate_subscriber(
devices.connections[1], characteristic, b'5678'
)
for _, queue in clients[:2]:
assert await queue.get() == b'5678'
assert queue.empty()
assert clients[2][1].empty()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_eatt_connection_failure():
devices = await TwoDevices.create_with_connection()
with pytest.raises(l2cap.L2capError):
await gatt_client.Client.connect_eatt(devices.connections[0])
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_multiple() -> None:
devices = await TwoDevices.create_with_connection()
characteristic1 = Characteristic(
'0001', Characteristic.Properties.READ, Characteristic.READABLE, b'1234'
)
characteristic2 = Characteristic(
'0002',
Characteristic.Properties.READ,
Characteristic.READABLE,
b'5678',
)
service = Service('0000', [characteristic1, characteristic2])
devices[1].add_service(service)
client = devices.connections[0].gatt_client
server = devices[1].gatt_server
await client.discover_services()
characteristics = await client.discover_characteristics(
[characteristic1.uuid, characteristic2.uuid], None
)
response = await client.send_request(
att.ATT_Read_Multiple_Request(
set_of_handles=[c.handle for c in characteristics]
)
)
assert isinstance(response, att.ATT_Read_Multiple_Response)
assert response.set_of_values == b'12345678'
response = await client.send_request(
att.ATT_Read_Multiple_Request(
set_of_handles=[
next(
handle
for handle in range(0x0001, 0xFFFF)
if not server.get_attribute(handle)
)
]
)
)
assert isinstance(response, att.ATT_Error_Response)
assert response.error_code == att.ATT_ATTRIBUTE_NOT_FOUND_ERROR
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_multiple_variable() -> None:
devices = await TwoDevices.create_with_connection()
characteristic1 = Characteristic(
'0001', Characteristic.Properties.READ, Characteristic.READABLE, b'1234'
)
characteristic2 = Characteristic(
'0002',
Characteristic.Properties.READ,
Characteristic.READABLE,
b'99',
)
service = Service('0000', [characteristic1, characteristic2])
devices[1].add_service(service)
client = devices.connections[0].gatt_client
server = devices[1].gatt_server
await client.discover_services()
characteristics = await client.discover_characteristics(
[characteristic1.uuid, characteristic2.uuid], None
)
response = await client.send_request(
att.ATT_Read_Multiple_Variable_Request(
set_of_handles=[c.handle for c in characteristics]
)
)
assert isinstance(response, att.ATT_Read_Multiple_Variable_Response)
assert response.length_value_tuple_list == [(4, b'1234'), (2, b'99')]
response = await client.send_request(
att.ATT_Read_Multiple_Variable_Request(
set_of_handles=[
next(
handle
for handle in range(0x0001, 0xFFFF)
if not server.get_attribute(handle)
)
]
)
)
assert isinstance(response, att.ATT_Error_Response)
assert response.error_code == att.ATT_ATTRIBUTE_NOT_FOUND_ERROR
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
+56 -45
View File
@@ -20,7 +20,7 @@ import struct
import pytest
from bumble import hci
from bumble import hci, utils
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
@@ -136,43 +136,25 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
# -----------------------------------------------------------------------------
def test_HCI_Command_Complete_Event():
# With a serializable object
event = hci.HCI_Command_Complete_Event(
event1 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=34,
command_opcode=hci.HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
return_parameters=hci.HCI_LE_Read_Buffer_Size_Command.return_parameters_class(
status=0,
le_acl_data_packet_length=1234,
total_num_le_acl_data_packets=56,
),
)
basic_check(event)
# With an arbitrary byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = hci.HCI_Packet.from_bytes(bytes(event))
assert event.return_parameters == 7
basic_check(event1)
# With a simple status as an integer status
event = hci.HCI_Command_Complete_Event(
event3 = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci.HCI_RESET_COMMAND,
return_parameters=9,
return_parameters=hci.HCI_StatusReturnParameters(hci.HCI_ErrorCode(9)),
)
basic_check(event)
assert event.return_parameters == 9
basic_check(event3)
assert event3.return_parameters.status == 9
# -----------------------------------------------------------------------------
@@ -229,6 +211,36 @@ def test_HCI_Vendor_Event():
assert isinstance(parsed, hci.HCI_Vendor_Event)
# -----------------------------------------------------------------------------
def test_return_parameters() -> None:
params = hci.HCI_Reset_Command.parse_return_parameters(bytes.fromhex('3C'))
assert params.status == hci.HCI_ErrorCode.ADVERTISING_TIMEOUT_ERROR
assert isinstance(params.status, utils.OpenIntEnum)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('00001122334455')
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.status, utils.OpenIntEnum)
assert isinstance(params.bd_addr, hci.Address)
params = hci.HCI_Read_Local_Name_Command.parse_return_parameters(
bytes.fromhex('0068656c6c6f') + bytes(248 - 5)
)
assert params.status == hci.HCI_ErrorCode.SUCCESS
assert isinstance(params.local_name, bytes)
assert len(params.local_name) == 248
assert hci.map_null_terminated_utf8_string(params.local_name) == 'hello'
# Some return parameters may be shorter than the full length
# (for Command Complete events with errors)
params = hci.HCI_Read_BD_ADDR_Command.parse_return_parameters(
bytes.fromhex('010011223344')
)
assert isinstance(params, hci.HCI_StatusReturnParameters)
assert params.status == hci.HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# -----------------------------------------------------------------------------
def test_HCI_Command():
command = hci.HCI_Command(op_code=0x5566)
@@ -291,7 +303,7 @@ def test_custom_le_meta_event():
for clazz in inspect.getmembers(hci)
if isinstance(clazz[1], type)
and issubclass(clazz[1], hci.HCI_Command)
and clazz[1] is not hci.HCI_Command
and clazz[1] not in (hci.HCI_Command, hci.HCI_SyncCommand, hci.HCI_AsyncCommand)
],
)
def test_hci_command_subclasses_op_code(clazz: type[hci.HCI_Command]):
@@ -620,21 +632,19 @@ def test_HCI_Read_Local_Supported_Codecs_Command_Complete():
# -----------------------------------------------------------------------------
def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
returned_parameters = (
hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
0,
]
)
returned_parameters = hci.HCI_Read_Local_Supported_Codecs_V2_Command.parse_return_parameters(
bytes(
[
hci.HCI_SUCCESS,
3,
hci.CodecID.A_LOG,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.CodecID.CVSD,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.CodecID.LINEAR_PCM,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
0,
]
)
)
assert returned_parameters.standard_codec_ids == [
@@ -643,9 +653,9 @@ def test_HCI_Read_Local_Supported_Codecs_V2_Command_Complete():
hci.CodecID.LINEAR_PCM,
]
assert returned_parameters.standard_codec_transports == [
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_Command.Transport.LE_CIS,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_ACL,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.BR_EDR_SCO,
hci.HCI_Read_Local_Supported_Codecs_V2_ReturnParameters.Transport.LE_CIS,
]
@@ -737,6 +747,7 @@ def run_test_commands():
if __name__ == '__main__':
run_test_events()
run_test_commands()
test_return_parameters()
test_address()
test_custom()
test_iso_data_packet()
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2021-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import itertools
from collections.abc import Sequence
import pytest
from bumble import device as device_module
from bumble.profiles import heart_rate_service
from . import test_utils
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"heart_rate, sensor_contact_detected, energy_expanded, rr_intervals",
itertools.product(
(1, 1000), (True, False, None), (2, None), ((3.0, 4.0, 5.0), None)
),
)
async def test_read_measurement(
heart_rate: int,
sensor_contact_detected: bool | None,
energy_expanded: int | None,
rr_intervals: Sequence[int] | None,
):
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(
heart_rate, sensor_contact_detected, energy_expanded, rr_intervals
)
service = heart_rate_service.HeartRateService(lambda _: measurement)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert await client.heart_rate_measurement.read_value() == measurement
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_body_sensor_location():
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(0)
location = heart_rate_service.HeartRateService.BodySensorLocation.FINGER
service = heart_rate_service.HeartRateService(
lambda _: measurement,
body_sensor_location=location,
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
assert client.body_sensor_location
assert await client.body_sensor_location.read_value() == location
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reset_energy_expended() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
measurement = heart_rate_service.HeartRateService.HeartRateMeasurement(1)
reset_energy_expended = asyncio.Queue[None]()
service = heart_rate_service.HeartRateService(
lambda _: measurement,
reset_energy_expended=lambda _: reset_energy_expended.put_nowait(None),
)
devices[0].add_service(service)
async with device_module.Peer(devices.connections[1]) as peer:
client = peer.create_service_proxy(heart_rate_service.HeartRateServiceProxy)
assert client
await client.reset_energy_expended()
await reset_energy_expended.get()
+155 -18
View File
@@ -15,16 +15,31 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import unittest
import unittest.mock
import pytest
from bumble import controller, hci
from bumble.controller import Controller
from bumble.hci import HCI_AclDataPacket
from bumble.hci import (
HCI_AclDataPacket,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_CommandStatus,
HCI_Disconnect_Command,
HCI_Error,
HCI_ErrorCode,
HCI_Event,
HCI_GenericReturnParameters,
HCI_LE_Terminate_BIG_Command,
HCI_Reset_Command,
HCI_StatusReturnParameters,
)
from bumble.host import DataPacketQueue, Host
from bumble.transport.common import AsyncPipeSink
from bumble.transport.common import AsyncPipeSink, TransportSink
# -----------------------------------------------------------------------------
# Logging
@@ -35,34 +50,27 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'supported_commands, lmp_features',
'supported_commands, max_lmp_features_page_number',
[
(
# Default commands
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000',
# Only LE LMP feature
'0000000060000000',
),
(controller.Controller.supported_commands, 0),
(
# All commands
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
set(hci.HCI_Command.command_names.keys()),
# 3 pages of LMP features
'000102030405060708090A0B0C0D0E0F011112131415161718191A1B1C1D1E1F',
2,
),
],
)
async def test_reset(supported_commands: str, lmp_features: str):
async def test_reset(supported_commands: set[int], max_lmp_features_page_number: int):
controller = Controller('C')
controller.supported_commands = bytes.fromhex(supported_commands)
controller.lmp_features = bytes.fromhex(lmp_features)
controller.supported_commands = supported_commands
controller.lmp_features_max_page_number = max_lmp_features_page_number
host = Host(controller, AsyncPipeSink(controller))
await host.reset()
assert host.local_lmp_features == int.from_bytes(
bytes.fromhex(lmp_features), 'little'
assert host.local_lmp_features == (
controller.lmp_features & ~(1 << (64 * max_lmp_features_page_number + 1))
)
@@ -151,3 +159,132 @@ def test_data_packet_queue():
assert drain_listener.on_flow.call_count == 1
assert queue.queued == 15
assert queue.completed == 15
# -----------------------------------------------------------------------------
class Source:
terminated: asyncio.Future[None]
sink: TransportSink
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
class Sink:
response: HCI_Event | None
def __init__(self, source: Source, response: HCI_Event | None) -> None:
self.source = source
self.response = response
def on_packet(self, packet: bytes) -> None:
if self.response is not None:
self.source.sink.on_packet(bytes(self.response))
@pytest.mark.asyncio
async def test_send_sync_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.SUCCESS),
),
)
host = Host(source, sink)
host.ready = True
# Sync command with success
response1 = await host.send_sync_command(HCI_Reset_Command())
assert response1.status == HCI_ErrorCode.SUCCESS
# Sync command with error status should raise
error_response = HCI_Command_Complete_Event(
1,
HCI_Reset_Command.op_code,
HCI_StatusReturnParameters(status=HCI_ErrorCode.COMMAND_DISALLOWED_ERROR),
)
sink.response = error_response
with pytest.raises(HCI_Error) as excinfo:
await host.send_sync_command(HCI_Reset_Command())
assert excinfo.value.error_code == error_response.return_parameters.status
# Sync command with raw result
response2 = await host.send_sync_command_raw(HCI_Reset_Command())
assert response2.return_parameters.status == HCI_ErrorCode.COMMAND_DISALLOWED_ERROR
# Sync command with a command that's not an HCI_SyncCommand
# (here, for convenience, we use an HCI_AsyncCommand instance)
command = HCI_Disconnect_Command(connection_handle=0x1234, reason=0x13)
sink.response = HCI_Command_Complete_Event(
1,
command.op_code,
HCI_GenericReturnParameters(data=bytes.fromhex("00112233")),
)
response3 = await host.send_sync_command_raw(command) # type: ignore
assert isinstance(response3.return_parameters, HCI_GenericReturnParameters)
@pytest.mark.asyncio
async def test_send_sync_command_timeout() -> None:
source = Source()
sink = Sink(source, None)
host = Host(source, sink)
host.ready = True
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(HCI_Reset_Command(), response_timeout=0.01)
# The sending semaphore should have been released, so this should not block
# indefinitely
with pytest.raises(asyncio.TimeoutError):
await host.send_sync_command(hci.HCI_Reset_Command(), response_timeout=0.01)
@pytest.mark.asyncio
async def test_send_async_command() -> None:
source = Source()
sink = Sink(
source,
HCI_Command_Status_Event(
HCI_CommandStatus.PENDING,
1,
HCI_Reset_Command.op_code,
),
)
host = Host(source, sink)
host.ready = True
# Normal pending status
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0)
)
assert response == HCI_CommandStatus.PENDING
# Unknown HCI command result returned as a Command Status
sink.response = HCI_Command_Status_Event(
HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR,
1,
HCI_LE_Terminate_BIG_Command.op_code,
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
# Unknown HCI command result returned as a Command Complete
sink.response = HCI_Command_Complete_Event(
1,
HCI_LE_Terminate_BIG_Command.op_code,
HCI_StatusReturnParameters(HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR),
)
response = await host.send_async_command(
HCI_LE_Terminate_BIG_Command(big_handle=0, reason=0), check_status=False
)
assert response == HCI_ErrorCode.UNKNOWN_HCI_COMMAND_ERROR
+45
View File
@@ -21,6 +21,7 @@ import logging
import os
import pathlib
import tempfile
from unittest import mock
import pytest
@@ -179,11 +180,55 @@ async def test_default_namespace(temporary_file):
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_filename(tmp_path):
import platformdirs
with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path):
# Case 1: no namespace, no filename
keystore = JsonKeyStore(None, None)
expected_directory = tmp_path / 'Pairing'
expected_filename = expected_directory / 'keys.json'
assert keystore.directory_name == expected_directory
assert keystore.filename == expected_filename
# Save some data
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
assert expected_filename.exists()
# Load back
keystore2 = JsonKeyStore(None, None)
foo = await keystore2.get('foo')
assert foo is not None
assert foo.ltk.value == ltk
# Case 2: namespace, no filename
keystore3 = JsonKeyStore('my:namespace', None)
# safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-')
expected_filename3 = expected_directory / 'my-namespace.json'
assert keystore3.filename == expected_filename3
# Save some data
await keystore3.update('bar', keys)
assert expected_filename3.exists()
# Load back
keystore4 = JsonKeyStore('my:namespace', None)
bar = await keystore4.get('bar')
assert bar is not None
assert bar.ltk.value == ltk
# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
await test_no_filename()
# -----------------------------------------------------------------------------
+10 -17
View File
@@ -239,20 +239,7 @@ async def transfer_payload(
channels[1].sink = received.put_nowait
sdu_lengths = (21, 70, 700, 5523)
if isinstance(channels[1], l2cap.LeCreditBasedChannel):
mps = channels[1].mps
elif isinstance(
processor := channels[1].processor, l2cap.EnhancedRetransmissionProcessor
):
mps = processor.mps
else:
mps = channels[1].mtu
messages = [
bytes([i % 8 for i in range(sdu_length)])
for sdu_length in sdu_lengths
if sdu_length <= mps
]
messages = [bytes([i % 8 for i in range(sdu_length)]) for sdu_length in sdu_lengths]
for message in messages:
channels[0].write(message)
if isinstance(channels[0], l2cap.LeCreditBasedChannel):
@@ -334,20 +321,26 @@ async def test_mtu():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_enhanced_retransmission_mode():
@pytest.mark.parametrize("mtu,", (50, 255, 256, 1000))
async def test_enhanced_retransmission_mode(mtu: int):
devices = TwoDevices()
await devices.setup_connection()
server_channels = asyncio.Queue[l2cap.ClassicChannel]()
server = devices.devices[1].create_l2cap_server(
spec=l2cap.ClassicChannelSpec(
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=256,
),
handler=server_channels.put_nowait,
)
client_channel = await devices.connections[0].create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(
server.psm, mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION
server.psm,
mode=l2cap.TransmissionMode.ENHANCED_RETRANSMISSION,
mtu=mtu,
mps=1024,
)
)
server_channel = await server_channels.get()
+52
View File
@@ -18,9 +18,11 @@
import asyncio
import logging
import os
import re
import pytest
from bumble import sdp
from bumble.core import BT_L2CAP_PROTOCOL_ID, UUID
from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
@@ -206,6 +208,16 @@ def sdp_records(record_count=1):
}
# -----------------------------------------------------------------------------
def test_pdu_parameter_length(caplog) -> None:
caplog.set_level(logging.WARNING)
pdu = sdp.SDP_ErrorResponse(
transaction_id=0, error_code=sdp.ErrorCode.INVALID_SDP_VERSION
)
assert sdp.SDP_PDU.from_bytes(bytes(pdu)) == pdu
assert not re.search(r"Expect \d+ bytes, got \d+", caplog.text)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search():
@@ -428,3 +440,43 @@ async def run():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
# -----------------------------------------------------------------------------
def test_nested_sequence_recursion_guard():
"""Regression test: deeply-nested SDP SEQUENCE/ALTERNATIVE must not crash
the parser with RecursionError. Instead a ValueError is raised once the
configured nesting limit is exceeded.
Root cause: DataElement.from_bytes -> list_from_bytes -> (constructor
dispatching back to list_from_bytes for SEQUENCE/ALTERNATIVE) recursed
without a depth limit. A malicious SDP peer could craft a PDU exceeding
Pythons default recursion limit (~1000 frames) and crash the host.
"""
# Build nested SEQUENCE payload with tag 0x36 (SEQUENCE, 2-byte length).
inner = b"\x35\x00" # empty SEQUENCE terminator
for _ in range(1500):
size = len(inner)
if size >= 65535:
break
inner = bytes([0x36, (size >> 8) & 0xFF, size & 0xFF]) + inner
with pytest.raises(ValueError, match="nesting exceeds max depth"):
DataElement.from_bytes(inner)
def test_nested_sequence_within_limit_still_works():
"""Nested-but-reasonable SDP SEQUENCEs must still parse correctly."""
leaf = DataElement.unsigned_integer(1, value_size=2)
payload = leaf
for _ in range(16): # under the 32-depth limit
payload = DataElement.sequence([payload])
raw = bytes(payload)
parsed = DataElement.from_bytes(raw)
# Walk back down to confirm structural integrity preserved
cur = parsed
for _ in range(16):
assert cur.type == DataElement.SEQUENCE
cur = cur.value[0]
assert cur.type == DataElement.UNSIGNED_INTEGER
assert cur.value == 1
+7 -10
View File
@@ -29,8 +29,7 @@ from bumble.gatt import Characteristic, Service
from bumble.hci import Role
from bumble.pairing import PairingConfig, PairingDelegate
from bumble.smp import (
SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_PAIRING_NOT_SUPPORTED_ERROR,
ErrorCode,
OobContext,
OobLegacyContext,
)
@@ -57,15 +56,13 @@ async def test_self_disconnection():
await two_devices.setup_connection()
await two_devices.connections[0].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
assert not two_devices.connections
two_devices = TwoDevices()
await two_devices.setup_connection()
await two_devices.connections[1].disconnect()
await async_barrier()
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
assert not two_devices.connections
# -----------------------------------------------------------------------------
@@ -380,7 +377,7 @@ async def test_self_smp_reject():
await _test_self_smp_with_configs(None, rejecting_pairing_config)
paired = True
except ProtocolError as error:
assert error.error_code == SMP_PAIRING_NOT_SUPPORTED_ERROR
assert error.error_code == ErrorCode.PAIRING_NOT_SUPPORTED
assert not paired
@@ -405,7 +402,7 @@ async def test_self_smp_wrong_pin():
)
paired = True
except ProtocolError as error:
assert error.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.error_code == ErrorCode.CONFIRM_VALUE_FAILED
assert not paired
@@ -536,11 +533,11 @@ async def test_self_smp_oob_sc():
with pytest.raises(ProtocolError) as error:
await _test_self_smp_with_configs(pairing_config_1, pairing_config_4)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
with pytest.raises(ProtocolError):
await _test_self_smp_with_configs(pairing_config_4, pairing_config_1)
assert error.value.error_code == SMP_CONFIRM_VALUE_FAILED_ERROR
assert error.value.error_code == ErrorCode.CONFIRM_VALUE_FAILED
# -----------------------------------------------------------------------------
+15 -1
View File
@@ -24,7 +24,7 @@ import pytest
from bumble import crypto, pairing, smp
from bumble.core import AdvertisingData
from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
from bumble.device import Device
from bumble.device import Device, DeviceConfiguration
from bumble.hci import Address
from bumble.pairing import LeRole, OobData, OobSharedData
@@ -312,3 +312,17 @@ async def test_send_identity_address_command(
actual_command = mock_method.call_args.args[0]
assert actual_command.addr_type == expected_identity_address.address_type
assert actual_command.bd_addr == expected_identity_address
@pytest.mark.asyncio
async def test_smp_debug_mode():
config = DeviceConfiguration(smp_debug_mode=True)
device = Device(config=config)
assert device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
device.smp_manager.debug_mode = False
assert not device.smp_manager.ecc_key.x == smp.SMP_DEBUG_KEY_PUBLIC_X
assert not device.smp_manager.ecc_key.y == smp.SMP_DEBUG_KEY_PUBLIC_Y
+8 -6
View File
@@ -31,10 +31,10 @@ from bumble.transport.common import AsyncPipeSink
# -----------------------------------------------------------------------------
class Devices:
connections: list[Connection | None]
connections: dict[int, Connection]
def __init__(self, num_devices: int) -> None:
self.connections = [None for _ in range(num_devices)]
self.connections = {}
self.link = LocalLink()
addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)]
@@ -60,12 +60,14 @@ class Devices:
asyncio.get_event_loop().create_future() for _ in range(num_devices)
]
def on_connection(self, which, connection):
def on_connection(self, which: int, connection: Connection) -> None:
self.connections[which] = connection
connection.on('disconnection', lambda code: self.on_disconnection(which))
connection.on(
connection.EVENT_DISCONNECTION, lambda *_: self.on_disconnection(which)
)
def on_disconnection(self, which):
self.connections[which] = None
def on_disconnection(self, which: int) -> None:
self.connections.pop(which, None)
def on_paired(self, which: int, keys: PairingKeys) -> None:
self.paired[which].set_result(keys)
@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Material+Symbols+Outlined:opsz,wght,FILL,GRAD@24,400,0,0" />
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="heart_rate_monitor.js"></script>
<style>
+1 -1
View File
@@ -89,7 +89,7 @@ class HeartRateMonitor:
async def stop(self):
# TODO: replace this once a proper reset is implemented in the lib.
await self.device.host.send_command(HCI_Reset_Command())
await self.device.host.send_sync_command(HCI_Reset_Command())
await self.device.power_off()
print('### Monitor stopped')
+1 -1
View File
@@ -3,7 +3,7 @@
<head>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<link rel="stylesheet" href="scanner.css">
<script src="https://cdn.jsdelivr.net/pyodide/v0.24.1/full/pyodide.js"></script>
<script src="https://cdn.jsdelivr.net/pyodide/v0.26.2/full/pyodide.js"></script>
<script type="module" src="../ui.js"></script>
<script type="module" src="scanner.js"></script>
</style>

Some files were not shown because too many files have changed in this diff Show More