Compare commits

...

574 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
312fc8db36 support controller-generated rpa 2024-08-05 08:59:05 -07:00
Gilles Boccon-Gibod
615691ec81 add basic RPA support 2024-08-01 15:37:11 -07:00
zxzxwu
ae8b83f294 Merge pull request #521 from zxzxwu/bap
Add Metadata LTV serializer and adapt Unicast
2024-07-31 11:36:46 +08:00
Josh Wu
4a8e21f4db Add Metadata LTV serializer and adapt Unicast 2024-07-31 01:20:28 +08:00
zxzxwu
3462e7c437 Merge pull request #439 from zxzxwu/mcp
Media Control Service Client implementation
2024-07-24 23:45:00 +08:00
Josh Wu
0f2e5239ad MCP constants and Client implementation 2024-07-24 22:57:26 +08:00
Gilles Boccon-Gibod
ee48cdc63f Merge pull request #517 from AlanRosenthal/scanner_pyee
Update scanner.py to use pyee.EventEmitter
2024-07-18 12:53:00 -07:00
Gilles Boccon-Gibod
1c278bec93 Merge pull request #518 from google/gbg/usb-queue
USB: better packet queue logic
2024-07-18 12:51:00 -07:00
Gilles Boccon-Gibod
6a51166af7 better packet queue logic 2024-07-17 17:48:26 -07:00
Alan Rosenthal
85d79fa914 Update scanner.py to use pyee.EventEmitter 2024-07-17 16:53:50 -04:00
zxzxwu
142bdce94a Merge pull request #515 from zxzxwu/unix
Add UNIX socket transport
2024-07-17 16:04:38 +08:00
Josh Wu
881a5a64b5 Add UNIX socket transport 2024-07-17 00:41:04 +08:00
zxzxwu
5aae44b610 Merge pull request #501 from zxzxwu/exception
Reorganize exceptions
2024-07-12 15:44:58 +08:00
Gilles Boccon-Gibod
e3ea167827 Merge pull request #506 from google/gbg/a2dp-fixes
a2dp: emit delay_report
2024-07-11 18:46:06 -07:00
Gilles Boccon-Gibod
eec145e095 add type hint 2024-07-11 18:39:02 -07:00
Gilles Boccon-Gibod
87fa02d6e5 Merge pull request #507 from google/packageFile
Create `inv web.build`
2024-07-11 18:35:29 -07:00
Gilles Boccon-Gibod
ad94c1e1f3 Merge pull request #509 from AlanRosenthal/discover
device.py: Add discover_all() api
2024-07-11 18:34:29 -07:00
Gilles Boccon-Gibod
546a0bce8d Merge pull request #510 from AlanRosenthal/get_characteristics_by_uuid
device.py: Update get_characteristics_by_uuid()
2024-07-11 18:33:45 -07:00
Gilles Boccon-Gibod
cb7ca44a1c Merge pull request #512 from AlanRosenthal/favicon
Add favicon.ico to docs folder
2024-07-11 18:27:19 -07:00
Gilles Boccon-Gibod
4081b93407 Merge pull request #513 from AlanRosenthal/devcontainer
Add devcontainer.json
2024-07-11 18:24:09 -07:00
Alan Rosenthal
26203ebaad Add devcontainer.json
devcontainer.json allows github's codespaces to be created with bumble's dependencies already installed
2024-07-11 18:47:32 +00:00
Alan Rosenthal
3389e3e1ed device.py: Update get_characteristics_by_uuid()
`get_characteristics_by_uuid()` now allows a UUID to be passed to the
service param. This allows for users to easily query for a service uuid
and characteristic uuid with one API.
2024-07-11 18:05:41 +00:00
Alan Rosenthal
7e1f01c01e Add favicon.ico to docs folder
Generated via: realfavicongenerator.net

validated via:
```
$ icotool -l favicon.ico
--icon --index=1 --width=48 --height=48 --bit-depth=32 --palette-size=0
--icon --index=2 --width=32 --height=32 --bit-depth=32 --palette-size=0
--icon --index=3 --width=16 --height=16 --bit-depth=32 --palette-size=0
```
2024-07-11 09:47:19 -04:00
Gilles Boccon-Gibod
613e15548a Merge pull request #511 from AlanRosenthal/random
console.py: Use Address.generate_static_address
2024-07-10 13:45:52 -07:00
Alan Rosenthal
e09c91df8e console.py: Use Address.generate_static_address 2024-07-10 18:51:46 +00:00
Alan Rosenthal
df206667b6 device.py: Add discover_all() api 2024-07-10 13:24:08 -04:00
Gilles Boccon-Gibod
0f19dd5263 Merge pull request #508 from google/web-readme
Add tip about disabling caching to web's readme
2024-07-09 09:17:25 -07:00
Alan Rosenthal
b98e4937f3 Add tip about disabling caching to web's readme 2024-07-09 13:48:53 +00:00
Alan Rosenthal
c2c46e9ace Create inv web.build
This command will build a wheel, copy it in the web directory, and create a file `packageFile` with the name of the wheel. If the correct override param is given, bumble.js will read `packageFile` and load that package.
2024-07-09 09:32:21 -04:00
Gilles Boccon-Gibod
27791cf218 emit delay_report 2024-07-03 13:51:15 -07:00
Gilles Boccon-Gibod
32a41a815d Merge pull request #502 from google/gbg/extended-advertising-termination-reverse
support out of order advertising set termination / connection events
2024-06-18 16:42:06 -07:00
Gilles Boccon-Gibod
df5fc2ddfe add test 2024-06-12 10:13:57 -07:00
Gilles Boccon-Gibod
79122313a6 Merge pull request #489 from google/gbg/basic-auracast-app
basic auracast app
2024-06-12 10:06:30 -07:00
Gilles Boccon-Gibod
d7d03e2e92 Merge pull request #504 from google/gbg/bench-role-change
bench role change
2024-06-12 10:06:11 -07:00
Gilles Boccon-Gibod
ea493480a9 remove duplicated lines 2024-06-11 13:23:35 -07:00
Gilles Boccon-Gibod
658f641a53 add manufacturer data 2024-06-11 13:21:04 -07:00
Josh Wu
f8a2d4f0e0 Reorganize exceptions
* Add BaseBumbleException as a "real" root error
* Add several core error classes and properly replace builtin errors
  with them
* Add several error classes for specific modules (transport, device)
2024-06-11 16:13:08 +08:00
Gilles Boccon-Gibod
00edd1fbf8 post-rebase fixes 2024-06-10 10:30:59 -07:00
Gilles Boccon-Gibod
999d7b07e1 wip 2024-06-09 11:39:44 -07:00
Gilles Boccon-Gibod
2e3aeb8648 support out of order advertising set termination / connection events 2024-06-05 16:29:31 -07:00
Gilles Boccon-Gibod
f910a696ad Merge pull request #499 from google/gbg/rfcomm-bridge
rfcomm bridge app
2024-06-05 11:18:13 -07:00
Gilles Boccon-Gibod
e1d10bc482 add rfcomm disconnect test 2024-06-05 10:03:27 -07:00
Gilles Boccon-Gibod
181467f11b Merge pull request #500 from google/gbg/fix-advertising-auto-restart
fix legacy advertising auto restart
2024-06-04 06:39:54 -07:00
Gilles Boccon-Gibod
394137b6f7 fix legacy advertising auto restart 2024-06-03 19:08:46 -07:00
Gilles Boccon-Gibod
dea907be86 attempt to fix pandora test (+3 squashed commits)
Squashed commits:
[759372d] address PR comments
[2f2a275] wip
[cc86b98] wip

wip

address PR comments

attempt to fix pandora test
2024-06-03 18:22:29 -07:00
Gilles Boccon-Gibod
f5baf51132 improve DLC parameters 2024-06-03 18:11:13 -07:00
Gilles Boccon-Gibod
f2dc8bd84e wip (+2 squashed commits)
Squashed commits:
[451a295] wip
[ed7b5b6] wip (+1 squashed commit)
Squashed commits:
[9d938c8] wip

wip

wip
2024-05-30 14:59:22 -07:00
zxzxwu
090309302f Merge pull request #372 from zxzxwu/source
ASCS Source Implementation
2024-05-29 13:17:51 +08:00
Charlie Boutier
28e6229b24 Fix: Preserve transport metadata
Preserve transport metadata when wrapping with SnoopingTransport
2024-05-28 09:20:53 -07:00
Josh Wu
1b66f03dbe ASCS: Add Source ASE operations 2024-05-27 14:48:23 +08:00
Gilles Boccon-Gibod
e34f6b5fd3 Merge pull request #484 from google/gbg/quick-fix-002
fix incorrect var reference
2024-05-13 16:11:42 -07:00
Gilles Boccon-Gibod
8a0482c947 Merge pull request #485 from google/gbg/gh-action-py312
add python 3.12 to GH actions
2024-05-13 16:11:25 -07:00
zxzxwu
938a189f3f Merge pull request #478 from zxzxwu/config
Make DeviceConfiguration dataclass
2024-05-13 16:57:15 +08:00
Gilles Boccon-Gibod
2005b4a11b python 3.12 compatibility 2024-05-12 12:54:52 -07:00
Gilles Boccon-Gibod
951fdc8bdd add python 3.12 to GH actions 2024-05-12 12:07:05 -07:00
Gilles Boccon-Gibod
12af7a526c fix incorrect var reference 2024-05-12 11:59:05 -07:00
zxzxwu
8781943646 Merge pull request #483 from zxzxwu/rfc
RFCOMM: Handle packets received before DLC sink set
2024-05-10 16:34:57 +08:00
Gilles Boccon-Gibod
7fbfdb634c Merge pull request #481 from google/gbg/command-status-fix
allow checking results for HCI_Command_Status_Event
2024-05-09 19:50:10 -07:00
Josh Wu
9682077f6b RFCOMM: Avoid receive packets before DLC sink set 2024-05-09 17:57:13 +08:00
Gilles Boccon-Gibod
22eb405fde Merge pull request #482 from servusdei2018/main
bumble.js(PacketSink): Implement asynchronous packet processing
2024-05-08 20:16:04 -07:00
zxzxwu
593c61973f Merge pull request #480 from zxzxwu/hfp-ag
HFP: Add AG example and fix errors
2024-05-07 17:50:01 +08:00
Josh Wu
ccff32102f HFP: Add example and fix AG errors 2024-05-07 00:36:52 +08:00
Nate
851d62c6c9 bumble.js(PacketSink): Implement asynchronous packet processing 2024-05-05 15:03:22 -04:00
Josh Wu
a5ac5f26e2 Make DeviceConfiguration dataclass 2024-05-05 17:25:01 +08:00
Gilles Boccon-Gibod
090158820f allow checking results for HCI_Command_Status_Event 2024-05-04 12:17:05 -07:00
zxzxwu
26e6650038 Merge pull request #477 from zxzxwu/hfp-ag
Fix HFP query call status
2024-05-02 01:17:17 +08:00
Josh Wu
c48568aabe Fix HFP query call status 2024-04-30 03:13:38 +00:00
zxzxwu
1b33c9eb74 Merge pull request #475 from zxzxwu/hfp-ag
Add more HFP command suppport
2024-04-26 12:01:20 +08:00
zxzxwu
6633228975 Add more HFP command suppport
* Support all Call Hold Operation
* Support CLI Presentation
* Support Voice Recognition
* Support RING and Volume Changes
* [AG] Support Enhanced Call Status
* Minor fixes
2024-04-24 15:29:48 +00:00
Gilles Boccon-Gibod
e9cba788a4 Merge pull request #473 from google/barbibulle-patch-2
quick fix: revert to protobuf 3.12.4
2024-04-22 11:46:04 +02:00
Gilles Boccon-Gibod
98822cfc6b quick fix: revert to protobuf 3.12.4
The upgrade to 4.x wasn't really needed, and breaks some users.
2024-04-18 21:20:18 -07:00
Gilles Boccon-Gibod
97ad7e5741 Merge pull request #472 from google/gbg/update-pandora-deps
update protobuf dep and make pandora install optional
2024-04-18 11:21:29 -07:00
Charlie Boutier
71df062e07 pyusb: power_cycle if '!' is present at the start of the transport 2024-04-17 14:12:55 -07:00
Charlie Boutier
049f9021e9 pyusb: powercycle the dongle 2024-04-17 14:12:55 -07:00
Gilles Boccon-Gibod
50eae2ef54 add pandora to code-check action 2024-04-17 13:19:07 -07:00
Gilles Boccon-Gibod
c8883a7d0f update protobuf dep and make pandora install optional 2024-04-17 13:14:21 -07:00
zxzxwu
51321caf5b Merge pull request #470 from zxzxwu/examples
Type hint all examples
2024-04-16 02:56:08 +08:00
zxzxwu
51a94288e2 Type hint all examples 2024-04-15 12:48:21 +00:00
zxzxwu
8758856e8c Merge pull request #465 from zxzxwu/hfp-ag
HFP AG implementation
2024-04-12 22:15:25 +08:00
Josh Wu
deba181857 HFP AG implementation 2024-04-10 09:51:37 +00:00
zxzxwu
c65188dcbf Merge pull request #466 from zxzxwu/format
Fix format presubmit error
2024-04-09 02:59:36 +08:00
Josh Wu
21d607898d Fix format presubmit error 2024-04-09 01:44:04 +08:00
Gilles Boccon-Gibod
2698d4534e Merge pull request #435 from jeru/main
open_tcp_server_transport: allow explicit sock as input.
2024-04-04 19:17:07 -07:00
zxzxwu
bbcd64286a Merge pull request #463 from zxzxwu/hfp
Correct HFP AG indicator index
2024-04-04 12:53:19 +08:00
Gilles Boccon-Gibod
9140afbf8c Merge pull request #456 from google/gbg/update-dependencies
update some dependencies
2024-04-03 17:50:18 -06:00
Gilles Boccon-Gibod
90a682c71b bump to avatar 0.0.9 2024-04-03 16:26:07 -07:00
Gilles Boccon-Gibod
e8737a8243 update to more recent versions 2024-04-03 10:00:11 -07:00
Gilles Boccon-Gibod
72fceca72e update some dependencies 2024-04-03 10:00:09 -07:00
Gilles Boccon-Gibod
732294abbc Merge pull request #462 from google/gbg/461
fix #461
2024-04-03 10:56:05 -06:00
Josh Wu
dc1204531e Correct HFP AG indicator index 2024-04-03 17:58:04 +08:00
Gilles Boccon-Gibod
962114379c fix #461 2024-04-02 23:14:32 -07:00
Gilles Boccon-Gibod
e6913a3055 Merge pull request #457 from google/gbg/bench-ascyncio-main
delay creation of runner object
2024-04-02 21:39:37 -06:00
Gilles Boccon-Gibod
e21d122aef Merge pull request #458 from google/gbg/update-formatter
update black formatter to version 24
2024-04-02 21:39:24 -06:00
Gilles Boccon-Gibod
58d4ab913a update black formatter to version 24 2024-04-01 14:44:46 -07:00
Gilles Boccon-Gibod
76bca03fe3 format with the project's version of black 2024-04-01 14:39:34 -07:00
Gilles Boccon-Gibod
f1e5c9e59e delay creation of runner object 2024-04-01 14:25:38 -07:00
zxzxwu
ec82242462 Merge pull request #440 from zxzxwu/hfp
Rework HFP example
2024-03-27 16:54:41 +08:00
zxzxwu
a4efdd3f3e Merge pull request #442 from zxzxwu/unicast_ad
Implement Unicast Server Advertising Data
2024-03-27 16:54:06 +08:00
Gilles Boccon-Gibod
69c6643bb8 Merge pull request #452 from marshallpierce/mp/rust-0.2.0
Bumble crate 0.2.0
2024-03-21 17:15:43 -07:00
Marshall Pierce
b8214bf948 Bumble crate 0.2.0 2024-03-21 12:36:32 -06:00
Charlie Boutier
a9c62c44b3 pandora host: change AdvertisingType
change advertising type from high duty to low duty

Test: python le_host_test.py -c config.yml --test_bed android.bumbles --tests "test_scan('connectable','non_scannable','directed',0)" -v
2024-03-20 11:17:50 -07:00
Charlie Boutier
7d0b4ef4e0 pandora_server: Parse FLAGS into advertising data
Bug: 328089785
2024-03-18 09:20:55 -07:00
Charlie Boutier
313340f1c6 intel driver: check the vendorId and productId 2024-03-15 10:53:33 -07:00
Charlie Boutier
e8ed69fb09 pyusb: Collect vendorId and productId as metadata 2024-03-15 10:53:33 -07:00
David Duarte
16d5cf6770 usb: Add usb path moniker
Add a new moniker for usb and pyusb driver allowing
to select the usb device using its bus id and port
path like `usb:3-3.4.1`.
2024-03-15 09:17:39 -07:00
Gilles Boccon-Gibod
a2caf1deb2 Merge pull request #448 from BenjaminLawson/bump-avatar
Bump pandora-avatar to 0.0.8
2024-03-14 20:49:28 -07:00
Ben Lawson
01bfdd2c98 Bump pandora-avater to 0.0.8 2024-03-14 14:13:27 -07:00
Gilles Boccon-Gibod
4a60df108a Merge pull request #447 from BenjaminLawson/bump-rootcanal
Bump rootcanal to 1.9.0
2024-03-14 14:00:36 -07:00
Ben Lawson
ad48109748 Bump rootcanal to 1.9.0 2024-03-14 13:15:02 -07:00
Cheng Sheng
1ceeccbbc0 open_tcp_server_transport: allow explicit sock as input.
When a user doesn't need an exact port, but cares more about getting
SOME unused port, they can do:
* Create a socket outside with port=None or port=0.
* Use socket.getsockname()[1] to get the allocated port and pass to the
TCP client somehow.
* Use the created socket to create a TCP server transport.

Use-case: unit-testing embedded software that implements a BLE host. The
controller will be a Bumble controller, connected to the host via a TCP
channel.
* The host will have a TCP-client HCI transport for testing.
* The pytest setup code will allocate the TCP server and pass the port
number to the host.

Also add some unittests with python mock.
2024-03-13 19:34:05 +01:00
Gilles Boccon-Gibod
44c51c13ac Merge pull request #445 from google/gbg/driver-probe-fix
fix intel driver probe
2024-03-12 12:51:08 -07:00
Gilles Boccon-Gibod
7507be1eab update metadata when setting the host controller directly 2024-03-12 11:50:47 -07:00
Gilles Boccon-Gibod
cbe9446dcf fix intel driver probe 2024-03-12 09:54:20 -07:00
Charlie Boutier
174930399a intel: send vsc INTEL_DDC_CONFIG_WRITE
This VSC enable host-initiated role-switching after connection.

Implement this VSC in a driver fashion.

Test: avatar security_test with the Bluetooth Dongle Intel BE200
2024-03-11 09:15:18 -07:00
Josh Wu
35db4a4c93 Implement Unicast Server Advertising Data 2024-03-08 16:48:37 +08:00
Gilles Boccon-Gibod
1f3aee5566 Merge pull request #438 from BenjaminLawson/pandora-extended-advertising
Implement Pandora extended advertising
2024-03-07 20:36:56 -08:00
Ben Lawson
256044a789 Implement Pandora extended advertising
Support setting the PHY of Pandora scans.
2024-03-07 16:18:49 -08:00
Josh Wu
6205199d7f Rework HFP example 2024-03-05 20:53:28 +08:00
Gilles Boccon-Gibod
e554bd1033 Merge pull request #434 from google/gbg/show-timestamps
show timestamps from snoop logs
2024-02-29 11:44:23 -08:00
Gilles Boccon-Gibod
38981cefa1 pad index field 2024-02-28 11:46:35 -08:00
Gilles Boccon-Gibod
f2d601f411 show timestamps from snoop logs 2024-02-27 16:40:37 -08:00
zxzxwu
6e7c64c1de Merge pull request #431 from zxzxwu/rust
Bump Rust to 1.76.0
2024-02-23 15:14:30 +08:00
Josh Wu
565d51f4db Bump Rust to 1.76.0
```
error: failed to compile `cargo-all-features v1.10.0`, intermediate artifacts can be found at `/tmp/cargo-installshCmAG`

Caused by:
  package `clap v4.5.1` cannot be built because it requires rustc 1.74 or newer, while the currently active rustc version is 1.70.0
  Try re-running cargo install with `--locked`

```
2024-02-22 15:22:20 +08:00
Gilles Boccon-Gibod
de8f3d9c1e Merge pull request #426 from akuker/patch-1
Add clarification to short circuit list feature
2024-02-12 21:22:14 -08:00
Tony Kuker
cde6d48690 Add clarification to short circuit list feature 2024-02-12 12:22:36 -06:00
zxzxwu
02180088b3 Merge pull request #425 from zxzxwu/command
Refactor command supporting list
2024-02-07 21:45:52 +08:00
zxzxwu
90f49267d1 Merge pull request #424 from zxzxwu/adv
Fix double-disable legacy advertising set
2024-02-06 16:13:51 +08:00
Josh Wu
0e6d69cd7b Refactor command supporting list 2024-02-06 12:06:00 +08:00
Josh Wu
9eccc583d5 Fix double-disable legacy advertising set
When legacy advertising set is disabled passively(by set termination),
the legacy advertising set won't be released, and the next
stop_advertising() call will try to disable it again and cause an error.
2024-02-06 12:00:30 +08:00
Gilles Boccon-Gibod
f4aeaa6eb3 Merge pull request #422 from google/gbg/bench-rfcomm-params
add rfcomm options and fix l2cap mtu negotiation
2024-02-05 09:14:16 -08:00
Gilles Boccon-Gibod
d7489a644a update websockets version (for better typecheck) 2024-02-05 09:07:39 -08:00
Gilles Boccon-Gibod
a877283360 add rfcomm options and fix l2cap mtu negotiation 2024-02-05 08:56:59 -08:00
zxzxwu
6d91e7e79b Merge pull request #423 from zxzxwu/vcp
Fix Lint error in VCP example
2024-02-06 00:40:05 +08:00
Josh Wu
567146b143 Fix Lint error in VCP example 2024-02-04 21:23:22 +08:00
zxzxwu
1a3272d7ca Merge pull request #412 from zxzxwu/vcp
Add Volume Control Service
2024-02-04 00:42:51 +08:00
zxzxwu
1ee1ff0b62 Merge pull request #420 from zxzxwu/rfc
Add RFCOMM and SDP context manager and search helper
2024-02-04 00:42:24 +08:00
zxzxwu
729fd97748 Merge pull request #419 from zxzxwu/feat
Add local LMP feature reader
2024-02-03 13:51:19 +08:00
Josh Wu
e308051885 Add LMP feature reader 2024-02-03 13:29:25 +08:00
Josh Wu
10e53553d7 Add RFCOMM and SDP helpers 2024-02-03 13:13:35 +08:00
Gilles Boccon-Gibod
ef0b30d059 Merge pull request #382 from google/gbg/extended-advertising-v2
extended advertising v2
2024-02-02 20:43:28 -08:00
Gilles Boccon-Gibod
e7e9f9509a update rootcanal version 2024-02-02 20:33:19 -08:00
zxzxwu
c6cfd101df Merge pull request #415 from zxzxwu/hfp
HFP: State memory and event emission
2024-02-02 11:36:53 +08:00
Josh Wu
d2dcf063ee HFP: State memory and event emit 2024-02-01 12:08:43 +08:00
Michael Mogenson
d15bc7d664 Merge pull request #417 from mogenson/controller-loopback-cid-range
controller_loopback: LE support and max packet count
2024-01-31 21:13:21 -05:00
zxzxwu
e4364d18a7 Merge pull request #418 from zxzxwu/rfc
RFCOMM: Slightly refactor and correct constants
2024-02-01 01:30:53 +08:00
Josh Wu
6a34c9f224 RFCOMM: Slightly refactor and correct constants 2024-02-01 01:18:56 +08:00
Michael Mogenson
2a764fd6bb controller_loopback: LE support and max packet count
Bound the packet count CLI option. We're using the L2CAP header CID for
a paket ID, so the max packet count value has to fit into this 16-bit
field.

Add support for controllers that are LE only by checking the
le_acl_packet_queue.max_size.

Tested with 65535 max packet count. Took 138 seconds at 481 kB/s with a
USB BT dongle.
2024-01-31 10:26:51 -05:00
Josh Wu
3e8ce38eba Add Volume Control Service 2024-01-31 10:04:30 +08:00
Gilles Boccon-Gibod
8d2f37aa7a inclusive language 2024-01-28 19:09:39 -08:00
Gilles Boccon-Gibod
b7b70ebcbb address PR comments 2024-01-28 19:09:37 -08:00
Gilles Boccon-Gibod
8ba91f4986 fix assert 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
79a5e953bc comply with limits for certain advertising event types 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
20de5ea250 format 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
bad9ce272c add doc 2024-01-28 19:02:32 -08:00
Gilles Boccon-Gibod
d3273ffa8c format (+3 squashed commits)
Squashed commits:
[60e610f] wip
[eeab73d] wip
[3cdd5b8] basic first pass
2024-01-28 19:02:30 -08:00
zxzxwu
071fc2723a Merge pull request #376 from zxzxwu/host
Manage lifecycle of CIS and SCO links in host
2024-01-28 22:09:08 +08:00
zxzxwu
ef4ea86f58 Merge pull request #381 from zxzxwu/offload
Support non-directed address generation offload
2024-01-28 22:08:32 +08:00
Gilles Boccon-Gibod
dfdaa149d0 Merge pull request #337 from google/gbg/avrcp
Add AVRCP support
2024-01-28 01:27:52 -08:00
Gilles Boccon-Gibod
986343a807 support multiple type checkers for pandora 2024-01-28 01:21:50 -08:00
Gilles Boccon-Gibod
5211d7ba96 revert to older pytest_asyncio 2024-01-28 01:10:31 -08:00
Gilles Boccon-Gibod
a167342778 deal with SupportsBytes for python <= 3.10 2024-01-28 01:04:13 -08:00
Gilles Boccon-Gibod
1efb8cdbee use matrixed python version 2024-01-28 00:34:42 -08:00
Gilles Boccon-Gibod
80d83e6a70 upgrade to mypy 1.8.0 2024-01-28 00:26:50 -08:00
Gilles Boccon-Gibod
31ec1c41ce cleanup 2024-01-28 00:07:31 -08:00
Gilles Boccon-Gibod
aba1ac0cea use a dict instead of a series of ifs (+6 squashed commits)
Squashed commits:
[90f2024] fix import order
[0edd321] add a few docstrings
[77a0ac0] wip
[adcf159] wip
[96cbd67] wip
[d8bfbab] wip (+1 squashed commit)
Squashed commits:
[43b4d66] wip (+2 squashed commits)
Squashed commits:
[3dafaa8] wip
[5844026] wip (+1 squashed commit)
Squashed commits:
[4cbb35a] wip (+1 squashed commit)
Squashed commits:
[4d2b6d3] wip (+4 squashed commits)
Squashed commits:
[f2da510] wip
[318c119] wip
[923b4eb] wip
[9d46365] wip

use a dict instead of a series of ifs (+6 squashed commits)
Squashed commits:
[90f2024] fix import order
[0edd321] add a few docstrings
[77a0ac0] wip
[adcf159] wip
[96cbd67] wip
[d8bfbab] wip
2024-01-27 16:26:17 -08:00
Josh Wu
c40824e51c Support non-directed address generation offload 2024-01-26 16:02:40 +08:00
Gilles Boccon-Gibod
2920f05dae Merge pull request #411 from AlanRosenthal/main
Add bumble-controller-loopback console_scripts
2024-01-24 13:20:19 -08:00
Alan Rosenthal
bc911d6da0 Add bumble-controller-loopback console_scripts 2024-01-24 14:07:35 -05:00
Gilles Boccon-Gibod
4f87f587e4 Merge pull request #409 from google/gbg/root-canal-update
update to rootcanal 1.4
2024-01-22 15:07:20 -08:00
Gilles Boccon-Gibod
3e38ab3638 update to rootcanal 1.4 2024-01-22 12:19:12 -08:00
Gilles Boccon-Gibod
21bb911fea Merge pull request #408 from suneeshs/btbench-update
Update the Bumble BT Bench AndroidManifest.xml
2024-01-22 12:15:35 -08:00
Suneesh Sasikumar
744dfa33a2 Update the Bumble BT Bench AndroidManifest.xml 2024-01-22 13:46:55 -05:00
zxzxwu
ec5f8535a8 Merge pull request #405 from zxzxwu/adv
Make Advertisement dataclass
2024-01-20 11:05:41 +08:00
Gilles Boccon-Gibod
5a83734a00 Merge pull request #388 from google/gbg/scan-with-irk
allow passing IRKs as arguments to scan.py
2024-01-19 11:37:02 -08:00
Josh Wu
b4ae8af3a7 Typing Advertisement 2024-01-19 15:16:24 +08:00
Josh Wu
da60386385 Manage lifecycle of CIS and SCO links in host 2024-01-18 11:56:38 +08:00
zxzxwu
45c4c4f4c5 Merge pull request #404 from zxzxwu/cis
Fix HCI_LE_Set_Host_Feature_Command
2024-01-18 10:56:05 +08:00
zxzxwu
9187c75d68 Merge pull request #397 from zxzxwu/controller
Controller: CIS implementation
2024-01-18 10:55:37 +08:00
zxzxwu
abeec22546 Merge pull request #402 from zxzxwu/key
Save Link Key in CTKD over BR/EDR
2024-01-18 10:55:14 +08:00
Josh Wu
a6bab755cf Fix HCI_LE_Set_Host_Feature_Command 2024-01-17 22:15:15 +08:00
Josh Wu
acd9d994c3 Save link_key in CTKD over BR/EDR
Since keystore.update() overwrites all existing keys, the existing link
key will be wiped out. To avoid this, SMP also need to keep the key.
2024-01-17 19:30:02 +08:00
Gilles Boccon-Gibod
37afda3ed3 Merge pull request #399 from google/gbg/hotfix-002
fix uninitialized variable
2024-01-16 17:29:48 -08:00
Gilles Boccon-Gibod
54f2981267 fix uninitialized variable 2024-01-16 16:49:06 -08:00
Charlie Boutier
bb025514e7 PandoraHost: compute advertising_interval_max with interval_range 2024-01-12 14:36:22 -08:00
Charlie Boutier
e228597269 Pandora host: support advertising interval in advertise 2024-01-12 14:36:22 -08:00
Gilles Boccon-Gibod
95b0d6c6f2 Merge pull request #398 from google/gbg/rfcomm-no-sink
update credits even without a sink
2024-01-11 13:18:15 -08:00
Josh Wu
fa4df6e3a2 Controller: CIS implementation 2024-01-11 01:16:42 +08:00
zxzxwu
46ceea7ecd Merge pull request #391 from zxzxwu/remote_feature
LE read remote features
2024-01-10 15:51:32 +08:00
Gilles Boccon-Gibod
30f89d5739 simplify 2024-01-09 18:01:34 -08:00
Gilles Boccon-Gibod
481cf40831 update credits even without a sink 2024-01-09 17:58:52 -08:00
Josh Wu
eff05afb7a LE read remote features 2024-01-09 11:30:08 +08:00
zxzxwu
d8e6700611 Merge pull request #383 from zxzxwu/controller
Controller: SCO implementation
2024-01-09 09:39:13 +08:00
Gilles Boccon-Gibod
56eb5a933b Merge pull request #394 from google/gbg/hci-latency
add support for HCI latency probing
2024-01-08 09:21:00 -08:00
Gilles Boccon-Gibod
caacc0c133 Merge pull request #395 from google/gbg/loopback-quick-fix
compatibility with recent host ACL property changes
2024-01-08 09:20:45 -08:00
Gilles Boccon-Gibod
5f377c024b format 2024-01-05 12:26:54 -08:00
Gilles Boccon-Gibod
00cd8fbdd0 compatibility with recent host ACL property changes 2024-01-05 12:17:09 -08:00
Gilles Boccon-Gibod
aeeff18428 add support for HCI latency probing 2024-01-05 10:26:04 -08:00
Michael Mogenson
c48e3f5e9c Merge pull request #393 from mogenson/controller-loopback
apps: Add a controller loopback throughput test app
2024-01-05 13:13:30 -05:00
Michael Mogenson
d6bbc1145a apps: Add a controller loopback throughput test app
Add a command line utility to open a transport to a BT controller, put
the controller into local loopback mode, and send and receive ACL data
packets. Record the time it takes to send and receive all packets and
calculate a throughput measurement in kB/s.

This utility is usefull for characterizing the speed of a transport to a
BT controller (such as a TCP socket or serial port) without having to
deal with a connected peer or the variability of over the air
transmissions.

The transport CLI argument is required. The packet size and packet
count arguments are optional. They default to the same values as the
bumble-bench app.
2024-01-05 10:01:24 -05:00
zxzxwu
e2fec67bd9 Merge pull request #390 from zxzxwu/csip
CSIP: Encrypted SIRK implementation
2024-01-04 13:28:23 +08:00
Josh Wu
88cb3b2a4d IWYU in CSIP 2024-01-04 13:22:09 +08:00
zxzxwu
9ebb03be46 Merge pull request #389 from zxzxwu/gitignore
.gitignore: Add venv directories
2024-01-04 12:54:30 +08:00
Gilles Boccon-Gibod
80d84af76c Merge pull request #392 from google/gbg/l2cap-drain
l2cap & rfcomm drain support
2024-01-03 09:59:36 -08:00
Gilles Boccon-Gibod
8f4721758f fix typo 2024-01-03 09:53:17 -08:00
Gilles Boccon-Gibod
8864af4acd format 2024-01-02 11:35:11 -08:00
Gilles Boccon-Gibod
8980fb8cc7 add drain support and a few tool options 2024-01-02 11:07:52 -08:00
Josh Wu
2c5f3472a9 CSIP: Encrypted SIRK implementation 2023-12-30 16:06:42 +08:00
Josh Wu
f18277ac78 Ignore venv directories 2023-12-30 14:23:35 +08:00
Josh Wu
8d46bc04d2 Controller: SCO implementation 2023-12-30 14:22:58 +08:00
Gilles Boccon-Gibod
09e5ea5dec Merge pull request #387 from google/gbg/async-gatt-server
support async read/write for characteristic values
2023-12-29 11:28:22 -08:00
Gilles Boccon-Gibod
d43281c57e allow passing IRKs as arguments 2023-12-28 14:35:23 -08:00
Gilles Boccon-Gibod
6810865670 Merge pull request #385 from google/gbg/android-enable-dle
request MTU change after connection
2023-12-28 13:46:25 -08:00
Gilles Boccon-Gibod
3e9e06a02c Merge pull request #386 from AlanRosenthal/main
app/bench.py: use logging rather than print()
2023-12-28 13:42:17 -08:00
Alan Rosenthal
ccd12f6591 app/bench.py: use logging rather than print() 2023-12-28 16:06:50 -05:00
Gilles Boccon-Gibod
f9a7843f7e request MTU change after connection 2023-12-28 11:17:18 -08:00
Gilles Boccon-Gibod
210c334db7 Merge pull request #380 from google/gbg/classic-buffer-size
support per-transport ACL queues
2023-12-28 09:24:52 -08:00
Gilles Boccon-Gibod
f297cdfcce Merge pull request #384 from eukub/string-concatination-to-fstring
сhanged concatenation of strings to f-strings to improve readability
2023-12-28 09:24:25 -08:00
eukub
5b536d00ab сhanged concatenation of strings to f-strings to improve readability and unify with the rest of code 2023-12-28 16:27:36 +03:00
Gilles Boccon-Gibod
b4af46ebd5 use TCP_NODELAY on socket 2023-12-27 12:11:20 -08:00
Gilles Boccon-Gibod
c08da3193e format 2023-12-27 11:56:06 -08:00
Gilles Boccon-Gibod
f2925ca647 support async read/write for characteristic values 2023-12-27 11:52:22 -08:00
Gilles Boccon-Gibod
fd4d68e5c0 print controller flow control info 2023-12-26 13:24:24 -08:00
Gilles Boccon-Gibod
5d83deffa4 Merge pull request #345 from rdhavan/bumble_hid_device
Bumble hid device implementation - Application and hid profile
2023-12-26 11:10:34 -08:00
Gilles Boccon-Gibod
2878cca478 Merge pull request #378 from benquike/pair_linger
Improve the linger option of bumble-pair
2023-12-26 10:55:28 -08:00
Gilles Boccon-Gibod
53934716db Merge pull request #377 from benquike/irk
Add functions/tool for gen/verifying BLE IRK/RPA
2023-12-26 10:54:18 -08:00
Hui Peng
d885d45824 Add functions/tool for gen/verifying BLE IRK/RPA 2023-12-26 09:34:19 -08:00
Gilles Boccon-Gibod
b90d0f8710 fix tests 2023-12-26 09:09:20 -08:00
zxzxwu
8ccfc90fe6 Merge pull request #379 from zxzxwu/addr
Add random address generation methods
2023-12-25 17:28:49 +08:00
Josh Wu
92aa7e9e2a Add random address generation methods 2023-12-24 18:07:40 +08:00
Gilles Boccon-Gibod
afc6d19e04 address PR comments 2023-12-23 14:21:44 -08:00
Gilles Boccon-Gibod
c05f073b33 Update bumble/host.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2023-12-23 14:15:53 -08:00
Gilles Boccon-Gibod
2b4c2a22f4 format 2023-12-22 14:22:08 -08:00
Gilles Boccon-Gibod
47fe93a148 support per-transport ACL queues 2023-12-22 13:52:33 -08:00
zxzxwu
6139ca8045 Merge pull request #374 from zxzxwu/csip
Complete CSIP and CAP
2023-12-23 02:49:35 +08:00
Josh Wu
87c76a4a0e Complete CSIP and CAP
Also add random address generation functions.
2023-12-23 02:14:32 +08:00
Hui Peng
f7b66db873 Improve the linger option in pair tool
No matter pairing fails or not, make linger effective
2023-12-21 17:25:42 -08:00
skarnataki
0b314bd7f7 Updated absctract class and method for on_ctrl_pdu in hid.py 2023-12-18 13:36:25 +00:00
skarnataki
9da2e32ad7 Review comment Fix 3 - rename json file and usage of Optional in parameters 2023-12-15 09:42:57 +00:00
Snehal Karnataki
93c0875740 Merge branch 'google:main' into bumble_hid_device 2023-12-13 09:51:27 +00:00
Gilles Boccon-Gibod
a286700239 Merge pull request #368 from google/gbg/driver-load-before-reset
support drivers that can't use reset directly.
2023-12-11 18:06:23 -08:00
Gilles Boccon-Gibod
98ed772e8a address PR comments and add some typing 2023-12-11 17:52:04 -08:00
Gilles Boccon-Gibod
f0b55a4f97 Merge pull request #367 from google/gbg/android-bench-update
Android bench app: add support for 2M phy
2023-12-11 10:20:56 -08:00
zxzxwu
b74503d345 Merge pull request #359 from zxzxwu/ascs
Audio Stream Control Service
2023-12-12 00:47:03 +08:00
Josh Wu
f911163e49 Improve ASCS logging 2023-12-12 00:36:24 +08:00
Gilles Boccon-Gibod
b083cc99ad fix spec parsing 2023-12-08 18:57:02 -08:00
Gilles Boccon-Gibod
d35643524e allow specifying the address type 2023-12-08 18:46:25 -08:00
Gilles Boccon-Gibod
62a8ced447 support drivers that can't use reset directly. 2023-12-08 17:28:57 -08:00
Gilles Boccon-Gibod
085f163c92 add support for 2M phy 2023-12-08 10:14:38 -08:00
Josh Wu
81a6b1e097 Replace 3.9 dict merger 2023-12-08 11:10:17 +08:00
Josh Wu
dd090c9e6b Add ASCS tests 2023-12-08 11:00:44 +08:00
Josh Wu
11faa48422 Fix ASE state change 2023-12-08 09:53:14 +08:00
Josh Wu
55596176c2 ffplay routing 2023-12-08 09:53:14 +08:00
Josh Wu
4d6822d312 Remove ISO data path on release 2023-12-08 09:53:14 +08:00
Josh Wu
985c365e6d Setup data path after CIS established 2023-12-08 09:53:14 +08:00
Josh Wu
af57762227 Parse CodecSpecificConfiguration 2023-12-08 09:53:14 +08:00
Josh Wu
3575f9030e Add Audio Stream Control Service 2023-12-08 09:53:14 +08:00
zxzxwu
698d947d85 Merge pull request #366 from zxzxwu/extadv
Add advertiser classes and handle adv set terminated events
2023-12-08 09:52:42 +08:00
Josh Wu
ff6528d2bf Add Advertising unit tests 2023-12-08 01:38:01 +08:00
Josh Wu
72ac75a98d Add advertiser classes and handle adv set terminated events
* Convert hci.OwnAddressType to enum
* Add LegacyAdvertiser and ExtendedAdvertiser classes
* Rename start/stop_advertising() => start/stop_legacy_advertising()
* Handle HCI_Advertising_Set_Terminated
* Properly restart advertisement on disconnection
2023-12-07 15:51:51 +08:00
skarnataki
5e3ecb74e4 Review comment fix -2 2023-12-05 13:41:30 +00:00
Snehal Karnataki
c59be293c8 Merge branch 'google:main' into bumble_hid_device 2023-12-05 13:07:36 +00:00
zxzxwu
88b4cbdf1a Merge pull request #364 from zxzxwu/iso
Fix ISO packet issues
2023-12-05 00:41:56 +08:00
Josh Wu
d6afbc6f4e Fix ISO packet issues 2023-12-04 20:31:11 +08:00
Gilles Boccon-Gibod
fc90de3e7b Merge pull request #351 from google/dependabot/cargo/rust/openssl-0.10.60
Bump openssl from 0.10.57 to 0.10.60 in /rust
2023-12-04 00:41:27 -08:00
Gilles Boccon-Gibod
847c2ef114 Merge pull request #362 from google/gbg/more-le-features-constants
a few more HCI constants from the spec
2023-12-04 00:38:02 -08:00
Gilles Boccon-Gibod
a0bf0c1f4d Merge pull request #363 from google/gbg/android-remote-proxy-cli
android remote proxy cli
2023-12-04 00:37:49 -08:00
Gilles Boccon-Gibod
8400ff0802 shared usage printer 2023-12-04 00:37:28 -08:00
Gilles Boccon-Gibod
0ed6aa230b address PR comment 2023-12-04 00:32:04 -08:00
Snehal Karnataki
6d22ed80ec Merge branch 'google:main' into bumble_hid_device 2023-12-04 07:29:04 +00:00
Gilles Boccon-Gibod
72d5360af9 keep projects compatible with Android Studio Hedgehog 2023-12-03 18:06:54 -08:00
Gilles Boccon-Gibod
ac3961e763 add doc 2023-12-03 17:50:42 -08:00
Gilles Boccon-Gibod
843466c822 a few more constants from the spec 2023-12-03 17:16:25 -08:00
Gilles Boccon-Gibod
8385035400 add CLI support 2023-12-03 16:35:14 -08:00
zxzxwu
3adcc8be09 Merge pull request #360 from zxzxwu/hci
Remove # type: ignore[call-arg] in HCI_Command builders
2023-12-03 19:18:04 +08:00
zxzxwu
c853d56302 Merge pull request #361 from zxzxwu/hci-bug
Fix typo
2023-12-03 04:22:59 +08:00
Josh Wu
dc97be5b35 Fix typo 2023-12-02 23:42:21 +08:00
zxzxwu
73dbdfff9f Merge pull request #356 from zxzxwu/bap
Add Published Audio Capabilities Service
2023-12-02 23:34:57 +08:00
Josh Wu
dff14e1258 Add Published Audio Capabilities Service 2023-12-02 23:16:37 +08:00
Josh Wu
10a3833893 Remove # type: ignore[call-arg] in HCI_Command builders 2023-12-02 19:18:54 +08:00
zxzxwu
247cb89332 Merge pull request #358 from zxzxwu/coding2
Add variable-length bytes field
2023-12-01 03:26:38 +08:00
Josh Wu
3fc71a0266 Add variable-length bytes field 2023-12-01 03:16:52 +08:00
zxzxwu
392dcc3a05 Merge pull request #357 from zxzxwu/coding
Refactor CodingFormat
2023-12-01 03:15:33 +08:00
Josh Wu
f27015d1b7 Refactor CodingFormat
As CodingFormat is now used by HFP and LEA, and vendor specific codecs
are introduced, this object needs to provide more information.
2023-12-01 02:58:09 +08:00
zxzxwu
86a19b41aa Merge pull request #344 from zxzxwu/cis
CIS and SCO responder support
2023-11-30 21:00:55 +08:00
Gilles Boccon-Gibod
320164d476 Merge pull request #355 from google/gbg/fix-gatt-unsubscribe
fix #354 (gatt unsubscribe)
2023-11-29 22:28:57 -08:00
Josh Wu
40ae661ee5 More SCO support and warnings and typo fix 2023-11-30 12:59:43 +08:00
Snehal Karnataki
ffb3eca68b Merge branch 'google:main' into bumble_hid_device 2023-11-30 04:50:05 +00:00
Josh Wu
c5def93bb8 CIS and SCO responder support 2023-11-30 12:16:40 +08:00
zxzxwu
a9c4c5833d Merge pull request #350 from zxzxwu/csip
Add Coordinated Set Identification Service(CSIS)
2023-11-30 12:15:56 +08:00
Gilles Boccon-Gibod
58c9c4f590 fix #354 2023-11-29 19:19:40 -08:00
zxzxwu
24524d88cb Merge pull request #342 from zxzxwu/typing
Typing helper
2023-11-30 00:21:44 +08:00
zxzxwu
b8849ab311 Merge pull request #349 from zxzxwu/stack
Log track back in on_packet
2023-11-30 00:20:20 +08:00
Josh Wu
f3cd8f8ed0 Typing helper 2023-11-29 21:24:27 +08:00
zxzxwu
2b26de3f3a Merge pull request #348 from zxzxwu/gattc
Typing GATT Client and Device Peer
2023-11-29 15:09:40 +08:00
Josh Wu
0149c4c212 Log track back in on_packet
Many errors are raised in on_packet() callbacks, but currently it only
provides a very brief error message.
2023-11-29 15:01:15 +08:00
Gilles Boccon-Gibod
f2ed898784 Merge pull request #352 from google/gbg/more-gatt-uuids
add a few uuids
2023-11-28 22:44:40 -08:00
Josh Wu
464a476f9f Add CSIP 2023-11-29 14:09:31 +08:00
Gilles Boccon-Gibod
e85d067fb5 add a few uuids 2023-11-28 20:02:00 -08:00
dependabot[bot]
7eb493990f Bump openssl from 0.10.57 to 0.10.60 in /rust
Bumps [openssl](https://github.com/sfackler/rust-openssl) from 0.10.57 to 0.10.60.
- [Release notes](https://github.com/sfackler/rust-openssl/releases)
- [Commits](https://github.com/sfackler/rust-openssl/compare/openssl-v0.10.57...openssl-v0.10.60)

---
updated-dependencies:
- dependency-name: openssl
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-28 21:43:18 +00:00
Josh Wu
04d5bf3afc Typing GATT Client and Device Peer 2023-11-28 21:57:57 +08:00
skarnataki
403a13e4c6 Review comment fix HID device 2023-11-28 13:42:25 +00:00
Snehal Karnataki
ad0f035df5 Merge branch 'google:main' into bumble_hid_device 2023-11-28 13:06:32 +00:00
zxzxwu
a13e193d3b Merge pull request #343 from zxzxwu/lea-gatt
Add LE Audio GATT services and characteristics definitions
2023-11-28 10:34:39 +08:00
Gilles Boccon-Gibod
28a1a5ebc2 Merge pull request #347 from akuker/main
Include transport.grpc_protobuf in the setup package.
2023-11-27 15:23:17 -08:00
Tony Kuker
6310dc777f Include transport.grpc_protobuf in the setup package. 2023-11-27 16:48:37 -06:00
skarnataki
07f71fc895 Project format and lint error fix. Redefination if Device class needs to be discussed 2023-11-27 13:04:54 +00:00
Fahad Afroze
f47b9178ad Added GET_REPORT and SET_REPORT changes
Added changes to handle invalid cases
2023-11-27 11:55:35 +00:00
Josh Wu
863de18877 Add LE Audio GATT definitions 2023-11-27 17:53:00 +08:00
SneKarnataki
4f399249bd Merge branch 'google:main' into bumble_hid_device 2023-11-27 09:00:44 +00:00
zxzxwu
f0e5cdee1a Merge pull request #339 from zxzxwu/enc
Refactor crypto and fix CTKD
2023-11-27 14:05:37 +08:00
zxzxwu
7bc7d0f5af Merge pull request #334 from zxzxwu/extadv
Add support for LE Extended Advertising
2023-11-27 14:01:31 +08:00
Josh Wu
a65a215fd7 Provide IntFlag.name property fallback 2023-11-26 19:42:22 +08:00
Josh Wu
80d34a226d Slightly refactor and fix CTKD
It seems sample input data provided in the spec is big-endian (just
like other AES-CMAC-based functions), but all keys are in little-endian(
HCI standard), so they need to be reverse before and after applying
AES-CMAC.
2023-11-26 16:55:10 +08:00
Josh Wu
a9628f73e3 Add support for Extended Advertising 2023-11-26 15:03:09 +08:00
skarnataki
9324237828 send_data comment fix and lint error fix 2023-11-24 11:13:20 +00:00
Fahad Afroze
d1033c018a Modified DeviceData class 2023-11-24 05:42:31 +00:00
Fahad Afroze
0f29052ade Added mousemove changes
Also modified keyboard data on keyup
2023-11-23 17:46:55 +00:00
skarnataki
0578e84586 Menu and name change review comments fix 2023-11-23 15:43:22 +00:00
Fahad Afroze
6ab41c466f Add review comment changes 3 2023-11-23 12:27:56 +00:00
Fahad Afroze
98a1093ebf Add review comment changes 2
Also corrected sending mouseData
2023-11-23 09:53:16 +00:00
dhavan
caf04373f3 keyboard data moved to DeviceData class 2023-11-23 08:01:07 +00:00
SneKarnataki
d4e8526766 Merge branch 'google:main' into bumble_hid_device 2023-11-23 07:59:43 +00:00
dhavan
515b83a8c7 deleted: bumble/classic3.json
modified:   examples/keyboard.html
2023-11-23 06:10:52 +00:00
Lucas Abel
9bf2e03354 device: set authenticated and sc state on AES encryption change 2023-11-23 06:39:55 +01:00
dhavan
dc18595c8a MTU size check added 2023-11-23 05:17:44 +00:00
SneKarnataki
488bcfe9c6 Merge branch 'google:main' into bumble_hid_device 2023-11-23 04:03:53 +00:00
Gilles Boccon-Gibod
2900b93bb3 Merge pull request #120 from google/gbg/usb-cleanup
minor cleanup of the internals of the usb transport implementation
2023-11-22 17:18:23 -08:00
Gilles Boccon-Gibod
284cc8a321 Merge pull request #326 from google/gbg/android-benchmark-app
Android benchmarking app
2023-11-22 15:39:52 -08:00
Gilles Boccon-Gibod
3dc2e4036c rebase 2023-11-22 15:32:37 -08:00
Gilles Boccon-Gibod
268f6b0d51 remove unneeded constructor parameters 2023-11-22 15:30:18 -08:00
Gilles Boccon-Gibod
46239b321b address PR comments 2023-11-22 15:30:18 -08:00
Gilles Boccon-Gibod
8a536cd522 fix missed merge 2023-11-22 15:30:18 -08:00
Gilles Boccon-Gibod
f9f5d7ccbd first implementation (+1 squashed commit)
Squashed commits:
[ee00d67] wip
2023-11-22 15:30:16 -08:00
dhavan
d6cefdff8e Renamed the status message class 2023-11-22 17:14:24 +00:00
dhavan
dc410b14c4 SET_REPORT and GET_REPORT implemented 2023-11-22 16:05:33 +00:00
dhavan
4c49ef9403 SET_REPORT implemented 2023-11-22 12:31:34 +00:00
dhavan
ba85dcbda5 Get the changes from hid_device to bumble_hid_device
Modified the get_report_cb
2023-11-22 11:06:27 +00:00
zxzxwu
e08c84dd20 Merge pull request #333 from zxzxwu/iso
Add ISO related HCI packets
2023-11-21 15:55:00 +08:00
Josh Wu
8b46136703 Add ISO related HCI packets 2023-11-20 22:47:02 +08:00
Gilles Boccon-Gibod
9c7089c8ff terminate when unplugged 2023-11-19 11:36:38 -08:00
Gilles Boccon-Gibod
aac8d89cd0 Merge pull request #330 from benquike/main
Do not exit after pairing is finished
2023-11-18 08:57:58 -08:00
Hui Peng
24e75bfeab Do not exit after pairing is finished
Android performs additional service
discovery during pairing, otherwise
pairing fails.
2023-11-17 09:17:40 -08:00
zxzxwu
42868b08d3 Merge pull request #335 from zxzxwu/a2dp
Typing A2DP
2023-11-18 00:21:20 +08:00
zxzxwu
19b61d9ac0 Merge pull request #336 from zxzxwu/hid
Cleanup HID module
2023-11-17 23:34:03 +08:00
Josh Wu
db2a2e2bb9 Cleanup HID module
* Remove unused imports
* Replace typing exceptions by better assertions
2023-11-17 17:43:07 +08:00
Josh Wu
e1fdb12647 Typing A2DP 2023-11-17 17:29:35 +08:00
Gilles Boccon-Gibod
a8ec1b0949 minor cleanup of the internals of the usb transport implementation 2023-11-15 17:26:21 -08:00
Gilles Boccon-Gibod
2e30b2de77 Merge pull request #329 from google/gbg/le-oob
le oob
2023-11-15 16:10:20 -08:00
Gilles Boccon-Gibod
7e407ccae1 address PR comments 2023-11-15 15:48:19 -08:00
zxzxwu
0667e83919 Merge pull request #254 from zxzxwu/sco
eSCO codec/HCI definitions + Host support
2023-11-13 20:01:06 +08:00
Gilles Boccon-Gibod
1a6c9a4d04 improve help 2023-11-10 12:17:21 -08:00
Gilles Boccon-Gibod
14f5b912ad use ad_data directly 2023-11-10 11:53:54 -08:00
Gilles Boccon-Gibod
46d6242171 Merge pull request #316 from whitevegagabriel/extended
Add support for extended advertising via Rust-only API
2023-11-09 13:43:00 -08:00
Gilles Boccon-Gibod
753b966148 format 2023-11-09 12:44:02 -08:00
Gilles Boccon-Gibod
5a307c19b8 add oob data on command line 2023-11-07 20:38:35 -08:00
Lucas Abel
2cd4f84800 pandora: add annotations import 2023-11-06 14:06:56 -08:00
Gilles Boccon-Gibod
4ae612090b wip 2023-11-06 13:19:13 -08:00
Gilles Boccon-Gibod
c67ca4a09e Merge pull request #324 from google/gbg/hotfix-002
fix typo
2023-10-31 20:58:19 +01:00
Gilles Boccon-Gibod
94506220d3 fix typo 2023-10-31 12:18:28 -07:00
Gilles Boccon-Gibod
dbd865a484 Merge pull request #323 from google/gbg/device-hive
Device hive
2023-10-31 16:44:18 +01:00
Gilles Boccon-Gibod
9d2f3e932a format 2023-10-29 11:32:00 -07:00
Gilles Boccon-Gibod
49d32f5b5b add netsim.ini info 2023-10-29 10:26:34 -07:00
Gilles Boccon-Gibod
f7b74c0bcb add hive to index page 2023-10-29 10:03:31 -07:00
Gilles Boccon-Gibod
c75cb0c7b7 fix css 2023-10-29 09:58:37 -07:00
Gilles Boccon-Gibod
a63b335149 wip 2023-10-29 09:36:17 -07:00
Gilles Boccon-Gibod
d8517ce407 add links 2023-10-29 08:53:25 -07:00
Gilles Boccon-Gibod
ad13b11464 wip 2023-10-29 08:53:23 -07:00
Gilles Boccon-Gibod
99bc92d53d wip (+5 squashed commits)
Squashed commits:
[53c6c53] wip
[66f482c] wip
[b003315] wip
[f6f9d9e] wip
[4c95c7b] wip
2023-10-29 08:50:25 -07:00
Josh Wu
72199f5615 Add address resolution offload to config 2023-10-24 17:04:43 -07:00
skarnataki
78b8b50082 fixed lint errors 2023-10-19 17:19:49 -07:00
skarnataki
3ab64ce00d Fixed lint and pre-commit errors. 2023-10-19 17:19:49 -07:00
skarnataki
651e44e0b6 Submitting review comment fix: header function and extra lines.
Executed formatter on file.
2023-10-19 17:19:49 -07:00
skarnataki
963fa41a49 Submitting review comment fix: header function and extra lines. 2023-10-19 17:19:49 -07:00
skarnataki
493f4f8b95 Submitting review comment fix: header function and spacing 2023-10-19 17:19:49 -07:00
skarnataki
fc1bf36ace Review changes comment fix. Classes/Subclass/dataclass. Enum constants.
Naming conventions
2023-10-19 17:19:49 -07:00
skarnataki
5ddee17411 Commit to fix review comments for dataclass and subclass, shifting contants to Message Class
Commit for enum and dataclass
2023-10-19 17:19:49 -07:00
skarnataki
5ce353bcde Review comment Fix 2023-10-19 17:19:49 -07:00
SneKarnataki
16d33199eb Change in sdp.py file while testing hid profile,
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')) changed to
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x)
as we were facing error "UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 4: invalid start byte" while fetching sdp records.
2023-10-19 17:19:49 -07:00
SneKarnataki
e02303a448 Submitting the initial version of HID Profile files
Includes:
1. HID Host implementation - hid.py
2. HID application to test Host with 3rd party HID Device application - run_hid_host.py
3. HID supporting files for testing - hid_report_parser.py & hid_key_map.py

Commands to run the application:
Default application:
python run_hid_host.py classic1.json usb:0 <device bd-addr>

Menu options for testing (Get/Set):
python run_hid_host.py classic1.json usb:0 <device bd-addr> test-mode

CuttleFish:tcp-client:127.0.0.1:7300

Application used for testing as Device : Bluetooth Keyboard & Mouse-5.3.0.apk

Note: Change in sdp.py file while testing hid profile,
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')) changed to
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x)
as we were facing error "UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 4: invalid start byte" while fetching sdp records.
2023-10-19 17:19:49 -07:00
Fahad Afroze
36fc966ad6 Trial checkin code 2023-10-19 17:19:49 -07:00
skarnataki
644f74400d Trial to commit in dhavan repo 2023-10-19 17:19:49 -07:00
dhavan
b7cd451ddb Hid profile implemenation. Empty file 2023-10-19 17:19:49 -07:00
Gabriel White-Vega
59d7717963 Remove mutable ret pattern and test feature combinations
After adding test for feature combinations, I found a corner case where, when Transport is dropped and the process is terminated in a test, the `close` Python future is not awaited.
I don't know what other situations this issue may arise, so I have safe-guarded it via `block_on` instead of spawning a thread.
2023-10-18 15:39:37 -04:00
Gilles Boccon-Gibod
88392efca4 Merge pull request #312 from google/gbg/android-remote-hci
remote hci android app
2023-10-17 13:30:49 +02:00
zxzxwu
907f2acc7e Merge pull request #318 from zxzxwu/l2cap_refactor
Cleanup legacy L2CAP API usage
2023-10-17 14:22:45 +08:00
Gilles Boccon-Gibod
6616477bcf Merge pull request #319 from google/gbg/bt-spec-version-5-4
add constant for 5.4
2023-10-11 18:44:03 -07:00
Gilles Boccon-Gibod
5b173cb879 add constant for 5.4 2023-10-11 17:47:21 -07:00
Gilles Boccon-Gibod
dc6b466a42 add intent parameters 2023-10-11 16:52:15 -07:00
zxzxwu
8b04161da3 Merge pull request #317 from zxzxwu/pytest
Add missing @pytest.mark.asyncio decorator
2023-10-11 15:16:35 +08:00
Josh Wu
5a85765360 Cleanup legacy L2CAP API 2023-10-11 14:33:44 +08:00
Josh Wu
333940919b Add missing @pytest.mark.asyncio decorator 2023-10-11 13:52:06 +08:00
Gilles Boccon-Gibod
b9476be9ad Merge pull request #315 from google/gbg/company-ids
update to latest list of company ids
2023-10-10 22:13:16 -07:00
Gilles Boccon-Gibod
704c60491c Merge pull request #313 from benquike/pair_fix
Allow turning on BLE in classic pairing mode
2023-10-10 21:30:24 -07:00
Gilles Boccon-Gibod
4a8e612c6e update rust list 2023-10-10 21:29:39 -07:00
Gilles Boccon-Gibod
5e5c9c2580 fix byte order and packet accounting 2023-10-10 21:17:20 -07:00
Gilles Boccon-Gibod
4e71ec5738 remove stale comment 2023-10-10 20:36:48 -07:00
Gabriel White-Vega
1004f10384 Address PR comments 2023-10-10 16:45:02 -04:00
Gabriel White-Vega
1051648ffb Add support for extended advertising via Rust-only API
* Extended functionality is gated on an "unstable" feature
* Designed for very simple use and minimal interferance with existing legacy implementation
* Intended to be temporary, until bumble can integrate extended advertising into its core functionality
* Dropped `HciCommandWrapper` in favor of using bumble's `HCI_Command.from_bytes` for converting from PDL into bumble implementation
* Refactored Address and Device constructors to better match what the python constructors expect
2023-10-10 13:35:31 -04:00
uael
7255a09705 ci: add python avatar tests 2023-10-09 23:37:23 +02:00
zxzxwu
c2bf6b5f13 Merge pull request #289 from zxzxwu/l2cap_refactor
Refactor L2CAP API
2023-10-09 23:27:25 +08:00
Gilles Boccon-Gibod
d8e699b588 use the new yaml file instead of the previous CSV file 2023-10-07 23:10:49 -07:00
zxzxwu
3e4d4705f5 Merge pull request #314 from zxzxwu/sec_pandora
Pandora: Handle exception in WaitSecurity()
2023-10-08 01:42:45 +08:00
Josh Wu
c8b2804446 Pandora: Handle exception in WaitSecurity() 2023-10-07 21:17:01 +08:00
Josh Wu
e732f2589f Refactor L2CAP API 2023-10-07 20:01:15 +08:00
zxzxwu
aec5543081 Merge pull request #310 from zxzxwu/avdtp
Typing AVDTP
2023-10-07 19:50:56 +08:00
Josh Wu
e03d90ca57 Add typing for MediaCodecCapabilities members 2023-10-07 19:32:19 +08:00
Josh Wu
495ce62d9c Typing AVDTP 2023-10-07 19:32:19 +08:00
Hui Peng
fbc3959a5a Allow turning on BLE in classic pairing mode 2023-10-06 19:54:18 -07:00
Gilles Boccon-Gibod
246b11925c add remote hci android app 2023-10-06 14:10:51 -07:00
Gilles Boccon-Gibod
dfa9131192 Merge pull request #311 from zxzxwu/rust
Fix Rust lints
2023-10-06 13:37:47 -07:00
Josh Wu
88c801b4c2 Replace or_insert_with with or_default 2023-10-06 18:02:46 +08:00
Gilles Boccon-Gibod
a1b55b94e0 Merge pull request #301 from whitevegagabriel/simplify-event-loop-copy
Remove unncecesary steps for injecting Python event loop
2023-10-02 12:12:41 -07:00
Gilles Boccon-Gibod
80db9e2e2f Merge pull request #303 from whitevegagabriel/hci-command-rs
Ability to send HCI commands from Rust
2023-10-02 12:12:05 -07:00
Gabriel White-Vega
ce74690420 Update pdl to 0.2.0
- Allows removing impl PartialEq for pdl Error
2023-10-02 11:20:44 -04:00
Gilles Boccon-Gibod
50de4dfb5d Merge pull request #307 from google/gbg/hotfix-001
don't delete advertising prefs on disconnection
2023-09-30 17:46:53 -07:00
Gilles Boccon-Gibod
9bcdf860f4 don't delete advertising prefs on disconnection 2023-09-30 17:41:18 -07:00
Gabriel White-Vega
511ab4b630 Add python async wrapper, move hci non-wrapper to internal, add hci::internal tests 2023-09-29 10:23:19 -04:00
Gilles Boccon-Gibod
6f2b623e3c Merge pull request #290 from google/gbg/netsim-transport-injectable-channels
make grpc channels injectable
2023-09-27 22:16:05 -07:00
Gilles Boccon-Gibod
fa12165cd3 Merge pull request #298 from google/gbg/use-address-to-string
use Address.to_string instead of manual suffix replacement
2023-09-27 21:59:32 -07:00
Gilles Boccon-Gibod
c0c6f3329d minor cleanup 2023-09-27 21:53:54 -07:00
Gilles Boccon-Gibod
406a932467 make grpc channels injectable 2023-09-27 21:37:36 -07:00
Gilles Boccon-Gibod
cc96d4245f address PR comments 2023-09-27 21:25:13 -07:00
Sparkling Diva
c6cdca8923 device: return the psm value from register_l2cap 2023-09-27 16:41:38 -07:00
Josh Wu
45edcafb06 SCO: A loopback example 2023-09-27 23:30:26 +08:00
Josh Wu
9f0bcc131f eSCO support 2023-09-27 23:30:17 +08:00
Gabriel White-Vega
7e331c2944 Ability to send HCI commands from Rust
* Autogenerate packet code in Rust from PDL (packet file copied from rootcanal)
* Implement parsing of packets that have a type header
* Expose Python APIs for sending HCI commands
* Expose Python APIs for instantiating a local controller
2023-09-27 11:17:47 -04:00
Gilles Boccon-Gibod
10347765cb Merge pull request #302 from google/gbg/netsim-with-instance-num
support netsim instance numbers
2023-09-26 09:34:28 -07:00
Gilles Boccon-Gibod
c12dee4e76 Merge pull request #294 from mauricelam/wasm-cryptography
Make cryptography a valid dependency for emscripten targets
2023-09-25 19:29:09 -07:00
Maurice Lam
772c188674 Fix typo 2023-09-25 18:08:52 -07:00
Maurice Lam
7c1a3bb8f9 Separate version specifier for cryptography in Emscripten builds 2023-09-22 16:43:40 -07:00
Maurice Lam
8c3c0b1e13 Make cryptography a valid dependency for emscripten targets
Since only the special cryptography package bundled with pyodide can be
used, relax the version requirement to anything that's version 39.*.

Fix #284
2023-09-22 16:43:40 -07:00
Gilles Boccon-Gibod
1ad84ad51c fix linter errors 2023-09-22 15:08:10 -07:00
Gilles Boccon-Gibod
64937c3f77 support netsim instance numbers 2023-09-22 14:22:04 -07:00
Gabriel White-Vega
50fd2218fa Remove unncecesary steps for injecting Python event loop
* Context vars can be injected directly into Rust future and spawned with tokio
2023-09-22 15:23:01 -04:00
Gilles Boccon-Gibod
4c29a16271 Merge pull request #297 from google/gbg/websocket-full-url
ws-client: make implementation match the doc
2023-09-22 11:41:24 -07:00
Gilles Boccon-Gibod
762d3e92de Merge pull request #300 from google/gbg/issue-299
use correct own_address_type when restarting advertising
2023-09-22 11:41:04 -07:00
uael
2f97531d78 pandora: use public identity address for public addresses 2023-09-22 20:08:34 +02:00
Gilles Boccon-Gibod
f6c7cae661 use correct own_address_type when restarting advertising 2023-09-22 10:33:36 -07:00
Gilles Boccon-Gibod
f1777a5bd2 use .to_string instead of a manual suffix replacement 2023-09-21 19:03:54 -07:00
Gilles Boccon-Gibod
78a06ae8cf make implementation match the doc 2023-09-21 19:01:40 -07:00
zxzxwu
d290df4aa9 Merge pull request #278 from zxzxwu/gatt2
Typing GATT
2023-09-21 16:09:36 +08:00
Josh Wu
e559744f32 Typing att 2023-09-21 15:52:07 +08:00
zxzxwu
67418e649a Merge pull request #288 from zxzxwu/l2cap_states
L2CAP: Refactor states to enums
2023-09-21 15:42:21 +08:00
Gilles Boccon-Gibod
5adf9fab53 Merge pull request #275 from whitevegagabriel/file-header
Add license header check for rust files
2023-09-20 16:21:38 -07:00
Josh Wu
2491b686fa Handle SMP_Security_Request 2023-09-20 23:13:08 +02:00
Josh Wu
efd02b2f3e Adopt reviews 2023-09-20 23:03:23 +02:00
Josh Wu
3b14078646 Overload signatures 2023-09-20 23:03:23 +02:00
Josh Wu
eb9d5632bc Add utils_test type hint 2023-09-20 23:03:23 +02:00
Josh Wu
45f60edbb6 Pyee watcher context 2023-09-20 23:03:23 +02:00
David Duarte
393ea6a7bb pandora_server: Load server config
Pandora server has it's own config that we load from the 'server'
property of the current bumble config file
2023-09-18 14:28:42 -07:00
Gabriel White-Vega
6ec6f1efe5 Add license header check for rust files
Added binary that can check for and add Apache 2.0 licenses.
Run this binary during the build-rust workflow.
2023-09-14 14:29:47 -04:00
Josh Wu
5d9598ea51 L2CAP: Refactor states to enums 2023-09-14 20:52:33 +08:00
Gilles Boccon-Gibod
0d36d99a73 Merge pull request #287 from google/revert-286-gbg/package-depencencies-for-wasm
Revert "make cryptography a valid dependency for emscripten targets"
2023-09-13 23:37:42 -07:00
Gilles Boccon-Gibod
d8a9f5a724 Revert "make cryptography a valid dependency for emscripten targets" 2023-09-13 23:36:33 -07:00
Gilles Boccon-Gibod
2c66e1a042 Merge pull request #285 from google/gbg/fix-mypy-errors
mypy: ignore false positive errors
2023-09-13 23:30:50 -07:00
Gilles Boccon-Gibod
d5eccdb00f Merge pull request #286 from google/gbg/package-depencencies-for-wasm
make cryptography a valid dependency for emscripten targets
2023-09-13 23:30:28 -07:00
Gilles Boccon-Gibod
32626573a6 ignore false positive errors 2023-09-13 23:17:00 -07:00
Gilles Boccon-Gibod
caa82b8f7e make cryptography a valid dependency for emscripten targets 2023-09-13 22:38:28 -07:00
Gilles Boccon-Gibod
5af347b499 Merge pull request #282 from google/gbg/multi-python-pre-commit-check
run pre-commit tests with all supported Python versions
2023-09-13 07:47:32 -07:00
zxzxwu
4ed5bb5a9e Merge pull request #281 from zxzxwu/cleanup-transport
Replace | typing usage with Optional and Union
2023-09-13 13:31:41 +08:00
Gilles Boccon-Gibod
2478d45673 more windows compat fixes 2023-09-12 14:52:42 -07:00
Gilles Boccon-Gibod
1bc7d94111 windows NamedTemporaryFile compatibility 2023-09-12 14:33:12 -07:00
Gilles Boccon-Gibod
6432414cd5 run tests on windows and mac in addition to linux 2023-09-12 13:50:15 -07:00
Gilles Boccon-Gibod
179064ba15 run pre-commit tests with all supported Python versions 2023-09-12 13:42:33 -07:00
William Escande
783b2d70a5 Add connection parameter update from peripheral 2023-09-12 11:08:04 -07:00
zxzxwu
80824f3fc1 Merge pull request #280 from zxzxwu/device_typing
Add terminated to TransportSource protocol
2023-09-12 20:46:35 +08:00
Josh Wu
f39f5f531c Replace | typing usage with Optional and Union 2023-09-12 15:50:51 +08:00
Gilles Boccon-Gibod
56139c622f Merge pull request #258 from mogenson/vsc_tx_power
Add support for Zephyr HCI VSC set TX power command
2023-09-11 21:34:11 -07:00
Michael Mogenson
da02f6a39b Add HCI Zephyr vendor commands to read and write TX power
Create platforms/zephyr/hci.py with definitions of vendor HCI commands
to read and write TX power.

Add documentation for how to prepare an nRF52840 dongle with a Zephyr
HCI USB firmware application that includes dynamic TX power support and
how to send a write TX power vendor HCI command from Bumble.
2023-09-11 10:06:10 -04:00
Josh Wu
548d5597c0 Transport: Add termination protocol signature 2023-09-11 14:36:40 +08:00
zxzxwu
7fd65d2412 Merge pull request #279 from zxzxwu/typo
Fix typo
2023-09-11 03:02:11 +08:00
Josh Wu
05a54a4af9 Fix typo 2023-09-10 20:32:58 +08:00
Gilles Boccon-Gibod
1e00c8f456 Merge pull request #276 from google/gbg/add-zephyr-zip-to-docs
add zephyr binary to docs
2023-09-08 18:07:15 -07:00
Gilles Boccon-Gibod
90d165aa01 add zephyr binary 2023-09-08 14:17:15 -07:00
zxzxwu
01603ca9e4 Merge pull request #271 from zxzxwu/device_typing
Typing transport and relateds
2023-09-09 00:55:59 +08:00
Gilles Boccon-Gibod
a1b6eb61f2 Merge pull request #269 from google/gbg/android_vendor_hci
add support for vendor HCI commands and events
2023-09-08 08:50:49 -07:00
zxzxwu
25f300d3ec Merge pull request #270 from zxzxwu/typo
Fix typos
2023-09-08 17:32:33 +08:00
Josh Wu
41fe63df06 Fix typos 2023-09-08 16:30:06 +08:00
Josh Wu
b312170d5f Typing transport 2023-09-08 15:27:01 +08:00
David Duarte
cf7f2e8f44 Make platformdirs import lazy
platformdirs is not available in Android
2023-09-07 21:13:29 -07:00
Gilles Boccon-Gibod
d292083ed1 Merge pull request #272 from zxzxwu/gfp
Bring HfpProtocol back
2023-09-07 13:03:36 -07:00
Gilles Boccon-Gibod
9b11142b45 Merge pull request #267 from google/gbg/rfcomm-with-uuid
rfcomm with UUID
2023-09-07 13:01:56 -07:00
Hui Peng
acdbc4d7b9 Raise an exception when an L2cap connection fails 2023-09-07 19:24:38 +02:00
Josh Wu
838d10a09d Add HFP tests 2023-09-07 23:20:16 +08:00
Josh Wu
3852aa056b Bring HfpProtocol back 2023-09-07 23:20:09 +08:00
Gilles Boccon-Gibod
ae77e4528f add support for vendor HCI commands and events 2023-09-06 20:00:15 -07:00
Gilles Boccon-Gibod
9303f4fc5b Merge pull request #262 from whitevegagabriel/l2cap
Port l2cap_bridge sample to Rust
2023-09-06 17:13:12 -07:00
Gilles Boccon-Gibod
8be9f4cb0e add doc and fix types 2023-09-06 17:05:30 -07:00
Gilles Boccon-Gibod
1ea12b1bf7 rebase 2023-09-06 17:05:24 -07:00
Gilles Boccon-Gibod
65e6d68355 add tcp server 2023-09-06 16:49:21 -07:00
Gabriel White-Vega
9732eb8836 Address PR feedback 2023-09-06 09:47:08 -04:00
Gabriel White-Vega
5ae668bc70 Port l2cap_bridge sample to Rust
- Added Rust wrappers where relevant
- Edited a couple logs in python l2cap_bridge to be more symmetrical
- Created cli subcommand for running the rustified l2cap bridge
2023-09-05 16:03:02 -04:00
Gilles Boccon-Gibod
fd4d1bcca3 Merge pull request #261 from marshallpierce/mp/rust-realtek-tools
Rust tools for working with Realtek firmware
2023-09-05 10:55:29 -07:00
Gilles Boccon-Gibod
0a251c9f8e Merge pull request #265 from mogenson/grpcio-update
Update grpcio and pip package versions
2023-08-31 14:53:54 -07:00
Michael Mogenson
351d77be59 Update grpcio and pip package versions
The current grpcio version 1.51.1 fails to build on aarch64 based MacOS
computers. Update the version of the grpcio and grpcio-tools packages to
the latest 1.57.0 version. There are binary wheels available for this
version from PyPi for aarch64 MacOS.

Also update the pip version for the Conda environment. It seems a newer
version of pip is required to detect and install these wheels.

Testing:

invoke test passes and I can start the bumble-pandora-server
successfully.
2023-08-31 14:01:14 -04:00
Marshall Pierce
0e2fc80509 Rust tools for working with Realtek firmware
Further adventures in porting tools to Rust to flesh out the supported
API.

These tools didn't feel like `example`s, so I made a top level `bumble`
CLI tool that hosts them all as subcommands. I also moved the usb probe
not-really-an-`example` into it as well. I'm open to suggestions on how
best to organize the subcommands to make them intuitive to explore with
`--help`, and how to leave room for other future tools.

I also adopted the per-OS project data dir for a default firmware
location so that users can download once and then use those .bin files
from anywhere without having to sprinkle .bin files in project
directories or reaching inside the python package dir hierarchy.
2023-08-30 15:37:35 -06:00
Gilles Boccon-Gibod
8f3fdecb93 Merge pull request #263 from zxzxwu/pdu
Typing packet transmission flow
2023-08-30 11:15:12 -07:00
Josh Wu
249a205d8e Typing packet transmission flow 2023-08-30 01:47:46 +08:00
Gilles Boccon-Gibod
7485801222 Merge pull request #256 from zxzxwu/sdp-type-fix
Typing SDP and add tests
2023-08-28 08:41:02 -07:00
Gilles Boccon-Gibod
4678e59737 Merge pull request #250 from google/gbg/new-rtk-dongles
add entry to the list of supported USB devices
2023-08-28 08:40:40 -07:00
Gilles Boccon-Gibod
952d351c00 Merge pull request #247 from google/gbg/wasm-with-ws
wasm with ws
2023-08-28 08:40:18 -07:00
Josh Wu
901eb55b0e Add SDP self tests 2023-08-24 01:27:07 +08:00
Josh Wu
727586e40e Typing SDP 2023-08-23 14:52:44 +08:00
Gilles Boccon-Gibod
3aa678a58e Merge pull request #253 from zxzxwu/rfcomm_type_fix
Adding more typing in rfcomm.py
2023-08-22 09:47:38 -07:00
Gilles Boccon-Gibod
fc7c1a8113 Merge pull request #255 from zxzxwu/player
Remove accidentally added files
2023-08-22 07:34:31 -07:00
Josh Wu
f62a0bbe75 Remove accidentally added files 2023-08-22 22:12:41 +08:00
Josh Wu
7341172739 Use __future__.annotations for typing 2023-08-22 14:44:15 +08:00
Gilles Boccon-Gibod
91b9fbe450 Merge pull request #240 from zxzxwu/ssp
Handle SSP Complete events
2023-08-21 18:01:28 -07:00
Josh Wu
e6b566b848 RFCOMM: Refactor role to enum 2023-08-21 15:16:34 +08:00
Josh Wu
2527a711dc Refactor RFCOMM states to enum 2023-08-21 15:12:52 +08:00
Josh Wu
5fba6b1cae Complete typing in RFCOMM 2023-08-21 15:12:52 +08:00
Gilles Boccon-Gibod
43e632f83c Merge pull request #244 from google/gbg/hci-source-termination-mode
add sink method for lost transports
2023-08-18 10:17:11 -07:00
Gilles Boccon-Gibod
623298b0e9 emit flush event when transport lost 2023-08-18 09:59:15 -07:00
Gilles Boccon-Gibod
85a61dc39d add entry to the list of supported USB devices 2023-08-18 09:56:06 -07:00
Gilles Boccon-Gibod
6e8c44b5e6 Merge pull request #249 from zxzxwu/player
Support SBC in speaker.app
2023-08-18 09:55:23 -07:00
Josh Wu
ec4dcc174e Support SBC in speaker.app 2023-08-18 17:13:11 +08:00
Charlie Boutier
b247aca3b4 pandora_server: add support to accept bumble config file 2023-08-17 14:24:56 -07:00
Gilles Boccon-Gibod
6226bfd196 fix typo after refactor 2023-08-17 09:51:56 -07:00
Gilles Boccon-Gibod
71e11b7cf8 format 2023-08-15 15:20:48 -07:00
Gilles Boccon-Gibod
800c62fdb6 add readme for web examples 2023-08-15 15:17:38 -07:00
Gilles Boccon-Gibod
640b9cd53a refactor pyiodide support and add examples 2023-08-15 13:36:58 -07:00
Gilles Boccon-Gibod
f4add16aea Merge pull request #241 from hchataing/hfp-hf
hfp: Implement initiate SLC procedure for HFP-HF
2023-08-14 10:32:55 -07:00
Gilles Boccon-Gibod
2bfec3c4ed add sink method for lost transports 2023-08-12 10:54:20 -07:00
Henri Chataing
9963b51c04 hfp: Implement initiate SLC procedure for HFP-HF 2023-08-10 08:37:54 -07:00
Josh Wu
2af3494d8c Handle SSP Complete events 2023-08-10 10:58:41 +08:00
Gilles Boccon-Gibod
fe28473ba8 Merge pull request #234 from zxzxwu/addr
Support address resolution offload
2023-08-08 21:30:13 -07:00
Gilles Boccon-Gibod
53d66bc74a Merge pull request #237 from marshallpierce/mp/company-ids
Faster company id table
2023-08-08 21:29:45 -07:00
Marshall Pierce
e2c1ad5342 Faster company id table
Following up on the [loose end from the initial
PR](https://github.com/google/bumble/pull/207#discussion_r1278015116),
we can avoid accessing the Python company id map at runtime by doing
code gen ahead of time.

Using an example to do the code gen avoids even the small build slowdown
from invoking the code gen logic in build.rs, but more importantly,
means that it's still a totally boring normal build that won't require
any IDE setup, etc, to work for everyone. Since the company ID list
changes rarely, and there's a test to ensure it always matches, this
seems like a good trade.
2023-08-04 10:12:52 -06:00
Josh Wu
6399c5fb04 Auto add device to resolving list after pairing 2023-08-03 20:51:00 +08:00
Josh Wu
784cf4f26a Add a flag to enable LE address resolution 2023-08-03 20:50:57 +08:00
Josh Wu
0301b1a999 Pandora: Configure identity address type 2023-08-02 11:31:07 -07:00
Lucas Abel
3ab2cd5e71 pandora: decrease all info logs to debug 2023-08-02 10:56:41 -07:00
uael
6ea669531a pandora: add tcp option to transport configuration
* Add a fallback to `tcp` when `transport` is not set.
* Default the `tcp` transport to the default rootcanal HCI address.
2023-08-01 08:51:12 -07:00
Josh Wu
cbbada4748 SMP: Delegate distributed address type 2023-08-01 08:38:03 -07:00
Gilles Boccon-Gibod
152b8d1233 Merge pull request #230 from google/gbg/hci-object-array
add support for field arrays in hci packet definitions
2023-08-01 07:44:31 -07:00
Gilles Boccon-Gibod
bdad225033 add support for field arrays in hci packet definitions 2023-07-30 22:19:10 -07:00
Gilles Boccon-Gibod
8eeb58e467 Merge pull request #207 from marshallpierce/mp/rust-poc
Proof-of-concept Rust wrapper
2023-07-28 20:14:23 -07:00
Marshall Pierce
91971433d2 PR feedback 2023-07-28 14:34:02 -06:00
Gilles Boccon-Gibod
a0a4bd457f Merge pull request #227 from google/gbg/py11
compatibility with python 11
2023-07-28 12:54:30 -07:00
Gilles Boccon-Gibod
4ffc050eed restore python < 11 compat 2023-07-27 16:37:27 -07:00
Gilles Boccon-Gibod
60678419a0 compatibility with python 11 2023-07-27 14:55:28 -07:00
Gilles Boccon-Gibod
648dcc9305 use type object instead of type strings 2023-07-27 13:19:37 -07:00
Josh Wu
190529184e L2CAP: Import device.Connection for typing 2023-07-27 09:07:55 -07:00
Josh Wu
46eb81466d Add more argement hints in L2CAP 2023-07-27 09:07:55 -07:00
Josh Wu
9c70c487b9 Add type hint to L2CAP module 2023-07-27 09:07:55 -07:00
Josh Wu
43234d7c3e Use with-patch to mock SMP session 2023-07-27 08:00:36 -07:00
Josh Wu
dbf878dc3f SMP: Remove PairingMethod.__str__ 2023-07-27 08:00:36 -07:00
Josh Wu
f6c0bd88d7 SMP: Do not send phase 2 commands in CTKD 2023-07-27 08:00:36 -07:00
Josh Wu
8440b7fbf1 SMP: Refactor pairing method as enum 2023-07-27 08:00:36 -07:00
Gilles Boccon-Gibod
808ab54135 Merge pull request #221 from google/gbg/core-classes
add new device class major/minor identifiers
2023-07-25 09:49:05 -07:00
Gilles Boccon-Gibod
52b29ad680 add new device class major/minor identifiers 2023-07-24 17:41:57 -07:00
Gilles Boccon-Gibod
d41bf9c587 Merge pull request #216 from google/gbg/host-buffer-size-command
accept Host Buffer Size Command in the controller
2023-07-24 09:05:10 -07:00
Gilles Boccon-Gibod
b758825164 add flow control command 2023-07-22 13:04:39 -07:00
Gilles Boccon-Gibod
779dfe5473 accept Host Buffer Size Command in the controller 2023-07-21 19:36:26 -07:00
Marshall Pierce
afb21220e2 Proof-of-concept Rust wrapper
This contains Rust wrappers around enough of the Python API to implement Rust versions of the `battery_client` and `run_scanner` examples. The goal is to gather feedback on the approach, and of course to show that it is possible.

The module structure mirrors that of the Python. The Rust API is not optimally Rust-y, but given the constraints of everything having to delegate to Python, it's at least usable.

Notably, this does not yet solve the packaging problem: users must have an appropriate virtualenv, libpython, etc. [PyOxidizer](https://github.com/indygreg/PyOxidizer) may be a viable path there.
2023-07-20 10:50:15 -06:00
Gilles Boccon-Gibod
f9a4c7518e Merge pull request #214 from marshallpierce/mp/scanner-rssi
Add a space after RSSI
2023-07-14 10:52:54 -07:00
Marshall Pierce
bad2fdf69f Add a space after RSSI
The other data elements have a space, so I'm guessing that RSSI
is intended to as well. Perhaps there's some subtle reason why
it should have a space, though, in which case feel free to
close this.

Output now looks like this:

```
>>> 58:D3:49:E7:40:DA/P [PUBLIC]:
  RSSI: -67
  [Flags]: LE General,BR/EDR C,BR/EDR H
  [TX Power Level]: 4
  [Manufacturer Specific Data]: company=Apple, Inc., data=0f08c00af4392b00040c10020f04
```
2023-07-13 12:47:45 -06:00
Lucas Abel
a84df469cd pairing: handle user errors from all delegate calls 2023-07-12 11:03:21 -07:00
Gilles Boccon-Gibod
03e33e39bd Merge pull request #211 from google/gbg/fix-ws-transport-doc
fix doc for ws-client ws-server transports
2023-07-12 07:06:32 -07:00
Gilles Boccon-Gibod
753fb69272 fix doc for ws-client ws-server transports 2023-07-12 06:06:20 -07:00
Gilles Boccon-Gibod
81a5f3a395 Merge pull request #203 from google/gbg/realtek-driver
realtek driver
2023-07-11 07:06:07 -07:00
Gilles Boccon-Gibod
696a8d82fd look for files in linux FW dir 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
5f294b1fea python 3.8 compatibility 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
2d8f5e80fb add missing doc files 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
7a042db78e add more USB ids 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
41ce311836 allow custom driver factories 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
03538d0f8a add doc 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
86bc222dc0 add missing file 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
e8d285fdab add downloader tool 2023-07-11 06:41:34 -07:00
Gilles Boccon-Gibod
852c933c92 wip (+4 squashed commits)
Squashed commits:
[d29a350] wip
[7f541ed] wip
[1e2902e] basic working version
[14b497a] wip
2023-07-11 06:41:34 -07:00
Lucas Abel
7867a99a54 Merge pull request #209 from google/click-types-quick-fix
temporarily pin click to 8.1.3
2023-07-11 06:21:11 -07:00
Gilles Boccon-Gibod
6cd14bb503 temporarily pin click to 8.1.3 2023-07-11 00:11:24 -07:00
Gilles Boccon-Gibod
532b99ffea Merge pull request #206 from benquike/main
Add some commands and events in hci
2023-07-10 01:23:08 -07:00
Hui Peng
d80f40ff5d Add some commands and events in hci 2023-06-28 08:51:10 -07:00
418 changed files with 59819 additions and 5589 deletions

View File

@@ -0,0 +1,30 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/universal:2",
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand":
"python -m pip install '.[build,test,development,documentation]'",
// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python"
]
}
}
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -14,6 +14,10 @@ jobs:
check:
name: Check Code
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps:
- name: Check out from Git
@@ -25,11 +29,11 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development]"
python -m pip install ".[build,test,development,pandora]"
- name: Check
run: |
invoke project.pre-commit

43
.github/workflows/python-avatar.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: Python Avatar
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
test:
name: Avatar [${{ matrix.shard }}]
runs-on: ubuntu-latest
strategy:
matrix:
shard: [
1/24, 2/24, 3/24, 4/24,
5/24, 6/24, 7/24, 8/24,
9/24, 10/24, 11/24, 12/24,
13/24, 14/24, 15/24, 16/24,
17/24, 18/24, 19/24, 20/24,
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install .[avatar,pandora]
- name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log &
- name: Test
run: |
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
run: cat rootcanal.log

View File

@@ -12,11 +12,11 @@ permissions:
jobs:
build:
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps:
@@ -41,3 +41,42 @@ jobs:
run: |
inv build
inv build.mkdocs
build-rust:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
rust-version: [ "1.76.0", "stable" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Install Rust toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
# Lints after build so what clippy needs is already built
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
- name: Rust Tests
run: cd rust && cargo test-all-features
# At some point, hook up publishing the binary. For now, just make sure it builds.
# Once we're ready to publish binaries, this should be built with `--release`.
- name: Build Bumble CLI
run: cd rust && cargo build --features bumble-tools --bin bumble

8
.gitignore vendored
View File

@@ -6,6 +6,14 @@ dist/
docs/mkdocs/site
test-results.xml
__pycache__
# Vim
.*.sw*
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
.vscode/settings.json
/.idea
venv/
.venv/
# snoop logs
out/

19
.vscode/settings.json vendored
View File

@@ -1,6 +1,7 @@
{
"cSpell.words": [
"Abortable",
"aiohttp",
"altsetting",
"ansiblue",
"ansicyan",
@@ -9,10 +10,13 @@
"ansired",
"ansiyellow",
"appendleft",
"ascs",
"ASHA",
"asyncio",
"ATRAC",
"avctp",
"avdtp",
"avrcp",
"bitpool",
"bitstruct",
"BSCP",
@@ -21,7 +25,10 @@
"cccds",
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
"DATALINK",
"delayreport",
@@ -29,6 +36,8 @@
"deregistration",
"dhkey",
"diversifier",
"endianness",
"ESCO",
"Fitbit",
"GATTLINK",
"HANDSFREE",
@@ -36,15 +45,21 @@
"keyup",
"levelname",
"libc",
"liblc",
"libusb",
"MITM",
"MSBC",
"NDIS",
"netsim",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"PRAND",
"protobuf",
"psms",
"pyee",
"Pyodide",
"pyusb",
"rfcomm",
"ROHC",
@@ -52,6 +67,7 @@
"SEID",
"seids",
"SERV",
"SIRK",
"ssrc",
"strerror",
"subband",
@@ -61,8 +77,11 @@
"substates",
"tobytes",
"tsep",
"UNMUTE",
"unmuted",
"usbmodem",
"vhci",
"wasmtime",
"websockets",
"xcursor",
"ycursor"

407
apps/auracast.py Normal file
View File

@@ -0,0 +1,407 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
from typing import cast, Dict, Optional, Tuple
import click
import pyee
from bumble.colors import color
import bumble.company_ids
import bumble.core
import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
# -----------------------------------------------------------------------------
# Discover Broadcasts
# -----------------------------------------------------------------------------
class BroadcastDiscoverer:
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str
sync: bumble.device.PeriodicAdvertisingSync
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
] = None
broadcast_audio_announcement: Optional[
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None
def __post_init__(self) -> None:
super().__init__()
self.sync.on('establishment', self.on_sync_establishment)
self.sync.on('loss', self.on_sync_loss)
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if (
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
)
continue
if (
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
)
continue
self.appearance = advertisement.data.get( # type: ignore[assignment]
bumble.core.AdvertisingData.APPEARANCE
)
if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
):
assert isinstance(manufacturer_data, tuple)
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
),
data,
)
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
self.sync.advertiser_address,
color(self.sync.state.name, 'green'),
)
print(f' {color("Name", "cyan")}: {self.name}')
if self.appearance:
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
print(f' {color("RSSI", "cyan")}: {self.rssi}')
print(f' {color("SID", "cyan")}: {self.sync.sid}')
if self.manufacturer_data:
print(
f' {color("Manufacturer Data", "cyan")}: '
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
)
if self.broadcast_audio_announcement:
print(
f' {color("Broadcast ID", "cyan")}: '
f'{self.broadcast_audio_announcement.broadcast_id}'
)
if self.public_broadcast_announcement:
print(
f' {color("Features", "cyan")}: '
f'{self.public_broadcast_announcement.features}'
)
print(
f' {color("Metadata", "cyan")}: '
f'{self.public_broadcast_announcement.metadata}'
)
if self.basic_audio_announcement:
print(color(' Audio:', 'cyan'))
print(
color(' Presentation Delay:', 'magenta'),
self.basic_audio_announcement.presentation_delay,
)
for subgroup in self.basic_audio_announcement.subgroups:
print(color(' Subgroup:', 'magenta'))
print(color(' Codec ID:', 'yellow'))
print(
color(' Coding Format: ', 'green'),
subgroup.codec_id.coding_format.name,
)
print(
color(' Company ID: ', 'green'),
subgroup.codec_id.company_id,
)
print(
color(' Vendor Specific Codec ID:', 'green'),
subgroup.codec_id.vendor_specific_codec_id,
)
print(
color(' Codec Config:', 'yellow'),
subgroup.codec_specific_configuration,
)
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
for bis in subgroup.bis:
print(color(f' BIS [{bis.index}]:', 'yellow'))
print(
color(' Codec Config:', 'green'),
bis.codec_specific_configuration,
)
if self.biginfo:
print(color(' BIG:', 'cyan'))
print(
color(' Number of BIS:', 'magenta'),
self.biginfo.num_bis,
)
print(
color(' PHY: ', 'magenta'),
self.biginfo.phy.name,
)
print(
color(' Framed: ', 'magenta'),
self.biginfo.framed,
)
print(
color(' Encrypted: ', 'magenta'),
self.biginfo.encrypted,
)
def on_sync_establishment(self) -> None:
self.establishment_timeout_task.cancel()
self.emit('change')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('change')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
) -> None:
if advertisement.data is None:
return
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
)
break
self.emit('change')
def on_biginfo_advertisement(
self, advertisement: bumble.device.BIGInfoAdvertisement
) -> None:
self.biginfo = advertisement
self.emit('change')
def __init__(
self,
device: bumble.device.Device,
filter_duplicates: bool,
sync_timeout: float,
):
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement)
async def run(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if (
broadcast_name := advertisement.data.get(
bumble.core.AdvertisingData.BROADCAST_NAME
)
) is None:
return
assert isinstance(broadcast_name, str)
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
self.refresh()
return
bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
)
async def on_new_broadcast(
self, name: str, advertisement: bumble.device.Advertisement
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(
name,
periodic_advertising_sync,
)
broadcast.on('change', self.refresh)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.status_message = color(
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.status_message = color(
f'-Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
async def run_discover_broadcasts(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device = bumble.device.Device.with_hci(
AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS,
hci_source,
hci_sink,
)
await device.power_on()
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run()
await hci_source.terminated
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
ctx.ensure_object(dict)
@auracast.command('discover-broadcasts')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=5.0,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

File diff suppressed because it is too large Load Diff

63
apps/ble_rpa_tool.py Normal file
View File

@@ -0,0 +1,63 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.gatt_client import CharacteristicProxy
from bumble.hci import (
Address,
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
@@ -289,11 +290,7 @@ class ConsoleApp:
device_config, hci_source, hci_sink
)
else:
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
random_address = Address.generate_static_address()
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
@@ -503,21 +500,9 @@ class ConsoleApp:
self.show_error('not connected')
return
# Discover all services, characteristics and descriptors
self.append_to_output('discovering services...')
await self.connected_peer.discover_services()
self.append_to_output(
f'found {len(self.connected_peer.services)} services,'
' discovering characteristics...'
)
await self.connected_peer.discover_characteristics()
self.append_to_output('found characteristics, discovering descriptors...')
for service in self.connected_peer.services:
for characteristic in service.characteristics:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.show_remote_services(self.connected_peer.services)
self.append_to_output('Service Discovery starting...')
await self.connected_peer.discover_all()
self.append_to_output('Service Discovery done!')
async def discover_attributes(self):
if not self.connected_peer:
@@ -777,7 +762,7 @@ class ConsoleApp:
if not service:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +781,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [attribute.read_value(None)]
values = [await attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +929,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
attribute.write_value(None, value)
await attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):
@@ -1172,7 +1157,7 @@ class ScanResult:
name = ''
# Remove any '/P' qualifier suffix from the address string
address_str = str(self.address).replace('/P', '')
address_str = self.address.to_string(with_type_qualifier=False)
# RSSI bar
bar_string = rssi_bar(self.rssi)

View File

@@ -18,30 +18,39 @@
import asyncio
import os
import logging
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
import time
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
LeFeature,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -57,13 +66,14 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_classic_info(host):
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
color('Classic Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
@@ -77,7 +87,7 @@ async def get_classic_info(host):
# -----------------------------------------------------------------------------
async def get_le_info(host):
async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -116,13 +126,50 @@ async def get_le_info(host):
'\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
print(f' {LeFeature(feature).name}')
# -----------------------------------------------------------------------------
async def async_main(transport):
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -130,6 +177,23 @@ async def async_main(transport):
host = Host(hci_source, hci_sink)
await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version
print(color('Version:', 'yellow'))
print(
@@ -153,19 +217,28 @@ async def async_main(transport):
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
for command in host.supported_commands:
print(' ', HCI_Command.command_name(command))
print(f' {HCI_Command.command_name(command)}')
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport')
def main(transport):
def main(latency_probes, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(transport))
asyncio.run(async_main(latency_probes, transport))
# -----------------------------------------------------------------------------

205
apps/controller_loopback.py Normal file
View File

@@ -0,0 +1,205 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=click.IntRange(1, 65535),
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -21,6 +21,7 @@ import struct
import logging
import click
from bumble import l2cap
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
@@ -204,7 +205,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
def __init__(self, device: Device):
super().__init__()
self.device = device
self.peer = None
@@ -218,7 +219,12 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=0xFB,
),
handler=self.on_coc,
)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service

View File

@@ -20,6 +20,7 @@ import logging
import os
import click
from bumble import l2cap
from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
@@ -47,16 +48,17 @@ class ServerBridge:
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
@@ -73,7 +75,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
def on_channel(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
@@ -83,7 +85,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
l2cap_channel.sink = self.on_channel_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
@@ -105,7 +107,7 @@ class ServerBridge:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
self.pipe.l2cap_channel.write(data)
try:
@@ -123,11 +125,12 @@ class ServerBridge:
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
print(color('*** L2CAP channel closed', 'red'))
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
def on_channel_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
@@ -182,7 +185,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
def on_channel_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
@@ -194,11 +197,13 @@ class ClientBridge:
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
l2cap_channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
@@ -206,7 +211,7 @@ class ClientBridge:
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
@@ -271,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
default=1024,
)
def cli(
@@ -295,17 +306,17 @@ def cli(
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
# -----------------------------------------------------------------------------

577
apps/lea_unicast/app.py Normal file
View File

@@ -0,0 +1,577 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1],
min_octets_per_codec_frame=30,
max_octets_per_codec_frame=100,
supported_max_codec_frames_per_sdu=1,
),
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task(
filename: str,
sdu_length: int,
frame_duration_us: int,
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = aiohttp.web.Application()
app.add_routes(
[
aiohttp.web.get('/', self.get_static),
aiohttp.web.get('/index.html', self.get_static),
aiohttp.web.get('/channel', self.get_channel),
]
)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/index.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.lea_unicast")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: Optional[str],
ui_port: int,
transport: str,
lc3_input_file_path: str,
):
self.device_config_path = device_config_path
self.transport = transport
self.lc3_input_file_path = lc3_input_file_path
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
async def run(self) -> None:
await self.ui_server.start_http()
async with await open_transport(self.transport) as hci_transport:
# Create a device
if self.device_config_path:
device_config = DeviceConfiguration.from_file(self.device_config_path)
else:
device_config = DeviceConfiguration(
name="Bumble LE Headphone",
class_of_device=0x244418,
keystore="JsonKeyStore",
advertising_interval_min=25,
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
)
device_config.le_enabled = True
device_config.cis_enabled = True
self.device = Device.from_config_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
self.device.add_service(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
available_sink_context=bap.ContextType(0xFFFF), # All context types
sink_audio_locations=(
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
),
sink_pac=[_sink_pac_record()],
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
source_pac=[_source_pac_record()],
)
)
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
filename=self.lc3_input_file_path,
sdu_length=(
codec_config.codec_frames_per_sdu
* codec_config.octets_per_codec_frame
),
frame_duration_us=codec_config.frame_duration.us,
device=self.device,
cis_handle=ase.cis_link.handle,
),
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()
await self.device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await hci_transport.source.terminated
@click.command()
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
"""Run the speaker."""
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,68 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
</div>
</nav>
<br>
<div class="container">
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
<button class="btn btn-primary" type="button" disabled>
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
<span role="status" id="ws-status">WebSocket Connecting...</span>
</button>
</div>
<script>
let player = null;
const wsStatus = document.getElementById("ws-status");
const wsStatusSpinner = document.getElementById("ws-status-spinner");
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
socket.binaryType = "arraybuffer";
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
console.log(`channel MESSAGE: ${message.data}`);
} else {
console.log(typeof (message.data))
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
socket.onopen = (message) => {
wsStatusSpinner.remove();
wsStatus.textContent = "WebSocket Connected";
}
socket.onclose = (message) => {
wsStatus.textContent = "WebSocket Disconnected";
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 2,
sampleRate: 48000,
flushTime: 10,
});
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
</script>
</div>
</body>
</html>

BIN
apps/lea_unicast/liblc3.wasm Executable file

Binary file not shown.

View File

@@ -24,10 +24,16 @@ from prompt_toolkit.shortcuts import PromptSession
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import PairingDelegate, PairingConfig
from bumble.pairing import OobData, PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import ProtocolError
from bumble.core import (
AdvertisingData,
ProtocolError,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
)
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -46,11 +52,13 @@ from bumble.att import (
class Waiter:
instance = None
def __init__(self):
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self):
self.done.set_result(None)
if not self.linger:
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
@@ -60,7 +68,7 @@ class Waiter:
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
io_capability={
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
@@ -285,7 +293,9 @@ async def pair(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
@@ -294,7 +304,7 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()
Waiter.instance = Waiter(linger=linger)
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -306,6 +316,7 @@ async def pair(
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
device.le_enabled = True
device.add_service(
Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
@@ -326,7 +337,6 @@ async def pair(
# Select LE or Classic
if mode == 'classic':
device.classic_enabled = True
device.le_enabled = False
device.classic_smp_enabled = ctkd
# Get things going
@@ -343,16 +353,51 @@ async def pair(
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
shared_data = (
None
if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
)
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig(
our_context=our_oob_context,
peer_data=shared_data,
legacy_context=legacy_context,
)
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc, mitm, bond, Delegate(mode, connection, io, prompt)
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
delegate=Delegate(mode, connection, io, prompt),
)
# Connect to a peer or wait for a connection
device.on('connection', lambda connection: on_connection(connection, request))
if address_or_name is not None:
print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect(address_or_name)
connection = await device.connect(
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
if not request:
try:
@@ -360,10 +405,9 @@ async def pair(
await connection.pair()
else:
await connection.authenticate()
return
except ProtocolError as error:
print(color(f'Pairing failed: {error}', 'red'))
return
else:
if mode == 'le':
# Advertise so that peers can find us and connect
@@ -413,6 +457,7 @@ class LogHandler(logging.Handler):
help='Enable CTKD',
show_default=True,
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
type=click.Choice(
@@ -421,6 +466,14 @@ class LogHandler(logging.Handler):
default='display+keyboard',
show_default=True,
)
@click.option(
'--oob',
metavar='<oob-data-hex>',
help=(
'Use OOB pairing with this data from the peer '
'(use "-" to enable OOB without peer data)'
),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
@@ -440,7 +493,9 @@ def main(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
@@ -463,7 +518,9 @@ def main(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,

View File

@@ -1,8 +1,10 @@
import asyncio
import click
import logging
import json
from bumble.pandora import PandoraDevice, serve
from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300
@@ -18,12 +20,31 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
help='HCI transport',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
)
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
@click.option(
'--config',
help='Bumble json configuration file',
)
def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
if '<rootcanal-port>' in transport:
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
device = PandoraDevice({'transport': transport})
bumble_config = retrieve_config(config)
bumble_config.setdefault('transport', transport)
device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port))
asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]:
if not config:
return {}
with open(config, 'r') as f:
return json.load(f)
if __name__ == '__main__':

511
apps/rfcomm_bridge.py Normal file
View File

@@ -0,0 +1,511 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
import click
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, Connection
from bumble import core
from bumble import hci
from bumble import rfcomm
from bumble import transport
from bumble import utils
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
DEFAULT_MTU = 4096
DEFAULT_CLIENT_TCP_PORT = 9544
DEFAULT_SERVER_TCP_PORT = 9545
TRACE_MAX_SIZE = 48
# -----------------------------------------------------------------------------
class Tracer:
"""
Trace data buffers transmitted from one endpoint to another, with stats.
"""
def __init__(self, channel_name: str) -> None:
self.channel_name = channel_name
self.last_ts: float = 0.0
def trace_data(self, data: bytes) -> None:
now = time.time()
elapsed_s = now - self.last_ts if self.last_ts else 0
elapsed_ms = int(elapsed_s * 1000)
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
hex_str = data[:TRACE_MAX_SIZE].hex() + (
"..." if len(data) > TRACE_MAX_SIZE else ""
)
print(
f"[{self.channel_name}] {len(data):4} bytes "
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
f" {hex_str}"
)
self.last_ts = now
# -----------------------------------------------------------------------------
class ServerBridge:
"""
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
The RFCOMM channel may be associated with a UUID published in an SDP service
description, or simply be on a system-assigned channel number.
When the connection is made, the bridge connects a TCP socket to a remote host and
bridges the data in both directions, with flow control.
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
and waits for a new channel to be connected.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Optional[Device] = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def start(self, device: Device) -> None:
self.device = device
# Create and register a server
rfcomm_server = rfcomm.Server(self.device)
# Listen for incoming DLC connections
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
# Setup the SDP to advertise this channel
service_record_handle = 0x00010001
self.device.sdp_service_records = {
service_record_handle: rfcomm.make_service_sdp_records(
service_record_handle, self.channel, core.UUID(self.uuid)
)
}
# We're ready for a connection
self.device.on("connection", self.on_connection)
await self.set_available(True)
print(
color(
(
f"### Listening for RFCOMM connection on {device.public_address}, "
f"channel {self.channel}"
),
"yellow",
)
)
async def set_available(self, available: bool):
# Become discoverable and connectable
assert self.device
await self.device.set_connectable(available)
await self.device.set_discoverable(available)
def on_connection(self, connection):
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
connection.on("disconnection", self.on_disconnection)
# Don't accept new connections until we're disconnected
utils.AsyncRunner.spawn(self.set_available(False))
def on_disconnection(self, reason: int):
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
# We're ready for a new connection
utils.AsyncRunner.spawn(self.set_available(True))
# Called when an RFCOMM channel is established
@utils.AsyncRunner.run_in_task()
async def on_rfcomm_channel(self, rfcomm_channel):
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
# Connect to the TCP server
print(
color(
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
"yellow",
)
)
try:
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
except OSError:
print(color("!!! Connection failed", "red"))
await rfcomm_channel.disconnect()
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "cyan"))
writer.close()
def write_rfcomm_data(data):
if self.rfcomm_tracer:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.sink = write_rfcomm_data
rfcomm_channel.on("close", on_rfcomm_channel_closed)
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
writer.close()
await writer.wait_closed()
print(color("~~~ Bye bye", "magenta"))
# -----------------------------------------------------------------------------
class ClientBridge:
"""
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
RFCOMM connection to the device is established, and the data is bridged in both
directions, with flow control.
When the TCP connection is closed by the client, the RFCOMM channel is
disconnected, but the connection to the device remains, ready for a new TCP client
to connect.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
channel: int,
uuid: str,
trace: bool,
address: str,
tcp_host: str,
tcp_port: int,
encrypt: bool,
):
self.channel = channel
self.uuid = uuid
self.trace = trace
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.encrypt = encrypt
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def connect(self) -> None:
if self.connection:
return
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
assert self.device
self.connection = await self.device.connect(
self.address, transport=core.BT_BR_EDR_TRANSPORT
)
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
self.connection.on("disconnection", self.on_disconnection)
if self.encrypt:
print(color("@@@ Encrypting Bluetooth connection", "blue"))
await self.connection.encrypt()
print(color("@@@ Bluetooth connection encrypted", "blue"))
self.rfcomm_client = rfcomm.Client(self.connection)
try:
self.rfcomm_mux = await self.rfcomm_client.start()
except BaseException as e:
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
raise
async def start(self, device: Device) -> None:
self.device = device
await device.set_connectable(False)
await device.set_discoverable(False)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
print(color("<<< TCP connection", "magenta"))
if self.tcp_connected:
print(
color("!!! TCP connection already active, rejecting new one", "red")
)
writer.close()
return
self.tcp_connected = True
try:
await self.pipe(reader, writer)
except BaseException as error:
print(color("!!! Exception while piping data:", "red"), error)
return
finally:
writer.close()
await writer.wait_closed()
self.tcp_connected = False
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != "_" else None,
port=self.tcp_port,
)
print(
color(
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
)
)
async def pipe(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
# Resolve the channel number from the UUID if needed
if self.channel == 0:
await self.connect()
assert self.connection
channel = await rfcomm.find_rfcomm_channel_with_uuid(
self.connection, self.uuid
)
if channel:
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
else:
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
return
else:
channel = self.channel
# Connect a new RFCOMM channel
await self.connect()
assert self.rfcomm_mux
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
try:
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
except Exception as error:
print(color(f"!!! RFCOMM open failed: {error}", "red"))
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "green"))
def write_rfcomm_data(data):
if self.trace:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.on("close", on_rfcomm_channel_closed)
rfcomm_channel.sink = write_rfcomm_data
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
self.tcp_connected = False
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
print(color("~~~ Bye bye", "magenta"))
def on_disconnection(self, reason: int) -> None:
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
self.connection = None
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print("<<< connecting to HCI...")
async with await transport.open_transport_or_link(hci_transport) as (
hci_source,
hci_sink,
):
print("<<< connected")
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.from_config_with_hci(
DeviceConfiguration(), hci_source, hci_sink
)
device.classic_enabled = True
# Let's go
await device.power_on()
try:
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error:
print(f"Exception while running bridge: {error}")
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option(
"--device-config",
metavar="CONFIG_FILE",
help="Device configuration file",
)
@click.option(
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
)
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
@click.option(
"--channel",
metavar="CHANNEL_NUMER",
help="RFCOMM channel number",
type=int,
default=0,
)
@click.option(
"--uuid",
metavar="UUID",
help="UUID for the RFCOMM channel",
default=DEFAULT_RFCOMM_UUID,
)
def cli(
context,
device_config,
hci_transport,
trace,
channel,
uuid,
):
context.ensure_object(dict)
context.obj["device_config"] = device_config
context.obj["hci_transport"] = hci_transport
context.obj["trace"] = trace
context.obj["channel"] = channel
context.obj["uuid"] = uuid
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option("--tcp-host", help="TCP host", default="localhost")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument("bluetooth-address")
@click.option("--tcp-host", help="TCP host", default="_")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
bridge = ClientBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
bluetooth_address,
tcp_host,
tcp_port,
encrypt,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
if __name__ == "__main__":
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public:
type_color = 'cyan'
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
else:
if address.is_static:
if address.is_public:
type_color = 'cyan'
elif address.is_static:
type_color = 'green'
address_qualifier = '(static)'
elif address.is_resolvable:
@@ -116,6 +121,7 @@ async def scan(
phy,
filter_duplicates,
raw,
irks,
keystore_file,
device_config,
transport,
@@ -140,9 +146,21 @@ async def scan(
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -187,8 +205,24 @@ async def scan(
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.argument('transport')
def main(
min_rssi,
@@ -198,6 +232,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
@@ -212,6 +247,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,

View File

@@ -15,7 +15,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import datetime
import logging
import os
import struct
import click
from bumble.colors import color
@@ -24,6 +28,14 @@ from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
@@ -36,12 +48,18 @@ class SnoopPacketReader:
DATALINK_BSCP = 1003
DATALINK_H5 = 1004
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
def __init__(self, source):
self.source = source
self.at_end = False
# Read the header
identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000':
if identification_pattern != self.IDENTIFICATION_PATTERN:
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
@@ -55,19 +73,32 @@ class SnoopPacketReader:
# Read the record header
header = self.source.read(24)
if len(header) < 24:
return (0, None)
self.at_end = True
return (None, 0, None)
# Parse the header
(
original_length,
included_length,
packet_flags,
_cumulative_drops,
_timestamp_seconds,
_timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
timestamp,
) = struct.unpack('>IIIIQ', header)
# Abort on truncated packets
# Skip truncated packets
if original_length != included_length:
return (0, None)
print(
color(
f"!!! truncated packet ({included_length}/{original_length})", "red"
)
)
self.source.read(included_length)
return (None, 0, None)
# Convert the timestamp to a datetime object.
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
microseconds=timestamp - self.TIMESTAMP_DELTA
)
if self.data_link_type == self.DATALINK_H1:
# The packet is un-encapsulated, look at the flags to figure out its type
@@ -89,7 +120,17 @@ class SnoopPacketReader:
bytes([packet_type]) + self.source.read(included_length),
)
return (packet_flags & 1, self.source.read(included_length))
return (ts_dt, packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
class Printer:
def __init__(self):
self.index = 0
def print(self, message: str) -> None:
self.index += 1
print(f"[{self.index:8}]{message}")
# -----------------------------------------------------------------------------
@@ -102,33 +143,48 @@ class SnoopPacketReader:
default='h4',
help='Format of the input file',
)
@click.option(
'--vendors',
type=click.Choice(['android', 'zephyr']),
multiple=True,
help='Support vendor-specific commands (list one or more)',
)
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, filename):
def main(format, vendors, filename):
for vendor in vendors:
if vendor == 'android':
import bumble.vendor.android.hci
elif vendor == 'zephyr':
import bumble.vendor.zephyr.hci
input = open(filename, 'rb')
if format == 'h4':
packet_reader = PacketReader(input)
def read_next_packet():
return (0, packet_reader.next_packet())
return (None, 0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet
tracer = PacketTracer(emit_message=print)
printer = Printer()
tracer = PacketTracer(emit_message=printer.print)
while True:
while not packet_reader.at_end:
try:
(direction, packet) = read_next_packet()
if packet is None:
break
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
(timestamp, direction, packet) = read_next_packet()
if packet:
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp)
else:
printer.print(color("[TRUNCATED]", "red"))
except Exception as error:
logger.exception()
print(color(f'!!! {error}', 'red'))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
main() # pylint: disable=no-value-for-parameter

View File

@@ -56,7 +56,7 @@ body, h1, h2, h3, h4, h5, h6 {
border-radius: 4px;
padding: 4px;
margin: 6px;
margin-left: 0px;
margin-left: 0;
}
th, td {
@@ -65,7 +65,7 @@ th, td {
}
.properties td:nth-child(even) {
background-color: #D6EEEE;
background-color: #d6eeee;
font-family: monospace;
}

View File

@@ -2,7 +2,7 @@
<html>
<head>
<title>Bumble Speaker</title>
<script type="text/javascript" src="speaker.js"></script>
<script src="speaker.js"></script>
<link rel="stylesheet" href="speaker.css">
</head>
<body>

View File

@@ -76,6 +76,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
# -----------------------------------------------------------------------------
class AudioExtractor:
@staticmethod
@@ -195,7 +196,7 @@ class WebSocketOutput(QueuedOutput):
except HCI_StatusError:
pass
peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = str(connection.peer_address).replace('/P', '')
peer_address = connection.peer_address.to_string(False)
await self.send_message(
'connection',
peer_address=peer_address,
@@ -228,10 +229,11 @@ class FfplayOutput(QueuedOutput):
subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: Optional[asyncio.Task]
def __init__(self) -> None:
super().__init__(AacAudioExtractor())
def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec))
self.subprocess = None
self.ffplay_task = None
self.codec = codec
async def start(self):
if self.started:
@@ -240,7 +242,7 @@ class FfplayOutput(QueuedOutput):
await super().start()
self.subprocess = await asyncio.create_subprocess_shell(
'ffplay -acodec aac pipe:0',
f'ffplay -f {self.codec} pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@@ -375,7 +377,7 @@ class UiServer:
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=str(connection.peer_address).replace('/P', ''),
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
@@ -419,7 +421,7 @@ class Speaker:
self.outputs = []
for output in outputs:
if output == '@ffplay':
self.outputs.append(FfplayOutput())
self.outputs.append(FfplayOutput(codec))
continue
# Default to FileOutput
@@ -640,7 +642,7 @@ class Speaker:
self.device.on('connection', self.on_bluetooth_connection)
# Create a listener to wait for AVDTP connections
self.listener = Listener(Listener.create_registrar(self.device))
self.listener = Listener.for_device(self.device)
self.listener.on('connection', self.on_avdtp_connection)
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')
@@ -708,17 +710,6 @@ def speaker(
):
"""Run the speaker."""
# ffplay only works with AAC for now
if codec != 'aac' and '@ffplay' in output:
print(
color(
f'{codec} not supported with @ffplay output, '
'@ffplay output will be skipped',
'yellow',
)
)
output = list(filter(lambda x: x != '@ffplay', output))
if '@ffplay' in output:
# Check if ffplay is installed
try:

View File

@@ -24,6 +24,7 @@ from bumble.device import Device
from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address):
if address is None:

View File

@@ -15,9 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
import logging
from collections import namedtuple
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -180,8 +184,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
@@ -230,8 +238,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
@@ -239,24 +251,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# -----------------------------------------------------------------------------
class SbcMediaCodecInformation(
namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
)
):
@dataclasses.dataclass
class SbcMediaCodecInformation:
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
@@ -272,7 +280,7 @@ class SbcMediaCodecInformation(
}
@staticmethod
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
@@ -293,14 +301,14 @@ class SbcMediaCodecInformation(
@classmethod
def from_discrete_values(
cls,
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
@@ -314,14 +322,14 @@ class SbcMediaCodecInformation(
@classmethod
def from_lists(
cls,
sampling_frequencies,
channel_modes,
block_lengths,
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
@@ -348,7 +356,7 @@ class SbcMediaCodecInformation(
]
)
def __str__(self):
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
@@ -367,16 +375,19 @@ class SbcMediaCodecInformation(
# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
@dataclasses.dataclass
class AacMediaCodecInformation:
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
@@ -400,7 +411,7 @@ class AacMediaCodecInformation(
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
@@ -413,8 +424,13 @@ class AacMediaCodecInformation(
@classmethod
def from_discrete_values(
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
@@ -425,7 +441,14 @@ class AacMediaCodecInformation(
)
@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
@@ -449,7 +472,7 @@ class AacMediaCodecInformation(
]
)
def __str__(self):
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
@@ -474,26 +497,26 @@ class AacMediaCodecInformation(
)
@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation:
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''
vendor_id: int
codec_id: int
value: bytes
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
def __bytes__(self):
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self):
def __str__(self) -> str:
# pylint: disable=line-too-long
return '\n'.join(
[
@@ -506,29 +529,27 @@ class VendorSpecificMediaCodecInformation:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
def __init__(
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
sampling_frequency: int
block_count: int
channel_mode: int
subband_count: int
payload: bytes
@property
def sample_count(self):
def sample_count(self) -> int:
return self.subband_count * self.block_count
@property
def bitrate(self):
def bitrate(self) -> int:
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
@property
def duration(self):
def duration(self) -> float:
return self.sample_count / self.sampling_frequency
def __str__(self):
def __str__(self) -> str:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
@@ -540,12 +561,12 @@ class SbcFrame:
# -----------------------------------------------------------------------------
class SbcParser:
def __init__(self, read):
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read
@property
def frames(self):
async def generate_frames():
def frames(self) -> AsyncGenerator[SbcFrame, None]:
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
while True:
# Read 4 bytes of header
header = await self.read(4)
@@ -589,7 +610,9 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities):
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@@ -629,7 +652,9 @@ class SbcPacketSource:
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
frames = [frame]
frames_size = len(frame.payload)
else:

91
bumble/at.py Normal file
View File

@@ -0,0 +1,91 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from bumble import core
class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens.
Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string."""
tokens = []
in_quotes = False
token = bytearray()
for b in buffer:
char = bytearray([b])
if in_quotes:
token.extend(char)
if char == b'\"':
in_quotes = False
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]]
current: Union[bytes, list] = bytes()
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = bytes()
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:
raise AtParsingError("missing close_paren")
return accumulator[0]

View File

@@ -23,13 +23,26 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import functools
import inspect
import struct
from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
from pyee import EventEmitter
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
@@ -182,6 +195,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
@@ -209,7 +223,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
name = None
name: str
@staticmethod
def from_bytes(pdu):
@@ -641,7 +655,7 @@ class ATT_Write_Command(ATT_PDU):
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
('attribute_value', '*'),
# ('authentication_signature', 'TODO')
]
)
@@ -719,48 +733,94 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
'''
# -----------------------------------------------------------------------------
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# -----------------------------------------------------------------------------
class Attribute(EventEmitter):
# Permission flags
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
class Permissions(enum.IntFlag):
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = {
READABLE: 'READABLE',
WRITEABLE: 'WRITEABLE',
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
}
@classmethod
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
try:
return functools.reduce(
lambda x, y: x | Attribute.Permissions[y],
permissions_str.replace('|', ',').split(","),
Attribute.Permissions(0),
)
except TypeError as exc:
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc
@staticmethod
def string_to_permissions(permissions_str: str):
try:
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
permissions_str.split(","),
0,
)
except TypeError as exc:
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
) from exc
# Permission flags(legacy-use only)
READABLE = Permissions.READABLE
WRITEABLE = Permissions.WRITEABLE
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
def __init__(self, attribute_type, permissions, value=b''):
value: Union[bytes, AttributeValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions)
self.permissions = Attribute.Permissions.from_string(permissions)
else:
self.permissions = permissions
@@ -778,22 +838,26 @@ class Attribute(EventEmitter):
else:
self.value = value
def encode_value(self, value):
def encode_value(self, value: Any) -> bytes:
return value
def decode_value(self, value_bytes):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def read_value(self, connection: Connection):
async def read_value(self, connection: Optional[Connection]) -> bytes:
if (
self.permissions & self.READ_REQUIRES_ENCRYPTION
) and not connection.encryption:
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
and not connection.encryption
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
and connection is not None
and not connection.authenticated
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
@@ -803,9 +867,11 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
if read := getattr(self.value, 'read', None):
if hasattr(self.value, 'read'):
try:
value = read(connection) # pylint: disable=not-callable
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -815,7 +881,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes):
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -836,9 +902,11 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None):
if hasattr(self.value, 'write'):
try:
write(connection, value) # pylint: disable=not-callable
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle

523
bumble/avc.py Normal file
View File

@@ -0,0 +1,523 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble import core
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise core.InvalidArgumentError(
f"invalid subclass name {subclass.__name__}"
)
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise core.InvalidPacketError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise core.InvalidPacketError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise core.InvalidPacketError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)

292
bumble/avctp.py Normal file
View File

@@ -0,0 +1,292 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import core
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise core.InvalidArgumentError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise core.InvalidArgumentError("response handler not registered")
del self.response_handlers[pid]

File diff suppressed because it is too large Load Diff

1919
bumble/avrcp.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,8 @@
from __future__ import annotations
from dataclasses import dataclass
from bumble import core
# -----------------------------------------------------------------------------
class BitReader:
@@ -40,7 +42,7 @@ class BitReader:
""" "Read up to 32 bits."""
if bits > 32:
raise ValueError('maximum read size is 32')
raise core.InvalidArgumentError('maximum read size is 32')
if self.bits_cached >= bits:
# We have enough bits.
@@ -53,7 +55,7 @@ class BitReader:
feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data')
raise core.InvalidArgumentError('trying to read past the data')
self.byte_position += feed_size
# Combine the new cache and the old cache
@@ -68,7 +70,7 @@ class BitReader:
def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data')
raise core.InvalidArgumentError('not enough data')
if self.bit_position % 8:
# Not byte aligned
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
@staticmethod
def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported')
raise core.InvalidPacketError('program_config_element not supported')
@dataclass
class GASpecificConfig:
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported')
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod
def audio_object_type(reader: BitReader):
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
reader, self.channel_configuration, self.audio_object_type
)
else:
raise ValueError(
raise core.InvalidPacketError(
f'audioObjectType {self.audio_object_type} not supported'
)
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
else:
audio_mux_version_a = 0
if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported')
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
num_sub_frames = reader.read(6)
num_program = reader.read(4)
if num_program != 0:
raise ValueError('num_program != 0 not supported')
raise core.InvalidPacketError('num_program != 0 not supported')
num_layer = reader.read(3)
if num_layer != 0:
raise ValueError('num_layer != 0 not supported')
raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
)
audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len')
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len
reader.skip(asc_len)
frame_length_type = reader.read(3)
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
elif frame_length_type == 1:
frame_length = reader.read(9)
else:
raise ValueError(f'frame_length_type {frame_length_type} not supported')
raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1)
if self.other_data_present:
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported')
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported')
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support:

View File

@@ -16,6 +16,10 @@ from functools import partial
from typing import List, Optional, Union
class ColorError(ValueError):
"""Error raised when a color spec is invalid."""
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
raise ColorError('Invalid color spec "%s"' % spec)
def color(
@@ -72,7 +76,7 @@ def color(
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
raise ColorError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -21,24 +21,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import operator
import platform
if platform.system() != 'Emscripten':
import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
else:
# TODO: implement stubs
pass
import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePrivateKey,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
from typing import Tuple
# -----------------------------------------------------------------------------
# Logging
@@ -50,16 +50,18 @@ logger = logging.getLogger(__name__)
# Classes
# -----------------------------------------------------------------------------
class EccKey:
def __init__(self, private_key):
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
self.private_key = private_key
@classmethod
def generate(cls):
def generate(cls) -> EccKey:
private_key = generate_private_key(SECP256R1())
return cls(private_key)
@classmethod
def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
def from_private_key_bytes(
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
@@ -69,7 +71,7 @@ class EccKey:
return cls(private_key)
@property
def x(self):
def x(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
@@ -77,14 +79,14 @@ class EccKey:
)
@property
def y(self):
def y(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x, public_key_y):
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
@@ -97,14 +99,33 @@ class EccKey:
# Functions
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def xor(x, y):
def generate_prand() -> bytes:
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y)
return bytes(map(operator.xor, x, y))
# -----------------------------------------------------------------------------
def r():
def reverse(input: bytes) -> bytes:
'''
Returns bytes of input in reversed endianness.
'''
return input[::-1]
# -----------------------------------------------------------------------------
def r() -> bytes:
'''
Generate 16 bytes of random data
'''
@@ -112,20 +133,20 @@ def r():
# -----------------------------------------------------------------------------
def e(key, data):
def e(key: bytes, data: bytes) -> bytes:
'''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
'''
cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
encryptor = cipher.encryptor()
return bytes(reversed(encryptor.update(bytes(reversed(data)))))
return reverse(encryptor.update(reverse(data)))
# -----------------------------------------------------------------------------
def ah(k, r): # pylint: disable=redefined-outer-name
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
'''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
'''
@@ -136,7 +157,16 @@ def ah(k, r): # pylint: disable=redefined-outer-name
# -----------------------------------------------------------------------------
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
def c1(
k: bytes,
r: bytes,
preq: bytes,
pres: bytes,
iat: int,
rat: int,
ia: bytes,
ra: bytes,
) -> bytes: # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
@@ -148,7 +178,7 @@ def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-n
# -----------------------------------------------------------------------------
def s1(k, r1, r2):
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
@@ -158,7 +188,7 @@ def s1(k, r1, r2):
# -----------------------------------------------------------------------------
def aes_cmac(m, k):
def aes_cmac(m: bytes, k: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
@@ -170,20 +200,16 @@ def aes_cmac(m, k):
# -----------------------------------------------------------------------------
def f4(u, v, x, z):
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
'''
return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x)))
# -----------------------------------------------------------------------------
def f5(w, n1, n2, a1, a2):
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
@@ -191,87 +217,83 @@ def f5(w, n1, n2, a1, a2):
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
'''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt)
t = aes_cmac(reverse(w), salt)
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return (
bytes(
reversed(
aes_cmac(
bytes([0])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
reverse(
aes_cmac(
bytes([0])
+ key_id
+ reverse(n1)
+ reverse(n2)
+ reverse(a1)
+ reverse(a2)
+ bytes([1, 0]),
t,
)
),
bytes(
reversed(
aes_cmac(
bytes([1])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
reverse(
aes_cmac(
bytes([1])
+ key_id
+ reverse(n1)
+ reverse(n2)
+ reverse(a1)
+ reverse(a2)
+ bytes([1, 0]),
t,
)
),
)
# -----------------------------------------------------------------------------
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
def f6(
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
) -> bytes: # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
'''
return bytes(
reversed(
aes_cmac(
bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(r))
+ bytes(reversed(io_cap))
+ bytes(reversed(a1))
+ bytes(reversed(a2)),
bytes(reversed(w)),
)
return reverse(
aes_cmac(
reverse(n1)
+ reverse(n2)
+ reverse(r)
+ reverse(io_cap)
+ reverse(a1)
+ reverse(a2),
reverse(w),
)
)
# -----------------------------------------------------------------------------
def g2(u, v, x, y):
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
reverse(u) + reverse(v) + reverse(y),
reverse(x),
)[-4:],
byteorder='big',
)
# -----------------------------------------------------------------------------
def h6(w, key_id):
def h6(w: bytes, key_id: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
'''
return aes_cmac(key_id, w)
return reverse(aes_cmac(key_id, reverse(w)))
# -----------------------------------------------------------------------------
def h7(salt, w):
def h7(salt: bytes, w: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
'''
return aes_cmac(w, salt)
return reverse(aes_cmac(reverse(w), salt))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,87 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Drivers that can be used to customize the interaction between a host and a controller,
like loading firmware after a cold start.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import pathlib
import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk, intel
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
"""
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None
def project_data_dir() -> pathlib.Path:
"""
Returns:
A path to an OS-specific directory for bumble data. The directory is created if
it doesn't exist.
"""
import platformdirs
if platform.system() == 'Darwin':
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
# out of author & project
return platformdirs.user_data_path(
appname='com.google.bumble', ensure_exists=True
)
else:
# windows and linux don't use the com qualifier
return platformdirs.user_data_path(
appname='bumble', appauthor='google', ensure_exists=True
)

45
bumble/drivers/common.py Normal file
View File

@@ -0,0 +1,45 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""

102
bumble/drivers/intel.py Normal file
View File

@@ -0,0 +1,102 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
from bumble.drivers import common
from bumble.hci import (
hci_vendor_command_op_code, # type: ignore
HCI_Command,
HCI_Reset_Command,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = {
# Intel AX210
(0x8087, 0x0032),
# Intel BE200
(0x8087, 0x0036),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore
fields=[("params", "*")],
return_parameters_fields=[
("params", "*"),
],
)
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
pass
class Driver(common.Driver):
def __init__(self, host):
self.host = host
@staticmethod
def check(host):
driver = host.hci_metadata.get("driver")
if driver == "intel":
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def for_host(cls, host, force=False): # type: ignore
# Only instantiate this driver if explicitly selected
if not force and not cls.check(host):
return None
return cls(host)
async def init_controller(self):
self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True)
await self.host.send_command(
Hci_Intel_DDC_Config_Write_Command(
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
)
)

671
bumble/drivers/rtk.py Normal file
View File

@@ -0,0 +1,671 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`)
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass
import asyncio
import enum
import logging
import math
import os
import pathlib
import platform
import struct
from typing import Tuple
import weakref
from bumble import core
from bumble.hci import (
hci_vendor_command_op_code,
STATUS_SPEC,
HCI_SUCCESS,
HCI_Command,
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
class RtkFirmwareError(core.BaseBumbleError):
"""Error raised when RTK firmware initialization fails."""
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
RTK_ROM_LMP_8723A = 0x1200
RTK_ROM_LMP_8723B = 0x8723
RTK_ROM_LMP_8821A = 0x8821
RTK_ROM_LMP_8761A = 0x8761
RTK_ROM_LMP_8822B = 0x8822
RTK_ROM_LMP_8852A = 0x8852
RTK_CONFIG_MAGIC = 0x8723AB55
RTK_EPATCH_SIGNATURE = b"Realtech"
RTK_FRAGMENT_LENGTH = 252
RTK_FIRMWARE_DIR_ENV = "BUMBLE_RTK_FIRMWARE_DIR"
RTK_LINUX_FIRMWARE_DIR = "/lib/firmware/rtl_bt"
class RtlProjectId(enum.IntEnum):
PROJECT_ID_8723A = 0
PROJECT_ID_8723B = 1
PROJECT_ID_8821A = 2
PROJECT_ID_8761A = 3
PROJECT_ID_8822B = 8
PROJECT_ID_8723D = 9
PROJECT_ID_8821C = 10
PROJECT_ID_8822C = 13
PROJECT_ID_8761B = 14
PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25
RTK_PROJECT_ID_TO_ROM = {
0: RTK_ROM_LMP_8723A,
1: RTK_ROM_LMP_8723B,
2: RTK_ROM_LMP_8821A,
3: RTK_ROM_LMP_8761A,
8: RTK_ROM_LMP_8822B,
9: RTK_ROM_LMP_8723B,
10: RTK_ROM_LMP_8821A,
13: RTK_ROM_LMP_8822B,
14: RTK_ROM_LMP_8761A,
18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A,
}
# List of USB (VendorID, ProductID) for Realtek-based devices.
RTK_USB_PRODUCTS = {
# Realtek 8723AE
(0x0930, 0x021D),
(0x13D3, 0x3394),
# Realtek 8723BE
(0x0489, 0xE085),
(0x0489, 0xE08B),
(0x04F2, 0xB49F),
(0x13D3, 0x3410),
(0x13D3, 0x3416),
(0x13D3, 0x3459),
(0x13D3, 0x3494),
# Realtek 8723BU
(0x7392, 0xA611),
# Realtek 8723DE
(0x0BDA, 0xB009),
(0x2FF8, 0xB011),
# Realtek 8761BUV
(0x0B05, 0x190E),
(0x0BDA, 0x8771),
(0x2230, 0x0016),
(0x2357, 0x0604),
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
(0x0BDA, 0x877B),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
(0x13D3, 0x3458),
(0x13D3, 0x3461),
(0x13D3, 0x3462),
# Realtek 8821CE
(0x0BDA, 0xB00C),
(0x0BDA, 0xC822),
(0x13D3, 0x3529),
# Realtek 8822BE
(0x0B05, 0x185C),
(0x13D3, 0x3526),
# Realtek 8822CE
(0x04C5, 0x161F),
(0x04CA, 0x4005),
(0x0B05, 0x18EF),
(0x0BDA, 0xB00C),
(0x0BDA, 0xC123),
(0x0BDA, 0xC822),
(0x0CB5, 0xC547),
(0x1358, 0xC123),
(0x13D3, 0x3548),
(0x13D3, 0x3549),
(0x13D3, 0x3553),
(0x13D3, 0x3555),
(0x2FF8, 0x3051),
# Realtek 8822CU
(0x13D3, 0x3549),
# Realtek 8852AE
(0x04C5, 0x165C),
(0x04CA, 0x4006),
(0x0BDA, 0x2852),
(0x0BDA, 0x385A),
(0x0BDA, 0x4852),
(0x0BDA, 0xC852),
(0x0CB8, 0xC549),
# Realtek 8852BE
(0x0BDA, 0x887B),
(0x0CB8, 0xC559),
(0x13D3, 0x3571),
# Realtek 8852CE
(0x04C5, 0x1675),
(0x04CA, 0x4007),
(0x0CB8, 0xC558),
(0x13D3, 0x3586),
(0x13D3, 0x3587),
(0x13D3, 0x3592),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
HCI_Command.register_commands(globals())
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
pass
@HCI_Command.command(
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
)
class HCI_RTK_Download_Command(HCI_Command):
pass
@HCI_Command.command()
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
pass
# -----------------------------------------------------------------------------
class Firmware:
def __init__(self, firmware):
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise RtkFirmwareError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig):
raise RtkFirmwareError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header.
epatch_header_size = 14
if len(firmware) < epatch_header_size:
raise RtkFirmwareError("Firmware too short")
# Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig)
project_id = -1
while offset >= epatch_header_size:
length, opcode = firmware[offset - 2 : offset]
offset -= 2
if opcode == 0xFF:
# End
break
if length == 0:
raise RtkFirmwareError("Invalid 0-length instruction")
if opcode == 0 and length == 1:
project_id = firmware[offset - 1]
break
offset -= length
if project_id < 0:
raise RtkFirmwareError("Project ID not found")
self.project_id = project_id
# Read the patch tables info.
self.version, num_patches = struct.unpack("<IH", firmware[8:14])
self.patches = []
# The patches tables are laid out as:
# <ChipID_1><ChipID_2>...<ChipID_N> (16 bits each)
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware):
raise RtkFirmwareError("Firmware too short")
chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
for patch_index in range(num_patches):
chip_id_offset = chip_id_table_offset + 2 * patch_index
(chip_id,) = struct.unpack_from("<H", firmware, chip_id_offset)
(patch_length,) = struct.unpack_from(
"<H", firmware, patch_length_table_offset + 2 * patch_index
)
(patch_offset,) = struct.unpack_from(
"<I", firmware, patch_offset_table_offset + 4 * patch_index
)
if patch_offset + patch_length > len(firmware):
raise RtkFirmwareError("Firmware too short")
# Get the SVN version for the patch
(svn_version,) = struct.unpack_from(
"<I", firmware, patch_offset + patch_length - 8
)
# Create a payload with the patch, replacing the last 4 bytes with
# the firmware version.
self.patches.append(
(
chip_id,
firmware[patch_offset : patch_offset + patch_length - 4]
+ struct.pack("<I", self.version),
svn_version,
)
)
class Driver(common.Driver):
@dataclass
class DriverInfo:
rom: int
hci: Tuple[int, int]
config_needed: bool
has_rom_version: bool
has_msft_ext: bool = False
fw_name: str = ""
config_name: str = ""
DRIVER_INFOS = [
# 8723A
DriverInfo(
rom=RTK_ROM_LMP_8723A,
hci=(0x0B, 0x06),
config_needed=False,
has_rom_version=False,
fw_name="rtl8723a_fw.bin",
config_name="",
),
# 8723B
DriverInfo(
rom=RTK_ROM_LMP_8723B,
hci=(0x0B, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8723b_fw.bin",
config_name="rtl8723b_config.bin",
),
# 8723D
DriverInfo(
rom=RTK_ROM_LMP_8723B,
hci=(0x0D, 0x08),
config_needed=True,
has_rom_version=True,
fw_name="rtl8723d_fw.bin",
config_name="rtl8723d_config.bin",
),
# 8821A
DriverInfo(
rom=RTK_ROM_LMP_8821A,
hci=(0x0A, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8821a_fw.bin",
config_name="rtl8821a_config.bin",
),
# 8821C
DriverInfo(
rom=RTK_ROM_LMP_8821A,
hci=(0x0C, 0x08),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8821c_fw.bin",
config_name="rtl8821c_config.bin",
),
# 8761A
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0A, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761a_fw.bin",
config_name="rtl8761a_config.bin",
),
# 8761BU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0B, 0x0A),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin",
),
# 8822C
DriverInfo(
rom=RTK_ROM_LMP_8822B,
hci=(0x0C, 0x0A),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8822cu_fw.bin",
config_name="rtl8822cu_config.bin",
),
# 8822B
DriverInfo(
rom=RTK_ROM_LMP_8822B,
hci=(0x0B, 0x07),
config_needed=True,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8822b_fw.bin",
config_name="rtl8822b_config.bin",
),
# 8852A
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0x0A, 0x0B),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852au_fw.bin",
config_name="rtl8852au_config.bin",
),
# 8852B
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0xB, 0xB),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852bu_fw.bin",
config_name="rtl8852bu_config.bin",
),
# 8852C
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0x0C, 0x0C),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852cu_fw.bin",
config_name="rtl8852cu_config.bin",
),
]
POST_DROP_DELAY = 0.2
@staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == (
hci_subversion,
hci_version,
):
return driver_info
return None
@staticmethod
def find_binary_path(file_name):
# First check if an environment variable is set
if RTK_FIRMWARE_DIR_ENV in os.environ:
if (
path := pathlib.Path(os.environ[RTK_FIRMWARE_DIR_ENV]) / file_name
).is_file():
logger.debug(f"{file_name} found in env dir")
return path
# When the environment variable is set, don't look elsewhere
return None
# Then, look where the firmware download tool writes by default
if (path := rtk_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "rtk_fw" / file_name).is_file():
logger.debug(f"{file_name} found in package dir")
return path
# On Linux, check the system's FW directory
if (
platform.system() == "Linux"
and (path := pathlib.Path(RTK_LINUX_FIRMWARE_DIR) / file_name).is_file()
):
logger.debug(f"{file_name} found in Linux system FW dir")
return path
# Finally look in the current directory
if (path := pathlib.Path.cwd() / file_name).is_file():
logger.debug(f"{file_name} found in CWD")
return path
return None
@staticmethod
def check(host):
if not host.hci_metadata:
logger.debug("USB metadata not found")
return False
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
local_version = response.return_parameters
logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
f"(0x{local_version.hci_version:02X}, "
f"0x{local_version.hci_subversion:04X})"
)
driver_info = cls.find_driver_info(
local_version.hci_version,
local_version.hci_subversion,
local_version.lmp_subversion,
)
if driver_info is None:
# TODO: it seems that the Linux driver will send command (0x3f, 0x66)
# in this case and then re-read the local version, then re-match.
logger.debug("firmware already loaded or no known driver for this device")
return driver_info
@classmethod
async def for_host(cls, host, force=False):
# Check that a driver is needed for this host
if not force and not cls.check(host):
return None
# Get the driver info
driver_info = await cls.driver_info_for_host(host)
if driver_info is None:
return None
# Load the firmware
firmware_path = cls.find_binary_path(driver_info.fw_name)
if not firmware_path:
logger.warning(f"Firmware file {driver_info.fw_name} not found")
logger.warning("See https://google.github.io/bumble/drivers/realtek.html")
return None
with open(firmware_path, "rb") as firmware_file:
firmware = firmware_file.read()
# Load the config
config = None
if driver_info.config_name:
config_path = cls.find_binary_path(driver_info.config_name)
if config_path:
with open(config_path, "rb") as config_file:
config = config_file.read()
if driver_info.config_needed and not config:
logger.warning("Config needed, but no config file available")
return None
return cls(host, driver_info, firmware, config)
def __init__(self, host, driver_info, firmware, config):
self.host = weakref.proxy(host)
self.driver_info = driver_info
self.firmware = firmware
self.config = config
@staticmethod
async def drop_firmware(host):
host.send_hci_packet(HCI_RTK_Drop_Firmware_Command())
# Wait for the command to be effective (no response is sent)
await asyncio.sleep(Driver.POST_DROP_DELAY)
async def download_for_rtl8723a(self):
# Check that the firmware image does not include an epatch signature.
if RTK_EPATCH_SIGNATURE in self.firmware:
logger.warning(
"epatch signature found in firmware, it is probably the wrong firmware"
)
return
# TODO: load the firmware
async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version:
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
)
if response.return_parameters.status != HCI_SUCCESS:
logger.warning("can't get ROM version")
return
rom_version = response.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}")
else:
rom_version = 0
firmware = Firmware(self.firmware)
logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}")
for patch in firmware.patches:
if patch[0] == rom_version + 1:
logger.debug(f"using patch {patch[0]}")
break
else:
logger.warning("no valid patch found for rom version {rom_version}")
return
# Append the config if there is one.
if self.config:
payload = patch[1] + self.config
else:
payload = patch[1]
# Download the payload, one fragment at a time.
fragment_count = math.ceil(len(payload) / RTK_FRAGMENT_LENGTH)
for fragment_index in range(fragment_count):
# NOTE: the Linux driver somehow adds 1 to the index after it wraps around.
# That's odd, but we"ll do the same here.
download_index = fragment_index & 0x7F
if download_index >= 0x80:
download_index += 1
if fragment_index == fragment_count - 1:
download_index |= 0x80 # End marker.
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command(
HCI_RTK_Download_Command(
index=download_index, payload=fragment, check_result=True
)
)
logger.debug("download complete!")
# Read the version again
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
)
if response.return_parameters.status != HCI_SUCCESS:
logger.warning("can't get ROM version")
else:
rom_version = response.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:04X}")
async def download_firmware(self):
if self.driver_info.rom == RTK_ROM_LMP_8723A:
return await self.download_for_rtl8723a()
if self.driver_info.rom in (
RTK_ROM_LMP_8723B,
RTK_ROM_LMP_8821A,
RTK_ROM_LMP_8761A,
RTK_ROM_LMP_8822B,
RTK_ROM_LMP_8852A,
):
return await self.download_for_rtl8723b()
raise RtkFirmwareError("ROM not supported")
async def init_controller(self):
await self.download_firmware()
await self.host.send_command(HCI_Reset_Command(), check_result=True)
logger.info(f"loaded FW image {self.driver_info.fw_name}")
def rtk_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Realtek firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "realtek"
p.mkdir(parents=True, exist_ok=True)
return p

View File

@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):

View File

@@ -23,16 +23,28 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import Optional, Sequence, List
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# -----------------------------------------------------------------------------
@@ -93,20 +105,35 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
# Types
# Attribute Types
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
@@ -129,6 +156,8 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
# Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -156,6 +185,96 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# Telephony And Media Audio Service (TMAS)
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
# Audio Input Control Service (AICS)
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
# Volume Control Service (VCS)
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
# Volume Offset Control Service (VOCS)
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
# Coordinated Set Identification Service (CSIS)
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
# Media Control Service (MCS)
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
# Audio Stream Control Service (ASCS)
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
# Broadcast Audio Scan Service (BASS)
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
# Published Audio Capabilities Service (PACS)
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -177,6 +296,9 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
# fmt: on
# pylint: enable=line-too-long
@@ -187,7 +309,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# -----------------------------------------------------------------------------
def show_services(services):
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
@@ -210,19 +332,21 @@ class Service(Attribute):
def __init__(
self,
uuid,
uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: List[Service] = [],
):
) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
),
Attribute.READABLE,
uuid.to_pdu_bytes(),
)
@@ -239,7 +363,7 @@ class Service(Attribute):
"""
return None
def __str__(self):
def __str__(self) -> str:
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -255,10 +379,15 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
UUID: Optional[UUID] = None
UUID: UUID
def __init__(self, characteristics, primary=True):
super().__init__(self.UUID, characteristics, primary)
def __init__(
self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
) -> None:
super().__init__(self.UUID, characteristics, primary, included_services)
# -----------------------------------------------------------------------------
@@ -269,7 +398,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service
def __init__(self, service):
def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
@@ -278,13 +407,12 @@ class IncludedServiceDeclaration(Attribute):
)
self.service = service
def __str__(self):
def __str__(self) -> str:
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
f'uuid={self.service.uuid}, '
f'{self.service.properties!s})'
f'uuid={self.service.uuid})'
)
@@ -309,31 +437,33 @@ class Characteristic(Attribute):
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
@staticmethod
def from_string(properties_str: str) -> Characteristic.Properties:
property_names: List[str] = []
for property in Characteristic.Properties:
if property.name is None:
raise TypeError()
property_names.append(property.name)
def string_to_property(property_string) -> Characteristic.Properties:
for property in zip(Characteristic.Properties, property_names):
if property_string == property[1]:
return property[0]
raise TypeError(f"Unable to convert {property_string} to Property")
@classmethod
def from_string(cls, properties_str: str) -> Characteristic.Properties:
try:
return functools.reduce(
lambda x, y: x | string_to_property(y),
properties_str.split(","),
lambda x, y: x | cls[y],
properties_str.replace("|", ",").split(","),
Characteristic.Properties(0),
)
except TypeError:
except (TypeError, KeyError):
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by commas: {','.join(property_names)}\nGot: {properties_str}"
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
flag.name
for flag in Characteristic.Properties
if self.value & flag.value and flag.name is not None
)
# For backwards compatibility these are defined here
# For new code, please use Characteristic.Properties.X
BROADCAST = Properties.BROADCAST
@@ -347,10 +477,10 @@ class Characteristic(Attribute):
def __init__(
self,
uuid,
uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
permissions,
value=b'',
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -368,12 +498,12 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
def __str__(self):
def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'{self.properties!s})'
f'{self.properties})'
)
@@ -385,7 +515,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic
def __init__(self, characteristic, value_handle):
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
@@ -396,66 +526,53 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
def __str__(self) -> str:
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, '
f'{self.characteristic.properties!s})'
f'{self.characteristic.properties})'
)
# -----------------------------------------------------------------------------
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
read_value: Callable
write_value: Callable
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[Callable, Callable] = (
{}
) # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
else:
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -474,11 +591,13 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -519,7 +638,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self):
def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
@@ -599,10 +718,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
def decode_value(self, value):
def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
@@ -612,14 +731,25 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __str__(self):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
f'value={value_str})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit

View File

@@ -28,7 +28,19 @@ import asyncio
import logging
import struct
from datetime import datetime
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
from typing import (
List,
Optional,
Dict,
Tuple,
Callable,
Union,
Any,
Iterable,
Type,
Set,
TYPE_CHECKING,
)
from pyee import EventEmitter
@@ -66,28 +78,48 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
TemplateService,
)
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services: Iterable[ServiceProxy]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
for descriptor in characteristic.descriptors:
print(color(' ' + str(descriptor), 'green'))
# -----------------------------------------------------------------------------
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
client: Client
def __init__(self, client, handle, end_group_handle, attribute_type):
def __init__(
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
) -> None:
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
async def read_value(self, no_long_read=False):
async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
@@ -97,13 +129,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
def encode_value(self, value: Any) -> bytes:
return value
def decode_value(self, value_bytes):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def __str__(self):
def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -113,7 +145,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy]
@staticmethod
def from_client(service_class, client, service_uuid):
def from_client(service_class, client: Client, service_uuid: UUID):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
@@ -136,14 +168,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
def __str__(self):
def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable]
subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__(
self,
@@ -171,7 +203,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(
self, subscriber: Optional[Callable] = None, prefer_notify=True
self,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
):
if subscriber is not None:
if subscriber in self.subscribers:
@@ -189,13 +223,13 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None):
async def unsubscribe(self, subscriber=None, force=False):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
return await self.client.unsubscribe(self, subscriber, force)
def __str__(self):
def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
@@ -207,7 +241,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
def __str__(self):
def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +250,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
SERVICE_CLASS: Type[TemplateService]
@classmethod
def from_client(cls, client):
def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,30 +263,36 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
def __init__(self, connection):
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.notification_subscribers = {} # Subscriber set, by attribute handle
self.indication_subscribers = {} # Subscriber set, by attribute handle
self.services = []
self.cached_values = {}
def send_gatt_pdu(self, pdu):
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command):
async def send_command(self, command: ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request):
async def send_request(self, request: ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
@@ -279,19 +321,19 @@ class Client:
return response
def send_confirmation(self, confirmation):
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu):
async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF:
raise ValueError('MTU must be <= 0xFFFF')
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection
if self.mtu_exchange_done:
@@ -313,10 +355,12 @@ class Client:
return self.connection.att_mtu
def get_services_by_uuid(self, uuid):
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service=None):
def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> List[CharacteristicProxy]:
services = [service] if service else self.services
return [
c
@@ -324,9 +368,7 @@ class Client:
if c.uuid == uuid
]
def get_attribute_grouping(
self, attribute_handle: int
) -> Optional[
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
Union[
ServiceProxy,
Tuple[ServiceProxy, CharacteristicProxy],
@@ -363,7 +405,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -435,7 +477,7 @@ class Client:
return services
async def discover_service(self, uuid):
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -468,7 +510,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
return []
break
for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +522,7 @@ class Client:
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -657,8 +699,8 @@ class Client:
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle=None,
end_handle=None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
) -> List[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -721,7 +763,7 @@ class Client:
return descriptors
async def discover_attributes(self):
async def discover_attributes(self) -> List[AttributeProxy]:
'''
Discover all attributes, regardless of type
'''
@@ -764,7 +806,12 @@ class Client:
return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -801,6 +848,7 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
@@ -808,7 +856,18 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -822,29 +881,45 @@ class Client:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return
# Remove the subscriber(s)
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
if (
subscribers := subscriber_set.get(characteristic.handle)
) and subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
# Remove all subscribers for this attribute from the sets
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# Update the CCCD
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False):
async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> bytes:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +980,9 @@ class Client:
# Return the value as bytes
return attribute_value
async def read_characteristics_by_uuid(self, uuid, service):
async def read_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy]
) -> List[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
'''
@@ -960,7 +1037,12 @@ class Client:
return characteristics_values
async def write_value(self, attribute, value, with_response=False):
async def write_value(
self,
attribute: Union[int, AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
@@ -990,7 +1072,7 @@ class Client:
)
)
def on_gatt_pdu(self, att_pdu):
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
@@ -1000,7 +1082,7 @@ class Client:
logger.warning('!!! unexpected response, there is no pending request')
return
# Sanity check: the response should match the pending request unless it is
# The response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace(
@@ -1013,6 +1095,7 @@ class Client:
return
# Return the response to the coroutine that is waiting for it
assert self.pending_response is not None
self.pending_response.set_result(att_pdu)
else:
handler_name = f'on_{att_pdu.name.lower()}'
@@ -1032,7 +1115,7 @@ class Client:
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
notification.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
@@ -1046,7 +1129,9 @@ class Client:
def on_att_handle_value_indication(self, indication):
# Call all subscribers
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
if not subscribers:
logger.warning('!!! received indication with no subscriber')
@@ -1060,7 +1145,7 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes):
def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = (
datetime.now(),
value,

View File

@@ -23,16 +23,17 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
import struct
from typing import List, Tuple, Optional, TypeVar, Type
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from .colors import color
from .core import UUID
from .att import (
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
@@ -58,7 +60,7 @@ from .att import (
ATT_Write_Response,
Attribute,
)
from .gatt import (
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -72,7 +74,10 @@ from .gatt import (
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
@@ -91,8 +96,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
services: List[Service]
attributes_by_handle: Dict[int, Attribute]
subscribers: Dict[int, Dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
def __init__(self, device):
def __init__(self, device: Device) -> None:
super().__init__()
self.device = device
self.services = []
@@ -107,16 +117,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu):
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self):
def next_handle(self) -> int:
return 1 + len(self.attributes)
def get_advertising_service_data(self):
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return {
attribute: data
for attribute in self.attributes
@@ -124,7 +134,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle):
def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -173,12 +183,17 @@ class Server(EventEmitter):
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
(
attribute,
self.get_attribute(attribute.characteristic.handle),
) # type: ignore
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
if attribute is not None
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid
),
None,
@@ -197,7 +212,7 @@ class Server(EventEmitter):
return next(
(
attribute
attribute # type: ignore
for attribute in map(
self.get_attribute,
range(
@@ -205,12 +220,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
if attribute is not None and attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute):
def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = (
@@ -220,7 +235,7 @@ class Server(EventEmitter):
# Add this attribute to the list
self.attributes.append(attribute)
def add_service(self, service: Service):
def add_service(self, service: Service) -> None:
# Add the service attribute to the DB
self.add_attribute(service)
@@ -285,11 +300,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle
self.services.append(service)
def add_services(self, services):
def add_services(self, services: Iterable[Service]) -> None:
for service in services:
self.add_service(service)
def read_cccd(self, connection, characteristic):
def read_cccd(
self, connection: Optional[Connection], characteristic: Characteristic
) -> bytes:
if connection is None:
return bytes([0, 0])
@@ -300,13 +317,18 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
def write_cccd(
self,
connection: Connection,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check
# Check parameters
if len(value) != 2:
logger.warning('CCCD value not 2 bytes long')
return
@@ -327,13 +349,19 @@ class Server(EventEmitter):
indicate_enabled,
)
def send_response(self, connection, response):
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -352,7 +380,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -370,7 +398,13 @@ class Server(EventEmitter):
)
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -389,7 +423,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -411,15 +445,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
pending_confirmation = self.pending_confirmations[connection.handle] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +459,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
self,
indicate: bool,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
@@ -450,13 +486,23 @@ class Server(EventEmitter):
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
async def notify_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
async def indicate_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection):
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -464,7 +510,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, connection, att_pdu):
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -506,7 +552,7 @@ class Server(EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, connection, pdu):
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
@@ -605,7 +651,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_find_by_type_value_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -613,13 +660,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
async for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -657,7 +704,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_by_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -679,9 +727,8 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
try:
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -723,14 +770,15 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -748,14 +796,15 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_blob_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -792,7 +841,8 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_by_group_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -820,7 +870,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -859,7 +909,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_write_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
@@ -892,12 +943,13 @@ class Server(EventEmitter):
return
# Accept the value
attribute.write_value(connection, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
def on_att_write_command(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -915,9 +967,9 @@ class Server(EventEmitter):
# Accept the value
try:
attribute.write_value(connection, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
logger.exception(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''

File diff suppressed because it is too large Load Diff

View File

@@ -15,30 +15,46 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
import datetime
from typing import cast, Any, Optional
import logging
from .colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .l2cap import (
from bumble import avc
from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response,
)
from .hci import (
from bumble.hci import (
Address,
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -48,26 +64,36 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
rfcomm.RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP',
avctp.AVCTP_PSM: 'AVCTP',
# TODO: add more PSM values
}
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# -----------------------------------------------------------------------------
class PacketTracer:
class AclStream:
def __init__(self, analyzer):
psms: MutableMapping[int, int]
peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
self.peer = None
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu):
def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
@@ -81,46 +107,59 @@ class PacketTracer:
# Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm
connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if (
control_frame.result
connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid)
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else:
# Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == SDP_PSM:
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
if psm == sdp.SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu)
elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
elif psm == rfcomm.RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
avdtp_assembler.on_pdu(l2cap_pdu.payload)
elif psm == avctp.AVCTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
@@ -130,22 +169,49 @@ class PacketTracer:
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message):
def on_avdtp_message(
self, transaction_label: int, message: avdtp.Message
) -> None:
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet):
def on_avctp_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet)
class Analyzer:
def __init__(self, label, emit_message):
acl_streams: MutableMapping[int, PacketTracer.AclStream]
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
self.packet_timestamp: Optional[datetime.datetime] = None
def start_acl_stream(self, connection_handle):
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
@@ -160,7 +226,7 @@ class PacketTracer:
return stream
def end_acl_stream(self, connection_handle):
def end_acl_stream(self, connection_handle: int) -> None:
if connection_handle in self.acl_streams:
logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
@@ -171,34 +237,52 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle)
def on_packet(self, packet):
def on_packet(
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
stream = self.start_acl_stream(packet.connection_handle)
stream.feed_packet(packet)
if (
stream := self.acl_streams.get(acl_packet.connection_handle)
) is None:
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET:
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(packet.connection_handle)
event_packet = cast(HCI_Event, packet)
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message):
self.emit_message(f'[{self.label}] {message}')
def emit(self, message: Any) -> None:
if self.packet_timestamp:
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
else:
prefix = ""
self.emit_message(f'{prefix}[{self.label}] {message}')
def trace(self, packet, direction=0):
def trace(
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if direction == 0:
self.host_to_controller_analyzer.on_packet(packet)
self.host_to_controller_analyzer.on_packet(timestamp, packet)
else:
self.controller_to_host_analyzer.on_packet(packet)
self.controller_to_host_analyzer.on_packet(timestamp, packet)
def __init__(
self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info,
):
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
emit_message: Callable[..., None] = logger.info,
) -> None:
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)
@@ -207,3 +291,15 @@ class PacketTracer:
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given

File diff suppressed because it is too large Load Diff

555
bumble/hid.py Normal file
View File

@@ -0,0 +1,555 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: on
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
class Message:
message_type: MessageType
# Report types
class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00
INPUT_REPORT = 0x01
OUTPUT_REPORT = 0x02
FEATURE_REPORT = 0x03
# Handshake parameters
class Handshake(enum.IntEnum):
SUCCESSFUL = 0x00
NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F
# Message Type
class MessageType(enum.IntEnum):
HANDSHAKE = 0x00
CONTROL = 0x01
GET_REPORT = 0x04
SET_REPORT = 0x05
GET_PROTOCOL = 0x06
SET_PROTOCOL = 0x07
DATA = 0x0A
# Protocol modes
class ProtocolMode(enum.IntEnum):
BOOT_PROTOCOL = 0x00
REPORT_PROTOCOL = 0x01
# Control Operations
class ControlCommand(enum.IntEnum):
SUSPEND = 0x03
EXIT_SUSPEND = 0x04
VIRTUAL_CABLE_UNPLUG = 0x05
# Class Method to derive header
@classmethod
def header(cls, lower_bits: int = 0x00) -> bytes:
return bytes([(cls.message_type << 4) | lower_bits])
# HIDP messages
@dataclass
class GetReportMessage(Message):
report_type: int
report_id: int
buffer_size: int
message_type = Message.MessageType.GET_REPORT
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.report_id)
if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes
else:
return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass
class SetReportMessage(Message):
report_type: int
data: bytes
message_type = Message.MessageType.SET_REPORT
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header()
@dataclass
class SetProtocolMessage(Message):
protocol_mode: int
message_type = Message.MessageType.SET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header(self.protocol_mode)
@dataclass
class Suspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.SUSPEND)
@dataclass
class ExitSuspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.EXIT_SUSPEND)
@dataclass
class VirtualCableUnplug(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
data: bytes
report_type: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
class Role(enum.IntEnum):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device
self.role = role
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_intr_channel
self.l2cap_intr_channel = None
await channel.disconnect()
async def disconnect_control_channel(self) -> None:
if self.l2cap_ctrl_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_ctrl_channel
self.l2cap_ctrl_channel = None
await channel.disconnect()
def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = l2cap_channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
else:
self.l2cap_intr_channel = l2cap_channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = None
else:
self.l2cap_intr_channel = None
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("interrupt_data", pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg)
def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg)
if self.l2cap_intr_channel is not None:
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def virtual_cable_unplug(self) -> None:
msg = VirtualCableUnplug()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.GET_REPORT:
logger.debug('<<< HID GET REPORT')
self.handle_get_report(pdu)
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend')
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend')
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code)
hid_message = bytes(msg)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None:
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None:
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None:
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def exit_suspend(self) -> None:
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,8 @@ import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing_extensions import Self
from .colors import color
from .hci import Address
@@ -128,10 +129,10 @@ class PairingKeys:
def print(self, prefix=''):
keys_dict = self.to_dict()
for (container_property, value) in keys_dict.items():
for container_property, value in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items():
for key_property, key_value in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(container_property, "cyan")}: {value}')
@@ -158,7 +159,7 @@ class KeyStore:
async def get_resolving_keys(self):
all_keys = await self.get_all()
resolving_keys = []
for (name, keys) in all_keys:
for name, keys in all_keys:
if keys.irk is not None:
if keys.address_type is None:
address_type = Address.RANDOM_DEVICE_ADDRESS
@@ -171,7 +172,7 @@ class KeyStore:
async def print(self, prefix=''):
entries = await self.get_all()
separator = ''
for (name, keys) in entries:
for name, keys in entries:
print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ')
separator = '\n'
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
@classmethod
def from_device(
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
if device.config.keystore is not None:
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
return JsonKeyStore(namespace, filename)
return cls(namespace, filename)
async def load(self):
# Try to open the file, without failing. If the file does not exist, it

File diff suppressed because it is too large Load Diff

View File

@@ -19,16 +19,25 @@ import logging
import asyncio
from functools import partial
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.core import (
BT_PERIPHERAL_ROLE,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError,
)
from bumble.colors import color
from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
)
from bumble import controller
from typing import Optional, Set
# -----------------------------------------------------------------------------
# Logging
@@ -57,6 +66,8 @@ class LocalLink:
Link bus for controllers to communicate with each other
'''
controllers: Set[controller.Controller]
def __init__(self):
self.controllers = set()
self.pending_connection = None
@@ -79,7 +90,9 @@ class LocalLink:
return controller
return None
def find_classic_controller(self, address):
def find_classic_controller(
self, address: Address
) -> Optional[controller.Controller]:
for controller in self.controllers:
if controller.public_address == address:
return controller
@@ -188,6 +201,60 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
@@ -271,6 +338,52 @@ class LocalLink:
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# -----------------------------------------------------------------------------
class RemoteLink:
@@ -297,12 +410,12 @@ class RemoteLink:
def add_controller(self, controller):
if self.controller:
raise ValueError('controller already set')
raise InvalidStateError('controller already set')
self.controller = controller
def remove_controller(self, controller):
if self.controller != controller:
raise ValueError('controller mismatch')
raise InvalidStateError('controller mismatch')
self.controller = None
def get_pending_connection(self):

View File

@@ -15,10 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Optional, Tuple
from .hci import (
Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
@@ -34,7 +37,60 @@ from .smp import (
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
)
from .core import AdvertisingData, LeRole
# -----------------------------------------------------------------------------
@dataclass
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
shared_data_c = ad_data
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
shared_data_r = ad_data
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
instance.legacy_context = OobLegacyContext(tk=ad_data)
if shared_data_c and shared_data_r:
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
return instance
def to_ad(self) -> AdvertisingData:
ad_structures = []
if self.address is not None:
ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
)
if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None:
ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
)
return AdvertisingData(ad_structures)
# -----------------------------------------------------------------------------
@@ -168,21 +224,39 @@ class PairingDelegate:
class PairingConfig:
"""Configuration for the Pairing protocol."""
class AddressType(enum.IntEnum):
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
self.oob = oob
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'delegate[{self.delegate.io_capability}])'
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}]), '
f'oob[{self.oob}])'
)

View File

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble.pairing import PairingDelegate
from __future__ import annotations
from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@@ -20,6 +21,7 @@ from typing import Any, Dict
@dataclass
class Config:
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
pairing_sc_enable: bool = True
pairing_mitm_enable: bool = True
pairing_bonding_enable: bool = True
@@ -35,6 +37,12 @@ class Config:
'io_capability', 'no_output_no_input'
).upper()
self.io_capability = getattr(PairingDelegate, io_capability_name)
identity_address_type_name: str = config.get(
'identity_address_type', 'random'
).upper()
self.identity_address_type = getattr(
PairingConfig.AddressType, identity_address_type_name
)
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)

View File

@@ -14,6 +14,7 @@
"""Generic & dependency free Bumble (reference) device."""
from __future__ import annotations
from bumble import transport
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,
@@ -34,6 +35,10 @@ from bumble.sdp import (
from typing import Any, Dict, List, Optional
# Default rootcanal HCI TCP address
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
class PandoraDevice:
"""
Small wrapper around a Bumble device and it's HCI transport.
@@ -53,7 +58,9 @@ class PandoraDevice:
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.device = _make_device(config)
self._hci_name = config.get('transport', '')
self._hci_name = config.get(
'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
)
self._hci = None
@property

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import bumble.device
import grpc
@@ -27,14 +28,18 @@ from bumble.core import (
BT_PERIPHERAL_ROLE,
UUID,
AdvertisingData,
Appearance,
ConnectionError,
)
from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingParameters,
AdvertisingEventProperties,
AdvertisingType,
Device,
Phy,
)
from bumble.gatt import Service
from bumble.hci import (
@@ -46,9 +51,12 @@ from bumble.hci import (
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer
from pandora import host_pb2
from pandora.host_pb2 import (
NOT_CONNECTABLE,
NOT_DISCOVERABLE,
DISCOVERABLE_LIMITED,
DISCOVERABLE_GENERAL,
PRIMARY_1M,
PRIMARY_CODED,
SECONDARY_1M,
@@ -64,6 +72,7 @@ from pandora.host_pb2 import (
ConnectResponse,
DataTypes,
DisconnectRequest,
DiscoverabilityMode,
InquiryResponse,
PrimaryPhy,
ReadLocalAddressResponse,
@@ -93,6 +102,25 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
3: SECONDARY_CODED,
}
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
PRIMARY_1M: Phy.LE_1M,
PRIMARY_CODED: Phy.LE_CODED,
}
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_NONE: Phy.LE_1M,
SECONDARY_1M: Phy.LE_1M,
SECONDARY_2M: Phy.LE_2M,
SECONDARY_CODED: Phy.LE_CODED,
}
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
}
class HostService(HostServicer):
waited_connections: Set[int]
@@ -112,7 +140,7 @@ class HostService(HostServicer):
async def FactoryReset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('FactoryReset')
self.log.debug('FactoryReset')
# delete all bonds
if self.device.keystore is not None:
@@ -126,7 +154,7 @@ class HostService(HostServicer):
async def Reset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('Reset')
self.log.debug('Reset')
# clear service.
self.waited_connections.clear()
@@ -139,7 +167,7 @@ class HostService(HostServicer):
async def ReadLocalAddress(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> ReadLocalAddressResponse:
self.log.info('ReadLocalAddress')
self.log.debug('ReadLocalAddress')
return ReadLocalAddressResponse(
address=bytes(reversed(bytes(self.device.public_address)))
)
@@ -152,7 +180,7 @@ class HostService(HostServicer):
address = Address(
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
)
self.log.info(f"Connect to {address}")
self.log.debug(f"Connect to {address}")
try:
connection = await self.device.connect(
@@ -167,7 +195,7 @@ class HostService(HostServicer):
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"Connect to {address} done (handle={connection.handle})")
self.log.debug(f"Connect to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectResponse(connection=Connection(cookie=cookie))
@@ -186,7 +214,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"WaitConnection from {address}...")
self.log.debug(f"WaitConnection from {address}...")
connection = self.device.find_connection_by_bd_addr(
address, transport=BT_BR_EDR_TRANSPORT
@@ -201,7 +229,7 @@ class HostService(HostServicer):
# save connection has waited and respond.
self.waited_connections.add(id(connection))
self.log.info(
self.log.debug(
f"WaitConnection from {address} done (handle={connection.handle})"
)
@@ -216,7 +244,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"ConnectLE to {address}...")
self.log.debug(f"ConnectLE to {address}...")
try:
connection = await self.device.connect(
@@ -233,7 +261,7 @@ class HostService(HostServicer):
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
self.log.debug(f"ConnectLE to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectLEResponse(connection=Connection(cookie=cookie))
@@ -243,12 +271,12 @@ class HostService(HostServicer):
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Disconnect: {connection_handle}")
self.log.debug(f"Disconnect: {connection_handle}")
self.log.info("Disconnecting...")
self.log.debug("Disconnecting...")
if connection := self.device.lookup_connection(connection_handle):
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
self.log.info("Disconnected")
self.log.debug("Disconnected")
return empty_pb2.Empty()
@@ -257,12 +285,12 @@ class HostService(HostServicer):
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitDisconnection: {connection_handle}")
self.log.debug(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[
None
] = asyncio.get_running_loop().create_future()
disconnection_future: asyncio.Future[None] = (
asyncio.get_running_loop().create_future()
)
def on_disconnection(_: None) -> None:
disconnection_future.set_result(None)
@@ -270,7 +298,7 @@ class HostService(HostServicer):
connection.on('disconnection', on_disconnection)
try:
await disconnection_future
self.log.info("Disconnected")
self.log.debug("Disconnected")
finally:
connection.remove_listener('disconnection', on_disconnection) # type: ignore
@@ -280,14 +308,118 @@ class HostService(HostServicer):
async def Advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if not request.legacy:
raise NotImplementedError(
"TODO: add support for extended advertising in Bumble"
try:
if request.legacy:
async for rsp in self.legacy_advertise(request, context):
yield rsp
else:
async for rsp in self.extended_advertise(request, context):
yield rsp
finally:
pass
async def extended_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
advertising_data = bytes(self.unpack_data_types(request.data))
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
scannable = len(scan_response_data) != 0
advertising_event_properties = AdvertisingEventProperties(
is_connectable=request.connectable,
is_scannable=scannable,
is_directed=request.target is not None,
is_high_duty_cycle_directed_connectable=False,
is_legacy=False,
is_anonymous=False,
include_tx_power=False,
)
peer_address = Address.ANY
if request.target:
# Need to reverse bytes order since Bumble Address is using MSB.
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
else:
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_parameters = AdvertisingParameters(
advertising_event_properties=advertising_event_properties,
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
peer_address=peer_address,
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
request.secondary_phy
],
)
if advertising_interval := request.interval:
advertising_parameters.primary_advertising_interval_min = int(
advertising_interval
)
if request.interval:
raise NotImplementedError("TODO: add support for `request.interval`")
if request.interval_range:
raise NotImplementedError("TODO: add support for `request.interval_range`")
advertising_parameters.primary_advertising_interval_max = int(
advertising_interval
)
if interval_range := request.interval_range:
advertising_parameters.primary_advertising_interval_max += int(
interval_range
)
advertising_set = await self.device.create_advertising_set(
advertising_parameters=advertising_parameters,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None:
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
pending_connection.set_result(connection)
self.device.on('connection', on_connection)
try:
# Advertise until RPC is canceled
while True:
if not advertising_set.enabled:
self.log.debug('Advertise (extended)')
await advertising_set.start()
if not request.connectable:
await asyncio.sleep(1)
continue
connection = await pending_connection
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie))
await asyncio.sleep(1)
finally:
try:
self.log.debug('Stop Advertise (extended)')
await advertising_set.stop()
await advertising_set.remove()
except Exception:
pass
async def legacy_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if advertising_interval := request.interval:
self.device.config.advertising_interval_min = int(advertising_interval)
self.device.config.advertising_interval_max = int(advertising_interval)
if interval_range := request.interval_range:
self.device.config.advertising_interval_max += int(interval_range)
if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy:
@@ -355,14 +487,10 @@ class HostService(HostServicer):
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
else:
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
if request.connectable:
@@ -378,7 +506,7 @@ class HostService(HostServicer):
try:
while True:
if not self.device.is_advertising:
self.log.info('Advertise')
self.log.debug('Advertise')
await self.device.start_advertising(
target=target,
advertising_type=advertising_type,
@@ -389,14 +517,14 @@ class HostService(HostServicer):
await asyncio.sleep(1)
continue
pending_connection: asyncio.Future[
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
self.log.info('Wait for LE connection...')
self.log.debug('Wait for LE connection...')
connection = await pending_connection
self.log.info(
self.log.debug(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
)
@@ -410,7 +538,7 @@ class HostService(HostServicer):
self.device.remove_listener('connection', on_connection) # type: ignore
try:
self.log.info('Stop advertising')
self.log.debug('Stop advertising')
await self.device.abort_on('flush', self.device.stop_advertising())
except:
pass
@@ -420,10 +548,15 @@ class HostService(HostServicer):
self, request: ScanRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ScanningResponse, None]:
# TODO: modify `start_scanning` to accept floats instead of int for ms values
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
self.log.debug('Scan')
self.log.info('Scan')
scanning_phys = []
if PRIMARY_1M in request.phys:
scanning_phys.append(int(Phy.LE_1M))
if PRIMARY_CODED in request.phys:
scanning_phys.append(int(Phy.LE_CODED))
if not scanning_phys:
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -431,12 +564,15 @@ class HostService(HostServicer):
legacy=request.legacy,
active=not request.passive,
own_address_type=request.own_address_type,
scan_interval=int(request.interval)
if request.interval
else DEVICE_DEFAULT_SCAN_INTERVAL,
scan_window=int(request.window)
if request.window
else DEVICE_DEFAULT_SCAN_WINDOW,
scan_interval=(
int(request.interval)
if request.interval
else DEVICE_DEFAULT_SCAN_INTERVAL
),
scan_window=(
int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW
),
scanning_phys=scanning_phys,
)
try:
@@ -470,7 +606,7 @@ class HostService(HostServicer):
finally:
self.device.remove_listener('advertisement', handler) # type: ignore
try:
self.log.info('Stop scanning')
self.log.debug('Stop scanning')
await self.device.abort_on('flush', self.device.stop_scanning())
except:
pass
@@ -479,7 +615,7 @@ class HostService(HostServicer):
async def Inquiry(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> AsyncGenerator[InquiryResponse, None]:
self.log.info('Inquiry')
self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[
Optional[Tuple[Address, int, AdvertisingData, int]]
@@ -510,7 +646,7 @@ class HostService(HostServicer):
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
try:
self.log.info('Stop inquiry')
self.log.debug('Stop inquiry')
await self.device.abort_on('flush', self.device.stop_discovery())
except:
pass
@@ -519,7 +655,7 @@ class HostService(HostServicer):
async def SetDiscoverabilityMode(
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetDiscoverabilityMode")
self.log.debug("SetDiscoverabilityMode")
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
return empty_pb2.Empty()
@@ -527,7 +663,7 @@ class HostService(HostServicer):
async def SetConnectabilityMode(
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetConnectabilityMode")
self.log.debug("SetConnectabilityMode")
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
return empty_pb2.Empty()
@@ -649,9 +785,11 @@ class HostService(HostServicer):
*struct.pack('<H', dt.peripheral_connection_interval_min),
*struct.pack(
'<H',
dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max
else dt.peripheral_connection_interval_min,
(
dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max
else dt.peripheral_connection_interval_min
),
),
]
),
@@ -733,6 +871,16 @@ class HostService(HostServicer):
)
)
flag_map = {
NOT_DISCOVERABLE: 0x00,
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
}
if dt.le_discoverability_mode:
flags = flag_map[dt.le_discoverability_mode]
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
@@ -841,8 +989,8 @@ class HostService(HostServicer):
dt.random_target_addresses.extend(
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
)
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = i
if appearance := cast(Appearance, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = int(appearance)
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
dt.advertising_interval = i
if s := cast(str, ad.get(AdvertisingData.URI)):

View File

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import grpc
import logging
@@ -27,8 +29,8 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -99,7 +101,7 @@ class PairingDelegate(BasePairingDelegate):
return ev
async def confirm(self, auto: bool = False) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
@@ -108,13 +110,13 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
@@ -123,13 +125,13 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def get_number(self) -> Optional[int]:
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -138,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -146,7 +148,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
self.log.info(
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -155,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # pytype: disable=name-error
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -177,7 +179,7 @@ class PairingDelegate(BasePairingDelegate):
):
return
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
@@ -232,6 +234,11 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
identity_address_type=(
PairingConfig.AddressType.PUBLIC
if connection.self_address.is_public
else config.identity_address_type
),
delegate=PairingDelegate(
connection,
self,
@@ -247,7 +254,7 @@ class SecurityService(SecurityServicer):
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
self.log.info('OnPairing')
self.log.debug('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
@@ -273,7 +280,7 @@ class SecurityService(SecurityServicer):
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Secure: {connection_handle}")
self.log.debug(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -291,25 +298,37 @@ class SecurityService(SecurityServicer):
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
self.log.info('Pair...')
self.log.debug('Pair...')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
security_result = asyncio.get_running_loop().create_future()
connection.request_pairing()
with contextlib.closing(EventWatcher()) as watcher:
await wait_for_security
else:
await connection.pair()
@watcher.on(connection, 'pairing')
def on_pairing(*_: Any) -> None:
security_result.set_result('success')
self.log.info('Paired')
@watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
connection.request_pairing()
else:
await connection.pair()
result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -320,9 +339,9 @@ class SecurityService(SecurityServicer):
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
self.log.info('Authenticate...')
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.info('Authenticated')
self.log.debug('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -333,9 +352,9 @@ class SecurityService(SecurityServicer):
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
self.log.info('Encrypt...')
self.log.debug('Encrypt...')
await connection.encrypt()
self.log.info('Encrypted')
self.log.debug('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -353,7 +372,7 @@ class SecurityService(SecurityServicer):
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitSecurity: {connection_handle}")
self.log.debug(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -364,10 +383,11 @@ class SecurityService(SecurityServicer):
connection.transport
] == request.level_variant()
wait_for_security: asyncio.Future[
str
] = asyncio.get_running_loop().create_future()
wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future()
)
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -390,7 +410,7 @@ class SecurityService(SecurityServicer):
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
self.log.info(f'Wait for security: error `{name}`: {args}')
self.log.debug(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
@@ -398,13 +418,13 @@ class SecurityService(SecurityServicer):
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
@@ -414,6 +434,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -422,32 +446,41 @@ class SecurityService(SecurityServicer):
'pairing': try_set_success,
'connection_authentication': try_set_success,
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
}
# register event handlers
for event, listener in listeners.items():
connection.on(event, listener)
with contextlib.closing(EventWatcher()) as watcher:
# register event handlers
for event, listener in listeners.items():
watcher.on(connection, event, listener)
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.info('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
connection.remove_listener(event, listener) # type: ignore
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# wait for `authenticate` to finish if any
if authenticate_task is not None:
self.log.info('Wait for authentication...')
self.log.debug('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
self.log.info('Authenticated')
self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs)
@@ -503,7 +536,7 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"IsBonded: {address}")
self.log.debug(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
@@ -517,10 +550,10 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"DeleteBond: {address}")
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):
with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import grpc

View File

@@ -18,7 +18,9 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from typing import List, Optional
from bumble import l2cap
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
@@ -65,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value):
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
@@ -149,7 +151,10 @@ class AshaService(TemplateService):
channel.sink = on_data
# let the server find a free PSM
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,

1424
bumble/profiles/bap.py Normal file

File diff suppressed because it is too large Load Diff

52
bumble/profiles/cap.py Normal file
View File

@@ -0,0 +1,52 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

257
bumble/profiles/csip.py Normal file
View File

@@ -0,0 +1,257 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Crypto Toolbox
# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise core.InvalidArgumentError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise core.InvalidPacketError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)

View File

@@ -19,8 +19,8 @@
import struct
from typing import Optional, Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
@@ -59,7 +59,7 @@ class DeviceInformationService(TemplateService):
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None
ieee_regulatory_certification_data_list: Optional[bytes] = None,
# TODO: pnp_id
):
characteristics = [
@@ -104,10 +104,19 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService
def __init__(self, service_proxy):
manufacturer_name: Optional[UTF8CharacteristicAdapter]
model_number: Optional[UTF8CharacteristicAdapter]
serial_number: Optional[UTF8CharacteristicAdapter]
hardware_revision: Optional[UTF8CharacteristicAdapter]
firmware_revision: Optional[UTF8CharacteristicAdapter]
software_revision: Optional[UTF8CharacteristicAdapter]
system_id: Optional[DelegatedCharacteristicAdapter]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy
for (field, uuid) in (
for field, uuid in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),

View File

@@ -19,6 +19,7 @@
from enum import IntEnum
import struct
from bumble import core
from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error
from ..gatt import (
@@ -42,12 +43,12 @@ class HeartRateService(TemplateService):
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
OTHER = 0
CHEST = 1
WRIST = 2
FINGER = 3
HAND = 4
EAR_LOBE = 5
FOOT = 6
class HeartRateMeasurement:
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected

View File

@@ -0,0 +1,83 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
from typing import List, Type
from typing_extensions import Self
from bumble import utils
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Metadata:
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
again outside the lib.
'''
class Tag(utils.OpenIntEnum):
# fmt: off
PREFERRED_AUDIO_CONTEXTS = 0x01
STREAMING_AUDIO_CONTEXTS = 0x02
PROGRAM_INFO = 0x03
LANGUAGE = 0x04
CCID_LIST = 0x05
PARENTAL_RATING = 0x06
PROGRAM_INFO_URI = 0x07
AUDIO_ACTIVE_STATE = 0x08
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
ASSISTED_LISTENING_STREAM = 0x0A
BROADCAST_NAME = 0x0B
EXTENDED_METADATA = 0xFE
VENDOR_SPECIFIC = 0xFF
@dataclasses.dataclass
class Entry:
tag: Metadata.Tag
data: bytes
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
def __bytes__(self) -> bytes:
return bytes([len(self.data) + 1, self.tag]) + self.data
entries: List[Entry] = dataclasses.field(default_factory=list)
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
entries = []
offset = 0
length = len(data)
while offset < length:
entry_length = data[offset]
offset += 1
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
offset += entry_length
return cls(entries)
def __bytes__(self) -> bytes:
return b''.join([bytes(entry) for entry in self.entries])

448
bumble/profiles/mcp.py Normal file
View File

@@ -0,0 +1,448 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import struct
from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import utils
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class PlayingOrder(utils.OpenIntEnum):
'''See Media Control Service 3.15. Playing Order.'''
SINGLE_ONCE = 0x01
SINGLE_REPEAT = 0x02
IN_ORDER_ONCE = 0x03
IN_ORDER_REPEAT = 0x04
OLDEST_ONCE = 0x05
OLDEST_REPEAT = 0x06
NEWEST_ONCE = 0x07
NEWEST_REPEAT = 0x08
SHUFFLE_ONCE = 0x09
SHUFFLE_REPEAT = 0x0A
class PlayingOrderSupported(enum.IntFlag):
'''See Media Control Service 3.16. Playing Orders Supported.'''
SINGLE_ONCE = 0x0001
SINGLE_REPEAT = 0x0002
IN_ORDER_ONCE = 0x0004
IN_ORDER_REPEAT = 0x0008
OLDEST_ONCE = 0x0010
OLDEST_REPEAT = 0x0020
NEWEST_ONCE = 0x0040
NEWEST_REPEAT = 0x0080
SHUFFLE_ONCE = 0x0100
SHUFFLE_REPEAT = 0x0200
class MediaState(utils.OpenIntEnum):
'''See Media Control Service 3.17. Media State.'''
INACTIVE = 0x00
PLAYING = 0x01
PAUSED = 0x02
SEEKING = 0x03
class MediaControlPointOpcode(utils.OpenIntEnum):
'''See Media Control Service 3.18. Media Control Point.'''
PLAY = 0x01
PAUSE = 0x02
FAST_REWIND = 0x03
FAST_FORWARD = 0x04
STOP = 0x05
MOVE_RELATIVE = 0x10
PREVIOUS_SEGMENT = 0x20
NEXT_SEGMENT = 0x21
FIRST_SEGMENT = 0x22
LAST_SEGMENT = 0x23
GOTO_SEGMENT = 0x24
PREVIOUS_TRACK = 0x30
NEXT_TRACK = 0x31
FIRST_TRACK = 0x32
LAST_TRACK = 0x33
GOTO_TRACK = 0x34
PREVIOUS_GROUP = 0x40
NEXT_GROUP = 0x41
FIRST_GROUP = 0x42
LAST_GROUP = 0x43
GOTO_GROUP = 0x44
class MediaControlPointResultCode(enum.IntFlag):
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
SUCCESS = 0x01
OPCODE_NOT_SUPPORTED = 0x02
MEDIA_PLAYER_INACTIVE = 0x03
COMMAND_CANNOT_BE_COMPLETED = 0x04
class MediaControlPointOpcodeSupported(enum.IntFlag):
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
PLAY = 0x00000001
PAUSE = 0x00000002
FAST_REWIND = 0x00000004
FAST_FORWARD = 0x00000008
STOP = 0x00000010
MOVE_RELATIVE = 0x00000020
PREVIOUS_SEGMENT = 0x00000040
NEXT_SEGMENT = 0x00000080
FIRST_SEGMENT = 0x00000100
LAST_SEGMENT = 0x00000200
GOTO_SEGMENT = 0x00000400
PREVIOUS_TRACK = 0x00000800
NEXT_TRACK = 0x00001000
FIRST_TRACK = 0x00002000
LAST_TRACK = 0x00004000
GOTO_TRACK = 0x00008000
PREVIOUS_GROUP = 0x00010000
NEXT_GROUP = 0x00020000
FIRST_GROUP = 0x00040000
LAST_GROUP = 0x00080000
GOTO_GROUP = 0x00100000
class SearchControlPointItemType(utils.OpenIntEnum):
'''See Media Control Service 3.20. Search Control Point.'''
TRACK_NAME = 0x01
ARTIST_NAME = 0x02
ALBUM_NAME = 0x03
GROUP_NAME = 0x04
EARLIEST_YEAR = 0x05
LATEST_YEAR = 0x06
GENRE = 0x07
ONLY_TRACKS = 0x08
ONLY_GROUPS = 0x09
class ObjectType(utils.OpenIntEnum):
'''See Media Control Service 4.4.1. Object Type field.'''
TASK = 0
GROUP = 1
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class ObjectId(int):
'''See Media Control Service 4.4.2. Object ID field.'''
@classmethod
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(int.from_bytes(data, byteorder='little', signed=False))
def __bytes__(self) -> bytes:
return self.to_bytes(6, 'little')
@dataclasses.dataclass
class GroupObjectType:
'''See Media Control Service 4.4. Group Object Type.'''
object_type: ObjectType
object_id: ObjectId
@classmethod
def from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls(
object_type=ObjectType(data[0]),
object_id=ObjectId.create_from_bytes(data[1:]),
)
def __bytes__(self) -> bytes:
return bytes([self.object_type]) + bytes(self.object_id)
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class MediaControlService(gatt.TemplateService):
'''Media Control Service server implementation, only for testing currently.'''
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=media_player_name or 'Bumble Player',
)
self.track_changed_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_title_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_duration_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.track_position_characteristic = gatt.Characteristic(
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_state_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.media_control_point_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self.on_media_control_point),
)
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
self.content_control_id_characteristic = gatt.Characteristic(
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=b'',
)
super().__init__(
[
self.media_player_name_characteristic,
self.track_changed_characteristic,
self.track_title_characteristic,
self.track_duration_characteristic,
self.track_position_characteristic,
self.media_state_characteristic,
self.media_control_point_characteristic,
self.media_control_point_opcodes_supported_characteristic,
self.content_control_id_characteristic,
]
)
async def on_media_control_point(
self, connection: Optional[device.Connection], data: bytes
) -> None:
if not connection:
raise core.InvalidStateError()
opcode = MediaControlPointOpcode(data[0])
await connection.device.notify_subscriber(
connection,
self.media_control_point_characteristic,
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
)
class GenericMediaControlService(MediaControlService):
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class MediaControlServiceProxy(
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
):
SERVICE_CLASS = MediaControlService
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
}
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
track_changed: Optional[gatt_client.CharacteristicProxy] = None
track_title: Optional[gatt_client.CharacteristicProxy] = None
track_duration: Optional[gatt_client.CharacteristicProxy] = None
track_position: Optional[gatt_client.CharacteristicProxy] = None
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
playing_order: Optional[gatt_client.CharacteristicProxy] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
media_state: Optional[gatt_client.CharacteristicProxy] = None
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
None
)
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
utils.CompositeEventEmitter.__init__(self)
self.service_proxy = service_proxy
self.lock = asyncio.Lock()
self.media_control_point_notifications = asyncio.Queue()
for field, uuid in self._CHARACTERISTICS.items():
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
setattr(self, field, characteristics[0])
async def subscribe_characteristics(self) -> None:
if self.media_control_point:
await self.media_control_point.subscribe(self._on_media_control_point)
if self.media_state:
await self.media_state.subscribe(self._on_media_state)
if self.track_changed:
await self.track_changed.subscribe(self._on_track_changed)
if self.track_title:
await self.track_title.subscribe(self._on_track_title)
if self.track_duration:
await self.track_duration.subscribe(self._on_track_duration)
if self.track_position:
await self.track_position.subscribe(self._on_track_position)
async def write_control_point(
self, opcode: MediaControlPointOpcode
) -> MediaControlPointResultCode:
'''Writes a Media Control Point Opcode to peer and waits for the notification.
The write operation will be executed when there isn't other pending commands.
Args:
opcode: opcode defined in `MediaControlPointOpcode`.
Returns:
Response code provided in `MediaControlPointResultCode`
Raises:
InvalidOperationError: Server does not have Media Control Point Characteristic.
InvalidStateError: Server replies a notification with mismatched opcode.
'''
if not self.media_control_point:
raise core.InvalidOperationError("Peer does not have media control point")
async with self.lock:
await self.media_control_point.write_value(
bytes([opcode]),
with_response=False,
)
(
response_opcode,
response_code,
) = await self.media_control_point_notifications.get()
if response_opcode != opcode:
raise core.InvalidStateError(
f"Expected {opcode} notification, but get {response_opcode}"
)
return MediaControlPointResultCode(response_code)
def _on_media_control_point(self, data: bytes) -> None:
self.media_control_point_notifications.put_nowait(data)
def _on_media_state(self, data: bytes) -> None:
self.emit('media_state', MediaState(data[0]))
def _on_track_changed(self, data: bytes) -> None:
del data
self.emit('track_changed')
def _on_track_title(self, data: bytes) -> None:
self.emit('track_title', data.decode("utf-8"))
def _on_track_duration(self, data: bytes) -> None:
self.emit('track_duration', struct.unpack_from('<i', data)[0])
def _on_track_position(self, data: bytes) -> None:
self.emit('track_position', struct.unpack_from('<i', data)[0])
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
SERVICE_CLASS = GenericMediaControlService

46
bumble/profiles/pbp.py Normal file
View File

@@ -0,0 +1,46 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
from typing_extensions import Self
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PublicBroadcastAnnouncement:
class Features(enum.IntFlag):
ENCRYPTED = 1 << 0
STANDARD_QUALITY_CONFIGURATION = 1 << 1
HIGH_QUALITY_CONFIGURATION = 1 << 2
features: Features
metadata: le_audio.Metadata
@classmethod
def from_bytes(cls, data: bytes) -> Self:
features = cls.Features(data[0])
metadata_length = data[1]
metadata_ltv = data[1 : 1 + metadata_length]
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)

228
bumble/profiles/vcp.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from bumble import att
from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME = 0
MAX_VOLUME = 255
class ErrorCode(enum.IntEnum):
'''
See Volume Control Service 1.6. Application error codes.
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
class VolumeFlags(enum.IntFlag):
'''
See Volume Control Service 3.3. Volume Flags.
'''
VOLUME_SETTING_PERSISTED = 0x01
# RFU
class VolumeControlPointOpcode(enum.IntEnum):
'''
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
'''
# fmt: off
RELATIVE_VOLUME_DOWN = 0x00
RELATIVE_VOLUME_UP = 0x01
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
UNMUTE_RELATIVE_VOLUME_UP = 0x03
SET_ABSOLUTE_VOLUME = 0x04
UNMUTE = 0x05
MUTE = 0x06
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic
volume_setting: int
muted: int
change_counter: int
def __init__(
self,
step_size: int = 16,
volume_setting: int = 0,
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
self.muted = muted
self.change_counter = change_counter
self.volume_state = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
)
self.volume_control_point = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
)
self.volume_flags = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([volume_flags]),
)
super().__init__(
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
]
)
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]
if change_counter != self.change_counter:
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256
connection.abort_on(
'disconnection',
connection.device.notify_subscribers(
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
)
def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
return self.volume_setting != old_volume
def _on_relative_volume_up(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
return self.volume_setting != old_volume
def _on_unmute_relative_volume_down(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_unmute_relative_volume_up(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
old_volume_setting = self.volume_setting
self.volume_setting = volume_setting
return old_volume_setting != self.volume_setting
def _on_unmute(self) -> bool:
old_muted_state = self.muted
self.muted = 0
return self.muted != old_muted_state
def _on_mute(self) -> bool:
old_muted_state = self.muted
self.muted = 1
return self.muted != old_muted_state
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0],
'BBB',
)
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0]
self.volume_flags = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0],
'B',
)

File diff suppressed because it is too large Load Diff

View File

@@ -18,13 +18,17 @@
from __future__ import annotations
import logging
import struct
from typing import Dict, List, Type
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self
from . import core
from . import core, l2cap
from .colors import color
from .core import InvalidStateError
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING:
from .device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -94,6 +98,11 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
SDP_ATTRIBUTE_ID_NAMES = {
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: 'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',
@@ -108,7 +117,8 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
}
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
@@ -160,7 +170,7 @@ class DataElement:
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
@@ -179,7 +189,9 @@ class DataElement:
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise ValueError('integer types must have a value size specified')
raise InvalidArgumentError(
'integer types must have a value size specified'
)
@staticmethod
def nil() -> DataElement:
@@ -222,7 +234,7 @@ class DataElement:
return DataElement(DataElement.UUID, value)
@staticmethod
def text_string(value: str) -> DataElement:
def text_string(value: bytes) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@staticmethod
@@ -255,7 +267,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
@@ -271,7 +283,7 @@ class DataElement:
if len(data) == 8:
return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
@@ -344,7 +356,7 @@ class DataElement:
data = b''
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
data = struct.pack('B', self.value)
@@ -355,7 +367,7 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>Q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1:
data = struct.pack('b', self.value)
@@ -366,10 +378,10 @@ class DataElement:
elif self.value_size == 8:
data = struct.pack('>q', self.value)
else:
raise ValueError('invalid value_size')
raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
elif self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
@@ -382,7 +394,7 @@ class DataElement:
size_bytes = b''
if self.type == DataElement.NIL:
if size != 0:
raise ValueError('NIL must be empty')
raise InvalidArgumentError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
@@ -400,7 +412,7 @@ class DataElement:
elif size == 16:
size_index = 4
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
@@ -417,10 +429,10 @@ class DataElement:
size_index = 7
size_bytes = struct.pack('>I', size)
else:
raise ValueError('invalid data size')
raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN:
if size != 1:
raise ValueError('boolean must be 1 byte')
raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
@@ -462,7 +474,7 @@ class ServiceAttribute:
self.value = value
@staticmethod
def list_from_data_elements(elements):
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
attribute_list = []
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -474,7 +486,9 @@ class ServiceAttribute:
return attribute_list
@staticmethod
def find_attribute_in_list(attribute_list, attribute_id):
def find_attribute_in_list(
attribute_list: List[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]:
return next(
(
attribute.value
@@ -489,7 +503,7 @@ class ServiceAttribute:
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod
def is_uuid_in_value(uuid, value):
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
# Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID:
return value.value == uuid
@@ -543,7 +557,9 @@ class SDP_PDU:
return self
@staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset):
def parse_service_record_handle_list_preceded_by_count(
data: bytes, offset: int
) -> Tuple[int, List[int]]:
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
@@ -641,6 +657,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
'''
service_search_pattern: DataElement
maximum_service_record_count: int
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
@@ -659,6 +679,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
'''
service_record_handle_list: List[int]
total_service_record_count: int
current_service_record_count: int
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
@@ -674,6 +699,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
'''
service_record_handle: int
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
@@ -688,6 +718,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
'''
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
@@ -703,6 +737,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
'''
service_search_pattern: DataElement
maximum_attribute_byte_count: int
attribute_id_list: DataElement
continuation_state: bytes
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
@@ -717,26 +756,35 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
'''
attribute_list_byte_count: int
attribute_list: bytes
continuation_state: bytes
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device):
self.device = device
channel: Optional[l2cap.ClassicChannel]
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.pending_request = None
self.channel = None
async def connect(self, connection):
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
self.channel = result
async def connect(self) -> None:
self.channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM)
)
async def disconnect(self):
async def disconnect(self) -> None:
if self.channel:
await self.channel.disconnect()
self.channel = None
async def search_services(self, uuids):
async def search_services(self, uuids: List[core.UUID]) -> List[int]:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
@@ -766,20 +814,26 @@ class Client:
return service_record_handle_list
async def search_attributes(self, uuids, attribute_ids):
async def search_attributes(
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
) -> List[List[ServiceAttribute]]:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
(
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
)
@@ -819,17 +873,25 @@ class Client:
if sequence.type == DataElement.SEQUENCE
]
async def get_attributes(self, service_record_handle, attribute_ids):
async def get_attributes(
self,
service_record_handle: int,
attribute_ids: List[Union[int, Tuple[int, int]]],
) -> List[ServiceAttribute]:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
if self.channel is None:
raise InvalidStateError('L2CAP not connected')
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
(
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
)
@@ -865,25 +927,38 @@ class Client:
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(self, *args) -> None:
await self.disconnect()
# -----------------------------------------------------------------------------
class Server:
CONTINUATION_STATE = bytes([0x01, 0x43])
channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', List[ServiceAttribute])
service_records: Dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]]
def __init__(self, device):
def __init__(self, device: Device) -> None:
self.device = device
self.service_records = {} # Service records maps, by record handle
self.channel = None
self.current_response = None
def register(self, l2cap_channel_manager):
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.create_classic_server(
spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
)
def send_response(self, response):
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
self.channel.send_pdu(response)
def match_services(self, search_pattern):
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
# Find the services for which the attributes in the pattern is a subset of the
# service's attribute values (NOTE: the value search recurses into sequences)
matching_services = {}
@@ -924,7 +999,7 @@ class Server:
try:
handler(sdp_pdu)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
@@ -953,7 +1028,9 @@ class Server:
return (payload, continuation_state)
@staticmethod
def get_service_attributes(service, attribute_ids):
def get_service_attributes(
service: Service, attribute_ids: List[DataElement]
) -> DataElement:
attributes = []
for attribute_id in attribute_ids:
if attribute_id.value_size == 4:
@@ -978,10 +1055,10 @@ class Server:
return attribute_list
def on_sdp_service_search_request(self, request):
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation
if len(request.continuation_state) > 1:
if not self.current_response:
if self.current_response is None:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
@@ -1010,6 +1087,7 @@ class Server:
)
# Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
@@ -1033,10 +1111,12 @@ class Server:
)
)
def on_sdp_service_attribute_request(self, request):
def on_sdp_service_attribute_request(
self, request: SDP_ServiceAttributeRequest
) -> None:
# Check if this is a continuation
if len(request.continuation_state) > 1:
if not self.current_response:
if self.current_response is None:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
@@ -1069,22 +1149,24 @@ class Server:
self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(
attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list),
attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list,
continuation_state=continuation_state,
)
)
def on_sdp_service_search_attribute_request(self, request):
def on_sdp_service_search_attribute_request(
self, request: SDP_ServiceSearchAttributeRequest
) -> None:
# Check if this is a continuation
if len(request.continuation_state) > 1:
if not self.current_response:
if self.current_response is None:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
@@ -1114,13 +1196,13 @@ class Server:
self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(
attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists),
attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists,
continuation_state=continuation_state,
)

View File

@@ -25,7 +25,9 @@
from __future__ import annotations
import logging
import asyncio
import enum
import secrets
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -36,6 +38,7 @@ from typing import (
Optional,
Tuple,
Type,
cast,
)
from pyee import EventEmitter
@@ -51,6 +54,8 @@ from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData,
InvalidArgumentError,
ProtocolError,
name_or_number,
)
@@ -183,8 +188,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000
# Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
# fmt: on
# pylint: enable=line-too-long
@@ -553,20 +558,64 @@ class AddressResolver:
# -----------------------------------------------------------------------------
class Session:
# Pairing methods
class PairingMethod(enum.IntEnum):
JUST_WORKS = 0
NUMERIC_COMPARISON = 1
PASSKEY = 2
OOB = 3
CTKD_OVER_CLASSIC = 4
PAIRING_METHOD_NAMES = {
JUST_WORKS: 'JUST_WORKS',
NUMERIC_COMPARISON: 'NUMERIC_COMPARISON',
PASSKEY: 'PASSKEY',
OOB: 'OOB',
}
# -----------------------------------------------------------------------------
class OobContext:
"""Cryptographic context for LE SC OOB pairing."""
ecc_key: crypto.EccKey
r: bytes
def __init__(
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
def share(self) -> OobSharedData:
pkx = self.ecc_key.x[::-1]
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
# -----------------------------------------------------------------------------
class OobLegacyContext:
"""Cryptographic context for LE Legacy OOB pairing."""
tk: bytes
def __init__(self, tk: Optional[bytes] = None) -> None:
self.tk = crypto.r() if tk is None else tk
# -----------------------------------------------------------------------------
@dataclass
class OobSharedData:
"""Shareable data for LE SC OOB pairing."""
c: bytes
r: bytes
def to_ad(self) -> AdvertisingData:
return AdvertisingData(
[
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
]
)
def __str__(self) -> str:
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
# -----------------------------------------------------------------------------
class Session:
# I/O Capability to pairing method decision matrix
#
# See Bluetooth spec @ Vol 3, part H - Table 2.8: Mapping of IO Capabilities to Key
@@ -581,51 +630,61 @@ class Session:
# (False).
PAIRING_METHODS = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
},
SMP_DISPLAY_YES_NO_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (JUST_WORKS, NUMERIC_COMPARISON),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
(PASSKEY, True, False),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, True, False),
PairingMethod.NUMERIC_COMPARISON,
),
},
SMP_KEYBOARD_ONLY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PASSKEY, False, True),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, False, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
},
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
},
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
(PASSKEY, False, True),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, False, True),
PairingMethod.NUMERIC_COMPARISON,
),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
(PASSKEY, True, False),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, True, False),
PairingMethod.NUMERIC_COMPARISON,
),
},
}
ea: bytes
eb: bytes
ltk: bytes
preq: bytes
pres: bytes
tk: bytes
def __init__(
self,
manager: Manager,
@@ -635,17 +694,10 @@ class Session:
) -> None:
self.manager = manager
self.connection = connection
self.preq: Optional[bytes] = None
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
self.r = bytes(16)
self.stk = None
self.ltk = None
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key = None
self.link_key: Optional[bytes] = None
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
@@ -658,13 +710,13 @@ class Session:
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = None
self.dh_key = b''
self.confirm_value = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
self.pairing_method: PairingMethod = PairingMethod.JUST_WORKS
self.pairing_config = pairing_config
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False
@@ -686,9 +738,9 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result: Optional[
asyncio.Future[None]
] = asyncio.get_running_loop().create_future()
self.pairing_result: Optional[asyncio.Future[None]] = (
asyncio.get_running_loop().create_future()
)
else:
self.pairing_result = None
@@ -711,12 +763,15 @@ class Session:
self.io_capability = pairing_config.delegate.io_capability
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB (not supported yet)
self.oob = False
# OOB
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# Set up addresses
self_address = connection.self_address
self_address = connection.self_resolvable_address or connection.self_address
peer_address = connection.peer_resolvable_address or connection.peer_address
logger.debug(
f"pairing with self_address={self_address}, peer_address={peer_address}"
)
if self.is_initiator:
self.ia = bytes(self_address)
self.iat = 1 if self_address.is_random else 0
@@ -728,9 +783,35 @@ class Session:
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
# Select the ECC key, TK and r initial value
if pairing_config.oob:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise InvalidArgumentError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
self.ecc_key = pairing_config.oob.our_context.ecc_key
if pairing_config.oob.legacy_context is not None:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise InvalidArgumentError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = pairing_config.oob.legacy_context.tk
else:
self.peer_oob_data = None
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = bytes(16)
@property
def pkx(self) -> Tuple[bytes, bytes]:
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
return (self.ecc_key.x[::-1], self.peer_public_key_x)
@property
def pka(self) -> bytes:
@@ -767,21 +848,28 @@ class Session:
return None
def decide_pairing_method(
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
self,
auth_req: int,
initiator_io_capability: int,
responder_io_capability: int,
) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS
self.pairing_method = PairingMethod.JUST_WORKS
return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0]
if isinstance(details, int):
if isinstance(details, PairingMethod):
# Just a method ID
self.pairing_method = details
else:
# PASSKEY method, with a method ID and display/input flags
assert isinstance(details[0], PairingMethod)
self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2]
@@ -858,10 +946,13 @@ class Session:
self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
self.connection.abort_on(
'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
try:
self.connection.abort_on(
'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
except Exception as error:
logger.warning(f'exception while displaying number: {error}')
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
@@ -901,7 +992,7 @@ class Session:
command = SMP_Pairing_Request_Command(
io_capability=self.io_capability,
oob_data_flag=0,
oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -913,7 +1004,7 @@ class Session:
def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command(
io_capability=self.io_capability,
oob_data_flag=0,
oob_data_flag=self.oob_data_flag,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -929,9 +1020,12 @@ class Session:
if self.sc:
async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
z = 0
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
assert self.passkey
@@ -971,8 +1065,8 @@ class Session:
def send_public_key_command(self) -> None:
self.send_command(
SMP_Pairing_Public_Key_Command(
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
public_key_x=self.ecc_key.x[::-1],
public_key_y=self.ecc_key.y[::-1],
)
)
@@ -983,11 +1077,24 @@ class Session:
)
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.manager.device.static_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.static_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
bd_addr=identity_address,
)
)
def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
HCI_LE_Enable_Encryption_Command(
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
@@ -995,15 +1102,54 @@ class Session:
)
)
async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None
@classmethod
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes:
'''Derives Long Term Key from Link Key.
Args:
link_key: BR/EDR Link Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
LE Long Tern Key bytes in little-endian.
'''
ilk = (
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
if self.ct2
if ct2
else crypto.h6(link_key, b'tmp2')
)
self.ltk = crypto.h6(ilk, b'brle')
return crypto.h6(ilk, b'brle')
@classmethod
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
'''Derives Link Key from Long Term Key.
Args:
ltk: LE Long Term Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
BR/EDR Link Key bytes in little-endian.
'''
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
if ct2
else crypto.h6(ltk, b'tmp1')
)
return crypto.h6(ilk, b'lebr')
async def get_link_key_and_derive_ltk(self) -> None:
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
self.link_key = await self.manager.device.get_link_key(
self.connection.peer_address
)
if self.link_key is None:
logging.warning(
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
def distribute_keys(self) -> None:
# Distribute the keys as required
@@ -1014,7 +1160,7 @@ class Session:
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.abort_on(
'disconnection', self.derive_ltk()
'disconnection', self.get_link_key_and_derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
@@ -1035,12 +1181,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1049,12 +1190,7 @@ class Session:
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
self.link_key = self.derive_link_key(self.ltk, self.ct2)
else:
# CTKD: Derive LTK from LinkKey
@@ -1063,7 +1199,7 @@ class Session:
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.abort_on(
'disconnection', self.derive_ltk()
'disconnection', self.get_link_key_and_derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
@@ -1084,12 +1220,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1098,12 +1229,7 @@ class Session:
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
self.link_key = self.derive_link_key(self.ltk, self.ct2)
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
@@ -1224,7 +1350,7 @@ class Session:
# Create an object to hold the keys
keys = PairingKeys()
keys.address_type = peer_address.address_type
authenticated = self.pairing_method != self.JUST_WORKS
authenticated = self.pairing_method != PairingMethod.JUST_WORKS
if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT:
keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated)
else:
@@ -1258,7 +1384,7 @@ class Session:
keys.link_key = PairingKeys.Key(
value=self.link_key, authenticated=authenticated
)
self.manager.on_pairing(self, peer_address, keys)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
@@ -1281,7 +1407,7 @@ class Session:
try:
handler(command)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
@@ -1300,7 +1426,11 @@ class Session:
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed
accepted = await self.pairing_config.delegate.accept()
try:
accepted = await self.pairing_config.delegate.accept()
except Exception as error:
logger.warning(f'exception while accepting: {error}')
accepted = False
if not accepted:
logger.debug('pairing rejected by delegate')
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
@@ -1314,18 +1444,29 @@ class Session:
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
# Check for OOB
if command.oob_data_flag != 0:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req,
command.io_capability,
self.io_capability,
)
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, command.io_capability, self.io_capability
)
logger.debug(
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
(
@@ -1341,7 +1482,7 @@ class Session:
# Display a passkey if we need to
if not self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
if self.pairing_method == PairingMethod.PASSKEY and self.passkey_display:
self.display_passkey()
# Respond
@@ -1373,18 +1514,27 @@ class Session:
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
# Check for OOB
if self.sc and command.oob_data_flag:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
logger.debug(
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
if (
@@ -1400,13 +1550,16 @@ class Session:
self.compute_peer_expected_distributions(self.responder_key_distribution)
# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.CTKD_OVER_CLASSIC:
# Authentication is already done in SMP, so remote shall start keys distribution immediately
return
elif self.sc:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey(self.send_pairing_confirm_command)
else:
self.send_pairing_confirm_command()
@@ -1418,7 +1571,10 @@ class Session:
self.send_pairing_random_command()
else:
# If the method is PASSKEY, now is the time to input the code
if self.pairing_method == self.PASSKEY and not self.passkey_display:
if (
self.pairing_method == PairingMethod.PASSKEY
and not self.passkey_display
):
self.input_passkey(self.send_pairing_confirm_command)
else:
self.send_pairing_confirm_command()
@@ -1426,11 +1582,14 @@ class Session:
def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
if self.is_initiator:
self.r = crypto.r()
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
if self.is_initiator:
self.send_pairing_random_command()
else:
@@ -1486,13 +1645,16 @@ class Session:
def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None:
if self.pairing_method == PairingMethod.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
assert self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1502,7 +1664,7 @@ class Session:
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1522,12 +1684,16 @@ class Session:
if self.passkey_step < 20:
self.send_pairing_confirm_command()
return
else:
elif self.pairing_method != PairingMethod.OOB:
return
else:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1558,15 +1724,18 @@ class Session:
(mac_key, self.ltk) = crypto.f5(self.dh_key, self.na, self.nb, a, b)
# Compute the DH Key checks
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
else:
# OOB not implemented yet
return
assert self.preq and self.pres
@@ -1585,13 +1754,16 @@ class Session:
self.wait_before_continuing.set_result(None)
# Prompt the user for confirmation if needed
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
# Compute the 6-digit code
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
# Ask for user confirmation
self.wait_before_continuing = asyncio.get_running_loop().create_future()
if self.pairing_method == self.JUST_WORKS:
if self.pairing_method == PairingMethod.JUST_WORKS:
self.prompt_user_for_confirmation(next_steps)
else:
self.prompt_user_for_numeric_comparison(code, next_steps)
@@ -1615,26 +1787,45 @@ class Session:
self.peer_public_key_y = command.public_key_y
# Compute the DH key
self.dh_key = bytes(
reversed(
self.manager.ecc_key.dh(
bytes(reversed(command.public_key_x)),
bytes(reversed(command.public_key_y)),
)
)
)
self.dh_key = self.ecc_key.dh(
command.public_key_x[::-1],
command.public_key_y[::-1],
)[::-1]
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.pairing_method == PairingMethod.OOB:
# Check against shared OOB data
if self.peer_oob_data:
confirm_verifier = crypto.f4(
self.peer_public_key_x,
self.peer_public_key_x,
self.peer_oob_data.r,
bytes(1),
)
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
):
return
if self.is_initiator:
self.send_pairing_confirm_command()
if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command()
else:
self.send_pairing_confirm_command()
else:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
# We can now send the confirmation value
self.send_pairing_confirm_command()
@@ -1662,7 +1853,6 @@ class Session:
else:
self.send_pairing_dhkey_check_command()
else:
assert self.ltk
self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(
@@ -1712,6 +1902,7 @@ class Manager(EventEmitter):
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: Type[Session]
_ecc_key: Optional[crypto.EccKey]
def __init__(
self,
@@ -1733,7 +1924,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
) -> None:
connection.emit('security_request', request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
return
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
@@ -1744,13 +1954,6 @@ class Manager(EventEmitter):
)
self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session
session.on_smp_command(command)
@@ -1789,21 +1992,13 @@ class Manager(EventEmitter):
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)
def on_pairing(
async def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
self.device.abort_on('flush', store_keys())
# Make sure on_pairing emits after key update.
await self.device.update_keys(str(identity_address), keys)
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

View File

@@ -23,6 +23,7 @@ import datetime
from typing import BinaryIO, Generator
import os
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
raise core.InvalidArgumentError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')

View File

@@ -18,9 +18,9 @@
from contextlib import asynccontextmanager
import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
@@ -53,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport:
"""
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The name must be <type>:<metadata><parameters>
Where <parameters> depend on the type (and may be empty for some types), and
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are:
* serial
* udp
@@ -69,88 +77,116 @@ async def open_transport(name: str) -> Transport:
* usb
* pyusb
* android-emulator
* android-netsim
"""
return _wrap_transport(await _open_transport(name))
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
return await open_serial_transport(spec)
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
return await open_udp_transport(spec)
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
return await open_tcp_client_transport(spec)
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
return await open_tcp_server_transport(spec)
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
return await open_ws_client_transport(spec)
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
return await open_ws_server_transport(spec)
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
return await open_pty_transport(spec)
if scheme == 'file':
from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
assert spec is not None
return await open_file_transport(spec)
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
return await open_vhci_transport(spec)
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
return await open_hci_socket_transport(spec)
if scheme == 'usb':
from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
assert spec
return await open_usb_transport(spec)
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
assert spec
return await open_pyusb_transport(spec)
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
return await open_android_emulator_transport(spec)
if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec[0] if spec else None)
return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme')
if scheme == 'unix':
from .unix import open_unix_client_transport
assert spec
return await open_unix_client_transport(spec)
raise TransportSpecError('unknown transport scheme')
# -----------------------------------------------------------------------------
@@ -167,11 +203,13 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link)
controller = Controller('remote', link=link) # type:ignore[arg-type]
class LinkTransport(Transport):
async def close(self):

View File

@@ -18,7 +18,15 @@
import logging
import grpc.aio
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from typing import Optional, Union
from .common import (
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -33,7 +41,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec):
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
@@ -66,8 +74,8 @@ async def open_android_emulator_transport(spec):
# Parse the parameters
mode = 'host'
server_host = 'localhost'
server_port = 8554
if spec is not None:
server_port = '8554'
if spec:
params = spec.split(',')
for param in params:
if param.startswith('mode='):
@@ -75,13 +83,14 @@ async def open_android_emulator_transport(spec):
elif ':' in param:
server_host, server_port = param.split(':')
else:
raise ValueError('invalid parameter')
raise TransportSpecError('invalid parameter')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)
@@ -91,13 +100,16 @@ async def open_android_emulator_transport(spec):
service = VhciForwardingServiceStub(channel)
hci_device = HciDevice(service.attachVhci())
else:
raise ValueError('invalid mode')
raise TransportSpecError('invalid mode')
# Create the transport object
transport = PumpedTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
channel.close,
class EmulatorTransport(PumpedTransport):
async def close(self):
await super().close()
await channel.close()
transport = EmulatorTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
)
transport.start()

View File

@@ -18,11 +18,12 @@
import asyncio
import atexit
import logging
import grpc.aio
import os
import pathlib
import sys
from typing import Optional
from typing import Dict, Optional
import grpc.aio
from .common import (
ParserSource,
@@ -30,11 +31,13 @@ from .common import (
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
TransportInitError,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
from .grpc_protobuf.packet_streamer_pb2_grpc import (
PacketStreamerStub,
PacketStreamerServicer,
add_PacketStreamerServicer_to_server,
)
@@ -43,6 +46,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
from .grpc_protobuf.common_pb2 import ChipKind
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -74,14 +78,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
# -----------------------------------------------------------------------------
def find_grpc_port() -> int:
def ini_file_name(instance_number: int) -> str:
suffix = f'_{instance_number}' if instance_number > 0 else ''
return f'netsim{suffix}.ini'
# -----------------------------------------------------------------------------
def find_grpc_port(instance_number: int) -> int:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return 0
ini_file = ini_dir / 'netsim.ini'
ini_file = ini_dir / ini_file_name(instance_number)
logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file():
logger.debug(f'Found .ini file at {ini_file}')
with open(ini_file, 'r') as ini_file_data:
for line in ini_file_data.readlines():
if '=' in line:
@@ -90,12 +100,14 @@ def find_grpc_port() -> int:
logger.debug(f'gRPC port = {value}')
return int(value)
logger.debug('no grpc.port property found in .ini file')
# Not found
return 0
# -----------------------------------------------------------------------------
def publish_grpc_port(grpc_port) -> bool:
def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return False
@@ -104,7 +116,7 @@ def publish_grpc_port(grpc_port) -> bool:
logger.debug('ini directory does not exist')
return False
ini_file = ini_dir / 'netsim.ini'
ini_file = ini_dir / ini_file_name(instance_number)
try:
ini_file.write_text(f'grpc.port={grpc_port}\n')
logger.debug(f"published gRPC port at {ini_file}")
@@ -121,13 +133,16 @@ def publish_grpc_port(grpc_port) -> bool:
# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(server_host, server_port):
async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport:
if not server_port:
raise ValueError('invalid port')
raise TransportSpecError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
if not publish_grpc_port(server_port):
instance_number = int(options.get('instance', "0"))
if not publish_grpc_port(server_port, instance_number):
logger.warning("unable to publish gRPC port")
class HciDevice:
@@ -184,15 +199,12 @@ async def open_android_netsim_controller_transport(server_host, server_port):
logger.debug(f'<<< PACKET: {data.hex()}')
self.on_data_received(data)
def send_packet(self, data):
async def send():
await self.context.write(
PacketResponse(
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
)
async def send_packet(self, data):
return await self.context.write(
PacketResponse(
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
)
self.loop.create_task(send())
)
def terminate(self):
self.task.cancel()
@@ -226,17 +238,17 @@ async def open_android_netsim_controller_transport(server_host, server_port):
logger.debug('gRPC server cancelled')
await self.grpc_server.stop(None)
def on_packet(self, packet):
async def send_packet(self, packet):
if not self.device:
logger.debug('no device, dropping packet')
return
self.device.send_packet(packet)
return await self.device.send_packet(packet)
async def StreamPackets(self, _request_iterator, context):
logger.debug('StreamPackets request')
# Check that we won't already have a device
# Check that we don't already have a device
if self.device:
logger.debug('busy, already serving a device')
return PacketResponse(error='Busy')
@@ -259,15 +271,42 @@ async def open_android_netsim_controller_transport(server_host, server_port):
await server.start()
asyncio.get_running_loop().create_task(server.serve())
class GrpcServerTransport(Transport):
async def close(self):
await super().close()
return GrpcServerTransport(server, server)
sink = PumpedPacketSink(server.send_packet)
sink.start()
return Transport(server, sink)
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport(server_host, server_port, options):
async def open_android_netsim_host_transport_with_address(
server_host: Optional[str],
server_port: int,
options: Optional[Dict[str, str]] = None,
):
if server_host == '_' or not server_host:
server_host = 'localhost'
if not server_port:
# Look for the gRPC config in a .ini file
instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number)
if not server_port:
raise TransportInitError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'Connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
return await open_android_netsim_host_transport_with_channel(
channel,
options,
)
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_channel(
channel, options: Optional[Dict[str, str]] = None
):
# Wrapper for I/O operations
class HciDevice:
def __init__(self, name, manufacturer, hci_device):
@@ -286,16 +325,18 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
async def read(self):
response = await self.hci_device.read()
response_type = response.WhichOneof('response_type')
if response_type == 'error':
logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error)
elif response_type == 'hci_packet':
raise TransportInitError(response.error)
if response_type == 'hci_packet':
return (
bytes([response.hci_packet.packet_type])
+ response.hci_packet.packet
)
raise ValueError('unsupported response type')
raise TransportSpecError('unsupported response type')
async def write(self, packet):
await self.hci_device.write(
@@ -304,24 +345,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
)
)
name = options.get('name', DEFAULT_NAME)
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
manufacturer = DEFAULT_MANUFACTURER
if server_host == '_' or not server_host:
server_host = 'localhost'
if not server_port:
# Look for the gRPC config in a .ini file
server_host = 'localhost'
server_port = find_grpc_port()
if not server_port:
raise RuntimeError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'Connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
# Connect as a host
service = PacketStreamerStub(channel)
hci_device = HciDevice(
@@ -332,10 +358,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
await hci_device.start()
# Create the transport object
transport = PumpedTransport(
class GrpcTransport(PumpedTransport):
async def close(self):
await super().close()
await channel.close()
transport = GrpcTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
channel.close,
)
transport.start()
@@ -343,7 +373,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
# -----------------------------------------------------------------------------
async def open_android_netsim_transport(spec):
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC.
@@ -357,6 +387,11 @@ async def open_android_netsim_transport(spec):
to connect *to* a netsim server (netsim is the controller), or accept
connections *as* a netsim-compatible server.
instance=<n>
Specifies an instance number, with <n> > 0. This is used to determine which
.init file to use. In `host` mode, it is ignored when the <host>:<port>
specifier is present, since in that case no .ini file is used.
In `host` mode:
The <host>:<port> part is optional. When not specified, the transport
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
@@ -385,26 +420,29 @@ async def open_android_netsim_transport(spec):
params = spec.split(',') if spec else []
if params and ':' in params[0]:
# Explicit <host>:<port>
host, port = params[0].split(':')
host, port_str = params[0].split(':')
port = int(port_str)
params_offset = 1
else:
host = None
port = 0
params_offset = 0
options = {}
options: Dict[str, str] = {}
for param in params[params_offset:]:
if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>')
raise TransportSpecError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=')
options[option_name] = option_value
mode = options.get('mode', 'host')
if mode == 'host':
return await open_android_netsim_host_transport(host, port, options)
return await open_android_netsim_host_transport_with_address(
host, port, options
)
if mode == 'controller':
if host is None:
raise ValueError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port)
raise TransportSpecError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option')
raise TransportSpecError('invalid mode option')

View File

@@ -20,11 +20,13 @@ import contextlib
import struct
import asyncio
import logging
from typing import ContextManager
import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from .. import hci
from ..colors import color
from ..snoop import Snooper
from bumble import core
from bumble import hci
from bumble.colors import color
from bumble.snoop import Snooper
# -----------------------------------------------------------------------------
@@ -36,42 +38,68 @@ logger = logging.getLogger(__name__)
# Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = {
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
}
# -----------------------------------------------------------------------------
class PacketPump:
'''
Pump HCI packets from a reader to a sink
'''
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(core.BaseBumbleError, RuntimeError):
"""The Transport has been lost/disconnected."""
def __init__(self, reader, sink):
class TransportInitError(core.BaseBumbleError, RuntimeError):
"""Error raised when the transport cannot be initialized."""
class TransportSpecError(core.BaseBumbleError, ValueError):
"""Error raised when the transport spec is invalid."""
# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None: ...
class TransportSource(Protocol):
terminated: asyncio.Future[None]
def set_packet_sink(self, sink: TransportSink) -> None: ...
# -----------------------------------------------------------------------------
class PacketPump:
"""
Pump HCI packets from a reader to a sink.
"""
def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
self.reader = reader
self.sink = sink
async def run(self):
async def run(self) -> None:
while True:
try:
# Get a packet from the source
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
# Deliver the packet to the sink
self.sink.on_packet(packet)
self.sink.on_packet(await self.reader.next_packet())
except Exception as error:
logger.warning(f'!!! {error}')
# -----------------------------------------------------------------------------
class PacketParser:
'''
"""
In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
'''
parsed.
"""
# pylint: disable=attribute-defined-outside-init
@@ -79,18 +107,22 @@ class PacketParser:
NEED_LENGTH = 1
NEED_BODY = 2
def __init__(self, sink=None):
sink: Optional[TransportSink]
extended_packet_info: Dict[int, Tuple[int, int, str]]
packet_info: Optional[Tuple[int, int, str]] = None
def __init__(self, sink: Optional[TransportSink] = None) -> None:
self.sink = sink
self.extended_packet_info = {}
self.reset()
def reset(self):
def reset(self) -> None:
self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1
self.packet = bytearray()
self.packet_info = None
def feed_data(self, data):
def feed_data(self, data: bytes) -> None:
data_offset = 0
data_left = len(data)
while data_left and self.bytes_needed:
@@ -107,10 +139,13 @@ class PacketParser:
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
raise core.InvalidPacketError(
f'invalid packet type {packet_type}'
)
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
assert self.packet_info is not None
body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
@@ -123,67 +158,69 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.warning(
logger.exception(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset()
def set_packet_sink(self, sink):
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
# -----------------------------------------------------------------------------
class PacketReader:
'''
Reader that reads HCI packets from a sync source
'''
"""
Reader that reads HCI packets from a sync source.
"""
def __init__(self, source):
def __init__(self, source: io.BufferedReader) -> None:
self.source = source
self.at_end = False
def next_packet(self):
def next_packet(self) -> Optional[bytes]:
# Get the packet type
packet_type = self.source.read(1)
if len(packet_type) != 1:
self.at_end = True
return None
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
header = self.source.read(header_size)
if len(header) != header_size:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
# Read the body
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
body = self.source.read(body_length)
if len(body) != body_length:
raise ValueError('packet too short')
raise core.InvalidPacketError('packet too short')
return packet_type + header + body
# -----------------------------------------------------------------------------
class AsyncPacketReader:
'''
Reader that reads HCI packets from an async source
'''
"""
Reader that reads HCI packets from an async source.
"""
def __init__(self, source):
def __init__(self, source: asyncio.StreamReader) -> None:
self.source = source
async def next_packet(self):
async def next_packet(self) -> bytes:
# Get the packet type
packet_type = await self.source.readexactly(1)
# Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None:
raise ValueError(f'invalid packet type {packet_type} found')
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1]
@@ -198,15 +235,15 @@ class AsyncPacketReader:
# -----------------------------------------------------------------------------
class AsyncPipeSink:
'''
Sink that forwards packets asynchronously to another sink
'''
"""
Sink that forwards packets asynchronously to another sink.
"""
def __init__(self, sink):
def __init__(self, sink: TransportSink) -> None:
self.sink = sink
self.loop = asyncio.get_running_loop()
def on_packet(self, packet):
def on_packet(self, packet: bytes) -> None:
self.loop.call_soon(self.sink.on_packet, packet)
@@ -216,35 +253,48 @@ class ParserSource:
Base class designed to be subclassed by transport-specific source classes
"""
def __init__(self):
terminated: asyncio.Future[None]
parser: PacketParser
def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink):
def set_packet_sink(self, sink: TransportSink) -> None:
self.parser.set_packet_sink(sink)
async def wait_for_termination(self):
def on_transport_lost(self) -> None:
self.terminated.set_result(None)
if self.parser.sink:
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost()
async def wait_for_termination(self) -> None:
"""
Convenience method for backward compatibility. Prefer using the `terminated`
attribute instead.
"""
return await self.terminated
def close(self):
def close(self) -> None:
pass
# -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data):
def data_received(self, data: bytes) -> None:
self.parser.feed_data(data)
# -----------------------------------------------------------------------------
class StreamPacketSink:
def __init__(self, transport):
def __init__(self, transport: asyncio.WriteTransport) -> None:
self.transport = transport
def on_packet(self, packet):
def on_packet(self, packet: bytes) -> None:
self.transport.write(packet)
def close(self):
def close(self) -> None:
self.transport.close()
@@ -264,7 +314,7 @@ class Transport:
...
"""
def __init__(self, source, sink):
def __init__(self, source: TransportSource, sink: TransportSink) -> None:
self.source = source
self.sink = sink
@@ -278,34 +328,39 @@ class Transport:
return iter((self.source, self.sink))
async def close(self) -> None:
self.source.close()
self.sink.close()
if hasattr(self.source, 'close'):
self.source.close()
if hasattr(self.sink, 'close'):
self.sink.close()
# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
def __init__(self, receive):
pump_task: Optional[asyncio.Task[None]]
def __init__(self, receive) -> None:
super().__init__()
self.receive_function = receive
self.pump_task = None
def start(self):
async def pump_packets():
def start(self) -> None:
async def pump_packets() -> None:
while True:
try:
packet = await self.receive_function()
self.parser.feed_data(packet)
except asyncio.exceptions.CancelledError:
except asyncio.CancelledError:
logger.debug('source pump task done')
self.terminated.set_result(None)
break
except Exception as error:
logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error)
self.terminated.set_exception(error)
break
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
def close(self) -> None:
if self.pump_task:
self.pump_task.cancel()
@@ -317,7 +372,7 @@ class PumpedPacketSink:
self.packet_queue = asyncio.Queue()
self.pump_task = None
def on_packet(self, packet):
def on_packet(self, packet: bytes) -> None:
self.packet_queue.put_nowait(packet)
def start(self):
@@ -326,7 +381,7 @@ class PumpedPacketSink:
try:
packet = await self.packet_queue.get()
await self.send_function(packet)
except asyncio.exceptions.CancelledError:
except asyncio.CancelledError:
logger.debug('sink pump task done')
break
except Exception as error:
@@ -342,18 +397,20 @@ class PumpedPacketSink:
# -----------------------------------------------------------------------------
class PumpedTransport(Transport):
def __init__(self, source, sink, close_function):
super().__init__(source, sink)
self.close_function = close_function
source: PumpedPacketSource
sink: PumpedPacketSink
def start(self):
def __init__(
self,
source: PumpedPacketSource,
sink: PumpedPacketSink,
) -> None:
super().__init__(source, sink)
def start(self) -> None:
self.source.start()
self.sink.start()
async def close(self):
await super().close()
await self.close_function()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
@@ -372,34 +429,45 @@ class SnoopingTransport(Transport):
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
raise core.UnreachableError() # Satisfy the type checker
class Source:
def __init__(self, source, snooper):
sink: TransportSink
@property
def metadata(self) -> dict[str, Any]:
return getattr(self.source, 'metadata', {})
def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source
self.snooper = snooper
self.sink = None
self.terminated = source.terminated
def set_packet_sink(self, sink):
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet):
def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink, snooper):
def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
self.sink = sink
self.snooper = snooper
def on_packet(self, packet):
def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None):
def __init__(
self,
transport: Transport,
snooper: Snooper,
close_snooper=None,
) -> None:
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_file_transport(spec):
async def open_file_transport(spec: str) -> Transport:
'''
Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files).

View File

@@ -23,6 +23,8 @@ import socket
import ctypes
import collections
from typing import Optional
from .common import Transport, ParserSource
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec):
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
'''
Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec):
# Create a raw HCI socket
try:
hci_socket = socket.socket(
socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
socket.BTPROTO_HCI, # type: ignore[attr-defined]
)
except AttributeError as error:
# Not supported on this platform
@@ -57,10 +59,7 @@ async def open_hci_socket_transport(spec):
) from error
# Compute the adapter index
if spec is None:
adapter_index = 0
else:
adapter_index = int(spec)
adapter_index = int(spec) if spec else 0
# Bind the socket
# NOTE: since Python doesn't support binding with the required address format (yet),
@@ -78,7 +77,7 @@ async def open_hci_socket_transport(spec):
bind_address = struct.pack(
# pylint: disable=no-member
'<HHH',
socket.AF_BLUETOOTH,
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
adapter_index,
HCI_CHANNEL_USER,
)

View File

@@ -23,6 +23,8 @@ import atexit
import os
import logging
from typing import Optional
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_pty_transport(spec):
async def open_pty_transport(spec: Optional[str]) -> Transport:
'''
Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link

View File

@@ -23,11 +23,24 @@ import time
import usb.core
import usb.util
from .common import Transport, ParserSource
from typing import Optional
from usb.core import Device as UsbDevice
from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource, TransportInitError
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
USB_PORT_FEATURE_POWER = 8
POWER_CYCLE_DELAY = 1
RESET_DELAY = 3
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -35,7 +48,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_pyusb_transport(spec):
async def open_pyusb_transport(spec: str) -> Transport:
'''
Open a USB transport. [Implementation based on PyUSB]
The parameter string has this syntax:
@@ -113,9 +126,10 @@ async def open_pyusb_transport(spec):
self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled):
def __init__(self, device, metadata, sco_enabled):
super().__init__()
self.device = device
self.metadata = metadata
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
@@ -213,9 +227,22 @@ async def open_pyusb_transport(spec):
usb_find = libusb_package.find
# Find the device according to the spec moniker
power_cycle = False
if spec.startswith('!'):
power_cycle = True
spec = spec[1:]
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
elif '-' in spec:
def device_path(device):
if device.port_numbers:
return f'{device.bus}-{".".join(map(str, device.port_numbers))}'
else:
return str(device.bus)
device = usb_find(custom_match=lambda device: device_path(device) == spec)
else:
device_index = int(spec)
devices = list(
@@ -232,9 +259,20 @@ async def open_pyusb_transport(spec):
device = None
if device is None:
raise ValueError('device not found')
raise TransportInitError('device not found')
logger.debug(f'USB Device: {device}')
# Power Cycle the device
if power_cycle:
try:
device = await _power_cycle(device) # type: ignore
except Exception as e:
logging.debug(e)
logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore
# Collect the metadata
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
# Detach the kernel driver if needed
if device.is_kernel_driver_active(0):
logger.debug("detaching kernel driver")
@@ -289,9 +327,79 @@ async def open_pyusb_transport(spec):
# except usb.USBError:
# logger.warning('failed to set alternate setting')
packet_source = UsbPacketSource(device, sco_enabled)
packet_source = UsbPacketSource(device, device_metadata, sco_enabled)
packet_sink = UsbPacketSink(device)
packet_source.start()
packet_sink.start()
return UsbTransport(device, packet_source, packet_sink)
async def _power_cycle(device: UsbDevice) -> UsbDevice:
"""
For devices connected to compatible USB hubs: Performs a power cycle on a given USB device.
This involves temporarily disabling its port on the hub and then re-enabling it.
"""
device_path = f'{device.bus}-{".".join(map(str, device.port_numbers))}' # type: ignore
hub = _find_hub_by_device_path(device_path)
if hub:
try:
device_port = device.port_numbers[-1] # type: ignore
_set_port_status(hub, device_port, False)
await asyncio.sleep(POWER_CYCLE_DELAY)
_set_port_status(hub, device_port, True)
await asyncio.sleep(RESET_DELAY)
# Device needs to be find again otherwise it will appear as disconnected
return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore
except USBError as e:
logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore
logger.error(e)
return device
def _set_port_status(device: UsbDevice, port: int, on: bool):
"""Sets the power status of a specific port on a USB hub."""
device.ctrl_transfer(
bmRequestType=CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER,
bRequest=REQ_SET_FEATURE if on else REQ_CLEAR_FEATURE,
wIndex=port,
wValue=USB_PORT_FEATURE_POWER,
)
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds a USB device based on its system path."""
bus_num, *port_parts = sys_path.split('-')
ports = [int(port) for port in port_parts[0].split('.')]
devices = usb.core.find(find_all=True, bus=int(bus_num))
if devices:
for device in devices:
if device.bus == int(bus_num) and list(device.port_numbers) == ports: # type: ignore
return device
return None
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds the USB hub associated with a specific device path."""
hub_sys_path = sys_path.rsplit('.', 1)[0]
hub_device = _find_device_by_path(hub_sys_path)
if hub_device is None:
return None
else:
return hub_device if _is_hub(hub_device) else None
def _is_hub(device: UsbDevice) -> bool:
"""Checks if a USB device is a hub"""
if device.bDeviceClass == CLASS_HUB: # type: ignore
return True
for config in device:
for interface in config:
if interface.bInterfaceClass == CLASS_HUB: # type: ignore
return True
return False

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_serial_transport(spec):
async def open_serial_transport(spec: str) -> Transport:
'''
Open a serial port transport.
The parameter string has this syntax:

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_tcp_client_transport(spec):
async def open_tcp_client_transport(spec: str) -> Transport:
'''
Open a TCP client transport.
The parameter string has this syntax:
@@ -39,7 +39,7 @@ async def open_tcp_client_transport(spec):
class TcpPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.terminated.set_result(exc)
self.on_transport_lost()
remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(

View File

@@ -15,8 +15,10 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import socket
from .common import Transport, StreamPacketSource
@@ -27,7 +29,14 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_tcp_server_transport(spec):
# A pass-through function to ease mock testing.
async def _create_server(*args, **kw_args):
await asyncio.get_running_loop().create_server(*args, **kw_args)
async def open_tcp_server_transport(spec: str) -> Transport:
'''
Open a TCP server transport.
The parameter string has this syntax:
@@ -37,12 +46,27 @@ async def open_tcp_server_transport(spec):
Example: _:9001
'''
local_host, local_port = spec.split(':')
return await _open_tcp_server_transport_impl(
host=local_host if local_host != '_' else None, port=int(local_port)
)
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
'''
Open a TCP server transport with an existing socket.
One reason to use this variant is to let python pick an unused port.
'''
return await _open_tcp_server_transport_impl(sock=sock)
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
class TcpServerTransport(Transport):
async def close(self):
await super().close()
class TcpServerProtocol:
class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink
@@ -76,13 +100,10 @@ async def open_tcp_server_transport(spec):
else:
logger.debug('no client, dropping packet')
local_host, local_port = spec.split(':')
packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink()
await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink),
host=local_host if local_host != '_' else None,
port=int(local_port),
await _create_server(
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
)
return TcpServerTransport(packet_source, packet_sink)

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_udp_transport(spec):
async def open_udp_transport(spec: str) -> Transport:
'''
Open a UDP transport.
The parameter string has this syntax:

56
bumble/transport/unix.py Normal file
View File

@@ -0,0 +1,56 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
from .common import Transport, StreamPacketSource, StreamPacketSink
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_unix_client_transport(spec: str) -> Transport:
'''Open a UNIX socket client transport.
The parameter is the path of unix socket. For abstract socket, the first character
needs to be '@'.
Example:
* /tmp/hci.socket
* @hci_socket
'''
class UnixPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.on_transport_lost()
# For abstract socket, the first character should be null character.
if spec.startswith('@'):
spec = '\0' + spec[1:]
(
unix_transport,
packet_source,
) = await asyncio.get_running_loop().create_unix_connection(UnixPacketSource, spec)
packet_sink = StreamPacketSink(unix_transport)
return Transport(packet_source, packet_sink)

View File

@@ -15,18 +15,18 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import threading
import collections
import ctypes
import platform
import usb1
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
from bumble.transport.common import Transport, ParserSource, TransportInitError
from bumble import hci
from bumble.colors import color
# -----------------------------------------------------------------------------
@@ -60,7 +60,7 @@ def load_libusb():
usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec):
async def open_usb_transport(spec: str) -> Transport:
'''
Open a USB transport.
The moniker string has this syntax:
@@ -107,20 +107,24 @@ async def open_usb_transport(spec):
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
READ_SIZE = 4096
class UsbPacketSink:
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.acl_out_transfer = device.getTransfer()
self.acl_out_transfer_ready = asyncio.Semaphore(1)
self.packets: asyncio.Queue[bytes] = (
asyncio.Queue()
) # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.queue_task = None
self.cancel_done = self.loop.create_future()
self.closed = False
def start(self):
pass
self.queue_task = asyncio.create_task(self.process_queue())
def on_packet(self, packet):
# Ignore packets if we're closed
@@ -132,72 +136,71 @@ async def open_usb_transport(spec):
return
# Queue the packet
self.packets.append(packet)
if len(self.packets) == 1:
# The queue was previously empty, re-prime the pump
self.process_queue()
self.packets.put_nowait(packet)
def on_packet_sent(self, transfer):
def transfer_callback(self, transfer):
self.acl_out_transfer_ready.release()
status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED:
if status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
return
if status != usb1.TRANSFER_COMPLETED:
logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
color(f'!!! OUT transfer not completed: status={status}', 'red')
)
def on_packet_sent_(self):
if self.packets:
self.packets.popleft()
self.process_queue()
async def process_queue(self):
while True:
# Wait for a packet to transfer.
packet = await self.packets.get()
def process_queue(self):
if len(self.packets) == 0:
return # Nothing to do
# Wait until we can start a transfer.
await self.acl_out_transfer_ready.acquire()
packet = self.packets[0]
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
self.acl_out, packet[1:], callback=self.on_packet_sent
)
logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.on_packet_sent,
)
logger.debug('submit COMMAND')
self.transfer.submit()
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
# Transfer the packet.
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback
)
self.acl_out_transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.transfer_callback,
)
self.acl_out_transfer.submit()
else:
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
def close(self):
self.closed = True
if self.queue_task:
self.queue_task.cancel()
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data
self.packets.clear()
while not self.packets.empty():
self.packets.get_nowait()
# If we have a transfer in flight, cancel it
if self.transfer.isSubmitted():
if self.acl_out_transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have
# already completed
try:
self.transfer.cancel()
self.acl_out_transfer.cancel()
logger.debug('waiting for OUT transfer cancellation to be done...')
await self.cancel_done
@@ -206,26 +209,22 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device, acl_in, events_in):
def __init__(self, device, metadata, acl_in, events_in):
super().__init__()
self.context = context
self.device = device
self.metadata = metadata
self.acl_in = acl_in
self.acl_in_transfer = None
self.events_in = events_in
self.events_in_transfer = None
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
}
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
self.closed = False
def start(self):
# Set up transfer objects for input
@@ -233,7 +232,7 @@ async def open_usb_transport(spec):
self.events_in_transfer.setInterrupt(
self.events_in,
READ_SIZE,
callback=self.on_packet_received,
callback=self.transfer_callback,
user_data=hci.HCI_EVENT_PACKET,
)
self.events_in_transfer.submit()
@@ -242,22 +241,23 @@ async def open_usb_transport(spec):
self.acl_in_transfer.setBulk(
self.acl_in,
READ_SIZE,
callback=self.on_packet_received,
callback=self.transfer_callback,
user_data=hci.HCI_ACL_DATA_PACKET,
)
self.acl_in_transfer.submit()
self.dequeue_task = self.loop.create_task(self.dequeue())
self.event_thread.start()
def on_packet_received(self, transfer):
@property
def usb_transfer_submitted(self):
return (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
)
def transfer_callback(self, transfer):
packet_type = transfer.getUserData()
status = transfer.getStatus()
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
@@ -266,18 +266,18 @@ async def open_usb_transport(spec):
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
# Re-submit the transfer so we can receive more data
transfer.submit()
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return
else:
logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
color(f'!!! IN transfer not completed: status={status}', 'red')
)
# Re-submit the transfer so we can receive more data
transfer.submit()
self.loop.call_soon_threadsafe(self.on_transport_lost)
async def dequeue(self):
while not self.closed:
@@ -287,21 +287,6 @@ async def open_usb_transport(spec):
return
self.parser.feed_data(packet)
def run(self):
logger.debug('starting USB event loop')
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
def close(self):
self.closed = True
@@ -330,15 +315,14 @@ async def open_usb_transport(spec):
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
self.interface = interface
self.loop = asyncio.get_running_loop()
self.event_loop_done = self.loop.create_future()
# Get exclusive access
device.claimInterface(interface)
@@ -351,6 +335,22 @@ async def open_usb_transport(spec):
source.start()
sink.start()
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
self.event_thread.start()
def run(self):
logger.debug('starting USB event loop')
while self.source.usb_transfer_submitted:
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self):
self.source.close()
self.sink.close()
@@ -360,6 +360,9 @@ async def open_usb_transport(spec):
self.device.close()
self.context.close()
# Wait for the thread to terminate
await self.event_loop_done
# Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext()
@@ -398,6 +401,16 @@ async def open_usb_transport(spec):
break
device_index -= 1
device.close()
elif '-' in spec:
def device_path(device):
return f'{device.getBusNumber()}-{".".join(map(str, device.getPortNumberList()))}'
for device in context.getDeviceIterator(skip_on_error=True):
if device_path(device) == spec:
found = device
break
device.close()
else:
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
@@ -434,14 +447,14 @@ async def open_usb_transport(spec):
if found is None:
context.close()
raise ValueError('device not found')
raise TransportInitError('device not found')
logger.debug(f'USB Device: {found}')
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks
for (configuration_index, configuration) in enumerate(device):
for configuration_index, configuration in enumerate(device):
interface = None
for interface in configuration:
setting = None
@@ -499,7 +512,7 @@ async def open_usb_transport(spec):
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
raise TransportInitError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
@@ -510,6 +523,10 @@ async def open_usb_transport(spec):
f'events_in=0x{events_in:02X}, '
)
device_metadata = {
'vendor_id': found.getVendorID(),
'product_id': found.getProductID(),
}
device = found.open()
# Auto-detach the kernel driver if supported
@@ -535,7 +552,7 @@ async def open_usb_transport(spec):
except usb1.USBError:
logger.warning('failed to set configuration')
source = UsbPacketSource(context, device, acl_in, events_in)
source = UsbPacketSource(device, device_metadata, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error:

View File

@@ -17,6 +17,9 @@
# -----------------------------------------------------------------------------
import logging
from typing import Optional
from .common import Transport
from .file import open_file_transport
# -----------------------------------------------------------------------------
@@ -26,7 +29,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_vhci_transport(spec):
async def open_vhci_transport(spec: Optional[str]) -> Transport:
'''
Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device
@@ -42,15 +45,15 @@ async def open_vhci_transport(spec):
# Override the source's `data_received` method so that we can
# filter out the vendor packet that is received just after the
# initial open
def vhci_data_received(data):
def vhci_data_received(data: bytes) -> None:
if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
if len(data) == 4:
hci_index = data[2] << 8 | data[3]
logger.info(f'HCI index {hci_index}')
else:
transport.source.parser.feed_data(data)
transport.source.parser.feed_data(data) # type: ignore
transport.source.data_received = vhci_data_received
transport.source.data_received = vhci_data_received # type: ignore
# Write the initial config
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))

View File

@@ -16,9 +16,9 @@
# Imports
# -----------------------------------------------------------------------------
import logging
import websockets
import websockets.client
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport
# -----------------------------------------------------------------------------
# Logging
@@ -27,23 +27,25 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_ws_client_transport(spec):
async def open_ws_client_transport(spec: str) -> Transport:
'''
Open a WebSocket client transport.
The parameter string has this syntax:
<remote-host>:<remote-port>
<websocket-url>
Example: 127.0.0.1:9001
Example: ws://localhost:7681/v1/websocket/bt
'''
remote_host, remote_port = spec.split(':')
uri = f'ws://{remote_host}:{remote_port}'
websocket = await websockets.connect(uri)
websocket = await websockets.client.connect(spec)
transport = PumpedTransport(
class WsTransport(PumpedTransport):
async def close(self):
await super().close()
await websocket.close()
transport = WsTransport(
PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send),
websocket.close,
)
transport.start()
return transport

View File

@@ -15,7 +15,6 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import websockets
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_ws_server_transport(spec):
async def open_ws_server_transport(spec: str) -> Transport:
'''
Open a WebSocket server transport.
The parameter string has this syntax:
@@ -43,7 +42,7 @@ async def open_ws_server_transport(spec):
def __init__(self):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future()
self.connection = None
self.server = None
super().__init__(source, sink)
@@ -63,7 +62,7 @@ async def open_ws_server_transport(spec):
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
self.connection.set_result(connection)
self.connection = connection
# pylint: disable=no-member
try:
async for packet in connection:
@@ -74,12 +73,14 @@ async def open_ws_server_transport(spec):
except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}')
# Wait for a new connection
self.connection = asyncio.get_running_loop().create_future()
# We're now disconnected
self.connection = None
async def send_packet(self, packet):
connection = await self.connection
return await connection.send(packet)
if self.connection is None:
logger.debug('no connection, dropping packet')
return
return await self.connection.send(packet)
local_host, local_port = spec.split(':')
transport = WsServerTransport()

View File

@@ -15,13 +15,27 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import traceback
import collections
import enum
import functools
import logging
import sys
from typing import Awaitable, Set, TypeVar
from functools import wraps
import warnings
from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any,
Optional,
Union,
overload,
)
from pyee import EventEmitter
from .colors import color
@@ -64,6 +78,106 @@ def composite_listener(cls):
return cls
# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)
class EventWatcher:
'''A wrapper class to control the lifecycle of event handlers better.
Usage:
```
watcher = EventWatcher()
def on_foo():
...
watcher.on(emitter, 'foo', on_foo)
@watcher.on(emitter, 'bar')
def on_bar():
...
# Close all event handlers watching through this watcher
watcher.close()
```
As context:
```
with contextlib.closing(EventWatcher()) as context:
@context.on(emitter, 'foo')
def on_foo():
...
# on_foo() has been removed here!
```
'''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None:
self.handlers = []
@overload
def on(
self, emitter: EventEmitter, event: str
) -> Callable[[_Handler], _Handler]: ...
@overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ...
def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method
works as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.on(event, wrapped)
return wrapped
return wrapper if handler is None else wrapper(handler)
@overload
def once(
self, emitter: EventEmitter, event: str
) -> Callable[[_Handler], _Handler]: ...
@overload
def once(
self, emitter: EventEmitter, event: str, handler: _Handler
) -> _Handler: ...
def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works
as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.once(event, wrapped)
return wrapped
return wrapper if handler is None else wrapper(handler)
def close(self) -> None:
for emitter, event, handler in self.handlers:
if handler in emitter.listeners(event):
emitter.remove_listener(event, handler)
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
@@ -114,13 +228,13 @@ class CompositeEventEmitter(AbortableEventEmitter):
if self._listener:
# Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro():
if hasattr(cls, '_bumble_register_composite'):
cls._bumble_deregister_composite(listener, self)
if '_bumble_register_composite' in cls.__dict__:
cls._bumble_deregister_composite(self._listener, self)
self._listener = listener
if listener:
# Call the registration methods for each base class that has them
for cls in listener.__class__.mro():
if hasattr(cls, '_bumble_deregister_composite'):
if '_bumble_deregister_composite' in cls.__dict__:
cls._bumble_register_composite(listener, self)
@@ -167,21 +281,18 @@ class AsyncRunner:
"""
def decorator(func):
@wraps(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
# Create a task to run the coroutine
# Spawn the coroutine as a task
async def run():
try:
await coroutine
except Exception:
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
logger.exception(color("!!! Exception in wrapper:", "red"))
asyncio.create_task(run())
AsyncRunner.spawn(run())
else:
# Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine)
@@ -302,3 +413,77 @@ class FlowControlAsyncPipe:
self.resume_source()
self.check_pump()
# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs):
"""
Immediately calls the function with provided args and kwargs, wrapping it in an
async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
a running loop.
result = await async_call(some_function, ...)
"""
return function(*args, **kwargs)
# -----------------------------------------------------------------------------
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return functools.partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str):
"""
Throw deprecation warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
"""
Subclass of enum.IntEnum that can hold integer values outside the set of
predefined values. This is convenient for implementing protocols where some
integer constants may be added over time.
"""
@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
return None
obj = int.__new__(cls, value)
obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]"
return obj

0
bumble/vendor/__init__.py vendored Normal file
View File

0
bumble/vendor/android/__init__.py vendored Normal file
View File

318
bumble/vendor/android/hci.py vendored Normal file
View File

@@ -0,0 +1,318 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from bumble.hci import (
name_or_number,
hci_vendor_command_op_code,
Address,
HCI_Constant,
HCI_Object,
HCI_Command,
HCI_Vendor_Event,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Android Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#chip-capabilities-and-configuration
HCI_LE_GET_VENDOR_CAPABILITIES_COMMAND = hci_vendor_command_op_code(0x153)
HCI_LE_APCF_COMMAND = hci_vendor_command_op_code(0x157)
HCI_GET_CONTROLLER_ACTIVITY_ENERGY_INFO_COMMAND = hci_vendor_command_op_code(0x159)
HCI_A2DP_HARDWARE_OFFLOAD_COMMAND = hci_vendor_command_op_code(0x15D)
HCI_BLUETOOTH_QUALITY_REPORT_COMMAND = hci_vendor_command_op_code(0x15E)
HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
HCI_Command.register_commands(globals())
HCI_Vendor_Event.register_subevents(globals())
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
)
class HCI_LE_Get_Vendor_Capabilities_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave un-signal parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = HCI_Object(cls.return_parameters_fields, **nones)
try:
offset = 0
for field in cls.return_parameters_fields:
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(
parameters, offset, field_type
)
setattr(return_parameters, field_name, field_value)
offset += field_size
except struct.error:
pass
return return_parameters
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_LE_APCF_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
APCF_ENABLE = 0x00
APCF_SET_FILTERING_PARAMETERS = 0x01
APCF_BROADCASTER_ADDRESS = 0x02
APCF_SERVICE_UUID = 0x03
APCF_SERVICE_SOLICITATION_UUID = 0x04
APCF_LOCAL_NAME = 0x05
APCF_MANUFACTURER_DATA = 0x06
APCF_SERVICE_DATA = 0x07
APCF_TRANSPORT_DISCOVERY_SERVICE = 0x08
APCF_AD_TYPE_FILTER = 0x09
APCF_READ_EXTENDED_FEATURES = 0xFF
OPCODE_NAMES = {
APCF_ENABLE: 'APCF_ENABLE',
APCF_SET_FILTERING_PARAMETERS: 'APCF_SET_FILTERING_PARAMETERS',
APCF_BROADCASTER_ADDRESS: 'APCF_BROADCASTER_ADDRESS',
APCF_SERVICE_UUID: 'APCF_SERVICE_UUID',
APCF_SERVICE_SOLICITATION_UUID: 'APCF_SERVICE_SOLICITATION_UUID',
APCF_LOCAL_NAME: 'APCF_LOCAL_NAME',
APCF_MANUFACTURER_DATA: 'APCF_MANUFACTURER_DATA',
APCF_SERVICE_DATA: 'APCF_SERVICE_DATA',
APCF_TRANSPORT_DISCOVERY_SERVICE: 'APCF_TRANSPORT_DISCOVERY_SERVICE',
APCF_AD_TYPE_FILTER: 'APCF_AD_TYPE_FILTER',
APCF_READ_EXTENDED_FEATURES: 'APCF_READ_EXTENDED_FEATURES',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
],
)
class HCI_Get_Controller_Activity_Energy_Info_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_A2DP_Hardware_Offload_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
OPCODE_NAMES = {
START_A2DP_OFFLOAD: 'START_A2DP_OFFLOAD',
STOP_A2DP_OFFLOAD: 'STOP_A2DP_OFFLOAD',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
OPCODE_NAMES = {
GET_AUDIO_BUFFER_TIME_CAPABILITY: 'GET_AUDIO_BUFFER_TIME_CAPABILITY',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Vendor_Event.event(
fields=[
('quality_report_id', 1),
('packet_types', 1),
('connection_handle', 2),
('connection_role', {'size': 1, 'mapper': HCI_Constant.role_name}),
('tx_power_level', -1),
('rssi', -1),
('snr', 1),
('unused_afh_channel_count', 1),
('afh_select_unideal_channel_count', 1),
('lsto', 2),
('connection_piconet_clock', 4),
('retransmission_count', 4),
('no_rx_count', 4),
('nak_count', 4),
('last_tx_ack_timestamp', 4),
('flow_off_count', 4),
('last_flow_on_timestamp', 4),
('buffer_overflow_bytes', 4),
('buffer_underflow_bytes', 4),
('bdaddr', Address.parse_address),
('cal_failed_item_count', 1),
('tx_total_packets', 4),
('tx_unacked_packets', 4),
('tx_flushed_packets', 4),
('tx_last_subevent_packets', 4),
('crc_error_packets', 4),
('rx_duplicate_packets', 4),
('vendor_specific_parameters', '*'),
]
)
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
'''

0
bumble/vendor/zephyr/__init__.py vendored Normal file
View File

88
bumble/vendor/zephyr/hci.py vendored Normal file
View File

@@ -0,0 +1,88 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.hci import (
hci_vendor_command_op_code,
HCI_Command,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Zephyr RTOS Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
HCI_WRITE_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000E)
HCI_READ_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000F)
HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
class TX_Power_Level_Command:
'''
Base class for read and write TX power level HCI commands
'''
TX_POWER_HANDLE_TYPE_ADV = 0x00
TX_POWER_HANDLE_TYPE_SCAN = 0x01
TX_POWER_HANDLE_TYPE_CONN = 0x02
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2), ('tx_power_level', -1)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
],
)
class HCI_Write_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
],
)
class HCI_Read_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''

BIN
docs/images/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Some files were not shown because too many files have changed in this diff Show More