forked from auracaster/bumble_mirror
Compare commits
394 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0fa517a4f6 | ||
|
|
a11962a487 | ||
|
|
32d448edf3 | ||
|
|
3d615b13ce | ||
|
|
1ad92dc759 | ||
|
|
aacfd4328c | ||
|
|
6aa1f5211c | ||
|
|
df8e454ee5 | ||
|
|
aec50ac616 | ||
|
|
6a3eaa457f | ||
|
|
6e6b4cd4b2 | ||
|
|
aa1d7933da | ||
|
|
34e0f293c2 | ||
|
|
85215df2c3 | ||
|
|
f8223ca81f | ||
|
|
2b0b1ad726 | ||
|
|
58debcd8bb | ||
|
|
6eba81e3dd | ||
|
|
768bbd95cc | ||
|
|
502b80af0d | ||
|
|
a25427305c | ||
|
|
3c47739029 | ||
|
|
8fc1330948 | ||
|
|
8a5f6a61d5 | ||
|
|
83c5061700 | ||
|
|
b80b790dc1 | ||
|
|
21bf69592c | ||
|
|
7d8addb849 | ||
|
|
d86d69d816 | ||
|
|
bb08a1c70b | ||
|
|
dc93f32a9a | ||
|
|
9838908a26 | ||
|
|
613519f0b3 | ||
|
|
a943ea57ef | ||
|
|
14401910bb | ||
|
|
5d35ed471c | ||
|
|
c720ad5fdc | ||
|
|
f02183f95d | ||
|
|
d903937a51 | ||
|
|
6381ee0ab1 | ||
|
|
59d99780e1 | ||
|
|
4bf0bc03af | ||
|
|
91ba2f61f1 | ||
|
|
116dc9b319 | ||
|
|
9f3d8c9b49 | ||
|
|
31961febe5 | ||
|
|
dab0993cba | ||
|
|
6f73b736d7 | ||
|
|
6091e6365d | ||
|
|
3333ba472b | ||
|
|
8bda7d2212 | ||
|
|
7aba36302a | ||
|
|
ceefe8b2a5 | ||
|
|
cd37027795 | ||
|
|
bb2aa8229d | ||
|
|
4aed53c48d | ||
|
|
4a88e9a0cf | ||
|
|
3b8dd6f3cf | ||
|
|
f41b7746d2 | ||
|
|
1b727741bf | ||
|
|
d2bc8175fb | ||
|
|
84dfff290a | ||
|
|
17563e423a | ||
|
|
19d3616032 | ||
|
|
4a48309643 | ||
|
|
870217acb3 | ||
|
|
f8077d7996 | ||
|
|
739907fa31 | ||
|
|
a275c399a3 | ||
|
|
c98275f385 | ||
|
|
0b19347bef | ||
|
|
f61fd64c0b | ||
|
|
ec12771be6 | ||
|
|
5b33e715da | ||
|
|
b885f29318 | ||
|
|
7ca13188d5 | ||
|
|
89586d5d18 | ||
|
|
381032ceb9 | ||
|
|
12ca1c01f0 | ||
|
|
a7111d0107 | ||
|
|
c034297bc0 | ||
|
|
a1eff958e6 | ||
|
|
d6282a7247 | ||
|
|
efdc770fde | ||
|
|
357d7f9c22 | ||
|
|
3bc08b4e0d | ||
|
|
982aaeabc3 | ||
|
|
1dc0950177 | ||
|
|
df0fd74533 | ||
|
|
822f97fa84 | ||
|
|
4a6b0ef840 | ||
|
|
a6ead0147e | ||
|
|
0665e9ca5c | ||
|
|
b8b78ca1ee | ||
|
|
d611d25802 | ||
|
|
bf8a2cdcb5 | ||
|
|
cce2e4d4e3 | ||
|
|
4bf7448a01 | ||
|
|
1b44e73f90 | ||
|
|
1a81c5d05c | ||
|
|
d8a43f0151 | ||
|
|
858788f05e | ||
|
|
41f8797a4c | ||
|
|
fc3fd7f25b | ||
|
|
48bbf9f1e0 | ||
|
|
3d6c595c6e | ||
|
|
d9d971b8b3 | ||
|
|
a5effb433b | ||
|
|
8802c95d31 | ||
|
|
a184cae560 | ||
|
|
fa6fe2aaca | ||
|
|
43a8cc37f8 | ||
|
|
e45143e33d | ||
|
|
1c1b947455 | ||
|
|
d7ddffd275 | ||
|
|
3cb97d2373 | ||
|
|
bad037b010 | ||
|
|
88777710a4 | ||
|
|
0ab5b6c49a | ||
|
|
22ff0d5e32 | ||
|
|
2f5de37d76 | ||
|
|
799d730f88 | ||
|
|
1a05eebfdb | ||
|
|
ebaa720e74 | ||
|
|
a505badffc | ||
|
|
45d938c901 | ||
|
|
a0498af626 | ||
|
|
bf027cf38f | ||
|
|
f2d7faa9af | ||
|
|
a0248a1cdf | ||
|
|
1e95e19f16 | ||
|
|
8137caf37b | ||
|
|
630243e243 | ||
|
|
39518c89f5 | ||
|
|
d631156f6c | ||
|
|
60e31884c8 | ||
|
|
8614e075b6 | ||
|
|
8a0cd5d0d1 | ||
|
|
3a64772cc5 | ||
|
|
1ecfb78d94 | ||
|
|
9ad276a757 | ||
|
|
4c4f8c8225 | ||
|
|
a00b2bd707 | ||
|
|
b8a055de45 | ||
|
|
4d07726acf | ||
|
|
2e523b6f49 | ||
|
|
8f9f12f1ee | ||
|
|
a875aa4055 | ||
|
|
775b2d5d7f | ||
|
|
3b399ea1a2 | ||
|
|
84f7cad678 | ||
|
|
778f439e1c | ||
|
|
1b95d4e1df | ||
|
|
512f6d4ee1 | ||
|
|
c52b614abb | ||
|
|
7b7afc7179 | ||
|
|
b1c6044533 | ||
|
|
38499dfe3c | ||
|
|
b58c29202a | ||
|
|
ca759ca967 | ||
|
|
3858bf80c1 | ||
|
|
a88a034ce2 | ||
|
|
6b2cd1147d | ||
|
|
bb8dcaf63e | ||
|
|
8e84b528ce | ||
|
|
8b59b4f515 | ||
|
|
dcc72e49a2 | ||
|
|
ce04c163db | ||
|
|
9f1e95d87f | ||
|
|
088bcbed0b | ||
|
|
57fbad6fa4 | ||
|
|
6926d5cb70 | ||
|
|
00c7df6a11 | ||
|
|
fbd03ed4a5 | ||
|
|
d3bd5a759f | ||
|
|
dedef79bef | ||
|
|
8db974877e | ||
|
|
e7d1531eae | ||
|
|
4785fe6002 | ||
|
|
22d6a7bf05 | ||
|
|
97757c0c3d | ||
|
|
ab60b42b85 | ||
|
|
febed8179b | ||
|
|
1bd83273e8 | ||
|
|
5e9fc89f80 | ||
|
|
2686663eb2 | ||
|
|
55801bc2ca | ||
|
|
6cecc16519 | ||
|
|
a57cf13e2e | ||
|
|
58f153afc4 | ||
|
|
7569da37e4 | ||
|
|
a8019a70da | ||
|
|
685f1dc43e | ||
|
|
220b3b0236 | ||
|
|
3495eb52ba | ||
|
|
1f7a1401eb | ||
|
|
ce2b02b62a | ||
|
|
5e55c0e358 | ||
|
|
ebeb0dc9f1 | ||
|
|
776bdae519 | ||
|
|
b2d9541f8f | ||
|
|
637224d5bc | ||
|
|
92ab171013 | ||
|
|
592475e2ed | ||
|
|
12bcdb7770 | ||
|
|
7a58f36020 | ||
|
|
ed0eb912c5 | ||
|
|
752ce6c830 | ||
|
|
8e509c18c9 | ||
|
|
cc21ed27c7 | ||
|
|
82d825071c | ||
|
|
b932bafe6d | ||
|
|
4e35aba033 | ||
|
|
0060ee8ee2 | ||
|
|
3263d71f54 | ||
|
|
f321143837 | ||
|
|
bac6f5baaf | ||
|
|
e027bcb57a | ||
|
|
eeb9de31ed | ||
|
|
4befc5bbae | ||
|
|
2c3af5b2bb | ||
|
|
dfb92e8ed1 | ||
|
|
73d2b54e30 | ||
|
|
8315a60f24 | ||
| 185d5fd577 | |||
|
|
ae5f9cf690 | ||
|
|
4b66a38fe6 | ||
|
|
f526f549ee | ||
|
|
da029a1749 | ||
|
|
8761129677 | ||
|
|
3f6f036270 | ||
|
|
859bb0609f | ||
|
|
5f2d24570e | ||
|
|
dbf94c8f3e | ||
|
|
b6adc29365 | ||
|
|
5caa7bfa90 | ||
|
|
f39d706fa0 | ||
|
|
c02c1f33d2 | ||
|
|
33435c2980 | ||
|
|
c08449d9db | ||
|
|
3c8718bb5b | ||
|
|
26e87f09fe | ||
|
|
7f5e0d190e | ||
|
|
efae307b3d | ||
|
|
26d38a855c | ||
|
|
7360a887d9 | ||
|
|
9756572c93 | ||
|
|
d6100755b1 | ||
|
|
a66eef6630 | ||
|
|
ae23ef7b9b | ||
|
|
f368b5e518 | ||
|
|
5293d32dc6 | ||
|
|
6d9a0bf4e1 | ||
|
|
3c7b5df7c5 | ||
|
|
70141c0439 | ||
|
|
dedc0aca54 | ||
|
|
7c019b574f | ||
|
|
9b485fd943 | ||
|
|
fdee8269ec | ||
|
|
0767f2d4ae | ||
|
|
c4a0846727 | ||
|
|
83ac70e426 | ||
|
|
01cce3525f | ||
|
|
b9d35aea47 | ||
|
|
079cf6b896 | ||
|
|
180655088c | ||
|
|
a1bade6f20 | ||
|
|
5d80e7fd80 | ||
|
|
2198692961 | ||
|
|
55d3fd90f5 | ||
|
|
afee659ca6 | ||
|
|
6fe7931d7d | ||
|
|
9023407ee4 | ||
|
|
54d961bbe5 | ||
|
|
cbd46adbcf | ||
|
|
745e107849 | ||
|
|
af466c2970 | ||
|
|
931e2de854 | ||
|
|
55eb7eb237 | ||
|
|
bade4502f9 | ||
|
|
9f952f202f | ||
|
|
1eb9d8d055 | ||
|
|
5a477eb391 | ||
|
|
86cda8771d | ||
|
|
c1ea0ddd35 | ||
|
|
f567711a6c | ||
|
|
509df4c676 | ||
|
|
b375ed07b4 | ||
|
|
69d62d3dd1 | ||
|
|
fe3fa3d505 | ||
|
|
27fcd43224 | ||
|
|
c3b2bb19d5 | ||
|
|
34287177b9 | ||
|
|
d238dd4059 | ||
|
|
865f3a249f | ||
|
|
7324d322fe | ||
|
|
af148b476d | ||
|
|
80d60aaf15 | ||
|
|
c80f89d20f | ||
|
|
a27f55a588 | ||
|
|
62e4670a39 | ||
|
|
99695bb264 | ||
|
|
eb54898106 | ||
|
|
4f5ee204d2 | ||
|
|
2552e21db1 | ||
|
|
6168f87e2f | ||
|
|
ca7d2ca4df | ||
|
|
60723323e9 | ||
|
|
3ce7b9255b | ||
|
|
97fcfc2fa0 | ||
|
|
19674e3758 | ||
|
|
1130e1db8f | ||
|
|
37c7f3a58a | ||
|
|
0a12b2bf2e | ||
|
|
d014acbe63 | ||
|
|
07f9997a49 | ||
|
|
b9f91f695a | ||
|
|
082d55af10 | ||
|
|
4c3fd5688d | ||
|
|
9d3d5495ce | ||
|
|
b3869f267c | ||
|
|
8715333706 | ||
|
|
b57096abe2 | ||
|
|
48685c8587 | ||
|
|
100bea6b41 | ||
|
|
63819bf9dd | ||
|
|
6e55390930 | ||
|
|
e3fdab4175 | ||
|
|
bbcd14dbf0 | ||
|
|
01dc0d574b | ||
|
|
5e959d638e | ||
|
|
8d908288c8 | ||
|
|
c88b32a406 | ||
|
|
5a72eefb89 | ||
|
|
430046944b | ||
|
|
21d23320eb | ||
|
|
d0990ee04d | ||
|
|
2d88e853e8 | ||
|
|
a060a70fba | ||
|
|
a06394ad4a | ||
|
|
a1414c2b5b | ||
|
|
b2864dac2d | ||
|
|
b78f895143 | ||
|
|
c4e9726828 | ||
|
|
d4b8e8348a | ||
|
|
19debaa52e | ||
|
|
73fe564321 | ||
|
|
a00abd65b3 | ||
|
|
f169ceaebb | ||
|
|
528af0d338 | ||
|
|
4b25eed869 | ||
|
|
fcd6bd7136 | ||
|
|
32642c5d7c | ||
|
|
ff8b0c375d | ||
|
|
ae0228aeb8 | ||
|
|
5d2dac18c8 | ||
|
|
d03fc14cfd | ||
|
|
ad7ce79bc4 | ||
|
|
c6bf27fd2c | ||
|
|
7584daa3f9 | ||
|
|
654030e789 | ||
|
|
1de7d2cd6f | ||
|
|
68db78c833 | ||
|
|
e1714c16cc | ||
|
|
0a20f14ea9 | ||
|
|
23f46b36b3 | ||
|
|
009649abd1 | ||
|
|
855a007116 | ||
|
|
d064de35e0 | ||
|
|
dab4d13303 | ||
|
|
2bed50b353 | ||
|
|
1fe3778a74 | ||
|
|
f5443a9826 | ||
|
|
db723a5196 | ||
|
|
5e31bcf23d | ||
|
|
fe429cb2eb | ||
|
|
c91695c23a | ||
|
|
55f99e6887 | ||
|
|
b190069f48 | ||
|
|
e16be1a8f4 | ||
|
|
2fa8075fb0 | ||
|
|
566ca13d23 | ||
|
|
e5666c0510 | ||
|
|
46ec39ccfb | ||
|
|
eef418ae5f | ||
|
|
9e663ad051 | ||
|
|
f28eac4c14 | ||
|
|
669bb3f3a8 | ||
|
|
347fe8b272 | ||
|
|
d56c4d0a11 | ||
|
|
034140ccbd | ||
|
|
35bef7d7b7 | ||
|
|
d069708c79 | ||
|
|
bdba5c9d95 |
26
.github/ci-gradle.properties
vendored
Normal file
26
.github/ci-gradle.properties
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
#
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
org.gradle.configureondemand=true
|
||||
org.gradle.caching=true
|
||||
org.gradle.parallel=true
|
||||
|
||||
# Declare we support AndroidX
|
||||
android.useAndroidX=true
|
||||
|
||||
org.gradle.jvmargs=-Xmx4608m -XX:MaxMetaspaceSize=1536m -XX:+HeapDumpOnOutOfMemoryError
|
||||
|
||||
kotlin.compiler.execution.strategy=in-process
|
||||
6
.github/workflows/code-check.yml
vendored
6
.github/workflows/code-check.yml
vendored
@@ -6,6 +6,8 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -16,7 +18,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -33,7 +35,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,pandora]"
|
||||
python -m pip install ".[build,test,development]"
|
||||
- name: Check
|
||||
run: |
|
||||
invoke project.pre-commit
|
||||
|
||||
2
.github/workflows/codeql-analysis.yml
vendored
2
.github/workflows/codeql-analysis.yml
vendored
@@ -17,6 +17,8 @@ on:
|
||||
pull_request:
|
||||
# The branches below must be a subset of the branches above
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
branches: [main]
|
||||
schedule:
|
||||
- cron: '39 21 * * 4'
|
||||
|
||||
|
||||
37
.github/workflows/gradle-btbench.yml
vendored
Normal file
37
.github/workflows/gradle-btbench.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
name: Gradle Android Build & test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'extras/android/BtBench/**'
|
||||
workflow_dispatch:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'extras/android/BtBench/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v4
|
||||
with:
|
||||
distribution: 'zulu'
|
||||
java-version: 17
|
||||
|
||||
- name: Setup Gradle
|
||||
uses: gradle/actions/setup-gradle@v3
|
||||
|
||||
- name: Build with Gradle
|
||||
run: cd extras/android/BtBench && ./gradlew build
|
||||
11
.github/workflows/python-avatar.yml
vendored
11
.github/workflows/python-avatar.yml
vendored
@@ -5,6 +5,8 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -32,7 +34,7 @@ jobs:
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar,pandora]
|
||||
python -m pip install .[avatar]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
@@ -40,4 +42,11 @@ jobs:
|
||||
avatar --list | grep -Ev '^=' > test-names.txt
|
||||
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
|
||||
- name: Rootcanal Logs
|
||||
if: always()
|
||||
run: cat rootcanal.log
|
||||
- name: Upload Mobly logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: mobly-logs-${{ strategy.job-index }}
|
||||
path: /tmp/logs/mobly/bumble.bumbles/
|
||||
|
||||
10
.github/workflows/python-build-test.yml
vendored
10
.github/workflows/python-build-test.yml
vendored
@@ -6,6 +6,8 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -16,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -46,8 +48,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
|
||||
rust-version: [ "1.76.0", "stable" ]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
rust-version: [ "1.80.0", "stable" ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -70,7 +72,7 @@ jobs:
|
||||
- name: Check License Headers
|
||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||
- name: Rust Build
|
||||
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
|
||||
run: cd rust && cargo build --all-targets && cargo build-all-features
|
||||
# Lints after build so what clippy needs is already built
|
||||
- name: Rust Lints
|
||||
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
|
||||
|
||||
14
.vscode/settings.json
vendored
14
.vscode/settings.json
vendored
@@ -14,9 +14,12 @@
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"auracast",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"biginfo",
|
||||
"bigs",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -36,6 +39,7 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"ediv",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
@@ -47,6 +51,7 @@
|
||||
"libc",
|
||||
"liblc",
|
||||
"libusb",
|
||||
"maxs",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
@@ -54,8 +59,10 @@
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"PDUS",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"prefs",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
@@ -95,5 +102,10 @@
|
||||
"."
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python-envs.defaultEnvManager": "ms-python.python:system",
|
||||
"python-envs.pythonProjects": [],
|
||||
"nrf-connect.applications": [
|
||||
"${workspaceFolder}/extras/zephyr/hci_usb"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
|
||||
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
|
||||
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
|
||||
|
||||
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
|
||||
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if you are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
|
||||
|
||||
## License
|
||||
|
||||
|
||||
@@ -12,9 +12,6 @@ Apps
|
||||
## `show.py`
|
||||
Parse a file with HCI packets and print the details of each packet in a human readable form
|
||||
|
||||
## `link_relay.py`
|
||||
Simple WebSocket relay for virtual RemoteLink instances to communicate with each other through.
|
||||
|
||||
## `hci_bridge.py`
|
||||
This app acts as a simple bridge between two HCI transports, with a host on one side and
|
||||
a controller on the other. All the HCI packets bridged between the two are printed on the console
|
||||
|
||||
774
apps/auracast.py
774
apps/auracast.py
File diff suppressed because it is too large
Load Diff
1055
apps/bench.py
1055
apps/bench.py
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
@@ -22,54 +22,56 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import humanize
|
||||
from typing import Optional, Union
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union
|
||||
|
||||
import click
|
||||
import humanize
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from prompt_toolkit import Application
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.completion import Completer, Completion, NestedCompleter
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.formatted_text import ANSI
|
||||
from prompt_toolkit.styles import Style
|
||||
from prompt_toolkit.filters import Condition
|
||||
from prompt_toolkit.widgets import TextArea, Frame
|
||||
from prompt_toolkit.widgets.toolbars import FormattedTextToolbar
|
||||
from prompt_toolkit.data_structures import Point
|
||||
from prompt_toolkit.filters import Condition
|
||||
from prompt_toolkit.formatted_text import ANSI
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.layout import (
|
||||
Layout,
|
||||
HSplit,
|
||||
Window,
|
||||
CompletionsMenu,
|
||||
Float,
|
||||
FormattedTextControl,
|
||||
FloatContainer,
|
||||
ConditionalContainer,
|
||||
Dimension,
|
||||
Float,
|
||||
FloatContainer,
|
||||
FormattedTextControl,
|
||||
HSplit,
|
||||
Layout,
|
||||
Window,
|
||||
)
|
||||
from prompt_toolkit.styles import Style
|
||||
from prompt_toolkit.widgets import Frame, TextArea
|
||||
from prompt_toolkit.widgets.toolbars import FormattedTextToolbar
|
||||
|
||||
from bumble import __version__
|
||||
import bumble.core
|
||||
from bumble import colors
|
||||
from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT
|
||||
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
|
||||
from bumble.utils import AsyncRunner
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
||||
from bumble import __version__, colors
|
||||
from bumble.core import UUID, AdvertisingData
|
||||
from bumble.device import (
|
||||
Connection,
|
||||
ConnectionParametersPreferences,
|
||||
ConnectionPHY,
|
||||
Device,
|
||||
Peer,
|
||||
)
|
||||
from bumble.gatt import Characteristic, CharacteristicDeclaration, Descriptor, Service
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_Constant,
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_LE_2M_PHY,
|
||||
HCI_LE_CODED_PHY,
|
||||
Address,
|
||||
HCI_Constant,
|
||||
)
|
||||
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -125,6 +127,7 @@ def parse_phys(phys):
|
||||
# -----------------------------------------------------------------------------
|
||||
class ConsoleApp:
|
||||
connected_peer: Optional[Peer]
|
||||
connection_phy: Optional[ConnectionPHY]
|
||||
|
||||
def __init__(self):
|
||||
self.known_addresses = set()
|
||||
@@ -132,6 +135,7 @@ class ConsoleApp:
|
||||
self.known_local_attributes = []
|
||||
self.device = None
|
||||
self.connected_peer = None
|
||||
self.connection_phy = None
|
||||
self.top_tab = 'device'
|
||||
self.monitor_rssi = False
|
||||
self.connection_rssi = None
|
||||
@@ -284,7 +288,7 @@ class ConsoleApp:
|
||||
async def run_async(self, device_config, transport):
|
||||
rssi_monitoring_task = asyncio.create_task(self.rssi_monitor_loop())
|
||||
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(transport) as (hci_source, hci_sink):
|
||||
if device_config:
|
||||
self.device = Device.from_config_file_with_hci(
|
||||
device_config, hci_source, hci_sink
|
||||
@@ -328,14 +332,14 @@ class ConsoleApp:
|
||||
elif self.connected_peer:
|
||||
connection = self.connected_peer.connection
|
||||
connection_parameters = (
|
||||
f'{connection.parameters.connection_interval}/'
|
||||
f'{connection.parameters.connection_interval:.2f}/'
|
||||
f'{connection.parameters.peripheral_latency}/'
|
||||
f'{connection.parameters.supervision_timeout}'
|
||||
f'{connection.parameters.supervision_timeout:.2f}'
|
||||
)
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
if self.connection_phy is not None:
|
||||
phy_state = (
|
||||
f' RX={le_phy_name(connection.phy.rx_phy)}/'
|
||||
f'TX={le_phy_name(connection.phy.tx_phy)}'
|
||||
f' RX={le_phy_name(self.connection_phy.rx_phy)}/'
|
||||
f'TX={le_phy_name(self.connection_phy.tx_phy)}'
|
||||
)
|
||||
else:
|
||||
phy_state = ''
|
||||
@@ -654,11 +658,12 @@ class ConsoleApp:
|
||||
self.append_to_output('connecting...')
|
||||
|
||||
try:
|
||||
await self.device.connect(
|
||||
connection = await self.device.connect(
|
||||
params[0],
|
||||
connection_parameters_preferences=connection_parameters_preferences,
|
||||
timeout=DEFAULT_CONNECTION_TIMEOUT,
|
||||
)
|
||||
self.connection_phy = await connection.get_phy()
|
||||
self.top_tab = 'services'
|
||||
except bumble.core.TimeoutError:
|
||||
self.show_error('connection timed out')
|
||||
@@ -838,8 +843,8 @@ class ConsoleApp:
|
||||
|
||||
phy = await self.connected_peer.connection.get_phy()
|
||||
self.append_to_output(
|
||||
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, '
|
||||
f'TX={HCI_Constant.le_phy_name(phy[1])}'
|
||||
f'PHY: RX={HCI_Constant.le_phy_name(phy.rx_phy)}, '
|
||||
f'TX={HCI_Constant.le_phy_name(phy.tx_phy)}'
|
||||
)
|
||||
|
||||
async def do_request_mtu(self, params):
|
||||
@@ -1076,10 +1081,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
|
||||
f'{self.app.connected_peer.connection.parameters}'
|
||||
)
|
||||
|
||||
def on_connection_phy_update(self):
|
||||
self.app.append_to_output(
|
||||
f'connection phy update: {self.app.connected_peer.connection.phy}'
|
||||
)
|
||||
def on_connection_phy_update(self, phy):
|
||||
self.app.connection_phy = phy
|
||||
self.app.append_to_output(f'connection phy update: {phy}')
|
||||
|
||||
def on_connection_att_mtu_update(self):
|
||||
self.app.append_to_output(
|
||||
|
||||
@@ -16,44 +16,48 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
import bumble.logging
|
||||
from bumble.colors import color
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.core import name_or_number
|
||||
from bumble.hci import (
|
||||
map_null_terminated_utf8_string,
|
||||
LeFeature,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
|
||||
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_SUCCESS,
|
||||
HCI_VERSION_NAMES,
|
||||
LMP_VERSION_NAMES,
|
||||
CodecID,
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
|
||||
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_V2_Command,
|
||||
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_Read_Local_Supported_Codecs_Command,
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
LeFeature,
|
||||
map_null_terminated_utf8_string,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -72,7 +76,7 @@ async def get_classic_info(host: Host) -> None:
|
||||
if command_succeeded(response):
|
||||
print()
|
||||
print(
|
||||
color('Classic Address:', 'yellow'),
|
||||
color('Public Address:', 'yellow'),
|
||||
response.return_parameters.bd_addr.to_string(False),
|
||||
)
|
||||
|
||||
@@ -144,7 +148,7 @@ async def get_le_info(host: Host) -> None:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
async def get_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
@@ -157,40 +161,123 @@ async def get_acl_flow_control_info(host: Host) -> None:
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
|
||||
)
|
||||
print(
|
||||
color('LE ISO Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.total_num_iso_data_packets} '
|
||||
f'packets of size {response.return_parameters.iso_data_packet_length}',
|
||||
)
|
||||
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
f'{response.return_parameters.total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(latency_probes, transport):
|
||||
async def get_codecs_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
|
||||
)
|
||||
print(color('Codecs:', 'yellow'))
|
||||
|
||||
for codec_id, transport in zip(
|
||||
response.return_parameters.standard_codec_ids,
|
||||
response.return_parameters.standard_codec_transports,
|
||||
):
|
||||
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
|
||||
transport
|
||||
).name
|
||||
codec_name = CodecID(codec_id).name
|
||||
print(f' {codec_name} - {transport_name}')
|
||||
|
||||
for codec_id, transport in zip(
|
||||
response.return_parameters.vendor_specific_codec_ids,
|
||||
response.return_parameters.vendor_specific_codec_transports,
|
||||
):
|
||||
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
|
||||
transport
|
||||
).name
|
||||
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
|
||||
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
|
||||
|
||||
if not response.return_parameters.standard_codec_ids:
|
||||
print(' No standard codecs')
|
||||
if not response.return_parameters.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
|
||||
)
|
||||
print(color('Codecs (BR/EDR):', 'yellow'))
|
||||
for codec_id in response.return_parameters.standard_codec_ids:
|
||||
codec_name = CodecID(codec_id).name
|
||||
print(f' {codec_name}')
|
||||
|
||||
for codec_id in response.return_parameters.vendor_specific_codec_ids:
|
||||
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
|
||||
print(f' {company} / {codec_id & 0xFFFF}')
|
||||
|
||||
if not response.return_parameters.standard_codec_ids:
|
||||
print(' No standard codecs')
|
||||
if not response.return_parameters.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(
|
||||
latency_probes, latency_probe_interval, latency_probe_command, transport
|
||||
):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# Measure the latency if requested
|
||||
# (we add an extra probe at the start, that we ignore, just to ensure that
|
||||
# the transport is primed)
|
||||
latencies = []
|
||||
if latency_probes:
|
||||
for _ in range(latency_probes):
|
||||
if latency_probe_command:
|
||||
probe_hci_command = HCI_Command.from_bytes(
|
||||
bytes.fromhex(latency_probe_command)
|
||||
)
|
||||
else:
|
||||
probe_hci_command = HCI_Read_Local_Version_Information_Command()
|
||||
|
||||
for iteration in range(1 + latency_probes):
|
||||
if latency_probe_interval:
|
||||
await asyncio.sleep(latency_probe_interval / 1000)
|
||||
start = time.time()
|
||||
await host.send_command(HCI_Read_Local_Version_Information_Command())
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
await host.send_command(probe_hci_command)
|
||||
if iteration:
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
print(
|
||||
color('HCI Command Latency:', 'yellow'),
|
||||
(
|
||||
f'min={min(latencies):.2f}, '
|
||||
f'max={max(latencies):.2f}, '
|
||||
f'average={sum(latencies)/len(latencies):.2f}'
|
||||
f'average={sum(latencies)/len(latencies):.2f},'
|
||||
),
|
||||
[f'{latency:.4}' for latency in latencies],
|
||||
'\n',
|
||||
)
|
||||
|
||||
@@ -217,8 +304,11 @@ async def async_main(latency_probes, transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
# Print the flow control info
|
||||
await get_flow_control_info(host)
|
||||
|
||||
# Get codec info
|
||||
await get_codecs_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
@@ -235,10 +325,28 @@ async def async_main(latency_probes, transport):
|
||||
type=int,
|
||||
help='Send N commands to measure HCI transport latency statistics',
|
||||
)
|
||||
@click.option(
|
||||
'--latency-probe-interval',
|
||||
metavar='INTERVAL',
|
||||
type=int,
|
||||
help='Interval between latency probes (milliseconds)',
|
||||
)
|
||||
@click.option(
|
||||
'--latency-probe-command',
|
||||
metavar='COMMAND_HEX',
|
||||
help=(
|
||||
'Probe command (HCI Command packet bytes, in hex. Use 0177FC00 for'
|
||||
' a loopback test with the HCI remote proxy app)'
|
||||
),
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(latency_probes, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
asyncio.run(async_main(latency_probes, transport))
|
||||
def main(latency_probes, latency_probe_interval, latency_probe_command, transport):
|
||||
bumble.logging.setup_basic_logging()
|
||||
asyncio.run(
|
||||
async_main(
|
||||
latency_probes, latency_probe_interval, latency_probe_command, transport
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -16,21 +16,22 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
import click
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
class Loopback:
|
||||
@@ -88,7 +89,7 @@ class Loopback:
|
||||
async def run(self):
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
async with await open_transport_or_link(self.transport) as (
|
||||
async with await open_transport(self.transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
@@ -194,8 +195,7 @@ class Loopback:
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
|
||||
bumble.logging.setup_basic_logging()
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
|
||||
@@ -15,14 +15,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
import bumble.logging
|
||||
from bumble.controller import Controller
|
||||
from bumble.link import LocalLink
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -42,7 +41,7 @@ async def async_main():
|
||||
transports = []
|
||||
controllers = []
|
||||
for index, transport_name in enumerate(sys.argv[1:]):
|
||||
transport = await open_transport_or_link(transport_name)
|
||||
transport = await open_transport(transport_name)
|
||||
transports.append(transport)
|
||||
controller = Controller(
|
||||
f'C{index}',
|
||||
@@ -62,7 +61,7 @@ async def async_main():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
asyncio.run(async_main())
|
||||
|
||||
|
||||
|
||||
@@ -16,21 +16,22 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import click
|
||||
|
||||
from bumble.core import ProtocolError
|
||||
import bumble.logging
|
||||
from bumble.colors import color
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import Service
|
||||
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
|
||||
from bumble.profiles.battery_service import BatteryServiceProxy
|
||||
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
|
||||
from bumble.profiles.gap import GenericAccessServiceProxy
|
||||
from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy
|
||||
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.profiles.vcs import VolumeControlServiceProxy
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -126,14 +127,52 @@ async def show_tmas(
|
||||
print(color('### Telephony And Media Audio Service', 'yellow'))
|
||||
|
||||
if tmas.role:
|
||||
print(
|
||||
color(' Role:', 'green'),
|
||||
await tmas.role.read_value(),
|
||||
)
|
||||
role = await tmas.role.read_value()
|
||||
print(color(' Role:', 'green'), role)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_pacs(pacs: PublishedAudioCapabilitiesServiceProxy) -> None:
|
||||
print(color('### Published Audio Capabilities Service', 'yellow'))
|
||||
|
||||
contexts = await pacs.available_audio_contexts.read_value()
|
||||
print(color(' Available Audio Contexts:', 'green'), contexts)
|
||||
|
||||
contexts = await pacs.supported_audio_contexts.read_value()
|
||||
print(color(' Supported Audio Contexts:', 'green'), contexts)
|
||||
|
||||
if pacs.sink_pac:
|
||||
pac = await pacs.sink_pac.read_value()
|
||||
print(color(' Sink PAC: ', 'green'), pac)
|
||||
|
||||
if pacs.sink_audio_locations:
|
||||
audio_locations = await pacs.sink_audio_locations.read_value()
|
||||
print(color(' Sink Audio Locations: ', 'green'), audio_locations)
|
||||
|
||||
if pacs.source_pac:
|
||||
pac = await pacs.source_pac.read_value()
|
||||
print(color(' Source PAC: ', 'green'), pac)
|
||||
|
||||
if pacs.source_audio_locations:
|
||||
audio_locations = await pacs.source_audio_locations.read_value()
|
||||
print(color(' Source Audio Locations: ', 'green'), audio_locations)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
|
||||
print(color('### Volume Control Service', 'yellow'))
|
||||
|
||||
volume_state = await vcs.volume_state.read_value()
|
||||
print(color(' Volume State:', 'green'), volume_state)
|
||||
|
||||
volume_flags = await vcs.volume_flags.read_value()
|
||||
print(color(' Volume Flags:', 'green'), volume_flags)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
try:
|
||||
@@ -161,6 +200,12 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
|
||||
await try_show(show_tmas, tmas)
|
||||
|
||||
if pacs := peer.create_service_proxy(PublishedAudioCapabilitiesServiceProxy):
|
||||
await try_show(show_pacs, pacs)
|
||||
|
||||
if vcs := peer.create_service_proxy(VolumeControlServiceProxy):
|
||||
await try_show(show_vcs, vcs)
|
||||
|
||||
if done is not None:
|
||||
done.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
@@ -169,7 +214,7 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(device_config, encrypt, transport, address_or_name):
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(transport) as (hci_source, hci_sink):
|
||||
|
||||
# Create a device
|
||||
if device_config:
|
||||
@@ -221,7 +266,7 @@ def main(device_config, encrypt, transport, address_or_name):
|
||||
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
|
||||
wait for an incoming connection.
|
||||
"""
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
|
||||
|
||||
|
||||
|
||||
@@ -16,15 +16,15 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
import bumble.core
|
||||
import bumble.logging
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import show_services
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -60,7 +60,7 @@ async def dump_gatt_db(peer, done):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(device_config, encrypt, transport, address_or_name):
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(transport) as (hci_source, hci_sink):
|
||||
|
||||
# Create a device
|
||||
if device_config:
|
||||
@@ -112,7 +112,7 @@ def main(device_config, encrypt, transport, address_or_name):
|
||||
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
|
||||
wait for an incoming connection.
|
||||
"""
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
|
||||
|
||||
|
||||
|
||||
@@ -16,20 +16,19 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.gatt import Service, Characteristic, CharacteristicValue
|
||||
from bumble.utils import AsyncRunner
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import Characteristic, CharacteristicValue, Service
|
||||
from bumble.hci import HCI_Constant
|
||||
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -234,7 +233,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=self.on_rx_write),
|
||||
)
|
||||
self.tx_characteristic = Characteristic(
|
||||
self.tx_characteristic: Characteristic[bytes] = Characteristic(
|
||||
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
@@ -325,7 +324,7 @@ async def run(
|
||||
receive_port,
|
||||
):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
# Instantiate a bridge object
|
||||
@@ -383,6 +382,7 @@ def main(
|
||||
receive_host,
|
||||
receive_port,
|
||||
):
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
asyncio.run(
|
||||
run(
|
||||
hci_transport,
|
||||
@@ -397,6 +397,5 @@ def main(
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -12,14 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import bumble.logging
|
||||
from bumble import hci, transport
|
||||
from bumble.bridge import HCI_Bridge
|
||||
|
||||
@@ -46,14 +47,14 @@ async def async_main():
|
||||
return
|
||||
|
||||
print('>>> connecting to HCI...')
|
||||
async with await transport.open_transport_or_link(sys.argv[1]) as (
|
||||
async with await transport.open_transport(sys.argv[1]) as (
|
||||
hci_host_source,
|
||||
hci_host_sink,
|
||||
):
|
||||
print('>>> connected')
|
||||
|
||||
print('>>> connecting to HCI...')
|
||||
async with await transport.open_transport_or_link(sys.argv[2]) as (
|
||||
async with await transport.open_transport(sys.argv[2]) as (
|
||||
hci_controller_source,
|
||||
hci_controller_sink,
|
||||
):
|
||||
@@ -83,7 +84,7 @@ async def async_main():
|
||||
return_parameters=bytes([hci.HCI_SUCCESS]),
|
||||
)
|
||||
# Return a packet with 'respond to sender' set to True
|
||||
return (response.to_bytes(), True)
|
||||
return (bytes(response), True)
|
||||
|
||||
return None
|
||||
|
||||
@@ -100,7 +101,7 @@ async def async_main():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
asyncio.run(async_main())
|
||||
|
||||
|
||||
|
||||
@@ -16,16 +16,16 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device
|
||||
from bumble.utils import FlowControlAsyncPipe
|
||||
from bumble.hci import HCI_Constant
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import FlowControlAsyncPipe
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -258,7 +258,7 @@ class ClientBridge:
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device_config, hci_transport, bridge):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
|
||||
@@ -356,6 +356,6 @@ def client(context, bluetooth_address, tcp_host, tcp_port):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
if __name__ == '__main__':
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
cli(obj={}) # pylint: disable=no-value-for-parameter
|
||||
|
||||
@@ -16,34 +16,34 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
from importlib import resources
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional, List, cast
|
||||
import wave
|
||||
import weakref
|
||||
import struct
|
||||
from importlib import resources
|
||||
|
||||
import ctypes
|
||||
import wasmtime
|
||||
import wasmtime.loader
|
||||
import liblc3 # type: ignore
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
|
||||
|
||||
import click
|
||||
import aiohttp.web
|
||||
import click
|
||||
|
||||
import bumble
|
||||
from bumble.core import AdvertisingData
|
||||
import bumble.logging
|
||||
from bumble import data_types, utils
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
||||
from bumble.transport import open_transport
|
||||
from bumble.profiles import ascs, bap, pacs
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import AdvertisingParameters, CisLink, Device, DeviceConfiguration
|
||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||
from bumble.profiles import ascs, bap, pacs
|
||||
from bumble.transport import open_transport
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -54,6 +54,7 @@ logger = logging.getLogger(__name__)
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
def _sink_pac_record() -> pacs.PacRecord:
|
||||
@@ -100,153 +101,8 @@ def _source_pac_record() -> pacs.PacRecord:
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# WASM - liblc3
|
||||
# -----------------------------------------------------------------------------
|
||||
store = wasmtime.loader.store
|
||||
_memory = cast(wasmtime.Memory, liblc3.memory)
|
||||
STACK_POINTER = _memory.data_len(store)
|
||||
_memory.grow(store, 1)
|
||||
# Mapping wasmtime memory to linear address
|
||||
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
|
||||
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class Liblc3PcmFormat(enum.IntEnum):
|
||||
S16 = 0
|
||||
S24 = 1
|
||||
S24_3LE = 2
|
||||
FLOAT = 3
|
||||
|
||||
|
||||
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
|
||||
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
|
||||
|
||||
DECODER_STACK_POINTER = STACK_POINTER
|
||||
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
|
||||
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
|
||||
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
|
||||
DEFAULT_PCM_SAMPLE_RATE = 48000
|
||||
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
encoders: List[int] = []
|
||||
decoders: List[int] = []
|
||||
|
||||
|
||||
def setup_encoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
encoders[:num_channels] = [
|
||||
liblc3.lc3_setup_encoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
|
||||
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def setup_decoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
decoders[:num_channels] = [
|
||||
liblc3.lc3_setup_decoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
|
||||
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def decode(
|
||||
frame_duration_us: int,
|
||||
num_channels: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
input_bytes_per_frame = input_buffer_size // num_channels
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
* num_channels
|
||||
)
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_decode(
|
||||
decoders[i],
|
||||
input_buffer_offset + input_bytes_per_frame * i,
|
||||
input_bytes_per_frame,
|
||||
DEFAULT_PCM_FORMAT,
|
||||
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
|
||||
num_channels, # Stride
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
|
||||
|
||||
def encode(
|
||||
sdu_length: int,
|
||||
num_channels: int,
|
||||
stride: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = sdu_length
|
||||
output_frame_size = output_buffer_size // num_channels
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_encode(
|
||||
encoders[i],
|
||||
DEFAULT_PCM_FORMAT,
|
||||
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
|
||||
stride,
|
||||
output_frame_size,
|
||||
output_buffer_offset + output_frame_size * i,
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
decoder: lc3.Decoder | None = None
|
||||
encoding_config: bap.CodecSpecificConfiguration | None = None
|
||||
|
||||
|
||||
async def lc3_source_task(
|
||||
@@ -254,44 +110,49 @@ async def lc3_source_task(
|
||||
sdu_length: int,
|
||||
frame_duration_us: int,
|
||||
device: Device,
|
||||
cis_handle: int,
|
||||
cis_link: CisLink,
|
||||
) -> None:
|
||||
with open(filename, 'rb') as f:
|
||||
header = f.read(44)
|
||||
assert header[8:12] == b'WAVE'
|
||||
logger.info(
|
||||
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
|
||||
filename,
|
||||
sdu_length,
|
||||
frame_duration_us / 1000,
|
||||
)
|
||||
with wave.open(filename, 'rb') as wav:
|
||||
bits_per_sample = wav.getsampwidth() * 8
|
||||
|
||||
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
|
||||
struct.unpack("<HIIHH", header[22:36])
|
||||
)
|
||||
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
|
||||
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
|
||||
frame_bytes = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
)
|
||||
packet_sequence_number = 0
|
||||
encoder: lc3.Encoder | None = None
|
||||
|
||||
while True:
|
||||
next_round = datetime.datetime.now() + datetime.timedelta(
|
||||
microseconds=frame_duration_us
|
||||
)
|
||||
pcm_data = f.read(frame_bytes)
|
||||
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
|
||||
if not encoder:
|
||||
if (
|
||||
encoding_config
|
||||
and (frame_duration := encoding_config.frame_duration)
|
||||
and (sampling_frequency := encoding_config.sampling_frequency)
|
||||
and (
|
||||
audio_channel_allocation := encoding_config.audio_channel_allocation
|
||||
)
|
||||
):
|
||||
logger.info("Use %s", encoding_config)
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=audio_channel_allocation.channel_count,
|
||||
input_sample_rate_hz=wav.getframerate(),
|
||||
)
|
||||
else:
|
||||
sdu = encoder.encode(
|
||||
pcm=wav.readframes(encoder.get_frame_samples()),
|
||||
num_bytes=sdu_length,
|
||||
bit_depth=bits_per_sample,
|
||||
)
|
||||
cis_link.write(sdu)
|
||||
|
||||
iso_packet = HCI_IsoDataPacket(
|
||||
connection_handle=cis_handle,
|
||||
data_total_length=sdu_length + 4,
|
||||
packet_sequence_number=packet_sequence_number,
|
||||
pb_flag=0b10,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=sdu_length,
|
||||
iso_sdu_fragment=sdu,
|
||||
)
|
||||
device.host.send_hci_packet(iso_packet)
|
||||
packet_sequence_number += 1
|
||||
sleep_time = next_round - datetime.datetime.now()
|
||||
await asyncio.sleep(sleep_time.total_seconds())
|
||||
await asyncio.sleep(sleep_time.total_seconds() * 0.9)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -410,7 +271,7 @@ class Speaker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_config_path: Optional[str],
|
||||
device_config_path: str | None,
|
||||
ui_port: int,
|
||||
transport: str,
|
||||
lc3_input_file_path: str,
|
||||
@@ -437,6 +298,7 @@ class Speaker:
|
||||
advertising_interval_min=25,
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
@@ -468,17 +330,13 @@ class Speaker:
|
||||
advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(device_config.name, 'utf-8'),
|
||||
data_types.CompleteLocalName(device_config.name),
|
||||
data_types.Flags(
|
||||
AdvertisingData.Flags.LE_GENERAL_DISCOVERABLE_MODE
|
||||
| AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
|
||||
data_types.IncompleteListOf16BitServiceUUIDs(
|
||||
[pacs.PublishedAudioCapabilitiesService.UUID]
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -486,21 +344,35 @@ class Speaker:
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
pcm = decode(
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
pdu.iso_sdu_fragment,
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
or decoder is None
|
||||
or not pdu.iso_sdu_fragment
|
||||
):
|
||||
return
|
||||
pcm = decoder.decode(
|
||||
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
)
|
||||
utils.cancel_on_event(
|
||||
self.device, 'disconnection', self.ui_server.send_audio(pcm)
|
||||
)
|
||||
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
||||
|
||||
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
if ase.state == ascs.AseStateMachine.State.STREAMING:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
assert ase.cis_link
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
ase.cis_link.abort_on(
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or ase.cis_link is None
|
||||
or codec_config.octets_per_codec_frame is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.codec_frames_per_sdu is None
|
||||
):
|
||||
return
|
||||
utils.cancel_on_event(
|
||||
ase.cis_link,
|
||||
'disconnection',
|
||||
lc3_source_task(
|
||||
filename=self.lc3_input_file_path,
|
||||
@@ -510,25 +382,30 @@ class Speaker:
|
||||
),
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
device=self.device,
|
||||
cis_handle=ase.cis_link.handle,
|
||||
cis_link=ase.cis_link,
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not ase.cis_link:
|
||||
return
|
||||
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
||||
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.sampling_frequency is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
):
|
||||
return
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
setup_encoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
global encoding_config
|
||||
encoding_config = codec_config
|
||||
else:
|
||||
setup_decoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
global decoder
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
sample_rate_hz=codec_config.sampling_frequency.hz,
|
||||
num_channels=codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
|
||||
for ase in ascs_service.ase_state_machines.values():
|
||||
@@ -567,7 +444,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
speaker()
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
@@ -1,289 +0,0 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# ----------------------------------------------------------------------------
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
import argparse
|
||||
import uuid
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
import websockets
|
||||
|
||||
from bumble.colors import color
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ----------------------------------------------------------------------------
|
||||
DEFAULT_RELAY_PORT = 10723
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# ----------------------------------------------------------------------------
|
||||
def error_to_json(error):
|
||||
return json.dumps({'error': error})
|
||||
|
||||
|
||||
def error_to_result(error):
|
||||
return f'result:{error_to_json(error)}'
|
||||
|
||||
|
||||
async def broadcast_message(message, connections):
|
||||
# Send to all the connections
|
||||
tasks = [connection.send_message(message) for connection in connections]
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Connection class
|
||||
# ----------------------------------------------------------------------------
|
||||
class Connection:
|
||||
"""
|
||||
A Connection represents a client connected to the relay over a websocket
|
||||
"""
|
||||
|
||||
def __init__(self, room, websocket):
|
||||
self.room = room
|
||||
self.websocket = websocket
|
||||
self.address = str(uuid.uuid4())
|
||||
|
||||
async def send_message(self, message):
|
||||
try:
|
||||
logger.debug(color(f'->{self.address}: {message}', 'yellow'))
|
||||
return await self.websocket.send(message)
|
||||
except websockets.exceptions.WebSocketException as error:
|
||||
logger.info(f'! client "{self}" disconnected: {error}')
|
||||
await self.cleanup()
|
||||
|
||||
async def send_error(self, error):
|
||||
return await self.send_message(f'result:{error_to_json(error)}')
|
||||
|
||||
async def receive_message(self):
|
||||
try:
|
||||
message = await self.websocket.recv()
|
||||
logger.debug(color(f'<-{self.address}: {message}', 'blue'))
|
||||
return message
|
||||
except websockets.exceptions.WebSocketException as error:
|
||||
logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
|
||||
await self.cleanup()
|
||||
|
||||
async def cleanup(self):
|
||||
if self.room:
|
||||
await self.room.remove_connection(self)
|
||||
|
||||
def set_address(self, address):
|
||||
logger.info(f'Connection address changed: {self.address} -> {address}')
|
||||
self.address = address
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'Connection(address="{self.address}", '
|
||||
f'client={self.websocket.remote_address[0]}:'
|
||||
f'{self.websocket.remote_address[1]})'
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Room class
|
||||
# ----------------------------------------------------------------------------
|
||||
class Room:
|
||||
"""
|
||||
A Room is a collection of bridged connections
|
||||
"""
|
||||
|
||||
def __init__(self, relay, name):
|
||||
self.relay = relay
|
||||
self.name = name
|
||||
self.observers = []
|
||||
self.connections = []
|
||||
|
||||
async def add_connection(self, connection):
|
||||
logger.info(f'New participant in {self.name}: {connection}')
|
||||
self.connections.append(connection)
|
||||
await self.broadcast_message(connection, f'joined:{connection.address}')
|
||||
|
||||
async def remove_connection(self, connection):
|
||||
if connection in self.connections:
|
||||
self.connections.remove(connection)
|
||||
await self.broadcast_message(connection, f'left:{connection.address}')
|
||||
|
||||
def find_connections_by_address(self, address):
|
||||
return [c for c in self.connections if c.address == address]
|
||||
|
||||
async def bridge_connection(self, connection):
|
||||
while True:
|
||||
# Wait for a message
|
||||
message = await connection.receive_message()
|
||||
|
||||
# Skip empty messages
|
||||
if message is None:
|
||||
return
|
||||
|
||||
# Parse the message to decide how to handle it
|
||||
if message.startswith('@'):
|
||||
# This is a targeted message
|
||||
await self.on_targeted_message(connection, message)
|
||||
elif message.startswith('/'):
|
||||
# This is an RPC request
|
||||
await self.on_rpc_request(connection, message)
|
||||
else:
|
||||
await connection.send_message(
|
||||
f'result:{error_to_json("error: invalid message")}'
|
||||
)
|
||||
|
||||
async def broadcast_message(self, sender, message):
|
||||
'''
|
||||
Send to all connections in the room except back to the sender
|
||||
'''
|
||||
await broadcast_message(message, [c for c in self.connections if c != sender])
|
||||
|
||||
async def on_rpc_request(self, connection, message):
|
||||
command, *params = message.split(' ', 1)
|
||||
if handler := getattr(
|
||||
self, f'on_{command[1:].lower().replace("-","_")}_command', None
|
||||
):
|
||||
try:
|
||||
result = await handler(connection, params)
|
||||
except Exception as error:
|
||||
result = error_to_result(error)
|
||||
else:
|
||||
result = error_to_result('unknown command')
|
||||
|
||||
await connection.send_message(result or 'result:{}')
|
||||
|
||||
async def on_targeted_message(self, connection, message):
|
||||
target, *payload = message.split(' ', 1)
|
||||
if not payload:
|
||||
return error_to_json('missing arguments')
|
||||
payload = payload[0]
|
||||
target = target[1:]
|
||||
|
||||
# Determine what targets to send to
|
||||
if target == '*':
|
||||
# Send to all connections in the room except the connection from which the
|
||||
# message was received
|
||||
connections = [c for c in self.connections if c != connection]
|
||||
else:
|
||||
connections = self.find_connections_by_address(target)
|
||||
if not connections:
|
||||
# Unicast with no recipient, let the sender know
|
||||
await connection.send_message(f'unreachable:{target}')
|
||||
|
||||
# Send to targets
|
||||
await broadcast_message(f'message:{connection.address}/{payload}', connections)
|
||||
|
||||
async def on_set_address_command(self, connection, params):
|
||||
if not params:
|
||||
return error_to_result('missing address')
|
||||
|
||||
current_address = connection.address
|
||||
new_address = params[0]
|
||||
connection.set_address(new_address)
|
||||
await self.broadcast_message(
|
||||
connection, f'address-changed:from={current_address},to={new_address}'
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
class Relay:
|
||||
"""
|
||||
A relay accepts connections with the following url: ws://<hostname>/<room>.
|
||||
Participants in a room can communicate with each other
|
||||
"""
|
||||
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
self.rooms = {}
|
||||
self.observers = []
|
||||
|
||||
def start(self):
|
||||
logger.info(f'Starting Relay on port {self.port}')
|
||||
|
||||
# pylint: disable-next=no-member
|
||||
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
|
||||
|
||||
async def serve_as_controller(self, connection):
|
||||
pass
|
||||
|
||||
async def serve(self, websocket, path):
|
||||
logger.debug(f'New connection with path {path}')
|
||||
|
||||
# Parse the path
|
||||
parsed = urlparse(path)
|
||||
|
||||
# Check if this is a controller client
|
||||
if parsed.path == '/':
|
||||
return await self.serve_as_controller(Connection('', websocket))
|
||||
|
||||
# Find or create a room for this connection
|
||||
room_name = parsed.path[1:].split('/')[0]
|
||||
if room_name not in self.rooms:
|
||||
self.rooms[room_name] = Room(self, room_name)
|
||||
room = self.rooms[room_name]
|
||||
|
||||
# Add the connection to the room
|
||||
connection = Connection(room, websocket)
|
||||
await room.add_connection(connection)
|
||||
|
||||
# Bridge until the connection is closed
|
||||
await room.bridge_connection(connection)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
def main():
|
||||
# Check the Python version
|
||||
if sys.version_info < (3, 6, 1):
|
||||
print('ERROR: Python 3.6.1 or higher is required')
|
||||
sys.exit(1)
|
||||
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
|
||||
# Parse arguments
|
||||
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
|
||||
arg_parser.add_argument('--log-level', default='INFO', help='logger level')
|
||||
arg_parser.add_argument('--log-config', help='logger config file (YAML)')
|
||||
arg_parser.add_argument(
|
||||
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
# Setup logger
|
||||
if args.log_config:
|
||||
from logging import config # pylint: disable=import-outside-toplevel
|
||||
|
||||
config.fileConfig(args.log_config)
|
||||
else:
|
||||
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
|
||||
|
||||
# Start a relay
|
||||
relay = Relay(args.port)
|
||||
asyncio.get_event_loop().run_until_complete(relay.start())
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,21 +0,0 @@
|
||||
[loggers]
|
||||
keys=root
|
||||
|
||||
[handlers]
|
||||
keys=stream_handler
|
||||
|
||||
[formatters]
|
||||
keys=formatter
|
||||
|
||||
[logger_root]
|
||||
level=DEBUG
|
||||
handlers=stream_handler
|
||||
|
||||
[handler_stream_handler]
|
||||
class=StreamHandler
|
||||
level=DEBUG
|
||||
formatter=formatter
|
||||
args=(sys.stderr,)
|
||||
|
||||
[formatter_formatter]
|
||||
format=%(asctime)s %(name)-12s %(levelname)-8s %(message)s
|
||||
288
apps/pair.py
288
apps/pair.py
@@ -16,36 +16,44 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
|
||||
import click
|
||||
from prompt_toolkit.shortcuts import PromptSession
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.pairing import OobData, PairingDelegate, PairingConfig
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
ProtocolError,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
from bumble import data_types
|
||||
from bumble.a2dp import make_audio_sink_service_sdp_records
|
||||
from bumble.att import (
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
|
||||
ATT_Error,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
Appearance,
|
||||
DataType,
|
||||
PhysicalTransport,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
Service,
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
GATT_HEART_RATE_SERVICE,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
)
|
||||
from bumble.att import (
|
||||
ATT_Error,
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
|
||||
Service,
|
||||
)
|
||||
from bumble.hci import OwnAddressType
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.pairing import OobData, PairingConfig, PairingDelegate
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -63,7 +71,7 @@ class Waiter:
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
if not self.linger:
|
||||
if not self.linger and not self.done.done:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
@@ -194,7 +202,7 @@ class Delegate(PairingDelegate):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_peer_name(peer, mode):
|
||||
if mode == 'classic':
|
||||
if peer.connection.transport == PhysicalTransport.BR_EDR:
|
||||
return await peer.request_name()
|
||||
|
||||
# Try to get the peer name from GATT
|
||||
@@ -226,13 +234,14 @@ def read_with_error(connection):
|
||||
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
|
||||
|
||||
|
||||
def write_with_error(connection, _value):
|
||||
if not connection.is_encrypted:
|
||||
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
|
||||
|
||||
if not AUTHENTICATION_ERROR_RETURNED[1]:
|
||||
AUTHENTICATION_ERROR_RETURNED[1] = True
|
||||
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
|
||||
# -----------------------------------------------------------------------------
|
||||
def sdp_records():
|
||||
service_record_handle = 0x00010001
|
||||
return {
|
||||
service_record_handle: make_audio_sink_service_sdp_records(
|
||||
service_record_handle
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -240,15 +249,19 @@ def on_connection(connection, request):
|
||||
print(color(f'<<< Connection: {connection}', 'green'))
|
||||
|
||||
# Listen for pairing events
|
||||
connection.on('pairing_start', on_pairing_start)
|
||||
connection.on('pairing', lambda keys: on_pairing(connection, keys))
|
||||
connection.on(connection.EVENT_PAIRING_START, on_pairing_start)
|
||||
connection.on(connection.EVENT_PAIRING, lambda keys: on_pairing(connection, keys))
|
||||
connection.on(
|
||||
'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
|
||||
connection.EVENT_CLASSIC_PAIRING, lambda: on_classic_pairing(connection)
|
||||
)
|
||||
connection.on(
|
||||
connection.EVENT_PAIRING_FAILURE,
|
||||
lambda reason: on_pairing_failure(connection, reason),
|
||||
)
|
||||
|
||||
# Listen for encryption changes
|
||||
connection.on(
|
||||
'connection_encryption_change',
|
||||
connection.EVENT_CONNECTION_ENCRYPTION_CHANGE,
|
||||
lambda: on_connection_encryption_change(connection),
|
||||
)
|
||||
|
||||
@@ -289,6 +302,20 @@ async def on_pairing(connection, keys):
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_classic_pairing(connection):
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
print(
|
||||
color(
|
||||
f'*** Paired [Classic]! (peer identity={connection.peer_address})', 'cyan'
|
||||
)
|
||||
)
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
await asyncio.sleep(POST_PAIRING_DELAY)
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_pairing_failure(connection, reason):
|
||||
@@ -306,6 +333,7 @@ async def pair(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
advertising_address,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
@@ -314,6 +342,8 @@ async def pair(
|
||||
request,
|
||||
print_keys,
|
||||
keystore_file,
|
||||
advertise_service_uuids,
|
||||
advertise_appearance,
|
||||
device_config,
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
@@ -321,7 +351,7 @@ async def pair(
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
# Create a device to manage the host
|
||||
@@ -329,29 +359,33 @@ async def pair(
|
||||
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
device.le_enabled = True
|
||||
if mode in ('le', 'dual'):
|
||||
device.add_service(
|
||||
Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
GATT_HEART_RATE_SERVICE,
|
||||
[
|
||||
Characteristic(
|
||||
'552957FB-CF1F-4A31-9535-E78847E1A714',
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.WRITE,
|
||||
Characteristic.READABLE | Characteristic.WRITEABLE,
|
||||
CharacteristicValue(
|
||||
read=read_with_error, write=write_with_error
|
||||
),
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READ_REQUIRES_AUTHENTICATION,
|
||||
bytes(1),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Select LE or Classic
|
||||
if mode == 'classic':
|
||||
# LE and Classic support
|
||||
if mode in ('classic', 'dual'):
|
||||
device.classic_enabled = True
|
||||
device.classic_smp_enabled = ctkd
|
||||
if mode in ('le', 'dual'):
|
||||
device.le_enabled = True
|
||||
if mode == 'dual':
|
||||
device.le_simultaneous_enabled = True
|
||||
|
||||
# Setup SDP
|
||||
if mode in ('classic', 'dual'):
|
||||
device.sdp_service_records = sdp_records()
|
||||
|
||||
# Get things going
|
||||
await device.power_on()
|
||||
@@ -370,27 +404,40 @@ async def pair(
|
||||
# Create an OOB context if needed
|
||||
if oob:
|
||||
our_oob_context = OobContext()
|
||||
shared_data = (
|
||||
None
|
||||
if oob == '-'
|
||||
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
|
||||
)
|
||||
legacy_context = OobLegacyContext()
|
||||
if oob == '-':
|
||||
shared_data = None
|
||||
legacy_context = OobLegacyContext()
|
||||
else:
|
||||
oob_data = OobData.from_ad(
|
||||
AdvertisingData.from_bytes(bytes.fromhex(oob))
|
||||
)
|
||||
shared_data = oob_data.shared_data
|
||||
legacy_context = oob_data.legacy_context
|
||||
if legacy_context is None and not sc:
|
||||
print(color('OOB pairing in legacy mode requires TK', 'red'))
|
||||
return
|
||||
|
||||
oob_contexts = PairingConfig.OobConfig(
|
||||
our_context=our_oob_context,
|
||||
peer_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
oob_data = OobData(
|
||||
address=device.random_address,
|
||||
shared_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
print(color('@@@ OOB Data:', 'yellow'))
|
||||
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
|
||||
if shared_data is None:
|
||||
oob_data = OobData(
|
||||
address=device.random_address,
|
||||
shared_data=our_oob_context.share(),
|
||||
legacy_context=(None if sc else legacy_context),
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
if legacy_context:
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
else:
|
||||
oob_contexts = None
|
||||
@@ -417,7 +464,9 @@ async def pair(
|
||||
print(color(f'=== Connecting to {address_or_name}...', 'green'))
|
||||
connection = await device.connect(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
transport=(
|
||||
PhysicalTransport.LE if mode == 'le' else PhysicalTransport.BR_EDR
|
||||
),
|
||||
)
|
||||
|
||||
if not request:
|
||||
@@ -430,13 +479,94 @@ async def pair(
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
await device.start_advertising(auto_restart=True)
|
||||
else:
|
||||
if mode in ('le', 'dual'):
|
||||
# Advertise so that peers can find us and connect.
|
||||
# Include the heart rate service UUID in the advertisement data
|
||||
# so that devices like iPhones can show this device in their
|
||||
# Bluetooth selector.
|
||||
service_uuids_16 = []
|
||||
service_uuids_32 = []
|
||||
service_uuids_128 = []
|
||||
if advertise_service_uuids:
|
||||
for uuid in advertise_service_uuids:
|
||||
uuid = uuid.replace("-", "")
|
||||
if len(uuid) == 4:
|
||||
service_uuids_16.append(UUID(uuid))
|
||||
elif len(uuid) == 8:
|
||||
service_uuids_32.append(UUID(uuid))
|
||||
elif len(uuid) == 32:
|
||||
service_uuids_128.append(UUID(uuid))
|
||||
else:
|
||||
print(color('Invalid UUID format', 'red'))
|
||||
return
|
||||
else:
|
||||
service_uuids_16.append(GATT_HEART_RATE_SERVICE)
|
||||
|
||||
flags = AdvertisingData.Flags.LE_LIMITED_DISCOVERABLE_MODE
|
||||
if mode == 'le':
|
||||
flags |= AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
|
||||
if mode == 'dual':
|
||||
flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE
|
||||
|
||||
advertising_data_types: list[DataType] = [
|
||||
data_types.Flags(flags),
|
||||
data_types.CompleteLocalName('Bumble'),
|
||||
]
|
||||
if service_uuids_16:
|
||||
advertising_data_types.append(
|
||||
data_types.IncompleteListOf16BitServiceUUIDs(service_uuids_16)
|
||||
)
|
||||
if service_uuids_32:
|
||||
advertising_data_types.append(
|
||||
data_types.IncompleteListOf32BitServiceUUIDs(service_uuids_32)
|
||||
)
|
||||
if service_uuids_128:
|
||||
advertising_data_types.append(
|
||||
data_types.IncompleteListOf128BitServiceUUIDs(service_uuids_128)
|
||||
)
|
||||
|
||||
if advertise_appearance:
|
||||
advertise_appearance = advertise_appearance.upper()
|
||||
try:
|
||||
advertise_appearance_int = int(advertise_appearance)
|
||||
except ValueError:
|
||||
category, subcategory = advertise_appearance.split('/')
|
||||
try:
|
||||
category_enum = Appearance.Category[category]
|
||||
except ValueError:
|
||||
print(
|
||||
color(f'Invalid appearance category {category}', 'red')
|
||||
)
|
||||
return
|
||||
subcategory_class = Appearance.SUBCATEGORY_CLASSES[
|
||||
category_enum
|
||||
]
|
||||
try:
|
||||
subcategory_enum = subcategory_class[subcategory]
|
||||
except ValueError:
|
||||
print(color(f'Invalid subcategory {subcategory}', 'red'))
|
||||
return
|
||||
advertise_appearance_int = int(
|
||||
Appearance(category_enum, subcategory_enum)
|
||||
)
|
||||
advertising_data_types.append(
|
||||
data_types.Appearance(category_enum, subcategory_enum)
|
||||
)
|
||||
device.advertising_data = bytes(AdvertisingData(advertising_data_types))
|
||||
await device.start_advertising(
|
||||
auto_restart=True,
|
||||
own_address_type=(
|
||||
OwnAddressType.PUBLIC
|
||||
if advertising_address == 'public'
|
||||
else OwnAddressType.RANDOM
|
||||
),
|
||||
)
|
||||
|
||||
if mode in ('classic', 'dual'):
|
||||
# Become discoverable and connectable
|
||||
await device.set_discoverable(True)
|
||||
await device.set_connectable(True)
|
||||
print(color('Ready for connections on', 'blue'), device.public_address)
|
||||
|
||||
# Run until the user asks to exit
|
||||
await Waiter.instance.wait_until_terminated()
|
||||
@@ -456,7 +586,10 @@ class LogHandler(logging.Handler):
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
|
||||
'--mode',
|
||||
type=click.Choice(['le', 'classic', 'dual']),
|
||||
default='le',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--sc',
|
||||
@@ -478,6 +611,10 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--advertising-address',
|
||||
type=click.Choice(['random', 'public']),
|
||||
)
|
||||
@click.option(
|
||||
'--identity-address',
|
||||
type=click.Choice(['random', 'public']),
|
||||
@@ -506,9 +643,20 @@ class LogHandler(logging.Handler):
|
||||
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
|
||||
@click.option(
|
||||
'--keystore-file',
|
||||
metavar='<filename>',
|
||||
metavar='FILENAME',
|
||||
help='File in which to store the pairing keys',
|
||||
)
|
||||
@click.option(
|
||||
'--advertise-service-uuid',
|
||||
metavar="UUID",
|
||||
multiple=True,
|
||||
help="Advertise a GATT service UUID (may be specified more than once)",
|
||||
)
|
||||
@click.option(
|
||||
'--advertise-appearance',
|
||||
metavar='APPEARANCE',
|
||||
help='Advertise an Appearance ID (int value or string)',
|
||||
)
|
||||
@click.argument('device-config')
|
||||
@click.argument('hci_transport')
|
||||
@click.argument('address-or-name', required=False)
|
||||
@@ -518,6 +666,7 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
advertising_address,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
@@ -526,6 +675,8 @@ def main(
|
||||
request,
|
||||
print_keys,
|
||||
keystore_file,
|
||||
advertise_service_uuid,
|
||||
advertise_appearance,
|
||||
device_config,
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
@@ -544,6 +695,7 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
advertising_address,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
@@ -552,6 +704,8 @@ def main(
|
||||
request,
|
||||
print_keys,
|
||||
keystore_file,
|
||||
advertise_service_uuid,
|
||||
advertise_appearance,
|
||||
device_config,
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import click
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from bumble.pandora import PandoraDevice, Config, serve
|
||||
from typing import Dict, Any
|
||||
import click
|
||||
|
||||
from bumble.pandora import Config, PandoraDevice, serve
|
||||
|
||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||
ROOTCANAL_PORT_CUTTLEFISH = 7300
|
||||
@@ -39,7 +40,7 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
|
||||
asyncio.run(serve(device, config=server_config, port=grpc_port))
|
||||
|
||||
|
||||
def retrieve_config(config: str) -> Dict[str, Any]:
|
||||
def retrieve_config(config: str) -> dict[str, Any]:
|
||||
if not config:
|
||||
return {}
|
||||
|
||||
|
||||
604
apps/player/player.py
Normal file
604
apps/player/player.py
Normal file
@@ -0,0 +1,604 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble.a2dp import (
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
A2DP_NON_A2DP_CODEC_TYPE,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
AacFrame,
|
||||
AacMediaCodecInformation,
|
||||
AacPacketSource,
|
||||
AacParser,
|
||||
OpusMediaCodecInformation,
|
||||
OpusPacket,
|
||||
OpusPacketSource,
|
||||
OpusParser,
|
||||
SbcFrame,
|
||||
SbcMediaCodecInformation,
|
||||
SbcPacketSource,
|
||||
SbcParser,
|
||||
make_audio_source_service_sdp_records,
|
||||
)
|
||||
from bumble.avdtp import (
|
||||
AVDTP_AUDIO_MEDIA_TYPE,
|
||||
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
|
||||
MediaCodecCapabilities,
|
||||
MediaPacketPump,
|
||||
)
|
||||
from bumble.avdtp import Protocol as AvdtpProtocol
|
||||
from bumble.avdtp import find_avdtp_service_with_connection
|
||||
from bumble.avrcp import Protocol as AvrcpProtocol
|
||||
from bumble.colors import color
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.core import ConnectionError as BumbleConnectionError
|
||||
from bumble.core import DeviceClass, PhysicalTransport
|
||||
from bumble.device import Connection, Device, DeviceConfiguration
|
||||
from bumble.hci import HCI_CONNECTION_ALREADY_EXISTS_ERROR, Address, HCI_Constant
|
||||
from bumble.pairing import PairingConfig
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def a2dp_source_sdp_records():
|
||||
service_record_handle = 0x00010001
|
||||
return {
|
||||
service_record_handle: make_audio_source_service_sdp_records(
|
||||
service_record_handle
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def sbc_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
sbc_parser = SbcParser(read_function)
|
||||
sbc_frame: SbcFrame
|
||||
async for sbc_frame in sbc_parser.frames:
|
||||
# We only need the first frame
|
||||
print(color(f"SBC format: {sbc_frame}", "cyan"))
|
||||
break
|
||||
|
||||
channel_mode = [
|
||||
SbcMediaCodecInformation.ChannelMode.MONO,
|
||||
SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL,
|
||||
SbcMediaCodecInformation.ChannelMode.STEREO,
|
||||
SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
|
||||
][sbc_frame.channel_mode]
|
||||
block_length = {
|
||||
4: SbcMediaCodecInformation.BlockLength.BL_4,
|
||||
8: SbcMediaCodecInformation.BlockLength.BL_8,
|
||||
12: SbcMediaCodecInformation.BlockLength.BL_12,
|
||||
16: SbcMediaCodecInformation.BlockLength.BL_16,
|
||||
}[sbc_frame.block_count]
|
||||
subbands = {
|
||||
4: SbcMediaCodecInformation.Subbands.S_4,
|
||||
8: SbcMediaCodecInformation.Subbands.S_8,
|
||||
}[sbc_frame.subband_count]
|
||||
allocation_method = [
|
||||
SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
|
||||
SbcMediaCodecInformation.AllocationMethod.SNR,
|
||||
][sbc_frame.allocation_method]
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_SBC_CODEC_TYPE,
|
||||
media_codec_information=SbcMediaCodecInformation(
|
||||
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.from_int(
|
||||
sbc_frame.sampling_frequency
|
||||
),
|
||||
channel_mode=channel_mode,
|
||||
block_length=block_length,
|
||||
subbands=subbands,
|
||||
allocation_method=allocation_method,
|
||||
minimum_bitpool_value=2,
|
||||
maximum_bitpool_value=40,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
aac_parser = AacParser(read_function)
|
||||
aac_frame: AacFrame
|
||||
async for aac_frame in aac_parser.frames:
|
||||
# We only need the first frame
|
||||
print(color(f"AAC format: {aac_frame}", "cyan"))
|
||||
break
|
||||
|
||||
sampling_frequency = AacMediaCodecInformation.SamplingFrequency.from_int(
|
||||
aac_frame.sampling_frequency
|
||||
)
|
||||
channels = (
|
||||
AacMediaCodecInformation.Channels.MONO
|
||||
if aac_frame.channel_configuration == 1
|
||||
else AacMediaCodecInformation.Channels.STEREO
|
||||
)
|
||||
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
media_codec_information=AacMediaCodecInformation(
|
||||
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
|
||||
sampling_frequency=sampling_frequency,
|
||||
channels=channels,
|
||||
vbr=1,
|
||||
bitrate=128000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def opus_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
opus_parser = OpusParser(read_function)
|
||||
opus_packet: OpusPacket
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only need the first packet
|
||||
print(color(f"Opus format: {opus_packet}", "cyan"))
|
||||
break
|
||||
|
||||
if opus_packet.channel_mode == OpusPacket.ChannelMode.MONO:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.MONO
|
||||
elif opus_packet.channel_mode == OpusPacket.ChannelMode.STEREO:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.STEREO
|
||||
else:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.DUAL_MONO
|
||||
|
||||
if opus_packet.duration == 10:
|
||||
frame_size = OpusMediaCodecInformation.FrameSize.FS_10MS
|
||||
else:
|
||||
frame_size = OpusMediaCodecInformation.FrameSize.FS_20MS
|
||||
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
|
||||
media_codec_information=OpusMediaCodecInformation(
|
||||
channel_mode=channel_mode,
|
||||
sampling_frequency=OpusMediaCodecInformation.SamplingFrequency.SF_48000,
|
||||
frame_size=frame_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Player:
|
||||
def __init__(
|
||||
self,
|
||||
transport: str,
|
||||
device_config: Optional[str],
|
||||
authenticate: bool,
|
||||
encrypt: bool,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.device_config = device_config
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt
|
||||
self.avrcp_protocol: Optional[AvrcpProtocol] = None
|
||||
self.done: Optional[asyncio.Event]
|
||||
|
||||
async def run(self, workload) -> None:
|
||||
self.done = asyncio.Event()
|
||||
try:
|
||||
await self._run(workload)
|
||||
except Exception as error:
|
||||
print(color(f"!!! ERROR: {error}", "red"))
|
||||
|
||||
async def _run(self, workload) -> None:
|
||||
async with await open_transport(self.transport) as (hci_source, hci_sink):
|
||||
# Create a device
|
||||
device_config = DeviceConfiguration()
|
||||
if self.device_config:
|
||||
device_config.load_from_file(self.device_config)
|
||||
else:
|
||||
device_config.name = "Bumble Player"
|
||||
device_config.class_of_device = DeviceClass.pack_class_of_device(
|
||||
DeviceClass.AUDIO_SERVICE_CLASS,
|
||||
DeviceClass.AUDIO_VIDEO_MAJOR_DEVICE_CLASS,
|
||||
DeviceClass.AUDIO_VIDEO_UNCATEGORIZED_MINOR_DEVICE_CLASS,
|
||||
)
|
||||
device_config.keystore = "JsonKeyStore"
|
||||
|
||||
device_config.classic_enabled = True
|
||||
device_config.le_enabled = False
|
||||
device_config.le_simultaneous_enabled = False
|
||||
device_config.classic_sc_enabled = False
|
||||
device_config.classic_smp_enabled = False
|
||||
device = Device.from_config_with_hci(device_config, hci_source, hci_sink)
|
||||
|
||||
# Setup the SDP records to expose the SRC service
|
||||
device.sdp_service_records = a2dp_source_sdp_records()
|
||||
|
||||
# Setup AVRCP
|
||||
self.avrcp_protocol = AvrcpProtocol()
|
||||
self.avrcp_protocol.listen(device)
|
||||
|
||||
# Don't require MITM when pairing.
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(mitm=False)
|
||||
|
||||
# Start the controller
|
||||
await device.power_on()
|
||||
|
||||
# Print some of the config/properties
|
||||
print(
|
||||
"Player Bluetooth Address:",
|
||||
color(
|
||||
device.public_address.to_string(with_type_qualifier=False),
|
||||
"yellow",
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for connections
|
||||
device.on("connection", self.on_bluetooth_connection)
|
||||
|
||||
# Run the workload
|
||||
try:
|
||||
await workload(device)
|
||||
except BumbleConnectionError as error:
|
||||
if error.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
print(color("Connection already established", "blue"))
|
||||
else:
|
||||
print(color(f"Failed to connect: {error}", "red"))
|
||||
|
||||
# Wait until it is time to exit
|
||||
assert self.done is not None
|
||||
await asyncio.wait(
|
||||
[hci_source.terminated, asyncio.ensure_future(self.done.wait())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
def on_bluetooth_connection(self, connection: Connection) -> None:
|
||||
print(color(f"--- Connected: {connection}", "cyan"))
|
||||
connection.on("disconnection", self.on_bluetooth_disconnection)
|
||||
|
||||
def on_bluetooth_disconnection(self, reason) -> None:
|
||||
print(color(f"--- Disconnected: {HCI_Constant.error_name(reason)}", "cyan"))
|
||||
self.set_done()
|
||||
|
||||
async def connect(self, device: Device, address: str) -> Connection:
|
||||
print(color(f"Connecting to {address}...", "green"))
|
||||
connection = await device.connect(address, transport=PhysicalTransport.BR_EDR)
|
||||
|
||||
# Request authentication
|
||||
if self.authenticate:
|
||||
print(color("*** Authenticating...", "blue"))
|
||||
await connection.authenticate()
|
||||
print(color("*** Authenticated", "blue"))
|
||||
|
||||
# Enable encryption
|
||||
if self.encrypt:
|
||||
print(color("*** Enabling encryption...", "blue"))
|
||||
await connection.encrypt()
|
||||
print(color("*** Encryption on", "blue"))
|
||||
|
||||
return connection
|
||||
|
||||
async def create_avdtp_protocol(self, connection: Connection) -> AvdtpProtocol:
|
||||
# Look for an A2DP service
|
||||
avdtp_version = await find_avdtp_service_with_connection(connection)
|
||||
if not avdtp_version:
|
||||
raise RuntimeError("no A2DP service found")
|
||||
|
||||
print(color(f"AVDTP Version: {avdtp_version}"))
|
||||
|
||||
# Create a client to interact with the remote device
|
||||
return await AvdtpProtocol.connect(connection, avdtp_version)
|
||||
|
||||
async def stream_packets(
|
||||
self,
|
||||
protocol: AvdtpProtocol,
|
||||
codec_type: int,
|
||||
vendor_id: int,
|
||||
codec_id: int,
|
||||
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
|
||||
codec_capabilities: MediaCodecCapabilities,
|
||||
):
|
||||
# Discover all endpoints on the remote device
|
||||
endpoints = await protocol.discover_remote_endpoints()
|
||||
for endpoint in endpoints:
|
||||
print('@@@', endpoint)
|
||||
|
||||
# Select a sink
|
||||
sink = protocol.find_remote_sink_by_codec(
|
||||
AVDTP_AUDIO_MEDIA_TYPE, codec_type, vendor_id, codec_id
|
||||
)
|
||||
if sink is None:
|
||||
print(color('!!! no compatible sink found', 'red'))
|
||||
return
|
||||
print(f'### Selected sink: {sink.seid}')
|
||||
|
||||
# Check if the sink supports delay reporting
|
||||
delay_reporting = False
|
||||
for capability in sink.capabilities:
|
||||
if capability.service_category == AVDTP_DELAY_REPORTING_SERVICE_CATEGORY:
|
||||
delay_reporting = True
|
||||
break
|
||||
|
||||
def on_delay_report(delay: int):
|
||||
print(color(f"*** DELAY REPORT: {delay}", "blue"))
|
||||
|
||||
# Adjust the codec capabilities for certain codecs
|
||||
for capability in sink.capabilities:
|
||||
if isinstance(capability, MediaCodecCapabilities):
|
||||
if isinstance(
|
||||
codec_capabilities.media_codec_information, SbcMediaCodecInformation
|
||||
) and isinstance(
|
||||
capability.media_codec_information, SbcMediaCodecInformation
|
||||
):
|
||||
codec_capabilities.media_codec_information.minimum_bitpool_value = (
|
||||
capability.media_codec_information.minimum_bitpool_value
|
||||
)
|
||||
codec_capabilities.media_codec_information.maximum_bitpool_value = (
|
||||
capability.media_codec_information.maximum_bitpool_value
|
||||
)
|
||||
print(color("Source media codec:", "green"), codec_capabilities)
|
||||
|
||||
# Stream the packets
|
||||
packet_pump = MediaPacketPump(packet_source.packets)
|
||||
source = protocol.add_source(codec_capabilities, packet_pump, delay_reporting)
|
||||
source.on("delay_report", on_delay_report)
|
||||
stream = await protocol.create_stream(source, sink)
|
||||
await stream.start()
|
||||
|
||||
await packet_pump.wait_for_completion()
|
||||
|
||||
async def discover(self, device: Device) -> None:
|
||||
@device.listens_to("inquiry_result")
|
||||
def on_inquiry_result(
|
||||
address: Address, class_of_device: int, data: AdvertisingData, rssi: int
|
||||
) -> None:
|
||||
(
|
||||
service_classes,
|
||||
major_device_class,
|
||||
minor_device_class,
|
||||
) = DeviceClass.split_class_of_device(class_of_device)
|
||||
separator = "\n "
|
||||
print(f">>> {color(address.to_string(False), 'yellow')}:")
|
||||
print(f" Device Class (raw): {class_of_device:06X}")
|
||||
major_class_name = DeviceClass.major_device_class_name(major_device_class)
|
||||
print(" Device Major Class: " f"{major_class_name}")
|
||||
minor_class_name = DeviceClass.minor_device_class_name(
|
||||
major_device_class, minor_device_class
|
||||
)
|
||||
print(" Device Minor Class: " f"{minor_class_name}")
|
||||
print(
|
||||
" Device Services: "
|
||||
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
|
||||
)
|
||||
print(f" RSSI: {rssi}")
|
||||
if data.ad_structures:
|
||||
print(f" {data.to_string(separator)}")
|
||||
|
||||
await device.start_discovery()
|
||||
|
||||
async def pair(self, device: Device, address: str) -> None:
|
||||
print(color(f"Connecting to {address}...", "green"))
|
||||
connection = await device.connect(address, transport=PhysicalTransport.BR_EDR)
|
||||
|
||||
print(color("Pairing...", "magenta"))
|
||||
await connection.authenticate()
|
||||
print(color("Pairing completed", "magenta"))
|
||||
self.set_done()
|
||||
|
||||
async def inquire(self, device: Device, address: str) -> None:
|
||||
connection = await self.connect(device, address)
|
||||
avdtp_protocol = await self.create_avdtp_protocol(connection)
|
||||
|
||||
# Discover the remote endpoints
|
||||
endpoints = await avdtp_protocol.discover_remote_endpoints()
|
||||
print(f'@@@ Found {len(list(endpoints))} endpoints')
|
||||
for endpoint in endpoints:
|
||||
print('@@@', endpoint)
|
||||
|
||||
self.set_done()
|
||||
|
||||
async def play(
|
||||
self,
|
||||
device: Device,
|
||||
address: Optional[str],
|
||||
audio_format: str,
|
||||
audio_file: str,
|
||||
) -> None:
|
||||
if audio_format == "auto":
|
||||
if audio_file.endswith(".sbc"):
|
||||
audio_format = "sbc"
|
||||
elif audio_file.endswith(".aac") or audio_file.endswith(".adts"):
|
||||
audio_format = "aac"
|
||||
elif audio_file.endswith(".ogg"):
|
||||
audio_format = "opus"
|
||||
else:
|
||||
raise ValueError("Unable to determine audio format from file extension")
|
||||
|
||||
device.on(
|
||||
"connection",
|
||||
lambda connection: AsyncRunner.spawn(on_connection(connection)),
|
||||
)
|
||||
|
||||
async def on_connection(connection: Connection):
|
||||
avdtp_protocol = await self.create_avdtp_protocol(connection)
|
||||
|
||||
with open(audio_file, 'rb') as input_file:
|
||||
# NOTE: this should be using asyncio file reading, but blocking reads
|
||||
# are good enough for this command line app.
|
||||
async def read_audio_data(byte_count):
|
||||
return input_file.read(byte_count)
|
||||
|
||||
# Obtain the codec capabilities from the stream
|
||||
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
|
||||
vendor_id = 0
|
||||
codec_id = 0
|
||||
if audio_format == "sbc":
|
||||
codec_type = A2DP_SBC_CODEC_TYPE
|
||||
codec_capabilities = await sbc_codec_capabilities(read_audio_data)
|
||||
packet_source = SbcPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
elif audio_format == "aac":
|
||||
codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
|
||||
codec_capabilities = await aac_codec_capabilities(read_audio_data)
|
||||
packet_source = AacPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
else:
|
||||
codec_type = A2DP_NON_A2DP_CODEC_TYPE
|
||||
vendor_id = OpusMediaCodecInformation.VENDOR_ID
|
||||
codec_id = OpusMediaCodecInformation.CODEC_ID
|
||||
codec_capabilities = await opus_codec_capabilities(read_audio_data)
|
||||
packet_source = OpusPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
|
||||
# Rewind to the start
|
||||
input_file.seek(0)
|
||||
|
||||
try:
|
||||
await self.stream_packets(
|
||||
avdtp_protocol,
|
||||
codec_type,
|
||||
vendor_id,
|
||||
codec_id,
|
||||
packet_source,
|
||||
codec_capabilities,
|
||||
)
|
||||
except Exception as error:
|
||||
print(color(f"!!! Error while streaming: {error}", "red"))
|
||||
|
||||
self.set_done()
|
||||
|
||||
if address:
|
||||
await self.connect(device, address)
|
||||
else:
|
||||
print(color("Waiting for an incoming connection...", "magenta"))
|
||||
|
||||
def set_done(self) -> None:
|
||||
if self.done:
|
||||
self.done.set()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def create_player(context) -> Player:
|
||||
return Player(
|
||||
transport=context.obj["hci_transport"],
|
||||
device_config=context.obj["device_config"],
|
||||
authenticate=context.obj["authenticate"],
|
||||
encrypt=context.obj["encrypt"],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
@click.option("--hci-transport", metavar="TRANSPORT", required=True)
|
||||
@click.option("--device-config", metavar="FILENAME", help="Device configuration file")
|
||||
@click.option(
|
||||
"--authenticate",
|
||||
is_flag=True,
|
||||
help="Request authentication when connecting",
|
||||
default=False,
|
||||
)
|
||||
@click.option(
|
||||
"--encrypt", is_flag=True, help="Request encryption when connecting", default=True
|
||||
)
|
||||
def player_cli(ctx, hci_transport, device_config, authenticate, encrypt):
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["hci_transport"] = hci_transport
|
||||
ctx.obj["device_config"] = device_config
|
||||
ctx.obj["authenticate"] = authenticate
|
||||
ctx.obj["encrypt"] = encrypt
|
||||
|
||||
|
||||
@player_cli.command("discover")
|
||||
@click.pass_context
|
||||
def discover(context):
|
||||
"""Discover speakers or headphones"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(player.discover))
|
||||
|
||||
|
||||
@player_cli.command("inquire")
|
||||
@click.pass_context
|
||||
@click.argument(
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
)
|
||||
def inquire(context, address):
|
||||
"""Connect to a speaker or headphone and inquire about their capabilities"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(lambda device: player.inquire(device, address)))
|
||||
|
||||
|
||||
@player_cli.command("pair")
|
||||
@click.pass_context
|
||||
@click.argument(
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
)
|
||||
def pair(context, address):
|
||||
"""Pair with a speaker or headphone"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(lambda device: player.pair(device, address)))
|
||||
|
||||
|
||||
@player_cli.command("play")
|
||||
@click.pass_context
|
||||
@click.option(
|
||||
"--connect",
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
help="Address or name to connect to",
|
||||
)
|
||||
@click.option(
|
||||
"-f",
|
||||
"--audio-format",
|
||||
type=click.Choice(["auto", "sbc", "aac", "opus"]),
|
||||
help="Audio file format (use 'auto' to infer the format from the file extension)",
|
||||
default="auto",
|
||||
)
|
||||
@click.argument("audio_file")
|
||||
def play(context, address, audio_format, audio_file):
|
||||
"""Play and audio file"""
|
||||
player = create_player(context)
|
||||
asyncio.run(
|
||||
player.run(
|
||||
lambda device: player.play(device, address, audio_format, audio_file)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
bumble.logging.setup_basic_logging("WARNING")
|
||||
player_cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
@@ -16,21 +16,15 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble import core, hci, rfcomm, transport, utils
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, Connection
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble import rfcomm
|
||||
from bumble import transport
|
||||
from bumble import utils
|
||||
|
||||
from bumble.device import Connection, Device, DeviceConfiguration
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -237,6 +231,7 @@ class ClientBridge:
|
||||
address: str,
|
||||
tcp_host: str,
|
||||
tcp_port: int,
|
||||
authenticate: bool,
|
||||
encrypt: bool,
|
||||
):
|
||||
self.channel = channel
|
||||
@@ -245,6 +240,7 @@ class ClientBridge:
|
||||
self.address = address
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt
|
||||
self.device: Optional[Device] = None
|
||||
self.connection: Optional[Connection] = None
|
||||
@@ -269,11 +265,16 @@ class ClientBridge:
|
||||
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
|
||||
assert self.device
|
||||
self.connection = await self.device.connect(
|
||||
self.address, transport=core.BT_BR_EDR_TRANSPORT
|
||||
self.address, transport=core.PhysicalTransport.BR_EDR
|
||||
)
|
||||
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
|
||||
self.connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
if self.authenticate:
|
||||
print(color("@@@ Authenticating Bluetooth connection", "blue"))
|
||||
await self.connection.authenticate()
|
||||
print(color("@@@ Bluetooth connection authenticated", "blue"))
|
||||
|
||||
if self.encrypt:
|
||||
print(color("@@@ Encrypting Bluetooth connection", "blue"))
|
||||
await self.connection.encrypt()
|
||||
@@ -399,7 +400,7 @@ class ClientBridge:
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device_config, hci_transport, bridge):
|
||||
print("<<< connecting to HCI...")
|
||||
async with await transport.open_transport_or_link(hci_transport) as (
|
||||
async with await transport.open_transport(hci_transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
@@ -491,8 +492,9 @@ def server(context, tcp_host, tcp_port):
|
||||
@click.argument("bluetooth-address")
|
||||
@click.option("--tcp-host", help="TCP host", default="_")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
|
||||
@click.option("--authenticate", is_flag=True, help="Authenticate the connection")
|
||||
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, authenticate, encrypt):
|
||||
bridge = ClientBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
@@ -500,12 +502,13 @@ def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
|
||||
bluetooth_address,
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
authenticate,
|
||||
encrypt,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
|
||||
if __name__ == "__main__":
|
||||
bumble.logging.setup_basic_logging("WARNING")
|
||||
cli(obj={}) # pylint: disable=no-value-for-parameter
|
||||
|
||||
27
apps/scan.py
27
apps/scan.py
@@ -16,17 +16,17 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble import data_types
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Advertisement, Device
|
||||
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.smp import AddressResolver
|
||||
from bumble.device import Advertisement
|
||||
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -95,13 +95,22 @@ class AdvertisementPrinter:
|
||||
else:
|
||||
phy_info = ''
|
||||
|
||||
details = separator.join(
|
||||
[
|
||||
data_type.to_string(use_label=True)
|
||||
for data_type in data_types.data_types_from_advertising_data(
|
||||
advertisement.data
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
print(
|
||||
f'>>> {color(address, address_color)} '
|
||||
f'[{color(address_type_string, type_color)}]{address_qualifier}'
|
||||
f'{resolution_qualifier}:{separator}'
|
||||
f'{phy_info}'
|
||||
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
|
||||
f'{advertisement.data.to_string(separator)}\n'
|
||||
f'{details}\n'
|
||||
)
|
||||
|
||||
def on_advertisement(self, advertisement):
|
||||
@@ -127,7 +136,7 @@ async def scan(
|
||||
transport,
|
||||
):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
async with await open_transport(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
if device_config:
|
||||
@@ -237,7 +246,7 @@ def main(
|
||||
device_config,
|
||||
transport,
|
||||
):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
asyncio.run(
|
||||
scan(
|
||||
min_rssi,
|
||||
|
||||
27
apps/show.py
27
apps/show.py
@@ -16,17 +16,17 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import datetime
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
import bumble.logging
|
||||
from bumble import hci
|
||||
from bumble.transport.common import PacketReader
|
||||
from bumble.colors import color
|
||||
from bumble.helpers import PacketTracer
|
||||
|
||||
from bumble.transport.common import PacketReader
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -144,19 +144,20 @@ class Printer:
|
||||
help='Format of the input file',
|
||||
)
|
||||
@click.option(
|
||||
'--vendors',
|
||||
'--vendor',
|
||||
type=click.Choice(['android', 'zephyr']),
|
||||
multiple=True,
|
||||
help='Support vendor-specific commands (list one or more)',
|
||||
)
|
||||
@click.argument('filename')
|
||||
# pylint: disable=redefined-builtin
|
||||
def main(format, vendors, filename):
|
||||
for vendor in vendors:
|
||||
if vendor == 'android':
|
||||
import bumble.vendor.android.hci
|
||||
elif vendor == 'zephyr':
|
||||
import bumble.vendor.zephyr.hci
|
||||
def main(format, vendor, filename):
|
||||
for vendor_name in vendor:
|
||||
if vendor_name == 'android':
|
||||
# Prevent being deleted by linter.
|
||||
importlib.import_module('bumble.vendor.android.hci')
|
||||
elif vendor_name == 'zephyr':
|
||||
importlib.import_module('bumble.vendor.zephyr.hci')
|
||||
|
||||
input = open(filename, 'rb')
|
||||
if format == 'h4':
|
||||
@@ -180,11 +181,11 @@ def main(format, vendors, filename):
|
||||
else:
|
||||
printer.print(color("[TRUNCATED]", "red"))
|
||||
except Exception as error:
|
||||
logger.exception()
|
||||
logger.exception('')
|
||||
print(color(f'!!! {error}', 'red'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
<tr><td>Codec</td><td><span id="codecText"></span></td></tr>
|
||||
<tr><td>Packets</td><td><span id="packetsReceivedText"></span></td></tr>
|
||||
<tr><td>Bytes</td><td><span id="bytesReceivedText"></span></td></tr>
|
||||
<tr><td>Bitrate</td><td><span id="bitrate"></span></td></tr>
|
||||
</table>
|
||||
</td>
|
||||
<td>
|
||||
|
||||
@@ -7,17 +7,19 @@ let connectionText;
|
||||
let codecText;
|
||||
let packetsReceivedText;
|
||||
let bytesReceivedText;
|
||||
let bitrateText;
|
||||
let streamStateText;
|
||||
let connectionStateText;
|
||||
let controlsDiv;
|
||||
let audioOnButton;
|
||||
let mediaSource;
|
||||
let sourceBuffer;
|
||||
let audioElement;
|
||||
let audioDecoder;
|
||||
let audioCodec;
|
||||
let audioContext;
|
||||
let audioAnalyzer;
|
||||
let audioFrequencyBinCount;
|
||||
let audioFrequencyData;
|
||||
let nextAudioStartPosition = 0;
|
||||
let audioStartTime = 0;
|
||||
let packetsReceived = 0;
|
||||
let bytesReceived = 0;
|
||||
let audioState = "stopped";
|
||||
@@ -29,20 +31,17 @@ let bandwidthCanvas;
|
||||
let bandwidthCanvasContext;
|
||||
let bandwidthBinCount;
|
||||
let bandwidthBins = [];
|
||||
let bitrateSamples = [];
|
||||
|
||||
const FFT_WIDTH = 800;
|
||||
const FFT_HEIGHT = 256;
|
||||
const BANDWIDTH_WIDTH = 500;
|
||||
const BANDWIDTH_HEIGHT = 100;
|
||||
|
||||
function hexToBytes(hex) {
|
||||
return Uint8Array.from(hex.match(/.{1,2}/g).map((byte) => parseInt(byte, 16)));
|
||||
}
|
||||
const BITRATE_WINDOW = 30;
|
||||
|
||||
function init() {
|
||||
initUI();
|
||||
initMediaSource();
|
||||
initAudioElement();
|
||||
initAudioContext();
|
||||
initAnalyzer();
|
||||
|
||||
connect();
|
||||
@@ -56,6 +55,7 @@ function initUI() {
|
||||
codecText = document.getElementById("codecText");
|
||||
packetsReceivedText = document.getElementById("packetsReceivedText");
|
||||
bytesReceivedText = document.getElementById("bytesReceivedText");
|
||||
bitrateText = document.getElementById("bitrate");
|
||||
streamStateText = document.getElementById("streamStateText");
|
||||
connectionStateText = document.getElementById("connectionStateText");
|
||||
audioSupportMessageText = document.getElementById("audioSupportMessageText");
|
||||
@@ -67,17 +67,9 @@ function initUI() {
|
||||
requestAnimationFrame(onAnimationFrame);
|
||||
}
|
||||
|
||||
function initMediaSource() {
|
||||
mediaSource = new MediaSource();
|
||||
mediaSource.onsourceopen = onMediaSourceOpen;
|
||||
mediaSource.onsourceclose = onMediaSourceClose;
|
||||
mediaSource.onsourceended = onMediaSourceEnd;
|
||||
}
|
||||
|
||||
function initAudioElement() {
|
||||
audioElement = document.getElementById("audio");
|
||||
audioElement.src = URL.createObjectURL(mediaSource);
|
||||
// audioElement.controls = true;
|
||||
function initAudioContext() {
|
||||
audioContext = new AudioContext();
|
||||
audioContext.onstatechange = () => console.log("AudioContext state:", audioContext.state);
|
||||
}
|
||||
|
||||
function initAnalyzer() {
|
||||
@@ -94,24 +86,16 @@ function initAnalyzer() {
|
||||
bandwidthCanvasContext = bandwidthCanvas.getContext('2d');
|
||||
bandwidthCanvasContext.fillStyle = "rgb(255, 255, 255)";
|
||||
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
|
||||
}
|
||||
|
||||
function startAnalyzer() {
|
||||
// FFT
|
||||
if (audioElement.captureStream !== undefined) {
|
||||
audioContext = new AudioContext();
|
||||
audioAnalyzer = audioContext.createAnalyser();
|
||||
audioAnalyzer.fftSize = 128;
|
||||
audioFrequencyBinCount = audioAnalyzer.frequencyBinCount;
|
||||
audioFrequencyData = new Uint8Array(audioFrequencyBinCount);
|
||||
const stream = audioElement.captureStream();
|
||||
const source = audioContext.createMediaStreamSource(stream);
|
||||
source.connect(audioAnalyzer);
|
||||
}
|
||||
|
||||
// Bandwidth
|
||||
bandwidthBinCount = BANDWIDTH_WIDTH / 2;
|
||||
bandwidthBins = [];
|
||||
bitrateSamples = [];
|
||||
|
||||
audioAnalyzer = audioContext.createAnalyser();
|
||||
audioAnalyzer.fftSize = 128;
|
||||
audioFrequencyBinCount = audioAnalyzer.frequencyBinCount;
|
||||
audioFrequencyData = new Uint8Array(audioFrequencyBinCount);
|
||||
|
||||
audioAnalyzer.connect(audioContext.destination)
|
||||
}
|
||||
|
||||
function setConnectionText(message) {
|
||||
@@ -148,7 +132,8 @@ function onAnimationFrame() {
|
||||
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
|
||||
bandwidthCanvasContext.fillStyle = `rgb(100, 100, 100)`;
|
||||
for (let t = 0; t < bandwidthBins.length; t++) {
|
||||
const lineHeight = (bandwidthBins[t] / 1000) * BANDWIDTH_HEIGHT;
|
||||
const bytesReceived = bandwidthBins[t]
|
||||
const lineHeight = (bytesReceived / 1000) * BANDWIDTH_HEIGHT;
|
||||
bandwidthCanvasContext.fillRect(t * 2, BANDWIDTH_HEIGHT - lineHeight, 2, lineHeight);
|
||||
}
|
||||
|
||||
@@ -156,28 +141,14 @@ function onAnimationFrame() {
|
||||
requestAnimationFrame(onAnimationFrame);
|
||||
}
|
||||
|
||||
function onMediaSourceOpen() {
|
||||
console.log(this.readyState);
|
||||
sourceBuffer = mediaSource.addSourceBuffer("audio/aac");
|
||||
}
|
||||
|
||||
function onMediaSourceClose() {
|
||||
console.log(this.readyState);
|
||||
}
|
||||
|
||||
function onMediaSourceEnd() {
|
||||
console.log(this.readyState);
|
||||
}
|
||||
|
||||
async function startAudio() {
|
||||
try {
|
||||
console.log("starting audio...");
|
||||
audioOnButton.disabled = true;
|
||||
audioState = "starting";
|
||||
await audioElement.play();
|
||||
audioContext.resume();
|
||||
console.log("audio started");
|
||||
audioState = "playing";
|
||||
startAnalyzer();
|
||||
} catch(error) {
|
||||
console.error(`play failed: ${error}`);
|
||||
audioState = "stopped";
|
||||
@@ -185,12 +156,47 @@ async function startAudio() {
|
||||
}
|
||||
}
|
||||
|
||||
function onAudioPacket(packet) {
|
||||
if (audioState != "stopped") {
|
||||
// Queue the audio packet.
|
||||
sourceBuffer.appendBuffer(packet);
|
||||
function onDecodedAudio(audioData) {
|
||||
const bufferSource = audioContext.createBufferSource()
|
||||
|
||||
const now = audioContext.currentTime;
|
||||
let nextAudioStartTime = audioStartTime + (nextAudioStartPosition / audioData.sampleRate);
|
||||
if (nextAudioStartTime < now) {
|
||||
console.log("starting new audio time base")
|
||||
audioStartTime = now;
|
||||
nextAudioStartTime = now;
|
||||
nextAudioStartPosition = 0;
|
||||
} else {
|
||||
console.log(`audio buffer scheduled in ${nextAudioStartTime - now}`)
|
||||
}
|
||||
|
||||
const audioBuffer = audioContext.createBuffer(
|
||||
audioData.numberOfChannels,
|
||||
audioData.numberOfFrames,
|
||||
audioData.sampleRate
|
||||
);
|
||||
|
||||
for (let channel = 0; channel < audioData.numberOfChannels; channel++) {
|
||||
audioData.copyTo(
|
||||
audioBuffer.getChannelData(channel),
|
||||
{
|
||||
planeIndex: channel,
|
||||
format: "f32-planar"
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
bufferSource.buffer = audioBuffer;
|
||||
bufferSource.connect(audioAnalyzer)
|
||||
bufferSource.start(nextAudioStartTime);
|
||||
nextAudioStartPosition += audioData.numberOfFrames;
|
||||
}
|
||||
|
||||
function onCodecError(error) {
|
||||
console.log("Codec error:", error)
|
||||
}
|
||||
|
||||
async function onAudioPacket(packet) {
|
||||
packetsReceived += 1;
|
||||
packetsReceivedText.innerText = packetsReceived;
|
||||
bytesReceived += packet.byteLength;
|
||||
@@ -200,6 +206,48 @@ function onAudioPacket(packet) {
|
||||
if (bandwidthBins.length > bandwidthBinCount) {
|
||||
bandwidthBins.shift();
|
||||
}
|
||||
bitrateSamples[bitrateSamples.length] = {ts: Date.now(), bytes: packet.byteLength}
|
||||
if (bitrateSamples.length > BITRATE_WINDOW) {
|
||||
bitrateSamples.shift();
|
||||
}
|
||||
if (bitrateSamples.length >= 2) {
|
||||
const windowBytes = bitrateSamples.reduce((accumulator, x) => accumulator + x.bytes, 0) - bitrateSamples[0].bytes;
|
||||
const elapsed = bitrateSamples[bitrateSamples.length-1].ts - bitrateSamples[0].ts;
|
||||
const bitrate = Math.floor(8 * windowBytes / elapsed)
|
||||
bitrateText.innerText = `${bitrate} kb/s`
|
||||
}
|
||||
|
||||
if (audioState == "stopped") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (audioDecoder === undefined) {
|
||||
let audioConfig;
|
||||
if (audioCodec == 'aac') {
|
||||
audioConfig = {
|
||||
codec: 'mp4a.40.2',
|
||||
sampleRate: 44100, // ignored
|
||||
numberOfChannels: 2, // ignored
|
||||
}
|
||||
} else if (audioCodec == 'opus') {
|
||||
audioConfig = {
|
||||
codec: 'opus',
|
||||
sampleRate: 48000, // ignored
|
||||
numberOfChannels: 2, // ignored
|
||||
}
|
||||
}
|
||||
audioDecoder = new AudioDecoder({ output: onDecodedAudio, error: onCodecError });
|
||||
audioDecoder.configure(audioConfig)
|
||||
}
|
||||
|
||||
const encodedAudio = new EncodedAudioChunk({
|
||||
type: "key",
|
||||
data: packet,
|
||||
timestamp: 0,
|
||||
transfer: [packet],
|
||||
});
|
||||
|
||||
audioDecoder.decode(encodedAudio);
|
||||
}
|
||||
|
||||
function onChannelOpen() {
|
||||
@@ -249,16 +297,19 @@ function onChannelMessage(message) {
|
||||
}
|
||||
}
|
||||
|
||||
function onHelloMessage(params) {
|
||||
async function onHelloMessage(params) {
|
||||
codecText.innerText = params.codec;
|
||||
if (params.codec != "aac") {
|
||||
audioOnButton.disabled = true;
|
||||
audioSupportMessageText.innerText = "Only AAC can be played, audio will be disabled";
|
||||
audioSupportMessageText.style.display = "inline-block";
|
||||
} else {
|
||||
|
||||
if (params.codec == "aac" || params.codec == "opus") {
|
||||
audioCodec = params.codec
|
||||
audioSupportMessageText.innerText = "";
|
||||
audioSupportMessageText.style.display = "none";
|
||||
} else {
|
||||
audioOnButton.disabled = true;
|
||||
audioSupportMessageText.innerText = "Only AAC and Opus can be played, audio will be disabled";
|
||||
audioSupportMessageText.style.display = "inline-block";
|
||||
}
|
||||
|
||||
if (params.streamState) {
|
||||
setStreamState(params.streamState);
|
||||
}
|
||||
|
||||
@@ -16,54 +16,49 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import asyncio.subprocess
|
||||
from importlib import resources
|
||||
import enum
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
import subprocess
|
||||
from typing import Dict, List, Optional
|
||||
import weakref
|
||||
from importlib import resources
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import aiohttp
|
||||
import click
|
||||
from aiohttp import web
|
||||
|
||||
import bumble
|
||||
from bumble.colors import color
|
||||
from bumble.core import BT_BR_EDR_TRANSPORT, CommandTimeoutError
|
||||
from bumble.device import Connection, Device, DeviceConfiguration
|
||||
from bumble.hci import HCI_StatusError
|
||||
from bumble.pairing import PairingConfig
|
||||
from bumble.sdp import ServiceAttribute
|
||||
from bumble.transport import open_transport
|
||||
import bumble.logging
|
||||
from bumble.a2dp import (
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
A2DP_NON_A2DP_CODEC_TYPE,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
AacMediaCodecInformation,
|
||||
OpusMediaCodecInformation,
|
||||
SbcMediaCodecInformation,
|
||||
make_audio_sink_service_sdp_records,
|
||||
)
|
||||
from bumble.avdtp import (
|
||||
AVDTP_AUDIO_MEDIA_TYPE,
|
||||
Listener,
|
||||
MediaCodecCapabilities,
|
||||
MediaPacket,
|
||||
Protocol,
|
||||
)
|
||||
from bumble.a2dp import (
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE,
|
||||
make_audio_sink_service_sdp_records,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
SBC_MONO_CHANNEL_MODE,
|
||||
SBC_DUAL_CHANNEL_MODE,
|
||||
SBC_SNR_ALLOCATION_METHOD,
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD,
|
||||
SBC_STEREO_CHANNEL_MODE,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE,
|
||||
SbcMediaCodecInformation,
|
||||
AacMediaCodecInformation,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
from bumble.codecs import AacAudioRtpPacket
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.core import CommandTimeoutError, PhysicalTransport
|
||||
from bumble.device import Connection, Device, DeviceConfiguration
|
||||
from bumble.hci import HCI_StatusError
|
||||
from bumble.pairing import PairingConfig
|
||||
from bumble.rtp import MediaPacket
|
||||
from bumble.sdp import ServiceAttribute
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -85,6 +80,8 @@ class AudioExtractor:
|
||||
return AacAudioExtractor()
|
||||
if codec == 'sbc':
|
||||
return SbcAudioExtractor()
|
||||
if codec == 'opus':
|
||||
return OpusAudioExtractor()
|
||||
|
||||
def extract_audio(self, packet: MediaPacket) -> bytes:
|
||||
raise NotImplementedError()
|
||||
@@ -93,7 +90,7 @@ class AudioExtractor:
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacAudioExtractor:
|
||||
def extract_audio(self, packet: MediaPacket) -> bytes:
|
||||
return AacAudioRtpPacket(packet.payload).to_adts()
|
||||
return AacAudioRtpPacket.from_bytes(packet.payload).to_adts()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -109,6 +106,13 @@ class SbcAudioExtractor:
|
||||
return packet.payload[1:]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusAudioExtractor:
|
||||
def extract_audio(self, packet: MediaPacket) -> bytes:
|
||||
# TODO: parse fields
|
||||
return packet.payload[1:]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Output:
|
||||
async def start(self) -> None:
|
||||
@@ -242,7 +246,7 @@ class FfplayOutput(QueuedOutput):
|
||||
await super().start()
|
||||
|
||||
self.subprocess = await asyncio.create_subprocess_shell(
|
||||
f'ffplay -f {self.codec} pipe:0',
|
||||
f'ffplay -probesize 32 -f {self.codec} pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
@@ -406,10 +410,24 @@ class Speaker:
|
||||
STARTED = 2
|
||||
SUSPENDED = 3
|
||||
|
||||
def __init__(self, device_config, transport, codec, discover, outputs, ui_port):
|
||||
def __init__(
|
||||
self,
|
||||
device_config,
|
||||
transport,
|
||||
codec,
|
||||
sampling_frequencies,
|
||||
bitrate,
|
||||
vbr,
|
||||
discover,
|
||||
outputs,
|
||||
ui_port,
|
||||
):
|
||||
self.device_config = device_config
|
||||
self.transport = transport
|
||||
self.codec = codec
|
||||
self.sampling_frequencies = sampling_frequencies
|
||||
self.bitrate = bitrate
|
||||
self.vbr = vbr
|
||||
self.discover = discover
|
||||
self.ui_port = ui_port
|
||||
self.device = None
|
||||
@@ -430,7 +448,7 @@ class Speaker:
|
||||
# Create an HTTP server for the UI
|
||||
self.ui_server = UiServer(speaker=self, port=ui_port)
|
||||
|
||||
def sdp_records(self) -> Dict[int, List[ServiceAttribute]]:
|
||||
def sdp_records(self) -> dict[int, list[ServiceAttribute]]:
|
||||
service_record_handle = 0x00010001
|
||||
return {
|
||||
service_record_handle: make_audio_sink_service_sdp_records(
|
||||
@@ -445,44 +463,92 @@ class Speaker:
|
||||
if self.codec == 'sbc':
|
||||
return self.sbc_codec_capabilities()
|
||||
|
||||
if self.codec == 'opus':
|
||||
return self.opus_codec_capabilities()
|
||||
|
||||
raise RuntimeError('unsupported codec')
|
||||
|
||||
def aac_codec_capabilities(self) -> MediaCodecCapabilities:
|
||||
supported_sampling_frequencies = AacMediaCodecInformation.SamplingFrequency(0)
|
||||
for sampling_frequency in self.sampling_frequencies or [
|
||||
8000,
|
||||
11025,
|
||||
12000,
|
||||
16000,
|
||||
22050,
|
||||
24000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
]:
|
||||
supported_sampling_frequencies |= (
|
||||
AacMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
|
||||
)
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
media_codec_information=AacMediaCodecInformation.from_lists(
|
||||
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
|
||||
sampling_frequencies=[48000, 44100],
|
||||
channels=[1, 2],
|
||||
vbr=1,
|
||||
bitrate=256000,
|
||||
media_codec_information=AacMediaCodecInformation(
|
||||
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
|
||||
sampling_frequency=supported_sampling_frequencies,
|
||||
channels=AacMediaCodecInformation.Channels.MONO
|
||||
| AacMediaCodecInformation.Channels.STEREO,
|
||||
vbr=1 if self.vbr else 0,
|
||||
bitrate=self.bitrate or 256000,
|
||||
),
|
||||
)
|
||||
|
||||
def sbc_codec_capabilities(self) -> MediaCodecCapabilities:
|
||||
supported_sampling_frequencies = SbcMediaCodecInformation.SamplingFrequency(0)
|
||||
for sampling_frequency in self.sampling_frequencies or [
|
||||
16000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
]:
|
||||
supported_sampling_frequencies |= (
|
||||
SbcMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
|
||||
)
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_SBC_CODEC_TYPE,
|
||||
media_codec_information=SbcMediaCodecInformation.from_lists(
|
||||
sampling_frequencies=[48000, 44100, 32000, 16000],
|
||||
channel_modes=[
|
||||
SBC_MONO_CHANNEL_MODE,
|
||||
SBC_DUAL_CHANNEL_MODE,
|
||||
SBC_STEREO_CHANNEL_MODE,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE,
|
||||
],
|
||||
block_lengths=[4, 8, 12, 16],
|
||||
subbands=[4, 8],
|
||||
allocation_methods=[
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD,
|
||||
SBC_SNR_ALLOCATION_METHOD,
|
||||
],
|
||||
media_codec_information=SbcMediaCodecInformation(
|
||||
sampling_frequency=supported_sampling_frequencies,
|
||||
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
|
||||
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
|
||||
| SbcMediaCodecInformation.ChannelMode.STEREO
|
||||
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
|
||||
block_length=SbcMediaCodecInformation.BlockLength.BL_4
|
||||
| SbcMediaCodecInformation.BlockLength.BL_8
|
||||
| SbcMediaCodecInformation.BlockLength.BL_12
|
||||
| SbcMediaCodecInformation.BlockLength.BL_16,
|
||||
subbands=SbcMediaCodecInformation.Subbands.S_4
|
||||
| SbcMediaCodecInformation.Subbands.S_8,
|
||||
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
|
||||
| SbcMediaCodecInformation.AllocationMethod.SNR,
|
||||
minimum_bitpool_value=2,
|
||||
maximum_bitpool_value=53,
|
||||
),
|
||||
)
|
||||
|
||||
def opus_codec_capabilities(self) -> MediaCodecCapabilities:
|
||||
supported_sampling_frequencies = OpusMediaCodecInformation.SamplingFrequency(0)
|
||||
for sampling_frequency in self.sampling_frequencies or [48000]:
|
||||
supported_sampling_frequencies |= (
|
||||
OpusMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
|
||||
)
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
|
||||
media_codec_information=OpusMediaCodecInformation(
|
||||
frame_size=OpusMediaCodecInformation.FrameSize.FS_10MS
|
||||
| OpusMediaCodecInformation.FrameSize.FS_20MS,
|
||||
channel_mode=OpusMediaCodecInformation.ChannelMode.MONO
|
||||
| OpusMediaCodecInformation.ChannelMode.STEREO
|
||||
| OpusMediaCodecInformation.ChannelMode.DUAL_MONO,
|
||||
sampling_frequency=supported_sampling_frequencies,
|
||||
),
|
||||
)
|
||||
|
||||
async def dispatch_to_outputs(self, function):
|
||||
for output in self.outputs:
|
||||
await function(output)
|
||||
@@ -570,7 +636,9 @@ class Speaker:
|
||||
async def connect(self, address):
|
||||
# Connect to the source
|
||||
print(f'=== Connecting to {address}...')
|
||||
connection = await self.device.connect(address, transport=BT_BR_EDR_TRANSPORT)
|
||||
connection = await self.device.connect(
|
||||
address, transport=PhysicalTransport.BR_EDR
|
||||
)
|
||||
print(f'=== Connected to {connection.peer_address}')
|
||||
|
||||
# Request authentication
|
||||
@@ -675,7 +743,26 @@ def speaker_cli(ctx, device_config):
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--codec', type=click.Choice(['sbc', 'aac']), default='aac', show_default=True
|
||||
'--codec',
|
||||
type=click.Choice(['sbc', 'aac', 'opus']),
|
||||
default='aac',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--sampling-frequency',
|
||||
metavar='SAMPLING-FREQUENCY',
|
||||
type=int,
|
||||
multiple=True,
|
||||
help='Enable a sampling frequency (may be specified more than once)',
|
||||
)
|
||||
@click.option(
|
||||
'--bitrate',
|
||||
metavar='BITRATE',
|
||||
type=int,
|
||||
help='Supported bitrate (AAC only)',
|
||||
)
|
||||
@click.option(
|
||||
'--vbr/--no-vbr', is_flag=True, default=True, help='Enable VBR (AAC only)'
|
||||
)
|
||||
@click.option(
|
||||
'--discover', is_flag=True, help='Discover remote endpoints once connected'
|
||||
@@ -706,7 +793,16 @@ def speaker_cli(ctx, device_config):
|
||||
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
|
||||
@click.argument('transport')
|
||||
def speaker(
|
||||
transport, codec, connect_address, discover, output, ui_port, device_config
|
||||
transport,
|
||||
codec,
|
||||
sampling_frequency,
|
||||
bitrate,
|
||||
vbr,
|
||||
connect_address,
|
||||
discover,
|
||||
output,
|
||||
ui_port,
|
||||
device_config,
|
||||
):
|
||||
"""Run the speaker."""
|
||||
|
||||
@@ -721,15 +817,23 @@ def speaker(
|
||||
output = list(filter(lambda x: x != '@ffplay', output))
|
||||
|
||||
asyncio.run(
|
||||
Speaker(device_config, transport, codec, discover, output, ui_port).run(
|
||||
connect_address
|
||||
)
|
||||
Speaker(
|
||||
device_config,
|
||||
transport,
|
||||
codec,
|
||||
sampling_frequency,
|
||||
bitrate,
|
||||
vbr,
|
||||
discover,
|
||||
output,
|
||||
ui_port,
|
||||
).run(connect_address)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
speaker()
|
||||
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
import bumble.logging
|
||||
from bumble.device import Device
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.transport import open_transport
|
||||
@@ -68,7 +68,7 @@ def main(keystore_file, hci_transport, device_config, address):
|
||||
instantiated.
|
||||
If no address is passed, the existing pairing keys for all addresses are printed.
|
||||
"""
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
bumble.logging.setup_basic_logging()
|
||||
|
||||
if not keystore_file and not hci_transport:
|
||||
print('either --keystore-file or --hci-transport must be specified.')
|
||||
|
||||
@@ -26,15 +26,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
import usb1
|
||||
|
||||
import bumble.logging
|
||||
from bumble.colors import color
|
||||
from bumble.transport.usb import load_libusb
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -169,7 +167,7 @@ def is_bluetooth_hci(device):
|
||||
@click.command()
|
||||
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
|
||||
def main(verbose):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
bumble.logging.setup_basic_logging('WARNING')
|
||||
|
||||
load_libusb()
|
||||
with usb1.USBContext() as context:
|
||||
|
||||
747
bumble/a2dp.py
747
bumble/a2dp.py
@@ -18,31 +18,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import struct
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import List, Callable, Awaitable
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
from .sdp import (
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
)
|
||||
from .core import (
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_AUDIO_SOURCE_SERVICE,
|
||||
BT_AUDIO_SINK_SERVICE,
|
||||
BT_AVDTP_PROTOCOL_ID,
|
||||
from typing_extensions import ClassVar, Self
|
||||
|
||||
from bumble.codecs import AacAudioRtpPacket
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.core import (
|
||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||
BT_AUDIO_SINK_SERVICE,
|
||||
BT_AUDIO_SOURCE_SERVICE,
|
||||
BT_AVDTP_PROTOCOL_ID,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
name_or_number,
|
||||
)
|
||||
|
||||
from bumble.rtp import MediaPacket
|
||||
from bumble.sdp import (
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -103,6 +107,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
||||
}
|
||||
|
||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||
8000,
|
||||
11025,
|
||||
@@ -130,6 +136,9 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
||||
}
|
||||
|
||||
|
||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -145,7 +154,7 @@ def flags_to_list(flags, values):
|
||||
# -----------------------------------------------------------------------------
|
||||
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import AVDTP_PSM
|
||||
from bumble.avdtp import AVDTP_PSM
|
||||
|
||||
version_int = version[0] << 8 | version[1]
|
||||
return [
|
||||
@@ -199,7 +208,7 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
|
||||
# -----------------------------------------------------------------------------
|
||||
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import AVDTP_PSM
|
||||
from bumble.avdtp import AVDTP_PSM
|
||||
|
||||
version_int = version[0] << 8 | version[1]
|
||||
return [
|
||||
@@ -257,38 +266,61 @@ class SbcMediaCodecInformation:
|
||||
A2DP spec - 4.3.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
sampling_frequency: int
|
||||
channel_mode: int
|
||||
block_length: int
|
||||
subbands: int
|
||||
allocation_method: int
|
||||
sampling_frequency: SamplingFrequency
|
||||
channel_mode: ChannelMode
|
||||
block_length: BlockLength
|
||||
subbands: Subbands
|
||||
allocation_method: AllocationMethod
|
||||
minimum_bitpool_value: int
|
||||
maximum_bitpool_value: int
|
||||
|
||||
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
|
||||
CHANNEL_MODE_BITS = {
|
||||
SBC_MONO_CHANNEL_MODE: 1 << 3,
|
||||
SBC_DUAL_CHANNEL_MODE: 1 << 2,
|
||||
SBC_STEREO_CHANNEL_MODE: 1 << 1,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
|
||||
}
|
||||
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
|
||||
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
|
||||
ALLOCATION_METHOD_BITS = {
|
||||
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
|
||||
}
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_16000 = 1 << 3
|
||||
SF_32000 = 1 << 2
|
||||
SF_44100 = 1 << 1
|
||||
SF_48000 = 1 << 0
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
|
||||
sampling_frequency = (data[0] >> 4) & 0x0F
|
||||
channel_mode = (data[0] >> 0) & 0x0F
|
||||
block_length = (data[1] >> 4) & 0x0F
|
||||
subbands = (data[1] >> 2) & 0x03
|
||||
allocation_method = (data[1] >> 0) & 0x03
|
||||
@classmethod
|
||||
def from_int(cls, sampling_frequency: int) -> Self:
|
||||
sampling_frequencies = [
|
||||
16000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
]
|
||||
index = sampling_frequencies.index(sampling_frequency)
|
||||
return cls(1 << (len(sampling_frequencies) - index - 1))
|
||||
|
||||
class ChannelMode(enum.IntFlag):
|
||||
MONO = 1 << 3
|
||||
DUAL_CHANNEL = 1 << 2
|
||||
STEREO = 1 << 1
|
||||
JOINT_STEREO = 1 << 0
|
||||
|
||||
class BlockLength(enum.IntFlag):
|
||||
BL_4 = 1 << 3
|
||||
BL_8 = 1 << 2
|
||||
BL_12 = 1 << 1
|
||||
BL_16 = 1 << 0
|
||||
|
||||
class Subbands(enum.IntFlag):
|
||||
S_4 = 1 << 1
|
||||
S_8 = 1 << 0
|
||||
|
||||
class AllocationMethod(enum.IntFlag):
|
||||
SNR = 1 << 1
|
||||
LOUDNESS = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
sampling_frequency = cls.SamplingFrequency((data[0] >> 4) & 0x0F)
|
||||
channel_mode = cls.ChannelMode((data[0] >> 0) & 0x0F)
|
||||
block_length = cls.BlockLength((data[1] >> 4) & 0x0F)
|
||||
subbands = cls.Subbands((data[1] >> 2) & 0x03)
|
||||
allocation_method = cls.AllocationMethod((data[1] >> 0) & 0x03)
|
||||
minimum_bitpool_value = (data[2] >> 0) & 0xFF
|
||||
maximum_bitpool_value = (data[3] >> 0) & 0xFF
|
||||
return SbcMediaCodecInformation(
|
||||
return cls(
|
||||
sampling_frequency,
|
||||
channel_mode,
|
||||
block_length,
|
||||
@@ -298,52 +330,6 @@ class SbcMediaCodecInformation:
|
||||
maximum_bitpool_value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls,
|
||||
sampling_frequency: int,
|
||||
channel_mode: int,
|
||||
block_length: int,
|
||||
subbands: int,
|
||||
allocation_method: int,
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
|
||||
block_length=cls.BLOCK_LENGTH_BITS[block_length],
|
||||
subbands=cls.SUBBANDS_BITS[subbands],
|
||||
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
|
||||
minimum_bitpool_value=minimum_bitpool_value,
|
||||
maximum_bitpool_value=maximum_bitpool_value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
sampling_frequencies: List[int],
|
||||
channel_modes: List[int],
|
||||
block_lengths: List[int],
|
||||
subbands: List[int],
|
||||
allocation_methods: List[int],
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
),
|
||||
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
|
||||
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
|
||||
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
|
||||
allocation_method=sum(
|
||||
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
|
||||
),
|
||||
minimum_bitpool_value=minimum_bitpool_value,
|
||||
maximum_bitpool_value=maximum_bitpool_value,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[
|
||||
@@ -356,23 +342,6 @@ class SbcMediaCodecInformation:
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
|
||||
allocation_methods = ['SNR', 'Loudness']
|
||||
return '\n'.join(
|
||||
# pylint: disable=line-too-long
|
||||
[
|
||||
'SbcMediaCodecInformation(',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
|
||||
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
|
||||
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
|
||||
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
|
||||
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
|
||||
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
|
||||
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
@@ -381,83 +350,66 @@ class AacMediaCodecInformation:
|
||||
A2DP spec - 4.5.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
object_type: int
|
||||
sampling_frequency: int
|
||||
channels: int
|
||||
rfa: int
|
||||
object_type: ObjectType
|
||||
sampling_frequency: SamplingFrequency
|
||||
channels: Channels
|
||||
vbr: int
|
||||
bitrate: int
|
||||
|
||||
OBJECT_TYPE_BITS = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
|
||||
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
|
||||
}
|
||||
SAMPLING_FREQUENCY_BITS = {
|
||||
8000: 1 << 11,
|
||||
11025: 1 << 10,
|
||||
12000: 1 << 9,
|
||||
16000: 1 << 8,
|
||||
22050: 1 << 7,
|
||||
24000: 1 << 6,
|
||||
32000: 1 << 5,
|
||||
44100: 1 << 4,
|
||||
48000: 1 << 3,
|
||||
64000: 1 << 2,
|
||||
88200: 1 << 1,
|
||||
96000: 1,
|
||||
}
|
||||
CHANNELS_BITS = {1: 1 << 1, 2: 1}
|
||||
class ObjectType(enum.IntFlag):
|
||||
MPEG_2_AAC_LC = 1 << 7
|
||||
MPEG_4_AAC_LC = 1 << 6
|
||||
MPEG_4_AAC_LTP = 1 << 5
|
||||
MPEG_4_AAC_SCALABLE = 1 << 4
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> AacMediaCodecInformation:
|
||||
object_type = data[0]
|
||||
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
channels = (data[2] >> 2) & 0x03
|
||||
rfa = 0
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_8000 = 1 << 11
|
||||
SF_11025 = 1 << 10
|
||||
SF_12000 = 1 << 9
|
||||
SF_16000 = 1 << 8
|
||||
SF_22050 = 1 << 7
|
||||
SF_24000 = 1 << 6
|
||||
SF_32000 = 1 << 5
|
||||
SF_44100 = 1 << 4
|
||||
SF_48000 = 1 << 3
|
||||
SF_64000 = 1 << 2
|
||||
SF_88200 = 1 << 1
|
||||
SF_96000 = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, sampling_frequency: int) -> Self:
|
||||
sampling_frequencies = [
|
||||
8000,
|
||||
11025,
|
||||
12000,
|
||||
16000,
|
||||
22050,
|
||||
24000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
64000,
|
||||
88200,
|
||||
96000,
|
||||
]
|
||||
index = sampling_frequencies.index(sampling_frequency)
|
||||
return cls(1 << (len(sampling_frequencies) - index - 1))
|
||||
|
||||
class Channels(enum.IntFlag):
|
||||
MONO = 1 << 1
|
||||
STEREO = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> AacMediaCodecInformation:
|
||||
object_type = cls.ObjectType(data[0])
|
||||
sampling_frequency = cls.SamplingFrequency(
|
||||
(data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
)
|
||||
channels = cls.Channels((data[2] >> 2) & 0x03)
|
||||
vbr = (data[3] >> 7) & 0x01
|
||||
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
|
||||
return AacMediaCodecInformation(
|
||||
object_type, sampling_frequency, channels, rfa, vbr, bitrate
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls,
|
||||
object_type: int,
|
||||
sampling_frequency: int,
|
||||
channels: int,
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=cls.OBJECT_TYPE_BITS[object_type],
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channels=cls.CHANNELS_BITS[channels],
|
||||
rfa=0,
|
||||
vbr=vbr,
|
||||
bitrate=bitrate,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
object_types: List[int],
|
||||
sampling_frequencies: List[int],
|
||||
channels: List[int],
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
),
|
||||
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
|
||||
rfa=0,
|
||||
vbr=vbr,
|
||||
bitrate=bitrate,
|
||||
object_type, sampling_frequency, channels, vbr, bitrate
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -472,30 +424,6 @@ class AacMediaCodecInformation:
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
object_types = [
|
||||
'MPEG_2_AAC_LC',
|
||||
'MPEG_4_AAC_LC',
|
||||
'MPEG_4_AAC_LTP',
|
||||
'MPEG_4_AAC_SCALABLE',
|
||||
'[4]',
|
||||
'[5]',
|
||||
'[6]',
|
||||
'[7]',
|
||||
]
|
||||
channels = [1, 2]
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
'AacMediaCodecInformation(',
|
||||
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
|
||||
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
|
||||
f' vbr: {self.vbr}',
|
||||
f' bitrate: {self.bitrate}' ')',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -514,7 +442,7 @@ class VendorSpecificMediaCodecInformation:
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
@@ -528,13 +456,75 @@ class VendorSpecificMediaCodecInformation:
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
|
||||
vendor_id: int = dataclasses.field(init=False, repr=False)
|
||||
codec_id: int = dataclasses.field(init=False, repr=False)
|
||||
value: bytes = dataclasses.field(init=False, repr=False)
|
||||
channel_mode: ChannelMode
|
||||
frame_size: FrameSize
|
||||
sampling_frequency: SamplingFrequency
|
||||
|
||||
class ChannelMode(enum.IntFlag):
|
||||
MONO = 1 << 0
|
||||
STEREO = 1 << 1
|
||||
DUAL_MONO = 1 << 2
|
||||
|
||||
class FrameSize(enum.IntFlag):
|
||||
FS_10MS = 1 << 0
|
||||
FS_20MS = 1 << 1
|
||||
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_48000 = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, sampling_frequency: int) -> Self:
|
||||
if sampling_frequency != 48000:
|
||||
raise ValueError("no such sampling frequency")
|
||||
return cls(1)
|
||||
|
||||
VENDOR_ID: ClassVar[int] = 0x000000E0
|
||||
CODEC_ID: ClassVar[int] = 0x0001
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.vendor_id = self.VENDOR_ID
|
||||
self.codec_id = self.CODEC_ID
|
||||
self.value = bytes(
|
||||
[
|
||||
self.channel_mode
|
||||
| (self.frame_size << 3)
|
||||
| (self.sampling_frequency << 7)
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
"""Create a new instance from the `value` part of the data, not including
|
||||
the vendor id and codec id"""
|
||||
channel_mode = cls.ChannelMode(data[0] & 0x07)
|
||||
frame_size = cls.FrameSize((data[0] >> 3) & 0x03)
|
||||
sampling_frequency = cls.SamplingFrequency((data[0] >> 7) & 0x01)
|
||||
|
||||
return cls(
|
||||
channel_mode,
|
||||
frame_size,
|
||||
sampling_frequency,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
sampling_frequency: int
|
||||
block_count: int
|
||||
channel_mode: int
|
||||
allocation_method: int
|
||||
subband_count: int
|
||||
bitpool: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
@@ -553,8 +543,10 @@ class SbcFrame:
|
||||
return (
|
||||
f'SBC(sf={self.sampling_frequency},'
|
||||
f'cm={self.channel_mode},'
|
||||
f'am={self.allocation_method},'
|
||||
f'br={self.bitrate},'
|
||||
f'sc={self.sample_count},'
|
||||
f'bp={self.bitpool},'
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
@@ -583,6 +575,7 @@ class SbcParser:
|
||||
blocks = 4 * (1 + ((header[1] >> 4) & 3))
|
||||
channel_mode = (header[1] >> 2) & 3
|
||||
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
|
||||
allocation_method = (header[1] >> 1) & 1
|
||||
subbands = 8 if ((header[1]) & 1) else 4
|
||||
bitpool = header[2]
|
||||
|
||||
@@ -602,7 +595,13 @@ class SbcParser:
|
||||
|
||||
# Emit the next frame
|
||||
yield SbcFrame(
|
||||
sampling_frequency, blocks, channel_mode, subbands, payload
|
||||
sampling_frequency,
|
||||
blocks,
|
||||
channel_mode,
|
||||
allocation_method,
|
||||
subbands,
|
||||
bitpool,
|
||||
payload,
|
||||
)
|
||||
|
||||
return generate_frames()
|
||||
@@ -610,21 +609,15 @@ class SbcParser:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcPacketSource:
|
||||
def __init__(
|
||||
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
|
||||
) -> None:
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
frames = []
|
||||
frames_size = 0
|
||||
max_rtp_payload = self.mtu - 12 - 1
|
||||
@@ -632,29 +625,29 @@ class SbcPacketSource:
|
||||
# NOTE: this doesn't support frame fragments
|
||||
sbc_parser = SbcParser(self.read)
|
||||
async for frame in sbc_parser.frames:
|
||||
print(frame)
|
||||
|
||||
if (
|
||||
frames_size + len(frame.payload) > max_rtp_payload
|
||||
or len(frames) == 16
|
||||
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
|
||||
):
|
||||
# Need to flush what has been accumulated so far
|
||||
logger.debug(f"yielding {len(frames)} frames")
|
||||
|
||||
# Emit a packet
|
||||
sbc_payload = bytes([len(frames)]) + b''.join(
|
||||
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
|
||||
[frame.payload for frame in frames]
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += sum((frame.sample_count for frame in frames))
|
||||
timestamp &= 0xFFFFFFFF
|
||||
sample_count += sum((frame.sample_count for frame in frames))
|
||||
frames = [frame]
|
||||
frames_size = len(frame.payload)
|
||||
else:
|
||||
@@ -663,3 +656,315 @@ class SbcPacketSource:
|
||||
frames_size += len(frame.payload)
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class AacFrame:
|
||||
class Profile(enum.IntEnum):
|
||||
MAIN = 0
|
||||
LC = 1
|
||||
SSR = 2
|
||||
LTP = 3
|
||||
|
||||
profile: Profile
|
||||
sampling_frequency: int
|
||||
channel_configuration: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
def sample_count(self) -> int:
|
||||
return 1024
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.sample_count / self.sampling_frequency
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'AAC(sf={self.sampling_frequency},'
|
||||
f'ch={self.channel_configuration},'
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
ADTS_AAC_SAMPLING_FREQUENCIES = [
|
||||
96000,
|
||||
88200,
|
||||
64000,
|
||||
48000,
|
||||
44100,
|
||||
32000,
|
||||
24000,
|
||||
22050,
|
||||
16000,
|
||||
12000,
|
||||
11025,
|
||||
8000,
|
||||
7350,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacParser:
|
||||
"""Parser for AAC frames in an ADTS stream"""
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def frames(self) -> AsyncGenerator[AacFrame, None]:
|
||||
async def generate_frames() -> AsyncGenerator[AacFrame, None]:
|
||||
while True:
|
||||
header = await self.read(7)
|
||||
if not header:
|
||||
return
|
||||
|
||||
sync_word = (header[0] << 4) | (header[1] >> 4)
|
||||
if sync_word != 0b111111111111:
|
||||
raise ValueError(f"invalid sync word ({sync_word:06x})")
|
||||
layer = (header[1] >> 1) & 0b11
|
||||
profile = AacFrame.Profile((header[2] >> 6) & 0b11)
|
||||
sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[
|
||||
(header[2] >> 2) & 0b1111
|
||||
]
|
||||
channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6)
|
||||
frame_length = (
|
||||
((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5)
|
||||
)
|
||||
|
||||
if layer != 0:
|
||||
raise ValueError("layer must be 0")
|
||||
|
||||
payload = await self.read(frame_length - 7)
|
||||
if payload:
|
||||
yield AacFrame(
|
||||
profile, sampling_frequency, channel_configuration, payload
|
||||
)
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacPacketSource:
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
sequence_number = 0
|
||||
sample_count = 0
|
||||
|
||||
aac_parser = AacParser(self.read)
|
||||
async for frame in aac_parser.frames:
|
||||
logger.debug("yielding one AAC frame")
|
||||
|
||||
# Emit a packet
|
||||
aac_payload = bytes(
|
||||
AacAudioRtpPacket.for_simple_aac(
|
||||
frame.sampling_frequency,
|
||||
frame.channel_configuration,
|
||||
frame.payload,
|
||||
)
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
sample_count += frame.sample_count
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusPacket:
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
channel_mode: ChannelMode
|
||||
duration: int # Duration in ms.
|
||||
sampling_frequency: int
|
||||
payload: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Opus(ch={self.channel_mode.name}, '
|
||||
f'd={self.duration}ms, '
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusParser:
|
||||
"""
|
||||
Parser for Opus packets in an Ogg stream
|
||||
|
||||
See RFC 3533
|
||||
|
||||
NOTE: this parser only supports bitstreams with a single logical stream.
|
||||
"""
|
||||
|
||||
CAPTURE_PATTERN = b'OggS'
|
||||
|
||||
class HeaderType(enum.IntFlag):
|
||||
CONTINUED = 0x01
|
||||
FIRST = 0x02
|
||||
LAST = 0x04
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def packets(self) -> AsyncGenerator[OpusPacket, None]:
|
||||
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
|
||||
packet = b''
|
||||
packet_count = 0
|
||||
expected_bitstream_serial_number = None
|
||||
expected_page_sequence_number = 0
|
||||
channel_mode = OpusPacket.ChannelMode.STEREO
|
||||
|
||||
while True:
|
||||
# Parse the page header
|
||||
header = await self.read(27)
|
||||
if len(header) != 27:
|
||||
logger.debug("end of stream")
|
||||
break
|
||||
|
||||
capture_pattern = header[:4]
|
||||
if capture_pattern != self.CAPTURE_PATTERN:
|
||||
print(capture_pattern.hex())
|
||||
raise ValueError("invalid capture pattern at start of page")
|
||||
|
||||
version = header[4]
|
||||
if version != 0:
|
||||
raise ValueError(f"version {version} not supported")
|
||||
|
||||
header_type = self.HeaderType(header[5])
|
||||
(
|
||||
granule_position,
|
||||
bitstream_serial_number,
|
||||
page_sequence_number,
|
||||
crc_checksum,
|
||||
page_segments,
|
||||
) = struct.unpack_from("<QIIIB", header, 6)
|
||||
segment_table = await self.read(page_segments)
|
||||
|
||||
if header_type & self.HeaderType.FIRST:
|
||||
if expected_bitstream_serial_number is None:
|
||||
# We will only accept pages for the first encountered stream
|
||||
logger.debug("BOS")
|
||||
expected_bitstream_serial_number = bitstream_serial_number
|
||||
expected_page_sequence_number = page_sequence_number
|
||||
|
||||
if (
|
||||
expected_bitstream_serial_number is None
|
||||
or expected_bitstream_serial_number != bitstream_serial_number
|
||||
):
|
||||
logger.debug("skipping page (not the first logical bitstream)")
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
await self.read(lacing_value)
|
||||
continue
|
||||
|
||||
if expected_page_sequence_number != page_sequence_number:
|
||||
raise ValueError(
|
||||
f"expected page sequence number {expected_page_sequence_number}"
|
||||
f" but got {page_sequence_number}"
|
||||
)
|
||||
expected_page_sequence_number = page_sequence_number + 1
|
||||
|
||||
# Assemble the page
|
||||
if not header_type & self.HeaderType.CONTINUED:
|
||||
packet = b''
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
packet += await self.read(lacing_value)
|
||||
if lacing_value < 255:
|
||||
# End of packet
|
||||
packet_count += 1
|
||||
|
||||
if packet_count == 1:
|
||||
# The first packet contains the identification header
|
||||
logger.debug("first packet (header)")
|
||||
if packet[:8] != b"OpusHead":
|
||||
raise ValueError("first packet is not OpusHead")
|
||||
packet_count = (
|
||||
OpusPacket.ChannelMode.MONO
|
||||
if packet[9] == 1
|
||||
else OpusPacket.ChannelMode.STEREO
|
||||
)
|
||||
|
||||
elif packet_count == 2:
|
||||
# The second packet contains the comment header
|
||||
logger.debug("second packet (tags)")
|
||||
if packet[:8] != b"OpusTags":
|
||||
logger.warning("second packet is not OpusTags")
|
||||
else:
|
||||
yield OpusPacket(channel_mode, 20, 48000, packet)
|
||||
|
||||
packet = b''
|
||||
|
||||
if header_type & self.HeaderType.LAST:
|
||||
logger.debug("EOS")
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusPacketSource:
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
sequence_number = 0
|
||||
elapsed_ms = 0
|
||||
|
||||
opus_parser = OpusParser(self.read)
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only support sending one Opus frame per RTP packet
|
||||
# TODO: check the spec for the first byte value here
|
||||
opus_payload = bytes([1]) + opus_packet.payload
|
||||
elapsed_s = elapsed_ms / 1000
|
||||
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
|
||||
rtp_packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
|
||||
)
|
||||
rtp_packet.timestamp_seconds = elapsed_s
|
||||
yield rtp_packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
elapsed_ms += opus_packet.duration
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# This map should be left at the end of the file so it can refer to the classes
|
||||
# above
|
||||
# -----------------------------------------------------------------------------
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
|
||||
OpusMediaCodecInformation.VENDOR_ID: {
|
||||
OpusMediaCodecInformation.CODEC_ID: OpusMediaCodecInformation
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
from bumble import core
|
||||
|
||||
@@ -21,7 +21,7 @@ class AtParsingError(core.InvalidPacketError):
|
||||
"""Error raised when parsing AT commands fails."""
|
||||
|
||||
|
||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
def tokenize_parameters(buffer: bytes) -> list[bytes]:
|
||||
"""Split input parameters into tokens.
|
||||
Removes space characters outside of double quote blocks:
|
||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||
@@ -63,12 +63,12 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
return [bytes(token) for token in tokens if len(token) > 0]
|
||||
|
||||
|
||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
def parse_parameters(buffer: bytes) -> list[Union[bytes, list]]:
|
||||
"""Parse the parameters using the comma and parenthesis separators.
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = tokenize_parameters(buffer)
|
||||
accumulator: List[list] = [[]]
|
||||
accumulator: list[list] = [[]]
|
||||
current: Union[bytes, list] = bytes()
|
||||
|
||||
for token in tokens:
|
||||
|
||||
617
bumble/att.py
617
bumble/att.py
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
setup()
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
549
bumble/audio/io.py
Normal file
549
bumble/audio/io.py
Normal file
@@ -0,0 +1,549 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
import sys
|
||||
import wave
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO
|
||||
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sounddevice # type: ignore[import-untyped]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PcmFormat:
|
||||
class Endianness(enum.Enum):
|
||||
LITTLE = 0
|
||||
BIG = 1
|
||||
|
||||
class SampleType(enum.Enum):
|
||||
FLOAT32 = 0
|
||||
INT16 = 1
|
||||
|
||||
endianness: Endianness
|
||||
sample_type: SampleType
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, format_str: str) -> PcmFormat:
|
||||
endianness = cls.Endianness.LITTLE # Others not yet supported.
|
||||
sample_type_str, sample_rate_str, channels_str = format_str.split(',')
|
||||
if sample_type_str == 'int16le':
|
||||
sample_type = cls.SampleType.INT16
|
||||
elif sample_type_str == 'float32le':
|
||||
sample_type = cls.SampleType.FLOAT32
|
||||
else:
|
||||
raise ValueError(f'sample type {sample_type_str} not supported')
|
||||
sample_rate = int(sample_rate_str)
|
||||
channels = int(channels_str)
|
||||
|
||||
return cls(endianness, sample_type, sample_rate, channels)
|
||||
|
||||
@property
|
||||
def bytes_per_sample(self) -> int:
|
||||
return 2 if self.sample_type == self.SampleType.INT16 else 4
|
||||
|
||||
|
||||
def check_audio_output(output: str) -> bool:
|
||||
if output == 'device' or output.startswith('device:'):
|
||||
try:
|
||||
import sounddevice
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
'audio output not available (sounddevice python module not installed)'
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
'audio output not available '
|
||||
'(sounddevice python module failed to load: '
|
||||
f'{exc})'
|
||||
) from exc
|
||||
|
||||
if output == 'device':
|
||||
# Default device
|
||||
return True
|
||||
|
||||
# Specific device
|
||||
device = output[7:]
|
||||
if device == '?':
|
||||
print(color('Audio Devices:', 'yellow'))
|
||||
for device_info in [
|
||||
device_info
|
||||
for device_info in sounddevice.query_devices()
|
||||
if device_info['max_output_channels'] > 0
|
||||
]:
|
||||
device_index = device_info['index']
|
||||
is_default = (
|
||||
color(' [default]', 'green')
|
||||
if sounddevice.default.device[1] == device_index
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f'{color(device_index, "cyan")}: {device_info["name"]}{is_default}'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
device_info = sounddevice.query_devices(int(device))
|
||||
except sounddevice.PortAudioError as exc:
|
||||
raise ValueError('No such audio device') from exc
|
||||
|
||||
if device_info['max_output_channels'] < 1:
|
||||
raise ValueError(
|
||||
f'Device {device} ({device_info["name"]}) does not have an output'
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def create_audio_output(output: str) -> AudioOutput:
|
||||
if output == 'stdout':
|
||||
return StreamAudioOutput(sys.stdout.buffer)
|
||||
|
||||
if output == 'device' or output.startswith('device:'):
|
||||
device_name = '' if output == 'device' else output[7:]
|
||||
return SoundDeviceAudioOutput(device_name)
|
||||
|
||||
if output == 'ffplay':
|
||||
return SubprocessAudioOutput(
|
||||
command=(
|
||||
'ffplay -probesize 32 -fflags nobuffer -analyzeduration 0 '
|
||||
'-ar {sample_rate} '
|
||||
'-ch_layout {channel_layout} '
|
||||
'-f f32le pipe:0'
|
||||
)
|
||||
)
|
||||
|
||||
if output.startswith('file:'):
|
||||
return FileAudioOutput(output[5:])
|
||||
|
||||
raise ValueError('unsupported audio output')
|
||||
|
||||
|
||||
class AudioOutput(abc.ABC):
|
||||
"""Audio output to which PCM samples can be written."""
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
"""Start the output."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
"""Write PCM samples. Must not block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the output."""
|
||||
|
||||
|
||||
class ThreadedAudioOutput(AudioOutput):
|
||||
"""Base class for AudioOutput classes that may need to call blocking functions.
|
||||
|
||||
The actual writing is performed in a thread, so as to ensure that calling write()
|
||||
does not block the caller.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._thread_pool = ThreadPoolExecutor(1)
|
||||
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._write_task = asyncio.create_task(self._write_loop())
|
||||
|
||||
async def _write_loop(self) -> None:
|
||||
while True:
|
||||
pcm_samples = await self._pcm_samples.get()
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._write, pcm_samples
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
"""This method does the actual writing and can block."""
|
||||
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
self._pcm_samples.put_nowait(pcm_samples)
|
||||
|
||||
def _close(self) -> None:
|
||||
"""This method does the actual closing and can block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
|
||||
self._write_task.cancel()
|
||||
self._thread_pool.shutdown()
|
||||
|
||||
|
||||
class SoundDeviceAudioOutput(ThreadedAudioOutput):
|
||||
def __init__(self, device_name: str) -> None:
|
||||
super().__init__()
|
||||
self._device = int(device_name) if device_name else None
|
||||
self._stream: sounddevice.RawOutputStream | None = None
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
|
||||
self._stream = sounddevice.RawOutputStream(
|
||||
samplerate=pcm_format.sample_rate,
|
||||
device=self._device,
|
||||
channels=pcm_format.channels,
|
||||
dtype='float32',
|
||||
)
|
||||
self._stream.start()
|
||||
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
if self._stream is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._stream.write(pcm_samples)
|
||||
except Exception:
|
||||
logger.exception('Sound device error')
|
||||
raise
|
||||
|
||||
def _close(self):
|
||||
self._stream.stop()
|
||||
self._stream = None
|
||||
|
||||
|
||||
class StreamAudioOutput(ThreadedAudioOutput):
|
||||
"""AudioOutput where PCM samples are written to a stream that may block."""
|
||||
|
||||
def __init__(self, stream: BinaryIO) -> None:
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
self._stream.write(pcm_samples)
|
||||
self._stream.flush()
|
||||
|
||||
|
||||
class FileAudioOutput(StreamAudioOutput):
|
||||
"""AudioOutput where PCM samples are written to a file."""
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
self._file = open(filename, "wb")
|
||||
super().__init__(self._file)
|
||||
|
||||
async def shutdown(self):
|
||||
self._file.close()
|
||||
return await super().shutdown()
|
||||
|
||||
|
||||
class SubprocessAudioOutput(AudioOutput):
|
||||
"""AudioOutput where audio samples are written to a subprocess via stdin."""
|
||||
|
||||
def __init__(self, command: str) -> None:
|
||||
self._command = command
|
||||
self._subprocess: asyncio.subprocess.Process | None
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
if pcm_format.channels == 1:
|
||||
channel_layout = 'mono'
|
||||
elif pcm_format.channels == 2:
|
||||
channel_layout = 'stereo'
|
||||
else:
|
||||
raise ValueError(f'{pcm_format.channels} channels not supported')
|
||||
|
||||
command = self._command.format(
|
||||
sample_rate=pcm_format.sample_rate, channel_layout=channel_layout
|
||||
)
|
||||
self._subprocess = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
if self._subprocess is None or self._subprocess.stdin is None:
|
||||
return
|
||||
|
||||
self._subprocess.stdin.write(pcm_samples)
|
||||
|
||||
async def aclose(self):
|
||||
if self._subprocess:
|
||||
self._subprocess.terminate()
|
||||
|
||||
|
||||
def check_audio_input(input: str) -> bool:
|
||||
if input == 'device' or input.startswith('device:'):
|
||||
try:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
'audio input not available (sounddevice python module not installed)'
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
'audio input not available '
|
||||
'(sounddevice python module failed to load: '
|
||||
f'{exc})'
|
||||
) from exc
|
||||
|
||||
if input == 'device':
|
||||
# Default device
|
||||
return True
|
||||
|
||||
# Specific device
|
||||
device = input[7:]
|
||||
if device == '?':
|
||||
print(color('Audio Devices:', 'yellow'))
|
||||
for device_info in [
|
||||
device_info
|
||||
for device_info in sounddevice.query_devices()
|
||||
if device_info['max_input_channels'] > 0
|
||||
]:
|
||||
device_index = device_info["index"]
|
||||
is_mono = device_info['max_input_channels'] == 1
|
||||
max_channels = color(f'[{"mono" if is_mono else "stereo"}]', 'cyan')
|
||||
is_default = (
|
||||
color(' [default]', 'green')
|
||||
if sounddevice.default.device[0] == device_index
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f'{color(device_index, "cyan")}: {device_info["name"]}'
|
||||
f' {max_channels}{is_default}'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
device_info = sounddevice.query_devices(int(device))
|
||||
except sounddevice.PortAudioError as exc:
|
||||
raise ValueError('No such audio device') from exc
|
||||
|
||||
if device_info['max_input_channels'] < 1:
|
||||
raise ValueError(
|
||||
f'Device {device} ({device_info["name"]}) does not have an input'
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def create_audio_input(input: str, input_format: str) -> AudioInput:
|
||||
pcm_format: PcmFormat | None
|
||||
if input_format == 'auto':
|
||||
pcm_format = None
|
||||
else:
|
||||
pcm_format = PcmFormat.from_str(input_format)
|
||||
|
||||
if input == 'stdin':
|
||||
if not pcm_format:
|
||||
raise ValueError('input format details required for stdin')
|
||||
return StreamAudioInput(sys.stdin.buffer, pcm_format)
|
||||
|
||||
if input == 'device' or input.startswith('device:'):
|
||||
if not pcm_format:
|
||||
raise ValueError('input format details required for device')
|
||||
device_name = '' if input == 'device' else input[7:]
|
||||
return SoundDeviceAudioInput(device_name, pcm_format)
|
||||
|
||||
# If there's no file: prefix, check if we can assume it is a file.
|
||||
if pathlib.Path(input).is_file():
|
||||
input = 'file:' + input
|
||||
|
||||
if input.startswith('file:'):
|
||||
filename = input[5:]
|
||||
if filename.endswith('.wav'):
|
||||
if input_format != 'auto':
|
||||
raise ValueError(".wav file only supported with 'auto' format")
|
||||
return WaveAudioInput(filename)
|
||||
|
||||
if pcm_format is None:
|
||||
raise ValueError('input format details required for raw PCM files')
|
||||
return FileAudioInput(filename, pcm_format)
|
||||
|
||||
raise ValueError('input not supported')
|
||||
|
||||
|
||||
class AudioInput(abc.ABC):
|
||||
"""Audio input that produces PCM samples."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def open(self) -> PcmFormat:
|
||||
"""Open the input."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
|
||||
"""Generate one frame of PCM samples. Must not block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the input."""
|
||||
|
||||
|
||||
class ThreadedAudioInput(AudioInput):
|
||||
"""Base class for AudioInput implementation where reading samples may block."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._thread_pool = ThreadPoolExecutor(1)
|
||||
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _open(self) -> PcmFormat:
|
||||
pass
|
||||
|
||||
def _close(self) -> None:
|
||||
pass
|
||||
|
||||
async def open(self) -> PcmFormat:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._open
|
||||
)
|
||||
|
||||
async def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
|
||||
while pcm_sample := await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._read, frame_size
|
||||
):
|
||||
yield pcm_sample
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
|
||||
self._thread_pool.shutdown()
|
||||
|
||||
|
||||
class WaveAudioInput(ThreadedAudioInput):
|
||||
"""Audio input that reads PCM samples from a .wav file."""
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._wav: wave.Wave_read | None = None
|
||||
self._bytes_read = 0
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
self._wav = wave.open(self._filename, 'rb')
|
||||
if self._wav.getsampwidth() != 2:
|
||||
raise ValueError('sample width not supported')
|
||||
return PcmFormat(
|
||||
PcmFormat.Endianness.LITTLE,
|
||||
PcmFormat.SampleType.INT16,
|
||||
self._wav.getframerate(),
|
||||
self._wav.getnchannels(),
|
||||
)
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
if not self._wav:
|
||||
return b''
|
||||
|
||||
pcm_samples = self._wav.readframes(frame_size)
|
||||
if not pcm_samples and self._bytes_read:
|
||||
# Loop around.
|
||||
self._wav.rewind()
|
||||
self._bytes_read = 0
|
||||
pcm_samples = self._wav.readframes(frame_size)
|
||||
|
||||
self._bytes_read += len(pcm_samples)
|
||||
return pcm_samples
|
||||
|
||||
def _close(self) -> None:
|
||||
if self._wav:
|
||||
self._wav.close()
|
||||
|
||||
|
||||
class StreamAudioInput(ThreadedAudioInput):
|
||||
"""AudioInput where samples are read from a raw PCM stream that may block."""
|
||||
|
||||
def __init__(self, stream: BinaryIO, pcm_format: PcmFormat) -> None:
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
self._pcm_format = pcm_format
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
return self._pcm_format
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
return self._stream.read(
|
||||
frame_size * self._pcm_format.channels * self._pcm_format.bytes_per_sample
|
||||
)
|
||||
|
||||
|
||||
class FileAudioInput(StreamAudioInput):
|
||||
"""AudioInput where PCM samples are read from a raw PCM file."""
|
||||
|
||||
def __init__(self, filename: str, pcm_format: PcmFormat) -> None:
|
||||
self._stream = open(filename, "rb")
|
||||
super().__init__(self._stream, pcm_format)
|
||||
|
||||
def _close(self) -> None:
|
||||
self._stream.close()
|
||||
|
||||
|
||||
class SoundDeviceAudioInput(ThreadedAudioInput):
|
||||
def __init__(self, device_name: str, pcm_format: PcmFormat) -> None:
|
||||
super().__init__()
|
||||
self._device = int(device_name) if device_name else None
|
||||
self._pcm_format = pcm_format
|
||||
self._stream: sounddevice.RawInputStream | None = None
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
|
||||
self._stream = sounddevice.RawInputStream(
|
||||
samplerate=self._pcm_format.sample_rate,
|
||||
device=self._device,
|
||||
channels=self._pcm_format.channels,
|
||||
dtype='int16',
|
||||
)
|
||||
self._stream.start()
|
||||
|
||||
return PcmFormat(
|
||||
PcmFormat.Endianness.LITTLE,
|
||||
PcmFormat.SampleType.INT16,
|
||||
self._pcm_format.sample_rate,
|
||||
2,
|
||||
)
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
if not self._stream:
|
||||
return b''
|
||||
pcm_buffer, overflowed = self._stream.read(frame_size)
|
||||
if overflowed:
|
||||
logger.warning("input overflow")
|
||||
|
||||
# Convert the buffer to stereo if needed
|
||||
if self._pcm_format.channels == 1:
|
||||
stereo_buffer = bytearray()
|
||||
for i in range(frame_size):
|
||||
sample = pcm_buffer[i * 2 : i * 2 + 2]
|
||||
stereo_buffer += sample + sample
|
||||
return stereo_buffer
|
||||
|
||||
return bytes(pcm_buffer)
|
||||
|
||||
def _close(self):
|
||||
self._stream.stop()
|
||||
self._stream = None
|
||||
@@ -16,12 +16,12 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import struct
|
||||
from typing import Dict, Type, Union, Tuple
|
||||
from typing import Union
|
||||
|
||||
from bumble import core
|
||||
from bumble.utils import OpenIntEnum
|
||||
from bumble import core, utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -43,7 +43,7 @@ class Frame:
|
||||
EXTENDED = 0x1E
|
||||
UNIT = 0x1F
|
||||
|
||||
class OperationCode(OpenIntEnum):
|
||||
class OperationCode(utils.OpenIntEnum):
|
||||
# 0x00 - 0x0F: Unit and subunit commands
|
||||
VENDOR_DEPENDENT = 0x00
|
||||
RESERVE = 0x01
|
||||
@@ -119,7 +119,7 @@ class Frame:
|
||||
# Not supported
|
||||
raise NotImplementedError("extended subunit types not supported")
|
||||
|
||||
if subunit_id < 5:
|
||||
if subunit_id < 5 or subunit_id == 7:
|
||||
opcode_offset = 2
|
||||
elif subunit_id == 5:
|
||||
# Extended to the next byte
|
||||
@@ -132,9 +132,10 @@ class Frame:
|
||||
else:
|
||||
subunit_id = 5 + extension
|
||||
opcode_offset = 3
|
||||
|
||||
elif subunit_id == 6:
|
||||
raise core.InvalidPacketError("reserved subunit ID")
|
||||
else:
|
||||
raise core.InvalidPacketError("invalid subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
@@ -203,7 +204,7 @@ class Frame:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommandFrame(Frame):
|
||||
class CommandType(OpenIntEnum):
|
||||
class CommandType(utils.OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.1
|
||||
CONTROL = 0x00
|
||||
@@ -212,11 +213,11 @@ class CommandFrame(Frame):
|
||||
NOTIFY = 0x03
|
||||
GENERAL_INQUIRY = 0x04
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
|
||||
subclasses: dict[Frame.OperationCode, type[CommandFrame]] = {}
|
||||
ctype: CommandType
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
def parse_operands(operands: bytes) -> tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
@@ -239,7 +240,7 @@ class CommandFrame(Frame):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ResponseFrame(Frame):
|
||||
class ResponseCode(OpenIntEnum):
|
||||
class ResponseCode(utils.OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.2
|
||||
NOT_IMPLEMENTED = 0x08
|
||||
@@ -250,11 +251,11 @@ class ResponseFrame(Frame):
|
||||
CHANGED = 0x0D
|
||||
INTERIM = 0x0F
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
|
||||
subclasses: dict[Frame.OperationCode, type[ResponseFrame]] = {}
|
||||
response: ResponseCode
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
def parse_operands(operands: bytes) -> tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
@@ -281,7 +282,7 @@ class VendorDependentFrame:
|
||||
vendor_dependent_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
def parse_operands(operands: bytes) -> tuple:
|
||||
return (
|
||||
struct.unpack(">I", b"\x00" + operands[:3])[0],
|
||||
operands[3:],
|
||||
@@ -367,7 +368,7 @@ class PassThroughFrame:
|
||||
PRESSED = 0
|
||||
RELEASED = 1
|
||||
|
||||
class OperationId(OpenIntEnum):
|
||||
class OperationId(utils.OpenIntEnum):
|
||||
SELECT = 0x00
|
||||
UP = 0x01
|
||||
DOWN = 0x01
|
||||
@@ -431,7 +432,7 @@ class PassThroughFrame:
|
||||
operation_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
def parse_operands(operands: bytes) -> tuple:
|
||||
return (
|
||||
PassThroughFrame.StateFlag(operands[0] >> 7),
|
||||
PassThroughFrame.OperationId(operands[0] & 0x7F),
|
||||
|
||||
@@ -16,15 +16,14 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Callable, cast, Dict, Optional
|
||||
from enum import IntEnum
|
||||
from typing import Callable, Optional, cast
|
||||
|
||||
from bumble import avc, core, l2cap
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import core
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -137,8 +136,8 @@ class MessageAssembler:
|
||||
self.pid,
|
||||
self.payload,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(color(f"!!! exception in callback: {error}", "red"))
|
||||
except Exception:
|
||||
logger.exception(color("!!! exception in callback", "red"))
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -146,9 +145,9 @@ class MessageAssembler:
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol:
|
||||
CommandHandler = Callable[[int, avc.CommandFrame], None]
|
||||
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
|
||||
command_handlers: dict[int, CommandHandler] # Command handlers, by PID
|
||||
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
|
||||
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
|
||||
response_handlers: dict[int, ResponseHandler] # Response handlers, by PID
|
||||
next_transaction_label: int
|
||||
message_assembler: MessageAssembler
|
||||
|
||||
@@ -166,8 +165,8 @@ class Protocol:
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on("open", self.on_l2cap_channel_open)
|
||||
l2cap_channel.on("close", self.on_l2cap_channel_close)
|
||||
l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
|
||||
l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
logger.debug(color("<<< AVCTP channel open", "magenta"))
|
||||
|
||||
425
bumble/avdtp.py
425
bumble/avdtp.py
@@ -16,47 +16,44 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import time
|
||||
import logging
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from pyee import EventEmitter
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Type,
|
||||
Tuple,
|
||||
Optional,
|
||||
Callable,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Union,
|
||||
Optional,
|
||||
SupportsBytes,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from .core import (
|
||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||
InvalidStateError,
|
||||
ProtocolError,
|
||||
InvalidArgumentError,
|
||||
name_or_number,
|
||||
)
|
||||
from .a2dp import (
|
||||
from bumble import device, l2cap, sdp, utils
|
||||
from bumble.a2dp import (
|
||||
A2DP_CODEC_TYPE_NAMES,
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
A2DP_NON_A2DP_CODEC_TYPE,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES,
|
||||
AacMediaCodecInformation,
|
||||
SbcMediaCodecInformation,
|
||||
VendorSpecificMediaCodecInformation,
|
||||
)
|
||||
from . import sdp, device, l2cap
|
||||
from .colors import color
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||
InvalidArgumentError,
|
||||
InvalidStateError,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
from bumble.rtp import MediaPacket
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -225,7 +222,7 @@ AVDTP_STATE_NAMES = {
|
||||
# -----------------------------------------------------------------------------
|
||||
async def find_avdtp_service_with_sdp_client(
|
||||
sdp_client: sdp.Client,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
) -> Optional[tuple[int, int]]:
|
||||
'''
|
||||
Find an AVDTP service, using a connected SDP client, and return its version,
|
||||
or None if none is found
|
||||
@@ -255,7 +252,7 @@ async def find_avdtp_service_with_sdp_client(
|
||||
# -----------------------------------------------------------------------------
|
||||
async def find_avdtp_service_with_connection(
|
||||
connection: device.Connection,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
) -> Optional[tuple[int, int]]:
|
||||
'''
|
||||
Find an AVDTP service, for a connection, and return its version,
|
||||
or None if none is found
|
||||
@@ -278,95 +275,6 @@ class RealtimeClock:
|
||||
await asyncio.sleep(duration)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaPacket:
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> MediaPacket:
|
||||
version = (data[0] >> 6) & 0x03
|
||||
padding = (data[0] >> 5) & 0x01
|
||||
extension = (data[0] >> 4) & 0x01
|
||||
csrc_count = data[0] & 0x0F
|
||||
marker = (data[1] >> 7) & 0x01
|
||||
payload_type = data[1] & 0x7F
|
||||
sequence_number = struct.unpack_from('>H', data, 2)[0]
|
||||
timestamp = struct.unpack_from('>I', data, 4)[0]
|
||||
ssrc = struct.unpack_from('>I', data, 8)[0]
|
||||
csrc_list = [
|
||||
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
|
||||
]
|
||||
payload = data[12 + csrc_count * 4 :]
|
||||
|
||||
return MediaPacket(
|
||||
version,
|
||||
padding,
|
||||
extension,
|
||||
marker,
|
||||
sequence_number,
|
||||
timestamp,
|
||||
ssrc,
|
||||
csrc_list,
|
||||
payload_type,
|
||||
payload,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version: int,
|
||||
padding: int,
|
||||
extension: int,
|
||||
marker: int,
|
||||
sequence_number: int,
|
||||
timestamp: int,
|
||||
ssrc: int,
|
||||
csrc_list: List[int],
|
||||
payload_type: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
self.version = version
|
||||
self.padding = padding
|
||||
self.extension = extension
|
||||
self.marker = marker
|
||||
self.sequence_number = sequence_number & 0xFFFF
|
||||
self.timestamp = timestamp & 0xFFFFFFFF
|
||||
self.ssrc = ssrc
|
||||
self.csrc_list = csrc_list
|
||||
self.payload_type = payload_type
|
||||
self.payload = payload
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
header = bytes(
|
||||
[
|
||||
self.version << 6
|
||||
| self.padding << 5
|
||||
| self.extension << 4
|
||||
| len(self.csrc_list),
|
||||
self.marker << 7 | self.payload_type,
|
||||
]
|
||||
) + struct.pack(
|
||||
'>HII',
|
||||
self.sequence_number,
|
||||
self.timestamp,
|
||||
self.ssrc,
|
||||
)
|
||||
for csrc in self.csrc_list:
|
||||
header += struct.pack('>I', csrc)
|
||||
return header + self.payload
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'RTP(v={self.version},'
|
||||
f'p={self.padding},'
|
||||
f'x={self.extension},'
|
||||
f'm={self.marker},'
|
||||
f'pt={self.payload_type},'
|
||||
f'sn={self.sequence_number},'
|
||||
f'ts={self.timestamp},'
|
||||
f'ssrc={self.ssrc},'
|
||||
f'csrcs={self.csrc_list},'
|
||||
f'payload_size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaPacketPump:
|
||||
pump_task: Optional[asyncio.Task]
|
||||
@@ -377,6 +285,7 @@ class MediaPacketPump:
|
||||
self.packets = packets
|
||||
self.clock = clock
|
||||
self.pump_task = None
|
||||
self.completed = asyncio.Event()
|
||||
|
||||
async def start(self, rtp_channel: l2cap.ClassicChannel) -> None:
|
||||
async def pump_packets():
|
||||
@@ -406,6 +315,8 @@ class MediaPacketPump:
|
||||
)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
logger.debug('pump canceled')
|
||||
finally:
|
||||
self.completed.set()
|
||||
|
||||
# Pump packets
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
@@ -417,6 +328,9 @@ class MediaPacketPump:
|
||||
await self.pump_task
|
||||
self.pump_task = None
|
||||
|
||||
async def wait_for_completion(self) -> None:
|
||||
await self.completed.wait()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MessageAssembler:
|
||||
@@ -519,8 +433,8 @@ class MessageAssembler:
|
||||
)
|
||||
try:
|
||||
self.callback(self.transaction_label, message)
|
||||
except Exception as error:
|
||||
logger.exception(color(f'!!! exception in callback: {error}', 'red'))
|
||||
except Exception:
|
||||
logger.exception(color('!!! exception in callback', 'red'))
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -532,7 +446,7 @@ class ServiceCapabilities:
|
||||
service_category: int, service_capabilities_bytes: bytes
|
||||
) -> ServiceCapabilities:
|
||||
# Select the appropriate subclass
|
||||
cls: Type[ServiceCapabilities]
|
||||
cls: type[ServiceCapabilities]
|
||||
if service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY:
|
||||
cls = MediaCodecCapabilities
|
||||
else:
|
||||
@@ -547,7 +461,7 @@ class ServiceCapabilities:
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def parse_capabilities(payload: bytes) -> List[ServiceCapabilities]:
|
||||
def parse_capabilities(payload: bytes) -> list[ServiceCapabilities]:
|
||||
capabilities = []
|
||||
while payload:
|
||||
service_category = payload[0]
|
||||
@@ -580,7 +494,7 @@ class ServiceCapabilities:
|
||||
self.service_category = service_category
|
||||
self.service_capabilities_bytes = service_capabilities_bytes
|
||||
|
||||
def to_string(self, details: Optional[List[str]] = None) -> str:
|
||||
def to_string(self, details: Optional[list[str]] = None) -> str:
|
||||
attributes = ','.join(
|
||||
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
|
||||
+ (details or [])
|
||||
@@ -615,11 +529,25 @@ class MediaCodecCapabilities(ServiceCapabilities):
|
||||
self.media_codec_information
|
||||
)
|
||||
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
|
||||
self.media_codec_information = (
|
||||
vendor_media_codec_information = (
|
||||
VendorSpecificMediaCodecInformation.from_bytes(
|
||||
self.media_codec_information
|
||||
)
|
||||
)
|
||||
if (
|
||||
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
|
||||
vendor_media_codec_information.vendor_id
|
||||
)
|
||||
) and (
|
||||
media_codec_information_class := vendor_class_map.get(
|
||||
vendor_media_codec_information.codec_id
|
||||
)
|
||||
):
|
||||
self.media_codec_information = media_codec_information_class.from_bytes(
|
||||
vendor_media_codec_information.value
|
||||
)
|
||||
else:
|
||||
self.media_codec_information = vendor_media_codec_information
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -679,7 +607,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
|
||||
RESPONSE_REJECT = 3
|
||||
|
||||
# Subclasses, by signal identifier and message type
|
||||
subclasses: Dict[int, Dict[int, Type[Message]]] = {}
|
||||
subclasses: dict[int, dict[int, type[Message]]] = {}
|
||||
message_type: MessageType
|
||||
signal_identifier: int
|
||||
|
||||
@@ -824,7 +752,7 @@ class Discover_Response(Message):
|
||||
See Bluetooth AVDTP spec - 8.6.2 Stream End Point Discovery Response
|
||||
'''
|
||||
|
||||
endpoints: List[EndPointInfo]
|
||||
endpoints: list[EndPointInfo]
|
||||
|
||||
def init_from_payload(self):
|
||||
self.endpoints = []
|
||||
@@ -963,7 +891,7 @@ class Set_Configuration_Reject(Message):
|
||||
self.service_category = self.payload[0]
|
||||
self.error_code = self.payload[1]
|
||||
|
||||
def __init__(self, service_category, error_code):
|
||||
def __init__(self, error_code: int, service_category: int = 0) -> None:
|
||||
super().__init__(payload=bytes([service_category, error_code]))
|
||||
self.service_category = service_category
|
||||
self.error_code = error_code
|
||||
@@ -1199,6 +1127,14 @@ class Security_Control_Command(Message):
|
||||
See Bluetooth AVDTP spec - 8.17.1 Security Control Command
|
||||
'''
|
||||
|
||||
def init_from_payload(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.acp_seid = self.payload[0] >> 2
|
||||
self.data = self.payload[1:]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_string([f'ACP_SEID: {self.acp_seid}', f'data: {self.data}'])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Message.subclass
|
||||
@@ -1260,13 +1196,16 @@ class DelayReport_Reject(Simple_Reject):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol(EventEmitter):
|
||||
local_endpoints: List[LocalStreamEndPoint]
|
||||
remote_endpoints: Dict[int, DiscoveredStreamEndPoint]
|
||||
streams: Dict[int, Stream]
|
||||
transaction_results: List[Optional[asyncio.Future[Message]]]
|
||||
class Protocol(utils.EventEmitter):
|
||||
local_endpoints: list[LocalStreamEndPoint]
|
||||
remote_endpoints: dict[int, DiscoveredStreamEndPoint]
|
||||
streams: dict[int, Stream]
|
||||
transaction_results: list[Optional[asyncio.Future[Message]]]
|
||||
channel_connector: Callable[[], Awaitable[l2cap.ClassicChannel]]
|
||||
|
||||
EVENT_OPEN = "open"
|
||||
EVENT_CLOSE = "close"
|
||||
|
||||
class PacketType(enum.IntEnum):
|
||||
SINGLE_PACKET = 0
|
||||
START_PACKET = 1
|
||||
@@ -1279,7 +1218,7 @@ class Protocol(EventEmitter):
|
||||
|
||||
@staticmethod
|
||||
async def connect(
|
||||
connection: device.Connection, version: Tuple[int, int] = (1, 3)
|
||||
connection: device.Connection, version: tuple[int, int] = (1, 3)
|
||||
) -> Protocol:
|
||||
channel = await connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
|
||||
@@ -1289,7 +1228,7 @@ class Protocol(EventEmitter):
|
||||
return protocol
|
||||
|
||||
def __init__(
|
||||
self, l2cap_channel: l2cap.ClassicChannel, version: Tuple[int, int] = (1, 3)
|
||||
self, l2cap_channel: l2cap.ClassicChannel, version: tuple[int, int] = (1, 3)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.l2cap_channel = l2cap_channel
|
||||
@@ -1306,8 +1245,8 @@ class Protocol(EventEmitter):
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on('open', self.on_l2cap_channel_open)
|
||||
l2cap_channel.on('close', self.on_l2cap_channel_close)
|
||||
l2cap_channel.on(l2cap_channel.EVENT_OPEN, self.on_l2cap_channel_open)
|
||||
l2cap_channel.on(l2cap_channel.EVENT_CLOSE, self.on_l2cap_channel_close)
|
||||
|
||||
def get_local_endpoint_by_seid(self, seid: int) -> Optional[LocalStreamEndPoint]:
|
||||
if 0 < seid <= len(self.local_endpoints):
|
||||
@@ -1316,10 +1255,20 @@ class Protocol(EventEmitter):
|
||||
return None
|
||||
|
||||
def add_source(
|
||||
self, codec_capabilities: MediaCodecCapabilities, packet_pump: MediaPacketPump
|
||||
self,
|
||||
codec_capabilities: MediaCodecCapabilities,
|
||||
packet_pump: MediaPacketPump,
|
||||
delay_reporting: bool = False,
|
||||
) -> LocalSource:
|
||||
seid = len(self.local_endpoints) + 1
|
||||
source = LocalSource(self, seid, codec_capabilities, packet_pump)
|
||||
service_capabilities = (
|
||||
[ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)]
|
||||
if delay_reporting
|
||||
else []
|
||||
)
|
||||
source = LocalSource(
|
||||
self, seid, codec_capabilities, service_capabilities, packet_pump
|
||||
)
|
||||
self.local_endpoints.append(source)
|
||||
|
||||
return source
|
||||
@@ -1372,7 +1321,7 @@ class Protocol(EventEmitter):
|
||||
return self.remote_endpoints.values()
|
||||
|
||||
def find_remote_sink_by_codec(
|
||||
self, media_type: int, codec_type: int
|
||||
self, media_type: int, codec_type: int, vendor_id: int = 0, codec_id: int = 0
|
||||
) -> Optional[DiscoveredStreamEndPoint]:
|
||||
for endpoint in self.remote_endpoints.values():
|
||||
if (
|
||||
@@ -1397,7 +1346,19 @@ class Protocol(EventEmitter):
|
||||
codec_capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
|
||||
and codec_capabilities.media_codec_type == codec_type
|
||||
):
|
||||
has_codec = True
|
||||
if isinstance(
|
||||
codec_capabilities.media_codec_information,
|
||||
VendorSpecificMediaCodecInformation,
|
||||
):
|
||||
if (
|
||||
codec_capabilities.media_codec_information.vendor_id
|
||||
== vendor_id
|
||||
and codec_capabilities.media_codec_information.codec_id
|
||||
== codec_id
|
||||
):
|
||||
has_codec = True
|
||||
else:
|
||||
has_codec = True
|
||||
if has_media_transport and has_codec:
|
||||
return endpoint
|
||||
|
||||
@@ -1438,10 +1399,8 @@ class Protocol(EventEmitter):
|
||||
try:
|
||||
response = handler(message)
|
||||
self.send_message(transaction_label, response)
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
f'{color("!!! Exception in handler:", "red")} {error}'
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
else:
|
||||
logger.warning('unhandled command')
|
||||
else:
|
||||
@@ -1455,20 +1414,20 @@ class Protocol(EventEmitter):
|
||||
self.transaction_results[transaction_label] = None
|
||||
self.transaction_semaphore.release()
|
||||
|
||||
def on_l2cap_connection(self, channel):
|
||||
def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
|
||||
# Forward the channel to the endpoint that's expecting it
|
||||
if self.channel_acceptor is None:
|
||||
logger.warning(color('!!! l2cap connection with no acceptor', 'red'))
|
||||
return
|
||||
self.channel_acceptor.on_l2cap_connection(channel)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
def on_l2cap_channel_open(self) -> None:
|
||||
logger.debug(color('<<< L2CAP channel open', 'magenta'))
|
||||
self.emit('open')
|
||||
self.emit(self.EVENT_OPEN)
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
def on_l2cap_channel_close(self) -> None:
|
||||
logger.debug(color('<<< L2CAP channel close', 'magenta'))
|
||||
self.emit('close')
|
||||
self.emit(self.EVENT_CLOSE)
|
||||
|
||||
def send_message(self, transaction_label: int, message: Message) -> None:
|
||||
logger.debug(
|
||||
@@ -1536,7 +1495,7 @@ class Protocol(EventEmitter):
|
||||
|
||||
return response
|
||||
|
||||
async def start_transaction(self) -> Tuple[int, asyncio.Future[Message]]:
|
||||
async def start_transaction(self) -> tuple[int, asyncio.Future[Message]]:
|
||||
# Wait until we can start a new transaction
|
||||
await self.transaction_semaphore.acquire()
|
||||
|
||||
@@ -1586,28 +1545,34 @@ class Protocol(EventEmitter):
|
||||
async def abort(self, seid: int) -> Abort_Response:
|
||||
return await self.send_command(Abort_Command(seid))
|
||||
|
||||
def on_discover_command(self, _command):
|
||||
def on_discover_command(self, command: Discover_Command) -> Optional[Message]:
|
||||
endpoint_infos = [
|
||||
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
|
||||
for endpoint in self.local_endpoints
|
||||
]
|
||||
return Discover_Response(endpoint_infos)
|
||||
|
||||
def on_get_capabilities_command(self, command):
|
||||
def on_get_capabilities_command(
|
||||
self, command: Get_Capabilities_Command
|
||||
) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Get_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
|
||||
return Get_Capabilities_Response(endpoint.capabilities)
|
||||
|
||||
def on_get_all_capabilities_command(self, command):
|
||||
def on_get_all_capabilities_command(
|
||||
self, command: Get_All_Capabilities_Command
|
||||
) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Get_All_Capabilities_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
|
||||
return Get_All_Capabilities_Response(endpoint.capabilities)
|
||||
|
||||
def on_set_configuration_command(self, command):
|
||||
def on_set_configuration_command(
|
||||
self, command: Set_Configuration_Command
|
||||
) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Set_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1623,7 +1588,9 @@ class Protocol(EventEmitter):
|
||||
result = stream.on_set_configuration_command(command.capabilities)
|
||||
return result or Set_Configuration_Response()
|
||||
|
||||
def on_get_configuration_command(self, command):
|
||||
def on_get_configuration_command(
|
||||
self, command: Get_Configuration_Command
|
||||
) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Get_Configuration_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1632,7 +1599,7 @@ class Protocol(EventEmitter):
|
||||
|
||||
return endpoint.stream.on_get_configuration_command()
|
||||
|
||||
def on_reconfigure_command(self, command):
|
||||
def on_reconfigure_command(self, command: Reconfigure_Command) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Reconfigure_Reject(0, AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1642,7 +1609,7 @@ class Protocol(EventEmitter):
|
||||
result = endpoint.stream.on_reconfigure_command(command.capabilities)
|
||||
return result or Reconfigure_Response()
|
||||
|
||||
def on_open_command(self, command):
|
||||
def on_open_command(self, command: Open_Command) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1652,25 +1619,26 @@ class Protocol(EventEmitter):
|
||||
result = endpoint.stream.on_open_command()
|
||||
return result or Open_Response()
|
||||
|
||||
def on_start_command(self, command):
|
||||
def on_start_command(self, command: Start_Command) -> Optional[Message]:
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if endpoint is None:
|
||||
return Start_Reject(seid, AVDTP_BAD_ACP_SEID_ERROR)
|
||||
if endpoint.stream is None:
|
||||
return Start_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
return Start_Reject(seid, AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
# Start all streams
|
||||
# TODO: deal with partial failures
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
result = endpoint.stream.on_start_command()
|
||||
if result is not None:
|
||||
if not endpoint or not endpoint.stream:
|
||||
raise InvalidStateError("Should already be checked!")
|
||||
if (result := endpoint.stream.on_start_command()) is not None:
|
||||
return result
|
||||
|
||||
return Start_Response()
|
||||
|
||||
def on_suspend_command(self, command):
|
||||
def on_suspend_command(self, command: Suspend_Command) -> Optional[Message]:
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
if endpoint is None:
|
||||
@@ -1682,13 +1650,14 @@ class Protocol(EventEmitter):
|
||||
# TODO: deal with partial failures
|
||||
for seid in command.acp_seids:
|
||||
endpoint = self.get_local_endpoint_by_seid(seid)
|
||||
result = endpoint.stream.on_suspend_command()
|
||||
if result is not None:
|
||||
if not endpoint or not endpoint.stream:
|
||||
raise InvalidStateError("Should already be checked!")
|
||||
if (result := endpoint.stream.on_suspend_command()) is not None:
|
||||
return result
|
||||
|
||||
return Suspend_Response()
|
||||
|
||||
def on_close_command(self, command):
|
||||
def on_close_command(self, command: Close_Command) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1698,7 +1667,7 @@ class Protocol(EventEmitter):
|
||||
result = endpoint.stream.on_close_command()
|
||||
return result or Close_Response()
|
||||
|
||||
def on_abort_command(self, command):
|
||||
def on_abort_command(self, command: Abort_Command) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None or endpoint.stream is None:
|
||||
return Abort_Response()
|
||||
@@ -1706,15 +1675,17 @@ class Protocol(EventEmitter):
|
||||
endpoint.stream.on_abort_command()
|
||||
return Abort_Response()
|
||||
|
||||
def on_security_control_command(self, command):
|
||||
def on_security_control_command(
|
||||
self, command: Security_Control_Command
|
||||
) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
|
||||
result = endpoint.on_security_control_command(command.payload)
|
||||
result = endpoint.on_security_control_command(command.data)
|
||||
return result or Security_Control_Response()
|
||||
|
||||
def on_delayreport_command(self, command):
|
||||
def on_delayreport_command(self, command: DelayReport_Command) -> Optional[Message]:
|
||||
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
|
||||
if endpoint is None:
|
||||
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)
|
||||
@@ -1724,8 +1695,10 @@ class Protocol(EventEmitter):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Listener(EventEmitter):
|
||||
servers: Dict[int, Protocol]
|
||||
class Listener(utils.EventEmitter):
|
||||
servers: dict[int, Protocol]
|
||||
|
||||
EVENT_CONNECTION = "connection"
|
||||
|
||||
@staticmethod
|
||||
def create_registrar(device: device.Device):
|
||||
@@ -1755,13 +1728,13 @@ class Listener(EventEmitter):
|
||||
|
||||
@classmethod
|
||||
def for_device(
|
||||
cls, device: device.Device, version: Tuple[int, int] = (1, 3)
|
||||
cls, device: device.Device, version: tuple[int, int] = (1, 3)
|
||||
) -> Listener:
|
||||
listener = Listener(registrar=None, version=version)
|
||||
l2cap_server = device.create_l2cap_server(
|
||||
spec=l2cap.ClassicChannelSpec(psm=AVDTP_PSM)
|
||||
)
|
||||
l2cap_server.on('connection', listener.on_l2cap_connection)
|
||||
l2cap_server.on(l2cap_server.EVENT_CONNECTION, listener.on_l2cap_connection)
|
||||
return listener
|
||||
|
||||
def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
|
||||
@@ -1777,14 +1750,14 @@ class Listener(EventEmitter):
|
||||
logger.debug('setting up new Protocol for the connection')
|
||||
server = Protocol(channel, self.version)
|
||||
self.set_server(channel.connection, server)
|
||||
self.emit('connection', server)
|
||||
self.emit(self.EVENT_CONNECTION, server)
|
||||
|
||||
def on_channel_close():
|
||||
logger.debug('removing Protocol for the connection')
|
||||
self.remove_server(channel.connection)
|
||||
|
||||
channel.on('open', on_channel_open)
|
||||
channel.on('close', on_channel_close)
|
||||
channel.on(channel.EVENT_OPEN, on_channel_open)
|
||||
channel.on(channel.EVENT_CLOSE, on_channel_close)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1833,6 +1806,7 @@ class Stream:
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""[Source] Start streaming."""
|
||||
# Auto-open if needed
|
||||
if self.state == AVDTP_CONFIGURED_STATE:
|
||||
await self.open()
|
||||
@@ -1849,6 +1823,7 @@ class Stream:
|
||||
self.change_state(AVDTP_STREAMING_STATE)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""[Source] Stop streaming and transit to OPEN state."""
|
||||
if self.state != AVDTP_STREAMING_STATE:
|
||||
raise InvalidStateError('current state is not STREAMING')
|
||||
|
||||
@@ -1861,6 +1836,7 @@ class Stream:
|
||||
self.change_state(AVDTP_OPEN_STATE)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""[Source] Close channel and transit to IDLE state."""
|
||||
if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE):
|
||||
raise InvalidStateError('current state is not OPEN or STREAMING')
|
||||
|
||||
@@ -1892,7 +1868,7 @@ class Stream:
|
||||
self.change_state(AVDTP_CONFIGURED_STATE)
|
||||
return None
|
||||
|
||||
def on_get_configuration_command(self, configuration):
|
||||
def on_get_configuration_command(self):
|
||||
if self.state not in (
|
||||
AVDTP_CONFIGURED_STATE,
|
||||
AVDTP_OPEN_STATE,
|
||||
@@ -1900,7 +1876,7 @@ class Stream:
|
||||
):
|
||||
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
|
||||
|
||||
return self.local_endpoint.on_get_configuration_command(configuration)
|
||||
return self.local_endpoint.on_get_configuration_command()
|
||||
|
||||
def on_reconfigure_command(self, configuration):
|
||||
if self.state != AVDTP_OPEN_STATE:
|
||||
@@ -1980,20 +1956,20 @@ class Stream:
|
||||
# Wait for the RTP channel to be closed
|
||||
self.change_state(AVDTP_ABORTING_STATE)
|
||||
|
||||
def on_l2cap_connection(self, channel):
|
||||
def on_l2cap_connection(self, channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(color('<<< stream channel connected', 'magenta'))
|
||||
self.rtp_channel = channel
|
||||
channel.on('open', self.on_l2cap_channel_open)
|
||||
channel.on('close', self.on_l2cap_channel_close)
|
||||
channel.on(channel.EVENT_OPEN, self.on_l2cap_channel_open)
|
||||
channel.on(channel.EVENT_CLOSE, self.on_l2cap_channel_close)
|
||||
|
||||
# We don't need more channels
|
||||
self.protocol.channel_acceptor = None
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
def on_l2cap_channel_open(self) -> None:
|
||||
logger.debug(color('<<< stream channel open', 'magenta'))
|
||||
self.local_endpoint.on_rtp_channel_open()
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
def on_l2cap_channel_close(self) -> None:
|
||||
logger.debug(color('<<< stream channel closed', 'magenta'))
|
||||
self.local_endpoint.on_rtp_channel_close()
|
||||
self.local_endpoint.in_use = 0
|
||||
@@ -2107,9 +2083,22 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
|
||||
class LocalStreamEndPoint(StreamEndPoint, utils.EventEmitter):
|
||||
stream: Optional[Stream]
|
||||
|
||||
EVENT_CONFIGURATION = "configuration"
|
||||
EVENT_OPEN = "open"
|
||||
EVENT_START = "start"
|
||||
EVENT_STOP = "stop"
|
||||
EVENT_RTP_PACKET = "rtp_packet"
|
||||
EVENT_SUSPEND = "suspend"
|
||||
EVENT_CLOSE = "close"
|
||||
EVENT_ABORT = "abort"
|
||||
EVENT_DELAY_REPORT = "delay_report"
|
||||
EVENT_SECURITY_CONTROL = "security_control"
|
||||
EVENT_RTP_CHANNEL_OPEN = "rtp_channel_open"
|
||||
EVENT_RTP_CHANNEL_CLOSE = "rtp_channel_close"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol: Protocol,
|
||||
@@ -2120,57 +2109,70 @@ class LocalStreamEndPoint(StreamEndPoint, EventEmitter):
|
||||
configuration: Optional[Iterable[ServiceCapabilities]] = None,
|
||||
):
|
||||
StreamEndPoint.__init__(self, seid, media_type, tsep, 0, capabilities)
|
||||
EventEmitter.__init__(self)
|
||||
utils.EventEmitter.__init__(self)
|
||||
self.protocol = protocol
|
||||
self.configuration = configuration if configuration is not None else []
|
||||
self.stream = None
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
async def start(self) -> None:
|
||||
"""[Source Only] Handles when receiving start command."""
|
||||
|
||||
async def stop(self):
|
||||
pass
|
||||
async def stop(self) -> None:
|
||||
"""[Source Only] Handles when receiving stop command."""
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
async def close(self) -> None:
|
||||
"""[Source Only] Handles when receiving close command."""
|
||||
|
||||
def on_reconfigure_command(self, command):
|
||||
pass
|
||||
def on_reconfigure_command(self, command) -> Optional[Message]:
|
||||
return None
|
||||
|
||||
def on_set_configuration_command(self, configuration):
|
||||
def on_set_configuration_command(self, configuration) -> Optional[Message]:
|
||||
logger.debug(
|
||||
'<<< received configuration: '
|
||||
f'{",".join([str(capability) for capability in configuration])}'
|
||||
)
|
||||
self.configuration = configuration
|
||||
self.emit('configuration')
|
||||
self.emit(self.EVENT_CONFIGURATION)
|
||||
return None
|
||||
|
||||
def on_get_configuration_command(self):
|
||||
def on_get_configuration_command(self) -> Optional[Message]:
|
||||
return Get_Configuration_Response(self.configuration)
|
||||
|
||||
def on_open_command(self):
|
||||
self.emit('open')
|
||||
def on_open_command(self) -> Optional[Message]:
|
||||
self.emit(self.EVENT_OPEN)
|
||||
return None
|
||||
|
||||
def on_start_command(self):
|
||||
self.emit('start')
|
||||
def on_start_command(self) -> Optional[Message]:
|
||||
self.emit(self.EVENT_START)
|
||||
return None
|
||||
|
||||
def on_suspend_command(self):
|
||||
self.emit('suspend')
|
||||
def on_suspend_command(self) -> Optional[Message]:
|
||||
self.emit(self.EVENT_SUSPEND)
|
||||
return None
|
||||
|
||||
def on_close_command(self):
|
||||
self.emit('close')
|
||||
def on_close_command(self) -> Optional[Message]:
|
||||
self.emit(self.EVENT_CLOSE)
|
||||
return None
|
||||
|
||||
def on_abort_command(self):
|
||||
self.emit('abort')
|
||||
def on_abort_command(self) -> Optional[Message]:
|
||||
self.emit(self.EVENT_ABORT)
|
||||
return None
|
||||
|
||||
def on_delayreport_command(self, delay: int):
|
||||
self.emit('delay_report', delay)
|
||||
def on_delayreport_command(self, delay: int) -> Optional[Message]:
|
||||
self.emit(self.EVENT_DELAY_REPORT, delay)
|
||||
return None
|
||||
|
||||
def on_rtp_channel_open(self):
|
||||
self.emit('rtp_channel_open')
|
||||
def on_security_control_command(self, data: bytes) -> Optional[Message]:
|
||||
self.emit(self.EVENT_SECURITY_CONTROL, data)
|
||||
return None
|
||||
|
||||
def on_rtp_channel_close(self):
|
||||
self.emit('rtp_channel_close')
|
||||
def on_rtp_channel_open(self) -> None:
|
||||
self.emit(self.EVENT_RTP_CHANNEL_OPEN)
|
||||
return None
|
||||
|
||||
def on_rtp_channel_close(self) -> None:
|
||||
self.emit(self.EVENT_RTP_CHANNEL_CLOSE)
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -2180,12 +2182,13 @@ class LocalSource(LocalStreamEndPoint):
|
||||
protocol: Protocol,
|
||||
seid: int,
|
||||
codec_capabilities: MediaCodecCapabilities,
|
||||
other_capabilitiles: Iterable[ServiceCapabilities],
|
||||
packet_pump: MediaPacketPump,
|
||||
) -> None:
|
||||
capabilities = [
|
||||
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
|
||||
codec_capabilities,
|
||||
]
|
||||
] + list(other_capabilitiles)
|
||||
super().__init__(
|
||||
protocol,
|
||||
seid,
|
||||
@@ -2200,13 +2203,13 @@ class LocalSource(LocalStreamEndPoint):
|
||||
if self.packet_pump and self.stream and self.stream.rtp_channel:
|
||||
return await self.packet_pump.start(self.stream.rtp_channel)
|
||||
|
||||
self.emit('start')
|
||||
self.emit(self.EVENT_START)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.packet_pump:
|
||||
return await self.packet_pump.stop()
|
||||
|
||||
self.emit('stop')
|
||||
self.emit(self.EVENT_STOP)
|
||||
|
||||
def on_start_command(self):
|
||||
asyncio.create_task(self.start())
|
||||
@@ -2247,4 +2250,4 @@ class LocalSink(LocalStreamEndPoint):
|
||||
f'{color("<<< RTP Packet:", "green")} '
|
||||
f'{rtp_packet} {rtp_packet.payload[:16].hex()}'
|
||||
)
|
||||
self.emit('rtp_packet', rtp_packet)
|
||||
self.emit(self.EVENT_RTP_PACKET, rtp_packet)
|
||||
|
||||
1934
bumble/avrcp.py
1934
bumble/avrcp.py
File diff suppressed because it is too large
Load Diff
@@ -17,8 +17,8 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
from .hci import HCI_Packet
|
||||
from .helpers import PacketTracer
|
||||
from bumble.hci import HCI_Packet
|
||||
from bumble.helpers import PacketTracer
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
|
||||
292
bumble/codecs.py
292
bumble/codecs.py
@@ -16,8 +16,11 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
@@ -101,12 +104,40 @@ class BitReader:
|
||||
break
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BitWriter:
|
||||
"""Simple but not optimized bit stream writer."""
|
||||
|
||||
data: int
|
||||
bit_count: int
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = 0
|
||||
self.bit_count = 0
|
||||
|
||||
def write(self, value: int, bit_count: int) -> None:
|
||||
self.data = (self.data << bit_count) | value
|
||||
self.bit_count += bit_count
|
||||
|
||||
def write_bytes(self, data: bytes) -> None:
|
||||
bit_count = 8 * len(data)
|
||||
self.data = (self.data << bit_count) | int.from_bytes(data, 'big')
|
||||
self.bit_count += bit_count
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes(
|
||||
(self.bit_count + 7) // 8, 'big'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacAudioRtpPacket:
|
||||
"""AAC payload encapsulated in an RTP packet payload"""
|
||||
|
||||
audio_mux_element: AudioMuxElement
|
||||
|
||||
@staticmethod
|
||||
def latm_value(reader: BitReader) -> int:
|
||||
def read_latm_value(reader: BitReader) -> int:
|
||||
bytes_for_value = reader.read(2)
|
||||
value = 0
|
||||
for _ in range(bytes_for_value + 1):
|
||||
@@ -114,24 +145,33 @@ class AacAudioRtpPacket:
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def program_config_element(reader: BitReader):
|
||||
raise core.InvalidPacketError('program_config_element not supported')
|
||||
def read_audio_object_type(reader: BitReader):
|
||||
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
|
||||
audio_object_type = reader.read(5)
|
||||
if audio_object_type == 31:
|
||||
audio_object_type = 32 + reader.read(6)
|
||||
|
||||
return audio_object_type
|
||||
|
||||
@dataclass
|
||||
class GASpecificConfig:
|
||||
def __init__(
|
||||
self, reader: BitReader, channel_configuration: int, audio_object_type: int
|
||||
) -> None:
|
||||
audio_object_type: int
|
||||
# NOTE: other fields not supported
|
||||
|
||||
@classmethod
|
||||
def from_bits(
|
||||
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
|
||||
) -> Self:
|
||||
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
|
||||
frame_length_flag = reader.read(1)
|
||||
depends_on_core_coder = reader.read(1)
|
||||
if depends_on_core_coder:
|
||||
self.core_coder_delay = reader.read(14)
|
||||
core_coder_delay = reader.read(14)
|
||||
extension_flag = reader.read(1)
|
||||
if not channel_configuration:
|
||||
AacAudioRtpPacket.program_config_element(reader)
|
||||
raise core.InvalidPacketError('program_config_element not supported')
|
||||
if audio_object_type in (6, 20):
|
||||
self.layer_nr = reader.read(3)
|
||||
layer_nr = reader.read(3)
|
||||
if extension_flag:
|
||||
if audio_object_type == 22:
|
||||
num_of_sub_frame = reader.read(5)
|
||||
@@ -144,14 +184,13 @@ class AacAudioRtpPacket:
|
||||
if extension_flag_3 == 1:
|
||||
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
|
||||
|
||||
@staticmethod
|
||||
def audio_object_type(reader: BitReader):
|
||||
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
|
||||
audio_object_type = reader.read(5)
|
||||
if audio_object_type == 31:
|
||||
audio_object_type = 32 + reader.read(6)
|
||||
return cls(audio_object_type)
|
||||
|
||||
return audio_object_type
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
assert self.audio_object_type in (1, 2)
|
||||
writer.write(0, 1) # frame_length_flag = 0
|
||||
writer.write(0, 1) # depends_on_core_coder = 0
|
||||
writer.write(0, 1) # extension_flag = 0
|
||||
|
||||
@dataclass
|
||||
class AudioSpecificConfig:
|
||||
@@ -159,6 +198,7 @@ class AacAudioRtpPacket:
|
||||
sampling_frequency_index: int
|
||||
sampling_frequency: int
|
||||
channel_configuration: int
|
||||
ga_specific_config: AacAudioRtpPacket.GASpecificConfig
|
||||
sbr_present_flag: int
|
||||
ps_present_flag: int
|
||||
extension_audio_object_type: int
|
||||
@@ -182,44 +222,73 @@ class AacAudioRtpPacket:
|
||||
7350,
|
||||
]
|
||||
|
||||
def __init__(self, reader: BitReader) -> None:
|
||||
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
|
||||
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
|
||||
self.sampling_frequency_index = reader.read(4)
|
||||
if self.sampling_frequency_index == 0xF:
|
||||
self.sampling_frequency = reader.read(24)
|
||||
else:
|
||||
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
|
||||
self.sampling_frequency_index
|
||||
]
|
||||
self.channel_configuration = reader.read(4)
|
||||
self.sbr_present_flag = -1
|
||||
self.ps_present_flag = -1
|
||||
if self.audio_object_type in (5, 29):
|
||||
self.extension_audio_object_type = 5
|
||||
self.sbc_present_flag = 1
|
||||
if self.audio_object_type == 29:
|
||||
self.ps_present_flag = 1
|
||||
self.extension_sampling_frequency_index = reader.read(4)
|
||||
if self.extension_sampling_frequency_index == 0xF:
|
||||
self.extension_sampling_frequency = reader.read(24)
|
||||
else:
|
||||
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
|
||||
self.extension_sampling_frequency_index
|
||||
]
|
||||
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
|
||||
if self.audio_object_type == 22:
|
||||
self.extension_channel_configuration = reader.read(4)
|
||||
else:
|
||||
self.extension_audio_object_type = 0
|
||||
@classmethod
|
||||
def for_simple_aac(
|
||||
cls,
|
||||
audio_object_type: int,
|
||||
sampling_frequency: int,
|
||||
channel_configuration: int,
|
||||
) -> Self:
|
||||
if sampling_frequency not in cls.SAMPLING_FREQUENCIES:
|
||||
raise ValueError(f'invalid sampling frequency {sampling_frequency}')
|
||||
|
||||
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
|
||||
reader, self.channel_configuration, self.audio_object_type
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type)
|
||||
|
||||
return cls(
|
||||
audio_object_type=audio_object_type,
|
||||
sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index(
|
||||
sampling_frequency
|
||||
),
|
||||
sampling_frequency=sampling_frequency,
|
||||
channel_configuration=channel_configuration,
|
||||
ga_specific_config=ga_specific_config,
|
||||
sbr_present_flag=0,
|
||||
ps_present_flag=0,
|
||||
extension_audio_object_type=0,
|
||||
extension_sampling_frequency_index=0,
|
||||
extension_sampling_frequency=0,
|
||||
extension_channel_configuration=0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
|
||||
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
|
||||
sampling_frequency_index = reader.read(4)
|
||||
if sampling_frequency_index == 0xF:
|
||||
sampling_frequency = reader.read(24)
|
||||
else:
|
||||
sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index]
|
||||
channel_configuration = reader.read(4)
|
||||
sbr_present_flag = 0
|
||||
ps_present_flag = 0
|
||||
extension_sampling_frequency_index = 0
|
||||
extension_sampling_frequency = 0
|
||||
extension_channel_configuration = 0
|
||||
extension_audio_object_type = 0
|
||||
if audio_object_type in (5, 29):
|
||||
extension_audio_object_type = 5
|
||||
sbr_present_flag = 1
|
||||
if audio_object_type == 29:
|
||||
ps_present_flag = 1
|
||||
extension_sampling_frequency_index = reader.read(4)
|
||||
if extension_sampling_frequency_index == 0xF:
|
||||
extension_sampling_frequency = reader.read(24)
|
||||
else:
|
||||
extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[
|
||||
extension_sampling_frequency_index
|
||||
]
|
||||
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
|
||||
if audio_object_type == 22:
|
||||
extension_channel_configuration = reader.read(4)
|
||||
|
||||
if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits(
|
||||
reader, channel_configuration, audio_object_type
|
||||
)
|
||||
else:
|
||||
raise core.InvalidPacketError(
|
||||
f'audioObjectType {self.audio_object_type} not supported'
|
||||
f'audioObjectType {audio_object_type} not supported'
|
||||
)
|
||||
|
||||
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
|
||||
@@ -248,13 +317,44 @@ class AacAudioRtpPacket:
|
||||
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
|
||||
# self.extension_channel_configuration = reader.read(4)
|
||||
|
||||
return cls(
|
||||
audio_object_type,
|
||||
sampling_frequency_index,
|
||||
sampling_frequency,
|
||||
channel_configuration,
|
||||
ga_specific_config,
|
||||
sbr_present_flag,
|
||||
ps_present_flag,
|
||||
extension_audio_object_type,
|
||||
extension_sampling_frequency_index,
|
||||
extension_sampling_frequency,
|
||||
extension_channel_configuration,
|
||||
)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
if self.sampling_frequency_index >= 15:
|
||||
raise ValueError(
|
||||
f"unsupported sampling frequency index {self.sampling_frequency_index}"
|
||||
)
|
||||
|
||||
if self.audio_object_type not in (1, 2):
|
||||
raise ValueError(
|
||||
f"unsupported audio object type {self.audio_object_type} "
|
||||
)
|
||||
|
||||
writer.write(self.audio_object_type, 5)
|
||||
writer.write(self.sampling_frequency_index, 4)
|
||||
writer.write(self.channel_configuration, 4)
|
||||
self.ga_specific_config.to_bits(writer)
|
||||
|
||||
@dataclass
|
||||
class StreamMuxConfig:
|
||||
other_data_present: int
|
||||
other_data_len_bits: int
|
||||
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
|
||||
|
||||
def __init__(self, reader: BitReader) -> None:
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
|
||||
audio_mux_version = reader.read(1)
|
||||
if audio_mux_version == 1:
|
||||
@@ -264,7 +364,7 @@ class AacAudioRtpPacket:
|
||||
if audio_mux_version_a != 0:
|
||||
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
|
||||
if audio_mux_version == 1:
|
||||
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
||||
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
|
||||
stream_cnt = 0
|
||||
all_streams_same_time_framing = reader.read(1)
|
||||
num_sub_frames = reader.read(6)
|
||||
@@ -275,13 +375,13 @@ class AacAudioRtpPacket:
|
||||
if num_layer != 0:
|
||||
raise core.InvalidPacketError('num_layer != 0 not supported')
|
||||
if audio_mux_version == 0:
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
|
||||
reader
|
||||
)
|
||||
else:
|
||||
asc_len = AacAudioRtpPacket.latm_value(reader)
|
||||
asc_len = AacAudioRtpPacket.read_latm_value(reader)
|
||||
marker = reader.bit_position
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
|
||||
reader
|
||||
)
|
||||
audio_specific_config_len = reader.bit_position - marker
|
||||
@@ -299,36 +399,49 @@ class AacAudioRtpPacket:
|
||||
f'frame_length_type {frame_length_type} not supported'
|
||||
)
|
||||
|
||||
self.other_data_present = reader.read(1)
|
||||
if self.other_data_present:
|
||||
other_data_present = reader.read(1)
|
||||
other_data_len_bits = 0
|
||||
if other_data_present:
|
||||
if audio_mux_version == 1:
|
||||
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
|
||||
other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader)
|
||||
else:
|
||||
self.other_data_len_bits = 0
|
||||
while True:
|
||||
self.other_data_len_bits *= 256
|
||||
other_data_len_bits *= 256
|
||||
other_data_len_esc = reader.read(1)
|
||||
self.other_data_len_bits += reader.read(8)
|
||||
other_data_len_bits += reader.read(8)
|
||||
if other_data_len_esc == 0:
|
||||
break
|
||||
crc_check_present = reader.read(1)
|
||||
if crc_check_present:
|
||||
crc_checksum = reader.read(8)
|
||||
|
||||
return cls(other_data_present, other_data_len_bits, audio_specific_config)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
writer.write(0, 1) # audioMuxVersion = 0
|
||||
writer.write(1, 1) # allStreamsSameTimeFraming = 1
|
||||
writer.write(0, 6) # numSubFrames = 0
|
||||
writer.write(0, 4) # numProgram = 0
|
||||
writer.write(0, 3) # numLayer = 0
|
||||
self.audio_specific_config.to_bits(writer)
|
||||
writer.write(0, 3) # frameLengthType = 0
|
||||
writer.write(0, 8) # latmBufferFullness = 0
|
||||
writer.write(0, 1) # otherDataPresent = 0
|
||||
writer.write(0, 1) # crcCheckPresent = 0
|
||||
|
||||
@dataclass
|
||||
class AudioMuxElement:
|
||||
payload: bytes
|
||||
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
|
||||
payload: bytes
|
||||
|
||||
def __init__(self, reader: BitReader, mux_config_present: int):
|
||||
if mux_config_present == 0:
|
||||
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
||||
# (only supports mux_config_present=1)
|
||||
use_same_stream_mux = reader.read(1)
|
||||
if use_same_stream_mux:
|
||||
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
|
||||
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
||||
stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader)
|
||||
|
||||
# We only support:
|
||||
# allStreamsSameTimeFraming == 1
|
||||
@@ -344,19 +457,46 @@ class AacAudioRtpPacket:
|
||||
if tmp != 255:
|
||||
break
|
||||
|
||||
self.payload = reader.read_bytes(mux_slot_length_bytes)
|
||||
payload = reader.read_bytes(mux_slot_length_bytes)
|
||||
|
||||
if self.stream_mux_config.other_data_present:
|
||||
reader.skip(self.stream_mux_config.other_data_len_bits)
|
||||
if stream_mux_config.other_data_present:
|
||||
reader.skip(stream_mux_config.other_data_len_bits)
|
||||
|
||||
# ByteAlign
|
||||
while reader.bit_position % 8:
|
||||
reader.read(1)
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
return cls(stream_mux_config, payload)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
writer.write(0, 1) # useSameStreamMux = 0
|
||||
self.stream_mux_config.to_bits(writer)
|
||||
mux_slot_length_bytes = len(self.payload)
|
||||
while mux_slot_length_bytes > 255:
|
||||
writer.write(255, 8)
|
||||
mux_slot_length_bytes -= 255
|
||||
writer.write(mux_slot_length_bytes, 8)
|
||||
if mux_slot_length_bytes == 255:
|
||||
writer.write(0, 8)
|
||||
writer.write_bytes(self.payload)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
# Parse the bit stream
|
||||
reader = BitReader(data)
|
||||
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
|
||||
return cls(cls.AudioMuxElement.from_bits(reader))
|
||||
|
||||
@classmethod
|
||||
def for_simple_aac(
|
||||
cls, sampling_frequency: int, channel_configuration: int, payload: bytes
|
||||
) -> Self:
|
||||
audio_specific_config = cls.AudioSpecificConfig.for_simple_aac(
|
||||
2, sampling_frequency, channel_configuration
|
||||
)
|
||||
stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config)
|
||||
audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload)
|
||||
|
||||
return cls(audio_mux_element)
|
||||
|
||||
def to_adts(self):
|
||||
# pylint: disable=line-too-long
|
||||
@@ -383,3 +523,11 @@ class AacAudioRtpPacket:
|
||||
)
|
||||
+ self.audio_mux_element.payload
|
||||
)
|
||||
|
||||
def __init__(self, audio_mux_element: AudioMuxElement) -> None:
|
||||
self.audio_mux_element = audio_mux_element
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
writer = BitWriter()
|
||||
self.audio_mux_element.to_bits(writer)
|
||||
return bytes(writer)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class ColorError(ValueError):
|
||||
@@ -65,7 +65,7 @@ def color(
|
||||
bg: Optional[ColorSpec] = None,
|
||||
style: Optional[str] = None,
|
||||
) -> str:
|
||||
codes: List[ColorSpec] = []
|
||||
codes: list[ColorSpec] = []
|
||||
|
||||
if fg:
|
||||
codes.append(_color_code(fg, 30))
|
||||
|
||||
@@ -17,34 +17,30 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
BT_CENTRAL_ROLE,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.core import PhysicalTransport
|
||||
from bumble.hci import (
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
HCI_COMMAND_PACKET,
|
||||
HCI_COMMAND_STATUS_PENDING,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_CONTROLLER_BUSY_ERROR,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_SUCCESS,
|
||||
HCI_UNKNOWN_HCI_COMMAND_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
HCI_SUCCESS,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_UNKNOWN_HCI_COMMAND_ERROR,
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_0,
|
||||
Address,
|
||||
HCI_AclDataPacket,
|
||||
@@ -55,7 +51,6 @@ from bumble.hci import (
|
||||
HCI_Connection_Request_Event,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
HCI_Encryption_Change_Event,
|
||||
HCI_Synchronous_Connection_Complete_Event,
|
||||
HCI_LE_Advertising_Report_Event,
|
||||
HCI_LE_CIS_Established_Event,
|
||||
HCI_LE_CIS_Request_Event,
|
||||
@@ -64,8 +59,9 @@ from bumble.hci import (
|
||||
HCI_Number_Of_Completed_Packets_Event,
|
||||
HCI_Packet,
|
||||
HCI_Role_Change_Event,
|
||||
HCI_Synchronous_Connection_Complete_Event,
|
||||
Role,
|
||||
)
|
||||
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.link import LocalLink
|
||||
@@ -91,6 +87,7 @@ class CisLink:
|
||||
cis_id: int
|
||||
cig_id: int
|
||||
acl_connection: Optional[Connection] = None
|
||||
data_paths: set[int] = dataclasses.field(default_factory=set)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -98,7 +95,7 @@ class CisLink:
|
||||
class Connection:
|
||||
controller: Controller
|
||||
handle: int
|
||||
role: int
|
||||
role: Role
|
||||
peer_address: Address
|
||||
link: Any
|
||||
transport: int
|
||||
@@ -110,7 +107,9 @@ class Connection:
|
||||
def on_hci_acl_data_packet(self, packet):
|
||||
self.assembler.feed_packet(packet)
|
||||
self.controller.send_hci_packet(
|
||||
HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])
|
||||
HCI_Number_Of_Completed_Packets_Event(
|
||||
connection_handles=[self.handle], num_completed_packets=[1]
|
||||
)
|
||||
)
|
||||
|
||||
def on_acl_pdu(self, data):
|
||||
@@ -134,17 +133,17 @@ class Controller:
|
||||
self.hci_sink = None
|
||||
self.link = link
|
||||
|
||||
self.central_connections: Dict[Address, Connection] = (
|
||||
self.central_connections: dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the central
|
||||
self.peripheral_connections: Dict[Address, Connection] = (
|
||||
self.peripheral_connections: dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the peripheral
|
||||
self.classic_connections: Dict[Address, Connection] = (
|
||||
self.classic_connections: dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections in BR/EDR
|
||||
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
self.central_cis_links: dict[int, CisLink] = {} # CIS links by handle
|
||||
self.peripheral_cis_links: dict[int, CisLink] = {} # CIS links by handle
|
||||
|
||||
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
|
||||
self.hci_revision = 0
|
||||
@@ -154,15 +153,17 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.acl_data_packet_length = 27
|
||||
self.total_num_acl_data_packets = 64
|
||||
self.le_acl_data_packet_length = 27
|
||||
self.total_num_le_acl_data_packets = 64
|
||||
self.iso_data_packet_length = 960
|
||||
self.total_num_iso_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
'30f0f9ff01008004002000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
self.advertising_parameters = None
|
||||
@@ -314,7 +315,7 @@ class Controller:
|
||||
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
|
||||
)
|
||||
if self.host:
|
||||
self.host.on_packet(packet.to_bytes())
|
||||
self.host.on_packet(bytes(packet))
|
||||
|
||||
# This method allows the controller to emulate the same API as a transport source
|
||||
async def wait_for_termination(self):
|
||||
@@ -368,12 +369,23 @@ class Controller:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_peripheral_connection_by_handle(self, handle):
|
||||
for connection in self.peripheral_connections.values():
|
||||
if connection.handle == handle:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_classic_connection_by_handle(self, handle):
|
||||
for connection in self.classic_connections.values():
|
||||
if connection.handle == handle:
|
||||
return connection
|
||||
return None
|
||||
|
||||
def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]:
|
||||
return self.central_cis_links.get(handle) or self.peripheral_cis_links.get(
|
||||
handle
|
||||
)
|
||||
|
||||
def on_link_central_connected(self, central_address):
|
||||
'''
|
||||
Called when an incoming connection occurs from a central on the link
|
||||
@@ -388,11 +400,11 @@ class Controller:
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_PERIPHERAL_ROLE,
|
||||
role=Role.PERIPHERAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
transport=PhysicalTransport.LE,
|
||||
link_type=HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
self.peripheral_connections[peer_address] = connection
|
||||
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
|
||||
@@ -412,7 +424,7 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_central_disconnected(self, peer_address, reason):
|
||||
def on_link_disconnected(self, peer_address, reason):
|
||||
'''
|
||||
Called when an active disconnection occurs from a peer
|
||||
'''
|
||||
@@ -429,6 +441,17 @@ class Controller:
|
||||
|
||||
# Remove the connection
|
||||
del self.peripheral_connections[peer_address]
|
||||
elif connection := self.central_connections.get(peer_address):
|
||||
self.send_hci_packet(
|
||||
HCI_Disconnection_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=connection.handle,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove the connection
|
||||
del self.central_connections[peer_address]
|
||||
else:
|
||||
logger.warning(f'!!! No peripheral connection found for {peer_address}')
|
||||
|
||||
@@ -448,11 +471,11 @@ class Controller:
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_CENTRAL_ROLE,
|
||||
role=Role.CENTRAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
transport=PhysicalTransport.LE,
|
||||
link_type=HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
self.central_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -467,7 +490,7 @@ class Controller:
|
||||
HCI_LE_Connection_Complete_Event(
|
||||
status=status,
|
||||
connection_handle=connection.handle if connection else 0,
|
||||
role=BT_CENTRAL_ROLE,
|
||||
role=Role.CENTRAL,
|
||||
peer_address_type=le_create_connection_command.peer_address_type,
|
||||
peer_address=le_create_connection_command.peer_address,
|
||||
connection_interval=le_create_connection_command.connection_interval_min,
|
||||
@@ -477,7 +500,7 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_peripheral_disconnection_complete(self, disconnection_command, status):
|
||||
def on_link_disconnection_complete(self, disconnection_command, status):
|
||||
'''
|
||||
Called when a disconnection has been completed
|
||||
'''
|
||||
@@ -497,26 +520,11 @@ class Controller:
|
||||
):
|
||||
logger.debug(f'CENTRAL Connection removed: {connection}')
|
||||
del self.central_connections[connection.peer_address]
|
||||
|
||||
def on_link_peripheral_disconnected(self, peer_address):
|
||||
'''
|
||||
Called when a connection to a peripheral is broken
|
||||
'''
|
||||
|
||||
# Send a disconnection complete event
|
||||
if connection := self.central_connections.get(peer_address):
|
||||
self.send_hci_packet(
|
||||
HCI_Disconnection_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=connection.handle,
|
||||
reason=HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove the connection
|
||||
del self.central_connections[peer_address]
|
||||
else:
|
||||
logger.warning(f'!!! No central connection found for {peer_address}')
|
||||
elif connection := self.find_peripheral_connection_by_handle(
|
||||
disconnection_command.connection_handle
|
||||
):
|
||||
logger.debug(f'PERIPHERAL Connection removed: {connection}')
|
||||
del self.peripheral_connections[connection.peer_address]
|
||||
|
||||
def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk):
|
||||
# For now, just setup the encryption without asking the host
|
||||
@@ -529,7 +537,7 @@ class Controller:
|
||||
|
||||
def on_link_acl_data(self, sender_address, transport, data):
|
||||
# Look for the connection to which this data belongs
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
if transport == PhysicalTransport.LE:
|
||||
connection = self.find_le_connection_by_address(sender_address)
|
||||
else:
|
||||
connection = self.find_classic_connection_by_address(sender_address)
|
||||
@@ -542,15 +550,14 @@ class Controller:
|
||||
acl_packet = HCI_AclDataPacket(connection.handle, 2, 0, len(data), data)
|
||||
self.send_hci_packet(acl_packet)
|
||||
|
||||
def on_link_advertising_data(self, sender_address, data):
|
||||
def on_link_advertising_data(self, sender_address: Address, data: bytes):
|
||||
# Ignore if we're not scanning
|
||||
if self.le_scan_enable == 0:
|
||||
return
|
||||
|
||||
# Send a scan report
|
||||
report = HCI_LE_Advertising_Report_Event.Report(
|
||||
HCI_LE_Advertising_Report_Event.Report.FIELDS,
|
||||
event_type=HCI_LE_Advertising_Report_Event.ADV_IND,
|
||||
event_type=HCI_LE_Advertising_Report_Event.EventType.ADV_IND,
|
||||
address_type=sender_address.address_type,
|
||||
address=sender_address,
|
||||
data=data,
|
||||
@@ -560,8 +567,7 @@ class Controller:
|
||||
|
||||
# Simulate a scan response
|
||||
report = HCI_LE_Advertising_Report_Event.Report(
|
||||
HCI_LE_Advertising_Report_Event.Report.FIELDS,
|
||||
event_type=HCI_LE_Advertising_Report_Event.SCAN_RSP,
|
||||
event_type=HCI_LE_Advertising_Report_Event.EventType.SCAN_RSP,
|
||||
address_type=sender_address.address_type,
|
||||
address=sender_address,
|
||||
data=data,
|
||||
@@ -618,8 +624,8 @@ class Controller:
|
||||
cis_sync_delay=0,
|
||||
transport_latency_c_to_p=0,
|
||||
transport_latency_p_to_c=0,
|
||||
phy_c_to_p=0,
|
||||
phy_p_to_c=0,
|
||||
phy_c_to_p=1,
|
||||
phy_p_to_c=1,
|
||||
nse=0,
|
||||
bn_c_to_p=0,
|
||||
bn_p_to_c=0,
|
||||
@@ -691,11 +697,11 @@ class Controller:
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
# Role doesn't matter in Classic because they are managed by HCI_Role_Change and HCI_Role_Discovery
|
||||
role=BT_CENTRAL_ROLE,
|
||||
role=Role.CENTRAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
transport=PhysicalTransport.BR_EDR,
|
||||
link_type=HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -709,7 +715,7 @@ class Controller:
|
||||
connection_handle=connection_handle,
|
||||
bd_addr=peer_address,
|
||||
encryption_enabled=False,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
link_type=HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -720,7 +726,7 @@ class Controller:
|
||||
connection_handle=0,
|
||||
bd_addr=peer_address,
|
||||
encryption_enabled=False,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
link_type=HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -759,10 +765,10 @@ class Controller:
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
# Role doesn't matter in SCO.
|
||||
role=BT_CENTRAL_ROLE,
|
||||
role=Role.CENTRAL,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
transport=PhysicalTransport.BR_EDR,
|
||||
link_type=link_type,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
@@ -877,6 +883,14 @@ class Controller:
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.central_connections[connection.peer_address]
|
||||
elif connection := self.find_peripheral_connection_by_handle(handle):
|
||||
if self.link:
|
||||
self.link.disconnect(
|
||||
self.random_address, connection.peer_address, command
|
||||
)
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.peripheral_connections[connection.peer_address]
|
||||
elif connection := self.find_classic_connection_by_handle(handle):
|
||||
if self.link:
|
||||
self.link.classic_disconnect(
|
||||
@@ -945,7 +959,7 @@ class Controller:
|
||||
)
|
||||
)
|
||||
self.link.classic_sco_connect(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.LinkType.ESCO
|
||||
)
|
||||
|
||||
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
|
||||
@@ -974,10 +988,71 @@ class Controller:
|
||||
)
|
||||
)
|
||||
self.link.classic_accept_sco_connection(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.LinkType.ESCO
|
||||
)
|
||||
|
||||
def on_hci_switch_role_command(self, command):
|
||||
def on_hci_sniff_mode_command(self, command: hci.HCI_Sniff_Mode_Command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.2.2 Sniff Mode command
|
||||
'''
|
||||
if self.link is None:
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Mode_Change_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=command.connection_handle,
|
||||
current_mode=hci.HCI_Mode_Change_Event.Mode.SNIFF,
|
||||
interval=2,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_exit_sniff_mode_command(self, command: hci.HCI_Exit_Sniff_Mode_Command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.2.3 Exit Sniff Mode command
|
||||
'''
|
||||
|
||||
if self.link is None:
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Mode_Change_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=command.connection_handle,
|
||||
current_mode=hci.HCI_Mode_Change_Event.Mode.ACTIVE,
|
||||
interval=2,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_switch_role_command(self, command: hci.HCI_Switch_Role_Command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
|
||||
'''
|
||||
@@ -1181,9 +1256,9 @@ class Controller:
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
self.acl_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
self.total_num_acl_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
@@ -1192,12 +1267,62 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
|
||||
'''
|
||||
bd_addr = (
|
||||
self._public_address.to_bytes()
|
||||
bytes(self._public_address)
|
||||
if self._public_address is not None
|
||||
else bytes(6)
|
||||
)
|
||||
return bytes([HCI_SUCCESS]) + bd_addr
|
||||
|
||||
def on_hci_le_set_default_subrate_command(
|
||||
self, command: hci.HCI_LE_Set_Default_Subrate_Command
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 6, Part E - 7.8.123 LE Set Event Mask Command
|
||||
'''
|
||||
|
||||
if (
|
||||
command.subrate_max * (command.max_latency) > 500
|
||||
or command.subrate_max < command.subrate_min
|
||||
or command.continuation_number >= command.subrate_max
|
||||
):
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_subrate_request_command(
|
||||
self, command: hci.HCI_LE_Subrate_Request_Command
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 6, Part E - 7.8.124 LE Subrate Request command
|
||||
'''
|
||||
if (
|
||||
command.subrate_max * (command.max_latency) > 500
|
||||
or command.continuation_number < command.continuation_number
|
||||
or command.subrate_max < command.subrate_min
|
||||
or command.continuation_number >= command.subrate_max
|
||||
):
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
self.send_hci_packet(
|
||||
hci.HCI_Command_Status_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
hci.HCI_LE_Subrate_Change_Event(
|
||||
status=hci.HCI_SUCCESS,
|
||||
connection_handle=command.connection_handle,
|
||||
subrate_factor=2,
|
||||
peripheral_latency=2,
|
||||
continuation_number=command.continuation_number,
|
||||
supervision_timeout=command.supervision_timeout,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def on_hci_le_set_event_mask_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.1 LE Set Event Mask Command
|
||||
@@ -1212,8 +1337,21 @@ class Controller:
|
||||
return struct.pack(
|
||||
'<BHB',
|
||||
HCI_SUCCESS,
|
||||
self.hc_le_data_packet_length,
|
||||
self.hc_total_num_le_data_packets,
|
||||
self.le_acl_data_packet_length,
|
||||
self.total_num_le_acl_data_packets,
|
||||
)
|
||||
|
||||
def on_hci_le_read_buffer_size_v2_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHB',
|
||||
HCI_SUCCESS,
|
||||
self.le_acl_data_packet_length,
|
||||
self.total_num_le_acl_data_packets,
|
||||
self.iso_data_packet_length,
|
||||
self.total_num_iso_data_packets,
|
||||
)
|
||||
|
||||
def on_hci_le_read_local_supported_features_command(self, _command):
|
||||
@@ -1543,6 +1681,41 @@ class Controller:
|
||||
}
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_advertising_set_random_address_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS, 0])
|
||||
|
||||
def on_hci_le_set_extended_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_scan_response_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
|
||||
@@ -1557,6 +1730,27 @@ class Controller:
|
||||
'''
|
||||
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
|
||||
|
||||
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_transmit_power_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
@@ -1664,14 +1858,57 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
def on_hci_le_setup_iso_data_path_command(
|
||||
self, command: hci.HCI_LE_Setup_ISO_Data_Path_Command
|
||||
) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
if command.data_path_direction in iso_link.data_paths:
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
iso_link.data_paths.add(command.data_path_direction)
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
def on_hci_le_remove_iso_data_path_command(
|
||||
self, command: hci.HCI_LE_Remove_ISO_Data_Path_Command
|
||||
) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
data_paths: set[int] = set(
|
||||
direction
|
||||
for direction in hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction
|
||||
if (1 << direction) & command.data_path_direction
|
||||
)
|
||||
if not data_paths.issubset(iso_link.data_paths):
|
||||
return struct.pack(
|
||||
'<BH',
|
||||
HCI_COMMAND_DISALLOWED_ERROR,
|
||||
command.connection_handle,
|
||||
)
|
||||
iso_link.data_paths.difference_update(data_paths)
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_set_host_feature_command(
|
||||
self, _command: hci.HCI_LE_Set_Host_Feature_Command
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.115 LE Set Host Feature command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
1417
bumble/core.py
1417
bumble/core.py
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
@@ -12,12 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Crypto support
|
||||
#
|
||||
# See Bluetooth spec Vol 3, Part H - 2.2 CRYPTOGRAPHIC TOOLBOX
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -25,19 +19,15 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import operator
|
||||
|
||||
import secrets
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from bumble.crypto.cryptography import EccKey, aes_cmac, e
|
||||
except ImportError:
|
||||
logging.getLogger(__name__).debug(
|
||||
"Unable to import cryptography, use built-in primitives."
|
||||
)
|
||||
from bumble.crypto.builtin import EccKey, aes_cmac, e # type: ignore[assignment]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -46,55 +36,6 @@ from typing import Tuple
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class EccKey:
|
||||
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> EccKey:
|
||||
private_key = generate_private_key(SECP256R1())
|
||||
return cls(private_key)
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(
|
||||
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
|
||||
) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
|
||||
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
|
||||
private_key = EllipticCurvePrivateNumbers(
|
||||
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
|
||||
).private_key()
|
||||
return cls(private_key)
|
||||
|
||||
@property
|
||||
def x(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.x.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
@property
|
||||
def y(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.y.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
|
||||
shared_key = self.private_key.exchange(ECDH(), public_key)
|
||||
|
||||
return shared_key
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -132,19 +73,6 @@ def r() -> bytes:
|
||||
return secrets.token_bytes(16)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
return reverse(encryptor.update(reverse(data)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
@@ -187,18 +115,6 @@ def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
|
||||
return e(k, r2[0:8] + r1[0:8])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
NOTE: the input and output of this internal function are in big-endian byte order
|
||||
'''
|
||||
mac = cmac.CMAC(algorithms.AES(k))
|
||||
mac.update(m)
|
||||
return mac.finalize()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
|
||||
'''
|
||||
@@ -209,7 +125,7 @@ def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
|
||||
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> tuple[bytes, bytes]:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
|
||||
Function f5
|
||||
652
bumble/crypto/builtin.py
Normal file
652
bumble/crypto/builtin.py
Normal file
@@ -0,0 +1,652 @@
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The implementation is modified from:
|
||||
# * AES - https://github.com/ricmoo/pyaes by Richard Moore under MIT License
|
||||
# * CMAC - https://github.com/pycrypto/pycrypto by contributors under pycrypto License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Built-in implementation of cryptography primitives.
|
||||
#
|
||||
# Note: It's very dangerous to use this library in the real world.
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
def _compact_word(word: bytes) -> int:
|
||||
return int.from_bytes(word, "big")
|
||||
|
||||
|
||||
def _shift_bytes(bs: bytes, xor_lsb: int = 0) -> bytes:
|
||||
return ((int.from_bytes(bs, "big") << 1) ^ xor_lsb).to_bytes(len(bs) + 1, "big")[1:]
|
||||
|
||||
|
||||
def _xor(a: bytes, b: bytes) -> bytes:
|
||||
return bytes(x ^ y for x, y in zip(a, b))
|
||||
|
||||
|
||||
# Based *largely* on the Rijndael implementation
|
||||
# See: http://csrc.nist.gov/publications/FIPS/FIPS197/FIPS-197.pdf
|
||||
class _AES:
|
||||
'''Encapsulates the AES block cipher.
|
||||
|
||||
You generally should not need this. Use the AESModeOfOperation classes
|
||||
below instead.'''
|
||||
|
||||
# fmt: off
|
||||
# Number of rounds by key size
|
||||
_NUMBER_OF_ROUNDS = {16: 10, 24: 12, 32: 14}
|
||||
|
||||
# Round constant words
|
||||
_RCON = [ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91 ]
|
||||
|
||||
# S-box and Inverse S-box (S is for Substitution)
|
||||
_S = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 ]
|
||||
_S_INV =[ 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d ]
|
||||
|
||||
# Transformations for encryption
|
||||
_T1 = [ 0xc66363a5, 0xf87c7c84, 0xee777799, 0xf67b7b8d, 0xfff2f20d, 0xd66b6bbd, 0xde6f6fb1, 0x91c5c554, 0x60303050, 0x02010103, 0xce6767a9, 0x562b2b7d, 0xe7fefe19, 0xb5d7d762, 0x4dababe6, 0xec76769a, 0x8fcaca45, 0x1f82829d, 0x89c9c940, 0xfa7d7d87, 0xeffafa15, 0xb25959eb, 0x8e4747c9, 0xfbf0f00b, 0x41adadec, 0xb3d4d467, 0x5fa2a2fd, 0x45afafea, 0x239c9cbf, 0x53a4a4f7, 0xe4727296, 0x9bc0c05b, 0x75b7b7c2, 0xe1fdfd1c, 0x3d9393ae, 0x4c26266a, 0x6c36365a, 0x7e3f3f41, 0xf5f7f702, 0x83cccc4f, 0x6834345c, 0x51a5a5f4, 0xd1e5e534, 0xf9f1f108, 0xe2717193, 0xabd8d873, 0x62313153, 0x2a15153f, 0x0804040c, 0x95c7c752, 0x46232365, 0x9dc3c35e, 0x30181828, 0x379696a1, 0x0a05050f, 0x2f9a9ab5, 0x0e070709, 0x24121236, 0x1b80809b, 0xdfe2e23d, 0xcdebeb26, 0x4e272769, 0x7fb2b2cd, 0xea75759f, 0x1209091b, 0x1d83839e, 0x582c2c74, 0x341a1a2e, 0x361b1b2d, 0xdc6e6eb2, 0xb45a5aee, 0x5ba0a0fb, 0xa45252f6, 0x763b3b4d, 0xb7d6d661, 0x7db3b3ce, 0x5229297b, 0xdde3e33e, 0x5e2f2f71, 0x13848497, 0xa65353f5, 0xb9d1d168, 0x00000000, 0xc1eded2c, 0x40202060, 0xe3fcfc1f, 0x79b1b1c8, 0xb65b5bed, 0xd46a6abe, 0x8dcbcb46, 0x67bebed9, 0x7239394b, 0x944a4ade, 0x984c4cd4, 0xb05858e8, 0x85cfcf4a, 0xbbd0d06b, 0xc5efef2a, 0x4faaaae5, 0xedfbfb16, 0x864343c5, 0x9a4d4dd7, 0x66333355, 0x11858594, 0x8a4545cf, 0xe9f9f910, 0x04020206, 0xfe7f7f81, 0xa05050f0, 0x783c3c44, 0x259f9fba, 0x4ba8a8e3, 0xa25151f3, 0x5da3a3fe, 0x804040c0, 0x058f8f8a, 0x3f9292ad, 0x219d9dbc, 0x70383848, 0xf1f5f504, 0x63bcbcdf, 0x77b6b6c1, 0xafdada75, 0x42212163, 0x20101030, 0xe5ffff1a, 0xfdf3f30e, 0xbfd2d26d, 0x81cdcd4c, 0x180c0c14, 0x26131335, 0xc3ecec2f, 0xbe5f5fe1, 0x359797a2, 0x884444cc, 0x2e171739, 0x93c4c457, 0x55a7a7f2, 0xfc7e7e82, 0x7a3d3d47, 0xc86464ac, 0xba5d5de7, 0x3219192b, 0xe6737395, 0xc06060a0, 0x19818198, 0x9e4f4fd1, 0xa3dcdc7f, 0x44222266, 0x542a2a7e, 0x3b9090ab, 0x0b888883, 0x8c4646ca, 0xc7eeee29, 0x6bb8b8d3, 0x2814143c, 0xa7dede79, 0xbc5e5ee2, 0x160b0b1d, 0xaddbdb76, 0xdbe0e03b, 0x64323256, 0x743a3a4e, 0x140a0a1e, 0x924949db, 0x0c06060a, 0x4824246c, 0xb85c5ce4, 0x9fc2c25d, 0xbdd3d36e, 0x43acacef, 0xc46262a6, 0x399191a8, 0x319595a4, 0xd3e4e437, 0xf279798b, 0xd5e7e732, 0x8bc8c843, 0x6e373759, 0xda6d6db7, 0x018d8d8c, 0xb1d5d564, 0x9c4e4ed2, 0x49a9a9e0, 0xd86c6cb4, 0xac5656fa, 0xf3f4f407, 0xcfeaea25, 0xca6565af, 0xf47a7a8e, 0x47aeaee9, 0x10080818, 0x6fbabad5, 0xf0787888, 0x4a25256f, 0x5c2e2e72, 0x381c1c24, 0x57a6a6f1, 0x73b4b4c7, 0x97c6c651, 0xcbe8e823, 0xa1dddd7c, 0xe874749c, 0x3e1f1f21, 0x964b4bdd, 0x61bdbddc, 0x0d8b8b86, 0x0f8a8a85, 0xe0707090, 0x7c3e3e42, 0x71b5b5c4, 0xcc6666aa, 0x904848d8, 0x06030305, 0xf7f6f601, 0x1c0e0e12, 0xc26161a3, 0x6a35355f, 0xae5757f9, 0x69b9b9d0, 0x17868691, 0x99c1c158, 0x3a1d1d27, 0x279e9eb9, 0xd9e1e138, 0xebf8f813, 0x2b9898b3, 0x22111133, 0xd26969bb, 0xa9d9d970, 0x078e8e89, 0x339494a7, 0x2d9b9bb6, 0x3c1e1e22, 0x15878792, 0xc9e9e920, 0x87cece49, 0xaa5555ff, 0x50282878, 0xa5dfdf7a, 0x038c8c8f, 0x59a1a1f8, 0x09898980, 0x1a0d0d17, 0x65bfbfda, 0xd7e6e631, 0x844242c6, 0xd06868b8, 0x824141c3, 0x299999b0, 0x5a2d2d77, 0x1e0f0f11, 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a ]
|
||||
_T2 = [ 0xa5c66363, 0x84f87c7c, 0x99ee7777, 0x8df67b7b, 0x0dfff2f2, 0xbdd66b6b, 0xb1de6f6f, 0x5491c5c5, 0x50603030, 0x03020101, 0xa9ce6767, 0x7d562b2b, 0x19e7fefe, 0x62b5d7d7, 0xe64dabab, 0x9aec7676, 0x458fcaca, 0x9d1f8282, 0x4089c9c9, 0x87fa7d7d, 0x15effafa, 0xebb25959, 0xc98e4747, 0x0bfbf0f0, 0xec41adad, 0x67b3d4d4, 0xfd5fa2a2, 0xea45afaf, 0xbf239c9c, 0xf753a4a4, 0x96e47272, 0x5b9bc0c0, 0xc275b7b7, 0x1ce1fdfd, 0xae3d9393, 0x6a4c2626, 0x5a6c3636, 0x417e3f3f, 0x02f5f7f7, 0x4f83cccc, 0x5c683434, 0xf451a5a5, 0x34d1e5e5, 0x08f9f1f1, 0x93e27171, 0x73abd8d8, 0x53623131, 0x3f2a1515, 0x0c080404, 0x5295c7c7, 0x65462323, 0x5e9dc3c3, 0x28301818, 0xa1379696, 0x0f0a0505, 0xb52f9a9a, 0x090e0707, 0x36241212, 0x9b1b8080, 0x3ddfe2e2, 0x26cdebeb, 0x694e2727, 0xcd7fb2b2, 0x9fea7575, 0x1b120909, 0x9e1d8383, 0x74582c2c, 0x2e341a1a, 0x2d361b1b, 0xb2dc6e6e, 0xeeb45a5a, 0xfb5ba0a0, 0xf6a45252, 0x4d763b3b, 0x61b7d6d6, 0xce7db3b3, 0x7b522929, 0x3edde3e3, 0x715e2f2f, 0x97138484, 0xf5a65353, 0x68b9d1d1, 0x00000000, 0x2cc1eded, 0x60402020, 0x1fe3fcfc, 0xc879b1b1, 0xedb65b5b, 0xbed46a6a, 0x468dcbcb, 0xd967bebe, 0x4b723939, 0xde944a4a, 0xd4984c4c, 0xe8b05858, 0x4a85cfcf, 0x6bbbd0d0, 0x2ac5efef, 0xe54faaaa, 0x16edfbfb, 0xc5864343, 0xd79a4d4d, 0x55663333, 0x94118585, 0xcf8a4545, 0x10e9f9f9, 0x06040202, 0x81fe7f7f, 0xf0a05050, 0x44783c3c, 0xba259f9f, 0xe34ba8a8, 0xf3a25151, 0xfe5da3a3, 0xc0804040, 0x8a058f8f, 0xad3f9292, 0xbc219d9d, 0x48703838, 0x04f1f5f5, 0xdf63bcbc, 0xc177b6b6, 0x75afdada, 0x63422121, 0x30201010, 0x1ae5ffff, 0x0efdf3f3, 0x6dbfd2d2, 0x4c81cdcd, 0x14180c0c, 0x35261313, 0x2fc3ecec, 0xe1be5f5f, 0xa2359797, 0xcc884444, 0x392e1717, 0x5793c4c4, 0xf255a7a7, 0x82fc7e7e, 0x477a3d3d, 0xacc86464, 0xe7ba5d5d, 0x2b321919, 0x95e67373, 0xa0c06060, 0x98198181, 0xd19e4f4f, 0x7fa3dcdc, 0x66442222, 0x7e542a2a, 0xab3b9090, 0x830b8888, 0xca8c4646, 0x29c7eeee, 0xd36bb8b8, 0x3c281414, 0x79a7dede, 0xe2bc5e5e, 0x1d160b0b, 0x76addbdb, 0x3bdbe0e0, 0x56643232, 0x4e743a3a, 0x1e140a0a, 0xdb924949, 0x0a0c0606, 0x6c482424, 0xe4b85c5c, 0x5d9fc2c2, 0x6ebdd3d3, 0xef43acac, 0xa6c46262, 0xa8399191, 0xa4319595, 0x37d3e4e4, 0x8bf27979, 0x32d5e7e7, 0x438bc8c8, 0x596e3737, 0xb7da6d6d, 0x8c018d8d, 0x64b1d5d5, 0xd29c4e4e, 0xe049a9a9, 0xb4d86c6c, 0xfaac5656, 0x07f3f4f4, 0x25cfeaea, 0xafca6565, 0x8ef47a7a, 0xe947aeae, 0x18100808, 0xd56fbaba, 0x88f07878, 0x6f4a2525, 0x725c2e2e, 0x24381c1c, 0xf157a6a6, 0xc773b4b4, 0x5197c6c6, 0x23cbe8e8, 0x7ca1dddd, 0x9ce87474, 0x213e1f1f, 0xdd964b4b, 0xdc61bdbd, 0x860d8b8b, 0x850f8a8a, 0x90e07070, 0x427c3e3e, 0xc471b5b5, 0xaacc6666, 0xd8904848, 0x05060303, 0x01f7f6f6, 0x121c0e0e, 0xa3c26161, 0x5f6a3535, 0xf9ae5757, 0xd069b9b9, 0x91178686, 0x5899c1c1, 0x273a1d1d, 0xb9279e9e, 0x38d9e1e1, 0x13ebf8f8, 0xb32b9898, 0x33221111, 0xbbd26969, 0x70a9d9d9, 0x89078e8e, 0xa7339494, 0xb62d9b9b, 0x223c1e1e, 0x92158787, 0x20c9e9e9, 0x4987cece, 0xffaa5555, 0x78502828, 0x7aa5dfdf, 0x8f038c8c, 0xf859a1a1, 0x80098989, 0x171a0d0d, 0xda65bfbf, 0x31d7e6e6, 0xc6844242, 0xb8d06868, 0xc3824141, 0xb0299999, 0x775a2d2d, 0x111e0f0f, 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616 ]
|
||||
_T3 = [ 0x63a5c663, 0x7c84f87c, 0x7799ee77, 0x7b8df67b, 0xf20dfff2, 0x6bbdd66b, 0x6fb1de6f, 0xc55491c5, 0x30506030, 0x01030201, 0x67a9ce67, 0x2b7d562b, 0xfe19e7fe, 0xd762b5d7, 0xabe64dab, 0x769aec76, 0xca458fca, 0x829d1f82, 0xc94089c9, 0x7d87fa7d, 0xfa15effa, 0x59ebb259, 0x47c98e47, 0xf00bfbf0, 0xadec41ad, 0xd467b3d4, 0xa2fd5fa2, 0xafea45af, 0x9cbf239c, 0xa4f753a4, 0x7296e472, 0xc05b9bc0, 0xb7c275b7, 0xfd1ce1fd, 0x93ae3d93, 0x266a4c26, 0x365a6c36, 0x3f417e3f, 0xf702f5f7, 0xcc4f83cc, 0x345c6834, 0xa5f451a5, 0xe534d1e5, 0xf108f9f1, 0x7193e271, 0xd873abd8, 0x31536231, 0x153f2a15, 0x040c0804, 0xc75295c7, 0x23654623, 0xc35e9dc3, 0x18283018, 0x96a13796, 0x050f0a05, 0x9ab52f9a, 0x07090e07, 0x12362412, 0x809b1b80, 0xe23ddfe2, 0xeb26cdeb, 0x27694e27, 0xb2cd7fb2, 0x759fea75, 0x091b1209, 0x839e1d83, 0x2c74582c, 0x1a2e341a, 0x1b2d361b, 0x6eb2dc6e, 0x5aeeb45a, 0xa0fb5ba0, 0x52f6a452, 0x3b4d763b, 0xd661b7d6, 0xb3ce7db3, 0x297b5229, 0xe33edde3, 0x2f715e2f, 0x84971384, 0x53f5a653, 0xd168b9d1, 0x00000000, 0xed2cc1ed, 0x20604020, 0xfc1fe3fc, 0xb1c879b1, 0x5bedb65b, 0x6abed46a, 0xcb468dcb, 0xbed967be, 0x394b7239, 0x4ade944a, 0x4cd4984c, 0x58e8b058, 0xcf4a85cf, 0xd06bbbd0, 0xef2ac5ef, 0xaae54faa, 0xfb16edfb, 0x43c58643, 0x4dd79a4d, 0x33556633, 0x85941185, 0x45cf8a45, 0xf910e9f9, 0x02060402, 0x7f81fe7f, 0x50f0a050, 0x3c44783c, 0x9fba259f, 0xa8e34ba8, 0x51f3a251, 0xa3fe5da3, 0x40c08040, 0x8f8a058f, 0x92ad3f92, 0x9dbc219d, 0x38487038, 0xf504f1f5, 0xbcdf63bc, 0xb6c177b6, 0xda75afda, 0x21634221, 0x10302010, 0xff1ae5ff, 0xf30efdf3, 0xd26dbfd2, 0xcd4c81cd, 0x0c14180c, 0x13352613, 0xec2fc3ec, 0x5fe1be5f, 0x97a23597, 0x44cc8844, 0x17392e17, 0xc45793c4, 0xa7f255a7, 0x7e82fc7e, 0x3d477a3d, 0x64acc864, 0x5de7ba5d, 0x192b3219, 0x7395e673, 0x60a0c060, 0x81981981, 0x4fd19e4f, 0xdc7fa3dc, 0x22664422, 0x2a7e542a, 0x90ab3b90, 0x88830b88, 0x46ca8c46, 0xee29c7ee, 0xb8d36bb8, 0x143c2814, 0xde79a7de, 0x5ee2bc5e, 0x0b1d160b, 0xdb76addb, 0xe03bdbe0, 0x32566432, 0x3a4e743a, 0x0a1e140a, 0x49db9249, 0x060a0c06, 0x246c4824, 0x5ce4b85c, 0xc25d9fc2, 0xd36ebdd3, 0xacef43ac, 0x62a6c462, 0x91a83991, 0x95a43195, 0xe437d3e4, 0x798bf279, 0xe732d5e7, 0xc8438bc8, 0x37596e37, 0x6db7da6d, 0x8d8c018d, 0xd564b1d5, 0x4ed29c4e, 0xa9e049a9, 0x6cb4d86c, 0x56faac56, 0xf407f3f4, 0xea25cfea, 0x65afca65, 0x7a8ef47a, 0xaee947ae, 0x08181008, 0xbad56fba, 0x7888f078, 0x256f4a25, 0x2e725c2e, 0x1c24381c, 0xa6f157a6, 0xb4c773b4, 0xc65197c6, 0xe823cbe8, 0xdd7ca1dd, 0x749ce874, 0x1f213e1f, 0x4bdd964b, 0xbddc61bd, 0x8b860d8b, 0x8a850f8a, 0x7090e070, 0x3e427c3e, 0xb5c471b5, 0x66aacc66, 0x48d89048, 0x03050603, 0xf601f7f6, 0x0e121c0e, 0x61a3c261, 0x355f6a35, 0x57f9ae57, 0xb9d069b9, 0x86911786, 0xc15899c1, 0x1d273a1d, 0x9eb9279e, 0xe138d9e1, 0xf813ebf8, 0x98b32b98, 0x11332211, 0x69bbd269, 0xd970a9d9, 0x8e89078e, 0x94a73394, 0x9bb62d9b, 0x1e223c1e, 0x87921587, 0xe920c9e9, 0xce4987ce, 0x55ffaa55, 0x28785028, 0xdf7aa5df, 0x8c8f038c, 0xa1f859a1, 0x89800989, 0x0d171a0d, 0xbfda65bf, 0xe631d7e6, 0x42c68442, 0x68b8d068, 0x41c38241, 0x99b02999, 0x2d775a2d, 0x0f111e0f, 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16 ]
|
||||
_T4 = [ 0x6363a5c6, 0x7c7c84f8, 0x777799ee, 0x7b7b8df6, 0xf2f20dff, 0x6b6bbdd6, 0x6f6fb1de, 0xc5c55491, 0x30305060, 0x01010302, 0x6767a9ce, 0x2b2b7d56, 0xfefe19e7, 0xd7d762b5, 0xababe64d, 0x76769aec, 0xcaca458f, 0x82829d1f, 0xc9c94089, 0x7d7d87fa, 0xfafa15ef, 0x5959ebb2, 0x4747c98e, 0xf0f00bfb, 0xadadec41, 0xd4d467b3, 0xa2a2fd5f, 0xafafea45, 0x9c9cbf23, 0xa4a4f753, 0x727296e4, 0xc0c05b9b, 0xb7b7c275, 0xfdfd1ce1, 0x9393ae3d, 0x26266a4c, 0x36365a6c, 0x3f3f417e, 0xf7f702f5, 0xcccc4f83, 0x34345c68, 0xa5a5f451, 0xe5e534d1, 0xf1f108f9, 0x717193e2, 0xd8d873ab, 0x31315362, 0x15153f2a, 0x04040c08, 0xc7c75295, 0x23236546, 0xc3c35e9d, 0x18182830, 0x9696a137, 0x05050f0a, 0x9a9ab52f, 0x0707090e, 0x12123624, 0x80809b1b, 0xe2e23ddf, 0xebeb26cd, 0x2727694e, 0xb2b2cd7f, 0x75759fea, 0x09091b12, 0x83839e1d, 0x2c2c7458, 0x1a1a2e34, 0x1b1b2d36, 0x6e6eb2dc, 0x5a5aeeb4, 0xa0a0fb5b, 0x5252f6a4, 0x3b3b4d76, 0xd6d661b7, 0xb3b3ce7d, 0x29297b52, 0xe3e33edd, 0x2f2f715e, 0x84849713, 0x5353f5a6, 0xd1d168b9, 0x00000000, 0xeded2cc1, 0x20206040, 0xfcfc1fe3, 0xb1b1c879, 0x5b5bedb6, 0x6a6abed4, 0xcbcb468d, 0xbebed967, 0x39394b72, 0x4a4ade94, 0x4c4cd498, 0x5858e8b0, 0xcfcf4a85, 0xd0d06bbb, 0xefef2ac5, 0xaaaae54f, 0xfbfb16ed, 0x4343c586, 0x4d4dd79a, 0x33335566, 0x85859411, 0x4545cf8a, 0xf9f910e9, 0x02020604, 0x7f7f81fe, 0x5050f0a0, 0x3c3c4478, 0x9f9fba25, 0xa8a8e34b, 0x5151f3a2, 0xa3a3fe5d, 0x4040c080, 0x8f8f8a05, 0x9292ad3f, 0x9d9dbc21, 0x38384870, 0xf5f504f1, 0xbcbcdf63, 0xb6b6c177, 0xdada75af, 0x21216342, 0x10103020, 0xffff1ae5, 0xf3f30efd, 0xd2d26dbf, 0xcdcd4c81, 0x0c0c1418, 0x13133526, 0xecec2fc3, 0x5f5fe1be, 0x9797a235, 0x4444cc88, 0x1717392e, 0xc4c45793, 0xa7a7f255, 0x7e7e82fc, 0x3d3d477a, 0x6464acc8, 0x5d5de7ba, 0x19192b32, 0x737395e6, 0x6060a0c0, 0x81819819, 0x4f4fd19e, 0xdcdc7fa3, 0x22226644, 0x2a2a7e54, 0x9090ab3b, 0x8888830b, 0x4646ca8c, 0xeeee29c7, 0xb8b8d36b, 0x14143c28, 0xdede79a7, 0x5e5ee2bc, 0x0b0b1d16, 0xdbdb76ad, 0xe0e03bdb, 0x32325664, 0x3a3a4e74, 0x0a0a1e14, 0x4949db92, 0x06060a0c, 0x24246c48, 0x5c5ce4b8, 0xc2c25d9f, 0xd3d36ebd, 0xacacef43, 0x6262a6c4, 0x9191a839, 0x9595a431, 0xe4e437d3, 0x79798bf2, 0xe7e732d5, 0xc8c8438b, 0x3737596e, 0x6d6db7da, 0x8d8d8c01, 0xd5d564b1, 0x4e4ed29c, 0xa9a9e049, 0x6c6cb4d8, 0x5656faac, 0xf4f407f3, 0xeaea25cf, 0x6565afca, 0x7a7a8ef4, 0xaeaee947, 0x08081810, 0xbabad56f, 0x787888f0, 0x25256f4a, 0x2e2e725c, 0x1c1c2438, 0xa6a6f157, 0xb4b4c773, 0xc6c65197, 0xe8e823cb, 0xdddd7ca1, 0x74749ce8, 0x1f1f213e, 0x4b4bdd96, 0xbdbddc61, 0x8b8b860d, 0x8a8a850f, 0x707090e0, 0x3e3e427c, 0xb5b5c471, 0x6666aacc, 0x4848d890, 0x03030506, 0xf6f601f7, 0x0e0e121c, 0x6161a3c2, 0x35355f6a, 0x5757f9ae, 0xb9b9d069, 0x86869117, 0xc1c15899, 0x1d1d273a, 0x9e9eb927, 0xe1e138d9, 0xf8f813eb, 0x9898b32b, 0x11113322, 0x6969bbd2, 0xd9d970a9, 0x8e8e8907, 0x9494a733, 0x9b9bb62d, 0x1e1e223c, 0x87879215, 0xe9e920c9, 0xcece4987, 0x5555ffaa, 0x28287850, 0xdfdf7aa5, 0x8c8c8f03, 0xa1a1f859, 0x89898009, 0x0d0d171a, 0xbfbfda65, 0xe6e631d7, 0x4242c684, 0x6868b8d0, 0x4141c382, 0x9999b029, 0x2d2d775a, 0x0f0f111e, 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c ]
|
||||
|
||||
# Transformations for decryption
|
||||
_T5 = [ 0x51f4a750, 0x7e416553, 0x1a17a4c3, 0x3a275e96, 0x3bab6bcb, 0x1f9d45f1, 0xacfa58ab, 0x4be30393, 0x2030fa55, 0xad766df6, 0x88cc7691, 0xf5024c25, 0x4fe5d7fc, 0xc52acbd7, 0x26354480, 0xb562a38f, 0xdeb15a49, 0x25ba1b67, 0x45ea0e98, 0x5dfec0e1, 0xc32f7502, 0x814cf012, 0x8d4697a3, 0x6bd3f9c6, 0x038f5fe7, 0x15929c95, 0xbf6d7aeb, 0x955259da, 0xd4be832d, 0x587421d3, 0x49e06929, 0x8ec9c844, 0x75c2896a, 0xf48e7978, 0x99583e6b, 0x27b971dd, 0xbee14fb6, 0xf088ad17, 0xc920ac66, 0x7dce3ab4, 0x63df4a18, 0xe51a3182, 0x97513360, 0x62537f45, 0xb16477e0, 0xbb6bae84, 0xfe81a01c, 0xf9082b94, 0x70486858, 0x8f45fd19, 0x94de6c87, 0x527bf8b7, 0xab73d323, 0x724b02e2, 0xe31f8f57, 0x6655ab2a, 0xb2eb2807, 0x2fb5c203, 0x86c57b9a, 0xd33708a5, 0x302887f2, 0x23bfa5b2, 0x02036aba, 0xed16825c, 0x8acf1c2b, 0xa779b492, 0xf307f2f0, 0x4e69e2a1, 0x65daf4cd, 0x0605bed5, 0xd134621f, 0xc4a6fe8a, 0x342e539d, 0xa2f355a0, 0x058ae132, 0xa4f6eb75, 0x0b83ec39, 0x4060efaa, 0x5e719f06, 0xbd6e1051, 0x3e218af9, 0x96dd063d, 0xdd3e05ae, 0x4de6bd46, 0x91548db5, 0x71c45d05, 0x0406d46f, 0x605015ff, 0x1998fb24, 0xd6bde997, 0x894043cc, 0x67d99e77, 0xb0e842bd, 0x07898b88, 0xe7195b38, 0x79c8eedb, 0xa17c0a47, 0x7c420fe9, 0xf8841ec9, 0x00000000, 0x09808683, 0x322bed48, 0x1e1170ac, 0x6c5a724e, 0xfd0efffb, 0x0f853856, 0x3daed51e, 0x362d3927, 0x0a0fd964, 0x685ca621, 0x9b5b54d1, 0x24362e3a, 0x0c0a67b1, 0x9357e70f, 0xb4ee96d2, 0x1b9b919e, 0x80c0c54f, 0x61dc20a2, 0x5a774b69, 0x1c121a16, 0xe293ba0a, 0xc0a02ae5, 0x3c22e043, 0x121b171d, 0x0e090d0b, 0xf28bc7ad, 0x2db6a8b9, 0x141ea9c8, 0x57f11985, 0xaf75074c, 0xee99ddbb, 0xa37f60fd, 0xf701269f, 0x5c72f5bc, 0x44663bc5, 0x5bfb7e34, 0x8b432976, 0xcb23c6dc, 0xb6edfc68, 0xb8e4f163, 0xd731dcca, 0x42638510, 0x13972240, 0x84c61120, 0x854a247d, 0xd2bb3df8, 0xaef93211, 0xc729a16d, 0x1d9e2f4b, 0xdcb230f3, 0x0d8652ec, 0x77c1e3d0, 0x2bb3166c, 0xa970b999, 0x119448fa, 0x47e96422, 0xa8fc8cc4, 0xa0f03f1a, 0x567d2cd8, 0x223390ef, 0x87494ec7, 0xd938d1c1, 0x8ccaa2fe, 0x98d40b36, 0xa6f581cf, 0xa57ade28, 0xdab78e26, 0x3fadbfa4, 0x2c3a9de4, 0x5078920d, 0x6a5fcc9b, 0x547e4662, 0xf68d13c2, 0x90d8b8e8, 0x2e39f75e, 0x82c3aff5, 0x9f5d80be, 0x69d0937c, 0x6fd52da9, 0xcf2512b3, 0xc8ac993b, 0x10187da7, 0xe89c636e, 0xdb3bbb7b, 0xcd267809, 0x6e5918f4, 0xec9ab701, 0x834f9aa8, 0xe6956e65, 0xaaffe67e, 0x21bccf08, 0xef15e8e6, 0xbae79bd9, 0x4a6f36ce, 0xea9f09d4, 0x29b07cd6, 0x31a4b2af, 0x2a3f2331, 0xc6a59430, 0x35a266c0, 0x744ebc37, 0xfc82caa6, 0xe090d0b0, 0x33a7d815, 0xf104984a, 0x41ecdaf7, 0x7fcd500e, 0x1791f62f, 0x764dd68d, 0x43efb04d, 0xccaa4d54, 0xe49604df, 0x9ed1b5e3, 0x4c6a881b, 0xc12c1fb8, 0x4665517f, 0x9d5eea04, 0x018c355d, 0xfa877473, 0xfb0b412e, 0xb3671d5a, 0x92dbd252, 0xe9105633, 0x6dd64713, 0x9ad7618c, 0x37a10c7a, 0x59f8148e, 0xeb133c89, 0xcea927ee, 0xb761c935, 0xe11ce5ed, 0x7a47b13c, 0x9cd2df59, 0x55f2733f, 0x1814ce79, 0x73c737bf, 0x53f7cdea, 0x5ffdaa5b, 0xdf3d6f14, 0x7844db86, 0xcaaff381, 0xb968c43e, 0x3824342c, 0xc2a3405f, 0x161dc372, 0xbce2250c, 0x283c498b, 0xff0d9541, 0x39a80171, 0x080cb3de, 0xd8b4e49c, 0x6456c190, 0x7bcb8461, 0xd532b670, 0x486c5c74, 0xd0b85742 ]
|
||||
_T6 = [ 0x5051f4a7, 0x537e4165, 0xc31a17a4, 0x963a275e, 0xcb3bab6b, 0xf11f9d45, 0xabacfa58, 0x934be303, 0x552030fa, 0xf6ad766d, 0x9188cc76, 0x25f5024c, 0xfc4fe5d7, 0xd7c52acb, 0x80263544, 0x8fb562a3, 0x49deb15a, 0x6725ba1b, 0x9845ea0e, 0xe15dfec0, 0x02c32f75, 0x12814cf0, 0xa38d4697, 0xc66bd3f9, 0xe7038f5f, 0x9515929c, 0xebbf6d7a, 0xda955259, 0x2dd4be83, 0xd3587421, 0x2949e069, 0x448ec9c8, 0x6a75c289, 0x78f48e79, 0x6b99583e, 0xdd27b971, 0xb6bee14f, 0x17f088ad, 0x66c920ac, 0xb47dce3a, 0x1863df4a, 0x82e51a31, 0x60975133, 0x4562537f, 0xe0b16477, 0x84bb6bae, 0x1cfe81a0, 0x94f9082b, 0x58704868, 0x198f45fd, 0x8794de6c, 0xb7527bf8, 0x23ab73d3, 0xe2724b02, 0x57e31f8f, 0x2a6655ab, 0x07b2eb28, 0x032fb5c2, 0x9a86c57b, 0xa5d33708, 0xf2302887, 0xb223bfa5, 0xba02036a, 0x5ced1682, 0x2b8acf1c, 0x92a779b4, 0xf0f307f2, 0xa14e69e2, 0xcd65daf4, 0xd50605be, 0x1fd13462, 0x8ac4a6fe, 0x9d342e53, 0xa0a2f355, 0x32058ae1, 0x75a4f6eb, 0x390b83ec, 0xaa4060ef, 0x065e719f, 0x51bd6e10, 0xf93e218a, 0x3d96dd06, 0xaedd3e05, 0x464de6bd, 0xb591548d, 0x0571c45d, 0x6f0406d4, 0xff605015, 0x241998fb, 0x97d6bde9, 0xcc894043, 0x7767d99e, 0xbdb0e842, 0x8807898b, 0x38e7195b, 0xdb79c8ee, 0x47a17c0a, 0xe97c420f, 0xc9f8841e, 0x00000000, 0x83098086, 0x48322bed, 0xac1e1170, 0x4e6c5a72, 0xfbfd0eff, 0x560f8538, 0x1e3daed5, 0x27362d39, 0x640a0fd9, 0x21685ca6, 0xd19b5b54, 0x3a24362e, 0xb10c0a67, 0x0f9357e7, 0xd2b4ee96, 0x9e1b9b91, 0x4f80c0c5, 0xa261dc20, 0x695a774b, 0x161c121a, 0x0ae293ba, 0xe5c0a02a, 0x433c22e0, 0x1d121b17, 0x0b0e090d, 0xadf28bc7, 0xb92db6a8, 0xc8141ea9, 0x8557f119, 0x4caf7507, 0xbbee99dd, 0xfda37f60, 0x9ff70126, 0xbc5c72f5, 0xc544663b, 0x345bfb7e, 0x768b4329, 0xdccb23c6, 0x68b6edfc, 0x63b8e4f1, 0xcad731dc, 0x10426385, 0x40139722, 0x2084c611, 0x7d854a24, 0xf8d2bb3d, 0x11aef932, 0x6dc729a1, 0x4b1d9e2f, 0xf3dcb230, 0xec0d8652, 0xd077c1e3, 0x6c2bb316, 0x99a970b9, 0xfa119448, 0x2247e964, 0xc4a8fc8c, 0x1aa0f03f, 0xd8567d2c, 0xef223390, 0xc787494e, 0xc1d938d1, 0xfe8ccaa2, 0x3698d40b, 0xcfa6f581, 0x28a57ade, 0x26dab78e, 0xa43fadbf, 0xe42c3a9d, 0x0d507892, 0x9b6a5fcc, 0x62547e46, 0xc2f68d13, 0xe890d8b8, 0x5e2e39f7, 0xf582c3af, 0xbe9f5d80, 0x7c69d093, 0xa96fd52d, 0xb3cf2512, 0x3bc8ac99, 0xa710187d, 0x6ee89c63, 0x7bdb3bbb, 0x09cd2678, 0xf46e5918, 0x01ec9ab7, 0xa8834f9a, 0x65e6956e, 0x7eaaffe6, 0x0821bccf, 0xe6ef15e8, 0xd9bae79b, 0xce4a6f36, 0xd4ea9f09, 0xd629b07c, 0xaf31a4b2, 0x312a3f23, 0x30c6a594, 0xc035a266, 0x37744ebc, 0xa6fc82ca, 0xb0e090d0, 0x1533a7d8, 0x4af10498, 0xf741ecda, 0x0e7fcd50, 0x2f1791f6, 0x8d764dd6, 0x4d43efb0, 0x54ccaa4d, 0xdfe49604, 0xe39ed1b5, 0x1b4c6a88, 0xb8c12c1f, 0x7f466551, 0x049d5eea, 0x5d018c35, 0x73fa8774, 0x2efb0b41, 0x5ab3671d, 0x5292dbd2, 0x33e91056, 0x136dd647, 0x8c9ad761, 0x7a37a10c, 0x8e59f814, 0x89eb133c, 0xeecea927, 0x35b761c9, 0xede11ce5, 0x3c7a47b1, 0x599cd2df, 0x3f55f273, 0x791814ce, 0xbf73c737, 0xea53f7cd, 0x5b5ffdaa, 0x14df3d6f, 0x867844db, 0x81caaff3, 0x3eb968c4, 0x2c382434, 0x5fc2a340, 0x72161dc3, 0x0cbce225, 0x8b283c49, 0x41ff0d95, 0x7139a801, 0xde080cb3, 0x9cd8b4e4, 0x906456c1, 0x617bcb84, 0x70d532b6, 0x74486c5c, 0x42d0b857 ]
|
||||
_T7 = [ 0xa75051f4, 0x65537e41, 0xa4c31a17, 0x5e963a27, 0x6bcb3bab, 0x45f11f9d, 0x58abacfa, 0x03934be3, 0xfa552030, 0x6df6ad76, 0x769188cc, 0x4c25f502, 0xd7fc4fe5, 0xcbd7c52a, 0x44802635, 0xa38fb562, 0x5a49deb1, 0x1b6725ba, 0x0e9845ea, 0xc0e15dfe, 0x7502c32f, 0xf012814c, 0x97a38d46, 0xf9c66bd3, 0x5fe7038f, 0x9c951592, 0x7aebbf6d, 0x59da9552, 0x832dd4be, 0x21d35874, 0x692949e0, 0xc8448ec9, 0x896a75c2, 0x7978f48e, 0x3e6b9958, 0x71dd27b9, 0x4fb6bee1, 0xad17f088, 0xac66c920, 0x3ab47dce, 0x4a1863df, 0x3182e51a, 0x33609751, 0x7f456253, 0x77e0b164, 0xae84bb6b, 0xa01cfe81, 0x2b94f908, 0x68587048, 0xfd198f45, 0x6c8794de, 0xf8b7527b, 0xd323ab73, 0x02e2724b, 0x8f57e31f, 0xab2a6655, 0x2807b2eb, 0xc2032fb5, 0x7b9a86c5, 0x08a5d337, 0x87f23028, 0xa5b223bf, 0x6aba0203, 0x825ced16, 0x1c2b8acf, 0xb492a779, 0xf2f0f307, 0xe2a14e69, 0xf4cd65da, 0xbed50605, 0x621fd134, 0xfe8ac4a6, 0x539d342e, 0x55a0a2f3, 0xe132058a, 0xeb75a4f6, 0xec390b83, 0xefaa4060, 0x9f065e71, 0x1051bd6e, 0x8af93e21, 0x063d96dd, 0x05aedd3e, 0xbd464de6, 0x8db59154, 0x5d0571c4, 0xd46f0406, 0x15ff6050, 0xfb241998, 0xe997d6bd, 0x43cc8940, 0x9e7767d9, 0x42bdb0e8, 0x8b880789, 0x5b38e719, 0xeedb79c8, 0x0a47a17c, 0x0fe97c42, 0x1ec9f884, 0x00000000, 0x86830980, 0xed48322b, 0x70ac1e11, 0x724e6c5a, 0xfffbfd0e, 0x38560f85, 0xd51e3dae, 0x3927362d, 0xd9640a0f, 0xa621685c, 0x54d19b5b, 0x2e3a2436, 0x67b10c0a, 0xe70f9357, 0x96d2b4ee, 0x919e1b9b, 0xc54f80c0, 0x20a261dc, 0x4b695a77, 0x1a161c12, 0xba0ae293, 0x2ae5c0a0, 0xe0433c22, 0x171d121b, 0x0d0b0e09, 0xc7adf28b, 0xa8b92db6, 0xa9c8141e, 0x198557f1, 0x074caf75, 0xddbbee99, 0x60fda37f, 0x269ff701, 0xf5bc5c72, 0x3bc54466, 0x7e345bfb, 0x29768b43, 0xc6dccb23, 0xfc68b6ed, 0xf163b8e4, 0xdccad731, 0x85104263, 0x22401397, 0x112084c6, 0x247d854a, 0x3df8d2bb, 0x3211aef9, 0xa16dc729, 0x2f4b1d9e, 0x30f3dcb2, 0x52ec0d86, 0xe3d077c1, 0x166c2bb3, 0xb999a970, 0x48fa1194, 0x642247e9, 0x8cc4a8fc, 0x3f1aa0f0, 0x2cd8567d, 0x90ef2233, 0x4ec78749, 0xd1c1d938, 0xa2fe8cca, 0x0b3698d4, 0x81cfa6f5, 0xde28a57a, 0x8e26dab7, 0xbfa43fad, 0x9de42c3a, 0x920d5078, 0xcc9b6a5f, 0x4662547e, 0x13c2f68d, 0xb8e890d8, 0xf75e2e39, 0xaff582c3, 0x80be9f5d, 0x937c69d0, 0x2da96fd5, 0x12b3cf25, 0x993bc8ac, 0x7da71018, 0x636ee89c, 0xbb7bdb3b, 0x7809cd26, 0x18f46e59, 0xb701ec9a, 0x9aa8834f, 0x6e65e695, 0xe67eaaff, 0xcf0821bc, 0xe8e6ef15, 0x9bd9bae7, 0x36ce4a6f, 0x09d4ea9f, 0x7cd629b0, 0xb2af31a4, 0x23312a3f, 0x9430c6a5, 0x66c035a2, 0xbc37744e, 0xcaa6fc82, 0xd0b0e090, 0xd81533a7, 0x984af104, 0xdaf741ec, 0x500e7fcd, 0xf62f1791, 0xd68d764d, 0xb04d43ef, 0x4d54ccaa, 0x04dfe496, 0xb5e39ed1, 0x881b4c6a, 0x1fb8c12c, 0x517f4665, 0xea049d5e, 0x355d018c, 0x7473fa87, 0x412efb0b, 0x1d5ab367, 0xd25292db, 0x5633e910, 0x47136dd6, 0x618c9ad7, 0x0c7a37a1, 0x148e59f8, 0x3c89eb13, 0x27eecea9, 0xc935b761, 0xe5ede11c, 0xb13c7a47, 0xdf599cd2, 0x733f55f2, 0xce791814, 0x37bf73c7, 0xcdea53f7, 0xaa5b5ffd, 0x6f14df3d, 0xdb867844, 0xf381caaf, 0xc43eb968, 0x342c3824, 0x405fc2a3, 0xc372161d, 0x250cbce2, 0x498b283c, 0x9541ff0d, 0x017139a8, 0xb3de080c, 0xe49cd8b4, 0xc1906456, 0x84617bcb, 0xb670d532, 0x5c74486c, 0x5742d0b8 ]
|
||||
_T8 = [ 0xf4a75051, 0x4165537e, 0x17a4c31a, 0x275e963a, 0xab6bcb3b, 0x9d45f11f, 0xfa58abac, 0xe303934b, 0x30fa5520, 0x766df6ad, 0xcc769188, 0x024c25f5, 0xe5d7fc4f, 0x2acbd7c5, 0x35448026, 0x62a38fb5, 0xb15a49de, 0xba1b6725, 0xea0e9845, 0xfec0e15d, 0x2f7502c3, 0x4cf01281, 0x4697a38d, 0xd3f9c66b, 0x8f5fe703, 0x929c9515, 0x6d7aebbf, 0x5259da95, 0xbe832dd4, 0x7421d358, 0xe0692949, 0xc9c8448e, 0xc2896a75, 0x8e7978f4, 0x583e6b99, 0xb971dd27, 0xe14fb6be, 0x88ad17f0, 0x20ac66c9, 0xce3ab47d, 0xdf4a1863, 0x1a3182e5, 0x51336097, 0x537f4562, 0x6477e0b1, 0x6bae84bb, 0x81a01cfe, 0x082b94f9, 0x48685870, 0x45fd198f, 0xde6c8794, 0x7bf8b752, 0x73d323ab, 0x4b02e272, 0x1f8f57e3, 0x55ab2a66, 0xeb2807b2, 0xb5c2032f, 0xc57b9a86, 0x3708a5d3, 0x2887f230, 0xbfa5b223, 0x036aba02, 0x16825ced, 0xcf1c2b8a, 0x79b492a7, 0x07f2f0f3, 0x69e2a14e, 0xdaf4cd65, 0x05bed506, 0x34621fd1, 0xa6fe8ac4, 0x2e539d34, 0xf355a0a2, 0x8ae13205, 0xf6eb75a4, 0x83ec390b, 0x60efaa40, 0x719f065e, 0x6e1051bd, 0x218af93e, 0xdd063d96, 0x3e05aedd, 0xe6bd464d, 0x548db591, 0xc45d0571, 0x06d46f04, 0x5015ff60, 0x98fb2419, 0xbde997d6, 0x4043cc89, 0xd99e7767, 0xe842bdb0, 0x898b8807, 0x195b38e7, 0xc8eedb79, 0x7c0a47a1, 0x420fe97c, 0x841ec9f8, 0x00000000, 0x80868309, 0x2bed4832, 0x1170ac1e, 0x5a724e6c, 0x0efffbfd, 0x8538560f, 0xaed51e3d, 0x2d392736, 0x0fd9640a, 0x5ca62168, 0x5b54d19b, 0x362e3a24, 0x0a67b10c, 0x57e70f93, 0xee96d2b4, 0x9b919e1b, 0xc0c54f80, 0xdc20a261, 0x774b695a, 0x121a161c, 0x93ba0ae2, 0xa02ae5c0, 0x22e0433c, 0x1b171d12, 0x090d0b0e, 0x8bc7adf2, 0xb6a8b92d, 0x1ea9c814, 0xf1198557, 0x75074caf, 0x99ddbbee, 0x7f60fda3, 0x01269ff7, 0x72f5bc5c, 0x663bc544, 0xfb7e345b, 0x4329768b, 0x23c6dccb, 0xedfc68b6, 0xe4f163b8, 0x31dccad7, 0x63851042, 0x97224013, 0xc6112084, 0x4a247d85, 0xbb3df8d2, 0xf93211ae, 0x29a16dc7, 0x9e2f4b1d, 0xb230f3dc, 0x8652ec0d, 0xc1e3d077, 0xb3166c2b, 0x70b999a9, 0x9448fa11, 0xe9642247, 0xfc8cc4a8, 0xf03f1aa0, 0x7d2cd856, 0x3390ef22, 0x494ec787, 0x38d1c1d9, 0xcaa2fe8c, 0xd40b3698, 0xf581cfa6, 0x7ade28a5, 0xb78e26da, 0xadbfa43f, 0x3a9de42c, 0x78920d50, 0x5fcc9b6a, 0x7e466254, 0x8d13c2f6, 0xd8b8e890, 0x39f75e2e, 0xc3aff582, 0x5d80be9f, 0xd0937c69, 0xd52da96f, 0x2512b3cf, 0xac993bc8, 0x187da710, 0x9c636ee8, 0x3bbb7bdb, 0x267809cd, 0x5918f46e, 0x9ab701ec, 0x4f9aa883, 0x956e65e6, 0xffe67eaa, 0xbccf0821, 0x15e8e6ef, 0xe79bd9ba, 0x6f36ce4a, 0x9f09d4ea, 0xb07cd629, 0xa4b2af31, 0x3f23312a, 0xa59430c6, 0xa266c035, 0x4ebc3774, 0x82caa6fc, 0x90d0b0e0, 0xa7d81533, 0x04984af1, 0xecdaf741, 0xcd500e7f, 0x91f62f17, 0x4dd68d76, 0xefb04d43, 0xaa4d54cc, 0x9604dfe4, 0xd1b5e39e, 0x6a881b4c, 0x2c1fb8c1, 0x65517f46, 0x5eea049d, 0x8c355d01, 0x877473fa, 0x0b412efb, 0x671d5ab3, 0xdbd25292, 0x105633e9, 0xd647136d, 0xd7618c9a, 0xa10c7a37, 0xf8148e59, 0x133c89eb, 0xa927eece, 0x61c935b7, 0x1ce5ede1, 0x47b13c7a, 0xd2df599c, 0xf2733f55, 0x14ce7918, 0xc737bf73, 0xf7cdea53, 0xfdaa5b5f, 0x3d6f14df, 0x44db8678, 0xaff381ca, 0x68c43eb9, 0x24342c38, 0xa3405fc2, 0x1dc37216, 0xe2250cbc, 0x3c498b28, 0x0d9541ff, 0xa8017139, 0x0cb3de08, 0xb4e49cd8, 0x56c19064, 0xcb84617b, 0x32b670d5, 0x6c5c7448, 0xb85742d0 ]
|
||||
|
||||
# Transformations for decryption key expansion
|
||||
_U1 = [ 0x00000000, 0x0e090d0b, 0x1c121a16, 0x121b171d, 0x3824342c, 0x362d3927, 0x24362e3a, 0x2a3f2331, 0x70486858, 0x7e416553, 0x6c5a724e, 0x62537f45, 0x486c5c74, 0x4665517f, 0x547e4662, 0x5a774b69, 0xe090d0b0, 0xee99ddbb, 0xfc82caa6, 0xf28bc7ad, 0xd8b4e49c, 0xd6bde997, 0xc4a6fe8a, 0xcaaff381, 0x90d8b8e8, 0x9ed1b5e3, 0x8ccaa2fe, 0x82c3aff5, 0xa8fc8cc4, 0xa6f581cf, 0xb4ee96d2, 0xbae79bd9, 0xdb3bbb7b, 0xd532b670, 0xc729a16d, 0xc920ac66, 0xe31f8f57, 0xed16825c, 0xff0d9541, 0xf104984a, 0xab73d323, 0xa57ade28, 0xb761c935, 0xb968c43e, 0x9357e70f, 0x9d5eea04, 0x8f45fd19, 0x814cf012, 0x3bab6bcb, 0x35a266c0, 0x27b971dd, 0x29b07cd6, 0x038f5fe7, 0x0d8652ec, 0x1f9d45f1, 0x119448fa, 0x4be30393, 0x45ea0e98, 0x57f11985, 0x59f8148e, 0x73c737bf, 0x7dce3ab4, 0x6fd52da9, 0x61dc20a2, 0xad766df6, 0xa37f60fd, 0xb16477e0, 0xbf6d7aeb, 0x955259da, 0x9b5b54d1, 0x894043cc, 0x87494ec7, 0xdd3e05ae, 0xd33708a5, 0xc12c1fb8, 0xcf2512b3, 0xe51a3182, 0xeb133c89, 0xf9082b94, 0xf701269f, 0x4de6bd46, 0x43efb04d, 0x51f4a750, 0x5ffdaa5b, 0x75c2896a, 0x7bcb8461, 0x69d0937c, 0x67d99e77, 0x3daed51e, 0x33a7d815, 0x21bccf08, 0x2fb5c203, 0x058ae132, 0x0b83ec39, 0x1998fb24, 0x1791f62f, 0x764dd68d, 0x7844db86, 0x6a5fcc9b, 0x6456c190, 0x4e69e2a1, 0x4060efaa, 0x527bf8b7, 0x5c72f5bc, 0x0605bed5, 0x080cb3de, 0x1a17a4c3, 0x141ea9c8, 0x3e218af9, 0x302887f2, 0x223390ef, 0x2c3a9de4, 0x96dd063d, 0x98d40b36, 0x8acf1c2b, 0x84c61120, 0xaef93211, 0xa0f03f1a, 0xb2eb2807, 0xbce2250c, 0xe6956e65, 0xe89c636e, 0xfa877473, 0xf48e7978, 0xdeb15a49, 0xd0b85742, 0xc2a3405f, 0xccaa4d54, 0x41ecdaf7, 0x4fe5d7fc, 0x5dfec0e1, 0x53f7cdea, 0x79c8eedb, 0x77c1e3d0, 0x65daf4cd, 0x6bd3f9c6, 0x31a4b2af, 0x3fadbfa4, 0x2db6a8b9, 0x23bfa5b2, 0x09808683, 0x07898b88, 0x15929c95, 0x1b9b919e, 0xa17c0a47, 0xaf75074c, 0xbd6e1051, 0xb3671d5a, 0x99583e6b, 0x97513360, 0x854a247d, 0x8b432976, 0xd134621f, 0xdf3d6f14, 0xcd267809, 0xc32f7502, 0xe9105633, 0xe7195b38, 0xf5024c25, 0xfb0b412e, 0x9ad7618c, 0x94de6c87, 0x86c57b9a, 0x88cc7691, 0xa2f355a0, 0xacfa58ab, 0xbee14fb6, 0xb0e842bd, 0xea9f09d4, 0xe49604df, 0xf68d13c2, 0xf8841ec9, 0xd2bb3df8, 0xdcb230f3, 0xcea927ee, 0xc0a02ae5, 0x7a47b13c, 0x744ebc37, 0x6655ab2a, 0x685ca621, 0x42638510, 0x4c6a881b, 0x5e719f06, 0x5078920d, 0x0a0fd964, 0x0406d46f, 0x161dc372, 0x1814ce79, 0x322bed48, 0x3c22e043, 0x2e39f75e, 0x2030fa55, 0xec9ab701, 0xe293ba0a, 0xf088ad17, 0xfe81a01c, 0xd4be832d, 0xdab78e26, 0xc8ac993b, 0xc6a59430, 0x9cd2df59, 0x92dbd252, 0x80c0c54f, 0x8ec9c844, 0xa4f6eb75, 0xaaffe67e, 0xb8e4f163, 0xb6edfc68, 0x0c0a67b1, 0x02036aba, 0x10187da7, 0x1e1170ac, 0x342e539d, 0x3a275e96, 0x283c498b, 0x26354480, 0x7c420fe9, 0x724b02e2, 0x605015ff, 0x6e5918f4, 0x44663bc5, 0x4a6f36ce, 0x587421d3, 0x567d2cd8, 0x37a10c7a, 0x39a80171, 0x2bb3166c, 0x25ba1b67, 0x0f853856, 0x018c355d, 0x13972240, 0x1d9e2f4b, 0x47e96422, 0x49e06929, 0x5bfb7e34, 0x55f2733f, 0x7fcd500e, 0x71c45d05, 0x63df4a18, 0x6dd64713, 0xd731dcca, 0xd938d1c1, 0xcb23c6dc, 0xc52acbd7, 0xef15e8e6, 0xe11ce5ed, 0xf307f2f0, 0xfd0efffb, 0xa779b492, 0xa970b999, 0xbb6bae84, 0xb562a38f, 0x9f5d80be, 0x91548db5, 0x834f9aa8, 0x8d4697a3 ]
|
||||
_U2 = [ 0x00000000, 0x0b0e090d, 0x161c121a, 0x1d121b17, 0x2c382434, 0x27362d39, 0x3a24362e, 0x312a3f23, 0x58704868, 0x537e4165, 0x4e6c5a72, 0x4562537f, 0x74486c5c, 0x7f466551, 0x62547e46, 0x695a774b, 0xb0e090d0, 0xbbee99dd, 0xa6fc82ca, 0xadf28bc7, 0x9cd8b4e4, 0x97d6bde9, 0x8ac4a6fe, 0x81caaff3, 0xe890d8b8, 0xe39ed1b5, 0xfe8ccaa2, 0xf582c3af, 0xc4a8fc8c, 0xcfa6f581, 0xd2b4ee96, 0xd9bae79b, 0x7bdb3bbb, 0x70d532b6, 0x6dc729a1, 0x66c920ac, 0x57e31f8f, 0x5ced1682, 0x41ff0d95, 0x4af10498, 0x23ab73d3, 0x28a57ade, 0x35b761c9, 0x3eb968c4, 0x0f9357e7, 0x049d5eea, 0x198f45fd, 0x12814cf0, 0xcb3bab6b, 0xc035a266, 0xdd27b971, 0xd629b07c, 0xe7038f5f, 0xec0d8652, 0xf11f9d45, 0xfa119448, 0x934be303, 0x9845ea0e, 0x8557f119, 0x8e59f814, 0xbf73c737, 0xb47dce3a, 0xa96fd52d, 0xa261dc20, 0xf6ad766d, 0xfda37f60, 0xe0b16477, 0xebbf6d7a, 0xda955259, 0xd19b5b54, 0xcc894043, 0xc787494e, 0xaedd3e05, 0xa5d33708, 0xb8c12c1f, 0xb3cf2512, 0x82e51a31, 0x89eb133c, 0x94f9082b, 0x9ff70126, 0x464de6bd, 0x4d43efb0, 0x5051f4a7, 0x5b5ffdaa, 0x6a75c289, 0x617bcb84, 0x7c69d093, 0x7767d99e, 0x1e3daed5, 0x1533a7d8, 0x0821bccf, 0x032fb5c2, 0x32058ae1, 0x390b83ec, 0x241998fb, 0x2f1791f6, 0x8d764dd6, 0x867844db, 0x9b6a5fcc, 0x906456c1, 0xa14e69e2, 0xaa4060ef, 0xb7527bf8, 0xbc5c72f5, 0xd50605be, 0xde080cb3, 0xc31a17a4, 0xc8141ea9, 0xf93e218a, 0xf2302887, 0xef223390, 0xe42c3a9d, 0x3d96dd06, 0x3698d40b, 0x2b8acf1c, 0x2084c611, 0x11aef932, 0x1aa0f03f, 0x07b2eb28, 0x0cbce225, 0x65e6956e, 0x6ee89c63, 0x73fa8774, 0x78f48e79, 0x49deb15a, 0x42d0b857, 0x5fc2a340, 0x54ccaa4d, 0xf741ecda, 0xfc4fe5d7, 0xe15dfec0, 0xea53f7cd, 0xdb79c8ee, 0xd077c1e3, 0xcd65daf4, 0xc66bd3f9, 0xaf31a4b2, 0xa43fadbf, 0xb92db6a8, 0xb223bfa5, 0x83098086, 0x8807898b, 0x9515929c, 0x9e1b9b91, 0x47a17c0a, 0x4caf7507, 0x51bd6e10, 0x5ab3671d, 0x6b99583e, 0x60975133, 0x7d854a24, 0x768b4329, 0x1fd13462, 0x14df3d6f, 0x09cd2678, 0x02c32f75, 0x33e91056, 0x38e7195b, 0x25f5024c, 0x2efb0b41, 0x8c9ad761, 0x8794de6c, 0x9a86c57b, 0x9188cc76, 0xa0a2f355, 0xabacfa58, 0xb6bee14f, 0xbdb0e842, 0xd4ea9f09, 0xdfe49604, 0xc2f68d13, 0xc9f8841e, 0xf8d2bb3d, 0xf3dcb230, 0xeecea927, 0xe5c0a02a, 0x3c7a47b1, 0x37744ebc, 0x2a6655ab, 0x21685ca6, 0x10426385, 0x1b4c6a88, 0x065e719f, 0x0d507892, 0x640a0fd9, 0x6f0406d4, 0x72161dc3, 0x791814ce, 0x48322bed, 0x433c22e0, 0x5e2e39f7, 0x552030fa, 0x01ec9ab7, 0x0ae293ba, 0x17f088ad, 0x1cfe81a0, 0x2dd4be83, 0x26dab78e, 0x3bc8ac99, 0x30c6a594, 0x599cd2df, 0x5292dbd2, 0x4f80c0c5, 0x448ec9c8, 0x75a4f6eb, 0x7eaaffe6, 0x63b8e4f1, 0x68b6edfc, 0xb10c0a67, 0xba02036a, 0xa710187d, 0xac1e1170, 0x9d342e53, 0x963a275e, 0x8b283c49, 0x80263544, 0xe97c420f, 0xe2724b02, 0xff605015, 0xf46e5918, 0xc544663b, 0xce4a6f36, 0xd3587421, 0xd8567d2c, 0x7a37a10c, 0x7139a801, 0x6c2bb316, 0x6725ba1b, 0x560f8538, 0x5d018c35, 0x40139722, 0x4b1d9e2f, 0x2247e964, 0x2949e069, 0x345bfb7e, 0x3f55f273, 0x0e7fcd50, 0x0571c45d, 0x1863df4a, 0x136dd647, 0xcad731dc, 0xc1d938d1, 0xdccb23c6, 0xd7c52acb, 0xe6ef15e8, 0xede11ce5, 0xf0f307f2, 0xfbfd0eff, 0x92a779b4, 0x99a970b9, 0x84bb6bae, 0x8fb562a3, 0xbe9f5d80, 0xb591548d, 0xa8834f9a, 0xa38d4697 ]
|
||||
_U3 = [ 0x00000000, 0x0d0b0e09, 0x1a161c12, 0x171d121b, 0x342c3824, 0x3927362d, 0x2e3a2436, 0x23312a3f, 0x68587048, 0x65537e41, 0x724e6c5a, 0x7f456253, 0x5c74486c, 0x517f4665, 0x4662547e, 0x4b695a77, 0xd0b0e090, 0xddbbee99, 0xcaa6fc82, 0xc7adf28b, 0xe49cd8b4, 0xe997d6bd, 0xfe8ac4a6, 0xf381caaf, 0xb8e890d8, 0xb5e39ed1, 0xa2fe8cca, 0xaff582c3, 0x8cc4a8fc, 0x81cfa6f5, 0x96d2b4ee, 0x9bd9bae7, 0xbb7bdb3b, 0xb670d532, 0xa16dc729, 0xac66c920, 0x8f57e31f, 0x825ced16, 0x9541ff0d, 0x984af104, 0xd323ab73, 0xde28a57a, 0xc935b761, 0xc43eb968, 0xe70f9357, 0xea049d5e, 0xfd198f45, 0xf012814c, 0x6bcb3bab, 0x66c035a2, 0x71dd27b9, 0x7cd629b0, 0x5fe7038f, 0x52ec0d86, 0x45f11f9d, 0x48fa1194, 0x03934be3, 0x0e9845ea, 0x198557f1, 0x148e59f8, 0x37bf73c7, 0x3ab47dce, 0x2da96fd5, 0x20a261dc, 0x6df6ad76, 0x60fda37f, 0x77e0b164, 0x7aebbf6d, 0x59da9552, 0x54d19b5b, 0x43cc8940, 0x4ec78749, 0x05aedd3e, 0x08a5d337, 0x1fb8c12c, 0x12b3cf25, 0x3182e51a, 0x3c89eb13, 0x2b94f908, 0x269ff701, 0xbd464de6, 0xb04d43ef, 0xa75051f4, 0xaa5b5ffd, 0x896a75c2, 0x84617bcb, 0x937c69d0, 0x9e7767d9, 0xd51e3dae, 0xd81533a7, 0xcf0821bc, 0xc2032fb5, 0xe132058a, 0xec390b83, 0xfb241998, 0xf62f1791, 0xd68d764d, 0xdb867844, 0xcc9b6a5f, 0xc1906456, 0xe2a14e69, 0xefaa4060, 0xf8b7527b, 0xf5bc5c72, 0xbed50605, 0xb3de080c, 0xa4c31a17, 0xa9c8141e, 0x8af93e21, 0x87f23028, 0x90ef2233, 0x9de42c3a, 0x063d96dd, 0x0b3698d4, 0x1c2b8acf, 0x112084c6, 0x3211aef9, 0x3f1aa0f0, 0x2807b2eb, 0x250cbce2, 0x6e65e695, 0x636ee89c, 0x7473fa87, 0x7978f48e, 0x5a49deb1, 0x5742d0b8, 0x405fc2a3, 0x4d54ccaa, 0xdaf741ec, 0xd7fc4fe5, 0xc0e15dfe, 0xcdea53f7, 0xeedb79c8, 0xe3d077c1, 0xf4cd65da, 0xf9c66bd3, 0xb2af31a4, 0xbfa43fad, 0xa8b92db6, 0xa5b223bf, 0x86830980, 0x8b880789, 0x9c951592, 0x919e1b9b, 0x0a47a17c, 0x074caf75, 0x1051bd6e, 0x1d5ab367, 0x3e6b9958, 0x33609751, 0x247d854a, 0x29768b43, 0x621fd134, 0x6f14df3d, 0x7809cd26, 0x7502c32f, 0x5633e910, 0x5b38e719, 0x4c25f502, 0x412efb0b, 0x618c9ad7, 0x6c8794de, 0x7b9a86c5, 0x769188cc, 0x55a0a2f3, 0x58abacfa, 0x4fb6bee1, 0x42bdb0e8, 0x09d4ea9f, 0x04dfe496, 0x13c2f68d, 0x1ec9f884, 0x3df8d2bb, 0x30f3dcb2, 0x27eecea9, 0x2ae5c0a0, 0xb13c7a47, 0xbc37744e, 0xab2a6655, 0xa621685c, 0x85104263, 0x881b4c6a, 0x9f065e71, 0x920d5078, 0xd9640a0f, 0xd46f0406, 0xc372161d, 0xce791814, 0xed48322b, 0xe0433c22, 0xf75e2e39, 0xfa552030, 0xb701ec9a, 0xba0ae293, 0xad17f088, 0xa01cfe81, 0x832dd4be, 0x8e26dab7, 0x993bc8ac, 0x9430c6a5, 0xdf599cd2, 0xd25292db, 0xc54f80c0, 0xc8448ec9, 0xeb75a4f6, 0xe67eaaff, 0xf163b8e4, 0xfc68b6ed, 0x67b10c0a, 0x6aba0203, 0x7da71018, 0x70ac1e11, 0x539d342e, 0x5e963a27, 0x498b283c, 0x44802635, 0x0fe97c42, 0x02e2724b, 0x15ff6050, 0x18f46e59, 0x3bc54466, 0x36ce4a6f, 0x21d35874, 0x2cd8567d, 0x0c7a37a1, 0x017139a8, 0x166c2bb3, 0x1b6725ba, 0x38560f85, 0x355d018c, 0x22401397, 0x2f4b1d9e, 0x642247e9, 0x692949e0, 0x7e345bfb, 0x733f55f2, 0x500e7fcd, 0x5d0571c4, 0x4a1863df, 0x47136dd6, 0xdccad731, 0xd1c1d938, 0xc6dccb23, 0xcbd7c52a, 0xe8e6ef15, 0xe5ede11c, 0xf2f0f307, 0xfffbfd0e, 0xb492a779, 0xb999a970, 0xae84bb6b, 0xa38fb562, 0x80be9f5d, 0x8db59154, 0x9aa8834f, 0x97a38d46 ]
|
||||
_U4 = [ 0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c38, 0x2d392736, 0x362e3a24, 0x3f23312a, 0x48685870, 0x4165537e, 0x5a724e6c, 0x537f4562, 0x6c5c7448, 0x65517f46, 0x7e466254, 0x774b695a, 0x90d0b0e0, 0x99ddbbee, 0x82caa6fc, 0x8bc7adf2, 0xb4e49cd8, 0xbde997d6, 0xa6fe8ac4, 0xaff381ca, 0xd8b8e890, 0xd1b5e39e, 0xcaa2fe8c, 0xc3aff582, 0xfc8cc4a8, 0xf581cfa6, 0xee96d2b4, 0xe79bd9ba, 0x3bbb7bdb, 0x32b670d5, 0x29a16dc7, 0x20ac66c9, 0x1f8f57e3, 0x16825ced, 0x0d9541ff, 0x04984af1, 0x73d323ab, 0x7ade28a5, 0x61c935b7, 0x68c43eb9, 0x57e70f93, 0x5eea049d, 0x45fd198f, 0x4cf01281, 0xab6bcb3b, 0xa266c035, 0xb971dd27, 0xb07cd629, 0x8f5fe703, 0x8652ec0d, 0x9d45f11f, 0x9448fa11, 0xe303934b, 0xea0e9845, 0xf1198557, 0xf8148e59, 0xc737bf73, 0xce3ab47d, 0xd52da96f, 0xdc20a261, 0x766df6ad, 0x7f60fda3, 0x6477e0b1, 0x6d7aebbf, 0x5259da95, 0x5b54d19b, 0x4043cc89, 0x494ec787, 0x3e05aedd, 0x3708a5d3, 0x2c1fb8c1, 0x2512b3cf, 0x1a3182e5, 0x133c89eb, 0x082b94f9, 0x01269ff7, 0xe6bd464d, 0xefb04d43, 0xf4a75051, 0xfdaa5b5f, 0xc2896a75, 0xcb84617b, 0xd0937c69, 0xd99e7767, 0xaed51e3d, 0xa7d81533, 0xbccf0821, 0xb5c2032f, 0x8ae13205, 0x83ec390b, 0x98fb2419, 0x91f62f17, 0x4dd68d76, 0x44db8678, 0x5fcc9b6a, 0x56c19064, 0x69e2a14e, 0x60efaa40, 0x7bf8b752, 0x72f5bc5c, 0x05bed506, 0x0cb3de08, 0x17a4c31a, 0x1ea9c814, 0x218af93e, 0x2887f230, 0x3390ef22, 0x3a9de42c, 0xdd063d96, 0xd40b3698, 0xcf1c2b8a, 0xc6112084, 0xf93211ae, 0xf03f1aa0, 0xeb2807b2, 0xe2250cbc, 0x956e65e6, 0x9c636ee8, 0x877473fa, 0x8e7978f4, 0xb15a49de, 0xb85742d0, 0xa3405fc2, 0xaa4d54cc, 0xecdaf741, 0xe5d7fc4f, 0xfec0e15d, 0xf7cdea53, 0xc8eedb79, 0xc1e3d077, 0xdaf4cd65, 0xd3f9c66b, 0xa4b2af31, 0xadbfa43f, 0xb6a8b92d, 0xbfa5b223, 0x80868309, 0x898b8807, 0x929c9515, 0x9b919e1b, 0x7c0a47a1, 0x75074caf, 0x6e1051bd, 0x671d5ab3, 0x583e6b99, 0x51336097, 0x4a247d85, 0x4329768b, 0x34621fd1, 0x3d6f14df, 0x267809cd, 0x2f7502c3, 0x105633e9, 0x195b38e7, 0x024c25f5, 0x0b412efb, 0xd7618c9a, 0xde6c8794, 0xc57b9a86, 0xcc769188, 0xf355a0a2, 0xfa58abac, 0xe14fb6be, 0xe842bdb0, 0x9f09d4ea, 0x9604dfe4, 0x8d13c2f6, 0x841ec9f8, 0xbb3df8d2, 0xb230f3dc, 0xa927eece, 0xa02ae5c0, 0x47b13c7a, 0x4ebc3774, 0x55ab2a66, 0x5ca62168, 0x63851042, 0x6a881b4c, 0x719f065e, 0x78920d50, 0x0fd9640a, 0x06d46f04, 0x1dc37216, 0x14ce7918, 0x2bed4832, 0x22e0433c, 0x39f75e2e, 0x30fa5520, 0x9ab701ec, 0x93ba0ae2, 0x88ad17f0, 0x81a01cfe, 0xbe832dd4, 0xb78e26da, 0xac993bc8, 0xa59430c6, 0xd2df599c, 0xdbd25292, 0xc0c54f80, 0xc9c8448e, 0xf6eb75a4, 0xffe67eaa, 0xe4f163b8, 0xedfc68b6, 0x0a67b10c, 0x036aba02, 0x187da710, 0x1170ac1e, 0x2e539d34, 0x275e963a, 0x3c498b28, 0x35448026, 0x420fe97c, 0x4b02e272, 0x5015ff60, 0x5918f46e, 0x663bc544, 0x6f36ce4a, 0x7421d358, 0x7d2cd856, 0xa10c7a37, 0xa8017139, 0xb3166c2b, 0xba1b6725, 0x8538560f, 0x8c355d01, 0x97224013, 0x9e2f4b1d, 0xe9642247, 0xe0692949, 0xfb7e345b, 0xf2733f55, 0xcd500e7f, 0xc45d0571, 0xdf4a1863, 0xd647136d, 0x31dccad7, 0x38d1c1d9, 0x23c6dccb, 0x2acbd7c5, 0x15e8e6ef, 0x1ce5ede1, 0x07f2f0f3, 0x0efffbfd, 0x79b492a7, 0x70b999a9, 0x6bae84bb, 0x62a38fb5, 0x5d80be9f, 0x548db591, 0x4f9aa883, 0x4697a38d ]
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, key: bytes) -> None:
|
||||
|
||||
if len(key) not in (16, 24, 32):
|
||||
raise core.InvalidArgumentError(f'Invalid key size {len(key)}')
|
||||
|
||||
rounds = self._NUMBER_OF_ROUNDS[len(key)]
|
||||
|
||||
# Encryption round keys
|
||||
self._ke = [[0] * 4 for i in range(rounds + 1)]
|
||||
|
||||
# Decryption round keys
|
||||
self._kd = [[0] * 4 for i in range(rounds + 1)]
|
||||
|
||||
round_key_count = (rounds + 1) * 4
|
||||
kc = len(key) // 4
|
||||
|
||||
# Convert the key into ints
|
||||
tk = [struct.unpack('>i', key[i : i + 4])[0] for i in range(0, len(key), 4)]
|
||||
|
||||
# Copy values into round key arrays
|
||||
for i in range(0, kc):
|
||||
self._ke[i // 4][i % 4] = tk[i]
|
||||
self._kd[rounds - (i // 4)][i % 4] = tk[i]
|
||||
|
||||
# Key expansion (FIPS-197 section 5.2)
|
||||
r_con_pointer = 0
|
||||
t = kc
|
||||
while t < round_key_count:
|
||||
|
||||
tt = tk[kc - 1]
|
||||
tk[0] ^= (
|
||||
(self._S[(tt >> 16) & 0xFF] << 24)
|
||||
^ (self._S[(tt >> 8) & 0xFF] << 16)
|
||||
^ (self._S[tt & 0xFF] << 8)
|
||||
^ self._S[(tt >> 24) & 0xFF]
|
||||
^ (self._RCON[r_con_pointer] << 24)
|
||||
)
|
||||
r_con_pointer += 1
|
||||
|
||||
if kc != 8:
|
||||
for i in range(1, kc):
|
||||
tk[i] ^= tk[i - 1]
|
||||
|
||||
# Key expansion for 256-bit keys is "slightly different" (FIPS-197)
|
||||
else:
|
||||
for i in range(1, kc // 2):
|
||||
tk[i] ^= tk[i - 1]
|
||||
tt = tk[kc // 2 - 1]
|
||||
|
||||
tk[kc // 2] ^= (
|
||||
self._S[tt & 0xFF]
|
||||
^ (self._S[(tt >> 8) & 0xFF] << 8)
|
||||
^ (self._S[(tt >> 16) & 0xFF] << 16)
|
||||
^ (self._S[(tt >> 24) & 0xFF] << 24)
|
||||
)
|
||||
|
||||
for i in range(kc // 2 + 1, kc):
|
||||
tk[i] ^= tk[i - 1]
|
||||
|
||||
# Copy values into round key arrays
|
||||
j = 0
|
||||
while j < kc and t < round_key_count:
|
||||
self._ke[t // 4][t % 4] = tk[j]
|
||||
self._kd[rounds - (t // 4)][t % 4] = tk[j]
|
||||
j += 1
|
||||
t += 1
|
||||
|
||||
# Inverse-Cipher-ify the decryption round key (FIPS-197 section 5.3)
|
||||
for r in range(1, rounds):
|
||||
for j in range(0, 4):
|
||||
tt = self._kd[r][j]
|
||||
self._kd[r][j] = (
|
||||
self._U1[(tt >> 24) & 0xFF]
|
||||
^ self._U2[(tt >> 16) & 0xFF]
|
||||
^ self._U3[(tt >> 8) & 0xFF]
|
||||
^ self._U4[tt & 0xFF]
|
||||
)
|
||||
|
||||
def encrypt(self, plaintext: bytes) -> bytes:
|
||||
"""Encrypt a block of plain text using the AES block cipher."""
|
||||
|
||||
if len(plaintext) != 16:
|
||||
raise core.InvalidArgumentError(f'wrong block length {len(plaintext)}')
|
||||
|
||||
rounds = len(self._ke) - 1
|
||||
(s1, s2, s3) = [1, 2, 3]
|
||||
a = [0, 0, 0, 0]
|
||||
|
||||
# Convert plaintext to (ints ^ key)
|
||||
t = [
|
||||
(_compact_word(plaintext[4 * i : 4 * i + 4]) ^ self._ke[0][i])
|
||||
for i in range(0, 4)
|
||||
]
|
||||
|
||||
# Apply round transforms
|
||||
for r in range(1, rounds):
|
||||
for i in range(0, 4):
|
||||
a[i] = (
|
||||
self._T1[(t[i] >> 24) & 0xFF]
|
||||
^ self._T2[(t[(i + s1) % 4] >> 16) & 0xFF]
|
||||
^ self._T3[(t[(i + s2) % 4] >> 8) & 0xFF]
|
||||
^ self._T4[t[(i + s3) % 4] & 0xFF]
|
||||
^ self._ke[r][i]
|
||||
)
|
||||
t = copy.copy(a)
|
||||
|
||||
# The last round is special
|
||||
result = []
|
||||
for i in range(0, 4):
|
||||
tt = self._ke[rounds][i]
|
||||
result.append((self._S[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
|
||||
result.append((self._S[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
|
||||
result.append((self._S[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
|
||||
result.append((self._S[t[(i + s3) % 4] & 0xFF] ^ tt) & 0xFF)
|
||||
|
||||
return bytes(result)
|
||||
|
||||
def decrypt(self, cipher_text: bytes) -> bytes:
|
||||
"""Decrypt a block of cipher text using the AES block cipher."""
|
||||
|
||||
if len(cipher_text) != 16:
|
||||
raise core.InvalidArgumentError(f'wrong block length {len(cipher_text)}')
|
||||
|
||||
rounds = len(self._kd) - 1
|
||||
(s1, s2, s3) = [3, 2, 1]
|
||||
a = [0, 0, 0, 0]
|
||||
|
||||
# Convert ciphertext to (ints ^ key)
|
||||
t = [
|
||||
(_compact_word(cipher_text[4 * i : 4 * i + 4]) ^ self._kd[0][i])
|
||||
for i in range(0, 4)
|
||||
]
|
||||
|
||||
# Apply round transforms
|
||||
for r in range(1, rounds):
|
||||
for i in range(0, 4):
|
||||
a[i] = (
|
||||
self._T5[(t[i] >> 24) & 0xFF]
|
||||
^ self._T6[(t[(i + s1) % 4] >> 16) & 0xFF]
|
||||
^ self._T7[(t[(i + s2) % 4] >> 8) & 0xFF]
|
||||
^ self._T8[t[(i + s3) % 4] & 0xFF]
|
||||
^ self._kd[r][i]
|
||||
)
|
||||
t = copy.copy(a)
|
||||
|
||||
# The last round is special
|
||||
result = []
|
||||
for i in range(0, 4):
|
||||
tt = self._kd[rounds][i]
|
||||
result.append((self._S_INV[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
|
||||
result.append(
|
||||
(self._S_INV[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF
|
||||
)
|
||||
result.append(
|
||||
(self._S_INV[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF
|
||||
)
|
||||
result.append((self._S_INV[t[(i + s3) % 4] & 0xFF] ^ tt) & 0xFF)
|
||||
|
||||
return bytes(result)
|
||||
|
||||
|
||||
class _ECB:
|
||||
def __init__(self, key: bytes):
|
||||
self._aes = _AES(key)
|
||||
|
||||
def encrypt(self, plaintext: bytes) -> bytes:
|
||||
return b"".join(
|
||||
[
|
||||
self._aes.encrypt(
|
||||
plaintext[offset : offset + 16].ljust(16, b"\x00") # Pad 0.
|
||||
)
|
||||
for offset in range(0, len(plaintext), 16)
|
||||
]
|
||||
)
|
||||
|
||||
def decrypt(self, cipher_text: bytes) -> bytes:
|
||||
return b"".join(
|
||||
[
|
||||
self._aes.encrypt(cipher_text[offset : offset + 16])
|
||||
for offset in range(0, len(cipher_text), 16)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class _CBC:
|
||||
|
||||
def __init__(self, key: bytes, iv: bytes = bytes(16)) -> None:
|
||||
if len(iv) != 16:
|
||||
raise core.InvalidArgumentError(
|
||||
f'initialization vector must be 16 bytes, get {len(iv)}'
|
||||
)
|
||||
else:
|
||||
self._last_cipher_block = iv
|
||||
self._aes = _AES(key)
|
||||
|
||||
def encrypt(self, plaintext: bytes) -> bytes:
|
||||
cipher_text = b""
|
||||
for offset in range(0, len(plaintext), 16):
|
||||
pre_cipher_block = _xor(
|
||||
plaintext[offset : offset + 16], self._last_cipher_block
|
||||
)
|
||||
self._last_cipher_block = self._aes.encrypt(pre_cipher_block)
|
||||
cipher_text += self._last_cipher_block
|
||||
return cipher_text
|
||||
|
||||
def decrypt(self, cipher_text: bytes) -> bytes:
|
||||
plaintext = b""
|
||||
for offset in range(0, len(cipher_text), 16):
|
||||
plaintext += _xor(
|
||||
self._aes.decrypt(cipher_text[offset : offset + 16]),
|
||||
self._last_cipher_block,
|
||||
)
|
||||
self._last_cipher_block = cipher_text[offset : offset + 16]
|
||||
|
||||
return plaintext
|
||||
|
||||
|
||||
class _CMAC:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: bytes,
|
||||
msg: bytes = bytes(16),
|
||||
mac_len: int = 16,
|
||||
update_after_digest: bool = False,
|
||||
) -> None:
|
||||
self.digest_size = mac_len
|
||||
self._key = key
|
||||
self._block_size = bs = 16
|
||||
self._mac_tag: Optional[bytes] = None
|
||||
self._update_after_digest = update_after_digest
|
||||
|
||||
# Section 5.3 of NIST SP 800 38B and Appendix B
|
||||
if bs == 8:
|
||||
const_Rb = 0x1B
|
||||
self._max_size = 8 * (2**21)
|
||||
elif bs == 16:
|
||||
const_Rb = 0x87
|
||||
self._max_size = 16 * (2**48)
|
||||
else:
|
||||
raise core.InvalidArgumentError(
|
||||
f"CMAC requires a cipher with a block size of 8 or 16 bytes, not {bs}"
|
||||
)
|
||||
|
||||
# Compute sub-keys
|
||||
zero_block = bytes(bs)
|
||||
self._ecb = _ECB(key)
|
||||
L = self._ecb.encrypt(zero_block)
|
||||
if L[0] & 0x80:
|
||||
self._k1 = _shift_bytes(L, const_Rb)
|
||||
else:
|
||||
self._k1 = _shift_bytes(L)
|
||||
if self._k1[0] & 0x80:
|
||||
self._k2 = _shift_bytes(self._k1, const_Rb)
|
||||
else:
|
||||
self._k2 = _shift_bytes(self._k1)
|
||||
|
||||
# Initialize CBC cipher with zero IV
|
||||
self._cbc = _CBC(key, zero_block)
|
||||
|
||||
# Cache for outstanding data to authenticate
|
||||
self._cache = bytearray(bs)
|
||||
self._cache_n = 0
|
||||
|
||||
# Last piece of cipher text produced
|
||||
self._last_ct = zero_block
|
||||
|
||||
# Last block that was encrypted with AES
|
||||
self._last_pt: Optional[bytes] = None
|
||||
|
||||
# Counter for total message size
|
||||
self._data_size = 0
|
||||
|
||||
if msg:
|
||||
self.update(msg)
|
||||
|
||||
def update(self, msg: bytes) -> _CMAC:
|
||||
"""Authenticate the next chunk of message.
|
||||
|
||||
Args:
|
||||
data (byte string/byte array/memoryview): The next chunk of data
|
||||
"""
|
||||
|
||||
if self._mac_tag is not None and not self._update_after_digest:
|
||||
raise core.InvalidStateError(
|
||||
"update() cannot be called after digest() or verify()"
|
||||
)
|
||||
|
||||
self._data_size += len(msg)
|
||||
bs = self._block_size
|
||||
|
||||
if self._cache_n > 0:
|
||||
filler = min(bs - self._cache_n, len(msg))
|
||||
self._cache[self._cache_n : self._cache_n + filler] = msg[:filler]
|
||||
self._cache_n += filler
|
||||
|
||||
if self._cache_n < bs:
|
||||
return self
|
||||
|
||||
msg = msg[filler:]
|
||||
self._update(self._cache)
|
||||
self._cache_n = 0
|
||||
|
||||
remain = len(msg) % bs
|
||||
if remain > 0:
|
||||
self._update(msg[:-remain])
|
||||
self._cache[:remain] = msg[-remain:]
|
||||
else:
|
||||
self._update(msg)
|
||||
self._cache_n = remain
|
||||
return self
|
||||
|
||||
def _update(self, data_block: bytes) -> None:
|
||||
"""Update a block aligned to the block boundary"""
|
||||
|
||||
bs = self._block_size
|
||||
assert len(data_block) % bs == 0
|
||||
|
||||
if len(data_block) == 0:
|
||||
return
|
||||
|
||||
ct = self._cbc.encrypt(data_block)
|
||||
if len(data_block) == bs:
|
||||
second_last = self._last_ct
|
||||
else:
|
||||
second_last = ct[-bs * 2 : -bs]
|
||||
self._last_ct = ct[-bs:]
|
||||
self._last_pt = _xor(second_last, data_block[-bs:])
|
||||
|
||||
def digest(self) -> bytes:
|
||||
|
||||
bs = self._block_size
|
||||
|
||||
if self._mac_tag is not None and not self._update_after_digest:
|
||||
return self._mac_tag
|
||||
|
||||
if self._data_size > self._max_size:
|
||||
raise core.InvalidArgumentError("MAC is unsafe for this message")
|
||||
|
||||
if self._cache_n == 0 and self._data_size > 0 and self._last_pt:
|
||||
# Last block was full
|
||||
pt = _xor(self._last_pt, self._k1)
|
||||
else:
|
||||
# Last block is partial (or message length is zero)
|
||||
partial = self._cache[:]
|
||||
partial[self._cache_n :] = b'\x80' + b'\x00' * (bs - self._cache_n - 1)
|
||||
pt = _xor(_xor(self._last_ct, partial), self._k2)
|
||||
|
||||
self._mac_tag = self._ecb.encrypt(pt)[: self.digest_size]
|
||||
|
||||
return self._mac_tag
|
||||
|
||||
|
||||
# Define the original Point class for clarity and conversion purposes
|
||||
@dataclasses.dataclass
|
||||
class _Point:
|
||||
"""Represents a point on the elliptic curve in affine coordinates."""
|
||||
|
||||
curve: _EllipticCurve
|
||||
x: int = 0
|
||||
y: int = 0
|
||||
infinite: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _JacobianPoint:
|
||||
"""Represents a point on the elliptic curve in Jacobian coordinates."""
|
||||
|
||||
curve: _EllipticCurve
|
||||
x: int = 1 # For point at infinity (1:1:0)
|
||||
y: int = 1
|
||||
z: int = 0 # z = 0 indicates point at infinity
|
||||
|
||||
@classmethod
|
||||
def point_at_infinity(cls, curve: _EllipticCurve) -> _JacobianPoint:
|
||||
return _JacobianPoint(curve=curve, x=1, y=1, z=0)
|
||||
|
||||
@classmethod
|
||||
def from_affine(cls, affine_point: _Point) -> _JacobianPoint:
|
||||
if affine_point.infinite:
|
||||
return _JacobianPoint.point_at_infinity(affine_point.curve)
|
||||
# A simple conversion is (x, y, 1)
|
||||
return _JacobianPoint(
|
||||
curve=affine_point.curve, x=affine_point.x, y=affine_point.y, z=1
|
||||
)
|
||||
|
||||
def to_affine(self) -> _Point:
|
||||
if self.z == 0:
|
||||
return _Point(infinite=True, curve=self.curve)
|
||||
|
||||
p = self.curve.p
|
||||
inv_z = pow(self.z, -1, p)
|
||||
affine_x = (self.x * inv_z**2) % p
|
||||
affine_y = (self.y * inv_z**3) % p
|
||||
|
||||
return _Point(curve=self.curve, x=affine_x, y=affine_y, infinite=False)
|
||||
|
||||
def double(self) -> _JacobianPoint:
|
||||
if self.z == 0 or self.y == 0:
|
||||
return _JacobianPoint.point_at_infinity(self.curve)
|
||||
|
||||
s = 4 * self.x * self.y**2
|
||||
m = 3 * self.x**2 + self.curve.a * self.z**4
|
||||
x2 = m**2 - 2 * s
|
||||
y2 = m * (s - x2) - 8 * self.y**4
|
||||
z2 = 2 * self.y * self.z
|
||||
p = self.curve.p
|
||||
|
||||
return _JacobianPoint(curve=self.curve, x=x2 % p, y=y2 % p, z=z2 % p)
|
||||
|
||||
def __add__(self, other: _JacobianPoint) -> _JacobianPoint:
|
||||
if self.z == 0 and other.z == 0:
|
||||
return _JacobianPoint.point_at_infinity(self.curve)
|
||||
elif self.z == 0:
|
||||
return other
|
||||
elif other.z == 0:
|
||||
return self
|
||||
|
||||
x1 = self.x
|
||||
y1 = self.y
|
||||
z1 = self.z
|
||||
x2 = other.x
|
||||
y2 = other.y
|
||||
z2 = other.z
|
||||
p = self.curve.p
|
||||
u1 = (x1 * z2**2) % p
|
||||
u2 = (x2 * z1**2) % p
|
||||
s1 = (y1 * z2**3) % p
|
||||
s2 = (y2 * z1**3) % p
|
||||
|
||||
if u1 == u2:
|
||||
if s1 != s2:
|
||||
return _JacobianPoint.point_at_infinity(self.curve)
|
||||
else:
|
||||
return self.double()
|
||||
else:
|
||||
h = u2 - u1
|
||||
r = s2 - s1
|
||||
|
||||
h3 = h**3 % p
|
||||
u1h2 = (u1 * h**2) % p
|
||||
x3 = r**2 - h3 - 2 * u1h2
|
||||
y3 = r * (u1h2 - x3) - s1 * h3
|
||||
z3 = h * z1 * z2
|
||||
|
||||
return _JacobianPoint(self.curve, x3 % p, y3 % p, z3 % p)
|
||||
|
||||
def __mul__(self, k: int) -> _JacobianPoint:
|
||||
addend = self
|
||||
result = _JacobianPoint.point_at_infinity(self.curve)
|
||||
|
||||
while k > 0:
|
||||
if k % 2 != 0:
|
||||
result = result + addend
|
||||
addend = addend.double()
|
||||
k = k >> 1
|
||||
return result
|
||||
|
||||
def __rmul__(self, k: int) -> _JacobianPoint:
|
||||
return self * k
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _EllipticCurve:
|
||||
p: int
|
||||
a: int
|
||||
b: int
|
||||
n: int
|
||||
g_x: int
|
||||
g_y: int
|
||||
|
||||
_generator_jacobian: _JacobianPoint = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._generator_jacobian = _JacobianPoint(
|
||||
curve=self, x=self.g_x, y=self.g_y, z=1
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PrivateKey:
|
||||
key: int
|
||||
curve: _EllipticCurve
|
||||
|
||||
def generate_private_key(self) -> PrivateKey:
|
||||
"""Generates a random private key."""
|
||||
return self.PrivateKey(key=secrets.randbelow(self.n), curve=self)
|
||||
|
||||
def generate_public_key(self, private_key: int) -> _Point:
|
||||
"""Generates a public key from a private key using Jacobian coordinates for scalar multiplication."""
|
||||
public_key_jacobian = self._generator_jacobian * private_key
|
||||
return public_key_jacobian.to_affine()
|
||||
|
||||
def ecdh_shared_secret(self, private_key: int, other_public_key: _Point) -> bytes:
|
||||
"""Computes the shared secret using ECDH."""
|
||||
other_public_key_jacobian = _JacobianPoint.from_affine(other_public_key)
|
||||
shared_point_jacobian = other_public_key_jacobian * private_key
|
||||
shared_point_affine = shared_point_jacobian.to_affine()
|
||||
if shared_point_affine.infinite:
|
||||
raise core.InvalidPacketError(
|
||||
"Shared secret calculation resulted in the point at infinite"
|
||||
)
|
||||
return shared_point_affine.x.to_bytes(32, 'big')
|
||||
|
||||
@classmethod
|
||||
def SECP256R1(cls) -> _EllipticCurve:
|
||||
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
|
||||
a = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC
|
||||
b = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
|
||||
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551 # Curve order
|
||||
g_x = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296
|
||||
g_y = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5
|
||||
|
||||
return _EllipticCurve(p=p, a=a, b=b, n=n, g_x=g_x, g_y=g_y)
|
||||
|
||||
|
||||
class EccKey:
|
||||
def __init__(self, private_key: _EllipticCurve.PrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@functools.cached_property
|
||||
def x(self) -> bytes:
|
||||
return self.private_key.curve.generate_public_key(
|
||||
self.private_key.key
|
||||
).x.to_bytes(32, byteorder='big')
|
||||
|
||||
@functools.cached_property
|
||||
def y(self) -> bytes:
|
||||
return self.private_key.curve.generate_public_key(
|
||||
self.private_key.key
|
||||
).y.to_bytes(32, byteorder='big')
|
||||
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
return self.private_key.curve.ecdh_shared_secret(
|
||||
self.private_key.key,
|
||||
_Point(x=x, y=y, curve=self.private_key.curve),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> EccKey:
|
||||
return EccKey(_EllipticCurve.SECP256R1().generate_private_key())
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(cls, d_bytes: bytes) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
return EccKey(_EllipticCurve.PrivateKey(d, _EllipticCurve.SECP256R1()))
|
||||
|
||||
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
return _ECB(key[::-1]).encrypt(data[::-1])[::-1]
|
||||
|
||||
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
NOTE: the input and output of this internal function are in big-endian byte order
|
||||
'''
|
||||
return _CMAC(key=k, msg=m).digest()
|
||||
82
bumble/crypto/cryptography.py
Normal file
82
bumble/crypto/cryptography.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
from cryptography.hazmat.primitives import ciphers, cmac
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, modes
|
||||
|
||||
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
cipher = ciphers.Cipher(algorithms.AES(key[::-1]), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
return encryptor.update(data[::-1])[::-1]
|
||||
|
||||
|
||||
class EccKey:
|
||||
def __init__(self, private_key: ec.EllipticCurvePrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> EccKey:
|
||||
return EccKey(ec.generate_private_key(ec.SECP256R1()))
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(cls, d_bytes: bytes) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
return EccKey(ec.derive_private_key(d, ec.SECP256R1()))
|
||||
|
||||
@functools.cached_property
|
||||
def x(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.x.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def y(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.y.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
return self.private_key.exchange(
|
||||
ec.ECDH(),
|
||||
ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key(),
|
||||
)
|
||||
|
||||
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
NOTE: the input and output of this internal function are in big-endian byte order
|
||||
'''
|
||||
mac = cmac.CMAC(algorithms.AES(k))
|
||||
mac.update(m)
|
||||
return mac.finalize()
|
||||
1025
bumble/data_types.py
Normal file
1025
bumble/data_types.py
Normal file
File diff suppressed because it is too large
Load Diff
3702
bumble/device.py
3702
bumble/device.py
File diff suppressed because it is too large
Load Diff
@@ -20,13 +20,14 @@ like loading firmware after a cold start.
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
|
||||
from . import rtk, intel
|
||||
from .common import Driver
|
||||
from bumble.drivers import intel, rtk
|
||||
from bumble.drivers.common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
@@ -45,7 +46,7 @@ async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
|
||||
driver_classes: dict[str, type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
|
||||
@@ -11,18 +11,32 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Support for Intel USB controllers.
|
||||
Loosely based on the Fuchsia OS implementation.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from bumble import core, hci, utils
|
||||
from bumble.drivers import common
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code, # type: ignore
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -34,39 +48,327 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
INTEL_USB_PRODUCTS = {
|
||||
# Intel AX210
|
||||
(0x8087, 0x0032),
|
||||
# Intel BE200
|
||||
(0x8087, 0x0036),
|
||||
(0x8087, 0x0032), # AX210
|
||||
(0x8087, 0x0033), # AX211
|
||||
(0x8087, 0x0036), # BE200
|
||||
}
|
||||
|
||||
INTEL_FW_IMAGE_NAMES = [
|
||||
"ibt-0040-0041",
|
||||
"ibt-0040-1020",
|
||||
"ibt-0040-1050",
|
||||
"ibt-0040-2120",
|
||||
"ibt-0040-4150",
|
||||
"ibt-0041-0041",
|
||||
"ibt-0180-0041",
|
||||
"ibt-0180-1050",
|
||||
"ibt-0180-4150",
|
||||
"ibt-0291-0291",
|
||||
"ibt-1040-0041",
|
||||
"ibt-1040-1020",
|
||||
"ibt-1040-1050",
|
||||
"ibt-1040-2120",
|
||||
"ibt-1040-4150",
|
||||
]
|
||||
|
||||
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
|
||||
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
|
||||
|
||||
_MAX_FRAGMENT_SIZE = 252
|
||||
_POST_RESET_DELAY = 0.2
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
|
||||
HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
|
||||
HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
|
||||
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
|
||||
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
|
||||
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
hci.HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command( # type: ignore
|
||||
fields=[("params", "*")],
|
||||
return_parameters_fields=[
|
||||
@hci.HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
|
||||
param0: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
|
||||
return_parameters_fields = [
|
||||
("status", hci.STATUS_SPEC),
|
||||
("tlv", "*"),
|
||||
]
|
||||
|
||||
|
||||
@hci.HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
|
||||
data_type: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
|
||||
|
||||
return_parameters_fields = [
|
||||
("status", 1),
|
||||
]
|
||||
|
||||
|
||||
@hci.HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
class HCI_Intel_Reset_Command(hci.HCI_Command):
|
||||
reset_type: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
boot_option: int = dataclasses.field(metadata=hci.metadata(1))
|
||||
boot_address: int = dataclasses.field(metadata=hci.metadata(4))
|
||||
|
||||
return_parameters_fields = [
|
||||
("data", "*"),
|
||||
]
|
||||
|
||||
|
||||
@hci.HCI_Command.command
|
||||
@dataclasses.dataclass
|
||||
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
|
||||
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
|
||||
|
||||
return_parameters_fields = [
|
||||
("status", hci.STATUS_SPEC),
|
||||
("params", "*"),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
|
||||
pass
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
def intel_firmware_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to a subdir of the project data dir for Intel firmware.
|
||||
The directory is created if it doesn't exist.
|
||||
"""
|
||||
from bumble.drivers import project_data_dir
|
||||
|
||||
p = project_data_dir() / "firmware" / "intel"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
|
||||
def _find_binary_path(file_name: str) -> pathlib.Path | None:
|
||||
# First check if an environment variable is set
|
||||
if INTEL_FIRMWARE_DIR_ENV in os.environ:
|
||||
if (
|
||||
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
|
||||
).is_file():
|
||||
logger.debug(f"{file_name} found in env dir")
|
||||
return path
|
||||
|
||||
# When the environment variable is set, don't look elsewhere
|
||||
return None
|
||||
|
||||
# Then, look where the firmware download tool writes by default
|
||||
if (path := intel_firmware_dir() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in project data dir")
|
||||
return path
|
||||
|
||||
# Then, look in the package's driver directory
|
||||
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in package dir")
|
||||
return path
|
||||
|
||||
# On Linux, check the system's FW directory
|
||||
if (
|
||||
platform.system() == "Linux"
|
||||
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
|
||||
):
|
||||
logger.debug(f"{file_name} found in Linux system FW dir")
|
||||
return path
|
||||
|
||||
# Finally look in the current directory
|
||||
if (path := pathlib.Path.cwd() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in CWD")
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
|
||||
result: list[tuple[ValueType, Any]] = []
|
||||
while len(data) >= 2:
|
||||
value_type = ValueType(data[0])
|
||||
value_length = data[1]
|
||||
value = data[2 : 2 + value_length]
|
||||
typed_value: Any
|
||||
|
||||
if value_type == ValueType.END:
|
||||
break
|
||||
|
||||
if value_type in (ValueType.CNVI, ValueType.CNVR):
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = (
|
||||
(((v >> 0) & 0xF) << 12)
|
||||
| (((v >> 4) & 0xF) << 0)
|
||||
| (((v >> 8) & 0xF) << 4)
|
||||
| (((v >> 24) & 0xF) << 8)
|
||||
)
|
||||
elif value_type == ValueType.HARDWARE_INFO:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = HardwareInfo(
|
||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||
)
|
||||
elif value_type in (
|
||||
ValueType.USB_VENDOR_ID,
|
||||
ValueType.USB_PRODUCT_ID,
|
||||
ValueType.DEVICE_REVISION,
|
||||
):
|
||||
(typed_value,) = struct.unpack("<H", value)
|
||||
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
|
||||
typed_value = ModeOfOperation(value[0])
|
||||
elif value_type in (
|
||||
ValueType.BUILD_TYPE,
|
||||
ValueType.BUILD_NUMBER,
|
||||
ValueType.SECURE_BOOT,
|
||||
ValueType.OTP_LOCK,
|
||||
ValueType.API_LOCK,
|
||||
ValueType.DEBUG_LOCK,
|
||||
ValueType.SECURE_BOOT_ENGINE_TYPE,
|
||||
):
|
||||
typed_value = value[0]
|
||||
elif value_type == ValueType.TIMESTAMP:
|
||||
typed_value = Timestamp(value[0], value[1])
|
||||
elif value_type == ValueType.FIRMWARE_BUILD:
|
||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||
elif value_type == ValueType.BLUETOOTH_ADDRESS:
|
||||
typed_value = hci.Address(
|
||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
else:
|
||||
typed_value = value
|
||||
|
||||
result.append((value_type, typed_value))
|
||||
data = data[2 + value_length :]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class DriverError(core.BaseBumbleError):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"IntelDriverError({self.message})"
|
||||
|
||||
|
||||
class ValueType(utils.OpenIntEnum):
|
||||
END = 0x00
|
||||
CNVI = 0x10
|
||||
CNVR = 0x11
|
||||
HARDWARE_INFO = 0x12
|
||||
DEVICE_REVISION = 0x16
|
||||
CURRENT_MODE_OF_OPERATION = 0x1C
|
||||
USB_VENDOR_ID = 0x17
|
||||
USB_PRODUCT_ID = 0x18
|
||||
TIMESTAMP = 0x1D
|
||||
BUILD_TYPE = 0x1E
|
||||
BUILD_NUMBER = 0x1F
|
||||
SECURE_BOOT = 0x28
|
||||
OTP_LOCK = 0x2A
|
||||
API_LOCK = 0x2B
|
||||
DEBUG_LOCK = 0x2C
|
||||
FIRMWARE_BUILD = 0x2D
|
||||
SECURE_BOOT_ENGINE_TYPE = 0x2F
|
||||
BLUETOOTH_ADDRESS = 0x30
|
||||
|
||||
|
||||
class HardwarePlatform(utils.OpenIntEnum):
|
||||
INTEL_37 = 0x37
|
||||
|
||||
|
||||
class HardwareVariant(utils.OpenIntEnum):
|
||||
# This is a just a partial list.
|
||||
# Add other constants here as new hardware is encountered and tested.
|
||||
TYPHOON_PEAK = 0x17
|
||||
GARFIELD_PEAK = 0x19
|
||||
GALE_PEAK = 0x1C
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HardwareInfo:
|
||||
platform: HardwarePlatform
|
||||
variant: HardwareVariant
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Timestamp:
|
||||
week: int
|
||||
year: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FirmwareBuild:
|
||||
build_number: int
|
||||
timestamp: Timestamp
|
||||
|
||||
|
||||
class ModeOfOperation(utils.OpenIntEnum):
|
||||
BOOTLOADER = 0x01
|
||||
INTERMEDIATE = 0x02
|
||||
OPERATIONAL = 0x03
|
||||
|
||||
|
||||
class SecureBootEngineType(utils.OpenIntEnum):
|
||||
RSA = 0x00
|
||||
ECDSA = 0x01
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BootParams:
|
||||
css_header_offset: int
|
||||
css_header_size: int
|
||||
pki_offset: int
|
||||
pki_size: int
|
||||
sig_offset: int
|
||||
sig_size: int
|
||||
write_offset: int
|
||||
|
||||
|
||||
_BOOT_PARAMS = {
|
||||
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
|
||||
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
|
||||
}
|
||||
|
||||
|
||||
class Driver(common.Driver):
|
||||
def __init__(self, host):
|
||||
def __init__(self, host: Host) -> None:
|
||||
self.host = host
|
||||
self.max_in_flight_firmware_load_commands = 1
|
||||
self.pending_firmware_load_commands: collections.deque[hci.HCI_Command] = (
|
||||
collections.deque()
|
||||
)
|
||||
self.can_send_firmware_load_command = asyncio.Event()
|
||||
self.can_send_firmware_load_command.set()
|
||||
self.firmware_load_complete = asyncio.Event()
|
||||
self.reset_complete = asyncio.Event()
|
||||
|
||||
# Parse configuration options from the driver name.
|
||||
self.ddc_addon: Optional[bytes] = None
|
||||
self.ddc_override: Optional[bytes] = None
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver is not None and driver.startswith("intel/"):
|
||||
for key, value in [
|
||||
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
|
||||
]:
|
||||
if key == "ddc_addon":
|
||||
self.ddc_addon = bytes.fromhex(value)
|
||||
elif key == "ddc_override":
|
||||
self.ddc_override = bytes.fromhex(value)
|
||||
|
||||
@staticmethod
|
||||
def check(host):
|
||||
def check(host: Host) -> bool:
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver == "intel":
|
||||
if driver == "intel" or driver is not None and driver.startswith("intel/"):
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
@@ -85,18 +387,284 @@ class Driver(common.Driver):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def for_host(cls, host, force=False): # type: ignore
|
||||
async def for_host(cls, host: Host, force: bool = False):
|
||||
# Only instantiate this driver if explicitly selected
|
||||
if not force and not cls.check(host):
|
||||
return None
|
||||
|
||||
return cls(host)
|
||||
|
||||
async def init_controller(self):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
"""Handler for event packets that are received from an ACL channel"""
|
||||
event = hci.HCI_Event.from_bytes(packet)
|
||||
|
||||
if not isinstance(event, hci.HCI_Command_Complete_Event):
|
||||
self.host.on_hci_event_packet(event)
|
||||
return
|
||||
|
||||
if not event.return_parameters == hci.HCI_SUCCESS:
|
||||
raise DriverError("HCI_Command_Complete_Event error")
|
||||
|
||||
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
|
||||
logger.debug(
|
||||
"max_in_flight_firmware_load_commands update: "
|
||||
f"{event.num_hci_command_packets}"
|
||||
)
|
||||
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
|
||||
logger.debug(f"event: {event}")
|
||||
self.pending_firmware_load_commands.popleft()
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
logger.debug(f"event received, {in_flight} still in flight")
|
||||
if in_flight < self.max_in_flight_firmware_load_commands:
|
||||
self.can_send_firmware_load_command.set()
|
||||
|
||||
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
|
||||
# Wait until we can send.
|
||||
await self.can_send_firmware_load_command.wait()
|
||||
|
||||
# Send the command and adjust counters.
|
||||
self.host.send_hci_packet(command)
|
||||
self.pending_firmware_load_commands.append(command)
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
if in_flight >= self.max_in_flight_firmware_load_commands:
|
||||
logger.debug(f"max commands in flight reached [{in_flight}]")
|
||||
self.can_send_firmware_load_command.clear()
|
||||
|
||||
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
|
||||
while data:
|
||||
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
|
||||
fragment = data[:fragment_size]
|
||||
data = data[fragment_size:]
|
||||
|
||||
await self.send_firmware_load_command(
|
||||
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
|
||||
)
|
||||
|
||||
async def load_firmware(self) -> None:
|
||||
self.host.ready = True
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
await self.host.send_command(
|
||||
Hci_Intel_DDC_Config_Write_Command(
|
||||
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
|
||||
device_info = await self.read_device_info()
|
||||
logger.debug(
|
||||
"device info: \n%s",
|
||||
"\n".join(
|
||||
[
|
||||
f" {value_type.name}: {value}"
|
||||
for value_type, value in device_info.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the firmware is already loaded.
|
||||
if (
|
||||
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
|
||||
== ModeOfOperation.OPERATIONAL
|
||||
):
|
||||
logger.debug("firmware already loaded")
|
||||
return
|
||||
|
||||
# We only support some platforms and variants.
|
||||
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
|
||||
if hardware_info is None:
|
||||
raise DriverError("hardware info missing")
|
||||
if hardware_info.platform != HardwarePlatform.INTEL_37:
|
||||
raise DriverError("hardware platform not supported")
|
||||
if hardware_info.variant not in (
|
||||
HardwareVariant.TYPHOON_PEAK,
|
||||
HardwareVariant.GARFIELD_PEAK,
|
||||
HardwareVariant.GALE_PEAK,
|
||||
):
|
||||
raise DriverError("hardware variant not supported")
|
||||
|
||||
# Compute the firmware name.
|
||||
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
|
||||
raise DriverError("insufficient device info, missing CNVI or CNVR")
|
||||
|
||||
firmware_base_name = (
|
||||
"ibt-"
|
||||
f"{device_info[ValueType.CNVI]:04X}-"
|
||||
f"{device_info[ValueType.CNVR]:04X}"
|
||||
)
|
||||
logger.debug(f"FW base name: {firmware_base_name}")
|
||||
|
||||
firmware_name = f"{firmware_base_name}.sfi"
|
||||
firmware_path = _find_binary_path(firmware_name)
|
||||
if not firmware_path:
|
||||
logger.warning(f"Firmware file {firmware_name} not found")
|
||||
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
|
||||
return None
|
||||
logger.debug(f"loading firmware from {firmware_path}")
|
||||
firmware_image = firmware_path.read_bytes()
|
||||
|
||||
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
|
||||
if engine_type is None:
|
||||
raise DriverError("secure boot engine type missing")
|
||||
if engine_type not in _BOOT_PARAMS:
|
||||
raise DriverError("secure boot engine type not supported")
|
||||
|
||||
boot_params = _BOOT_PARAMS[engine_type]
|
||||
if len(firmware_image) < boot_params.write_offset:
|
||||
raise DriverError("firmware image too small")
|
||||
|
||||
# Register to receive vendor events.
|
||||
def on_vendor_event(event: hci.HCI_Vendor_Event):
|
||||
logger.debug(f"vendor event: {event}")
|
||||
event_type = event.parameters[0]
|
||||
if event_type == 0x02:
|
||||
# Boot event
|
||||
logger.debug("boot complete")
|
||||
self.reset_complete.set()
|
||||
elif event_type == 0x06:
|
||||
# Firmware load event
|
||||
logger.debug("download complete")
|
||||
self.firmware_load_complete.set()
|
||||
else:
|
||||
logger.debug(f"ignoring vendor event type {event_type}")
|
||||
|
||||
self.host.on("vendor_event", on_vendor_event)
|
||||
|
||||
# We need to temporarily intercept packets from the controller,
|
||||
# because they are formatted as HCI event packets but are received
|
||||
# on the ACL channel, so the host parser would get confused.
|
||||
saved_on_packet = self.host.on_packet
|
||||
self.host.on_packet = self.on_packet # type: ignore
|
||||
self.firmware_load_complete.clear()
|
||||
|
||||
# Send the CSS header
|
||||
data = firmware_image[
|
||||
boot_params.css_header_offset : boot_params.css_header_offset
|
||||
+ boot_params.css_header_size
|
||||
]
|
||||
await self.send_firmware_data(0x00, data)
|
||||
|
||||
# Send the PKI header
|
||||
data = firmware_image[
|
||||
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
|
||||
]
|
||||
await self.send_firmware_data(0x03, data)
|
||||
|
||||
# Send the Signature header
|
||||
data = firmware_image[
|
||||
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
|
||||
]
|
||||
await self.send_firmware_data(0x02, data)
|
||||
|
||||
# Send the rest of the image.
|
||||
# The payload consists of command objects, which are sent when they add up
|
||||
# to a multiple of 4 bytes.
|
||||
boot_address = 0
|
||||
offset = boot_params.write_offset
|
||||
fragment_size = 0
|
||||
while offset + 3 < len(firmware_image):
|
||||
(command_opcode,) = struct.unpack_from(
|
||||
"<H", firmware_image, offset + fragment_size
|
||||
)
|
||||
command_size = firmware_image[offset + fragment_size + 2]
|
||||
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
|
||||
(boot_address,) = struct.unpack_from(
|
||||
"<I", firmware_image, offset + fragment_size + 3
|
||||
)
|
||||
logger.debug(
|
||||
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
|
||||
f"boot_address={boot_address}"
|
||||
)
|
||||
fragment_size += 3 + command_size
|
||||
if fragment_size % 4 == 0:
|
||||
await self.send_firmware_data(
|
||||
0x01, firmware_image[offset : offset + fragment_size]
|
||||
)
|
||||
logger.debug(f"sent {fragment_size} bytes")
|
||||
offset += fragment_size
|
||||
fragment_size = 0
|
||||
|
||||
# Wait for the firmware loading to be complete.
|
||||
logger.debug("waiting for firmware to be loaded")
|
||||
await self.firmware_load_complete.wait()
|
||||
logger.debug("firmware loaded")
|
||||
|
||||
# Restore the original packet handler.
|
||||
self.host.on_packet = saved_on_packet # type: ignore
|
||||
|
||||
# Reset
|
||||
self.reset_complete.clear()
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x00,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x00,
|
||||
boot_option=0x01,
|
||||
boot_address=boot_address,
|
||||
)
|
||||
)
|
||||
logger.debug("waiting for reset completion")
|
||||
await self.reset_complete.wait()
|
||||
logger.debug("reset complete")
|
||||
|
||||
# Load the device config if there is one.
|
||||
if self.ddc_override:
|
||||
logger.debug("loading overridden DDC")
|
||||
await self.load_device_config(self.ddc_override)
|
||||
else:
|
||||
ddc_name = f"{firmware_base_name}.ddc"
|
||||
ddc_path = _find_binary_path(ddc_name)
|
||||
if ddc_path:
|
||||
logger.debug(f"loading DDC from {ddc_path}")
|
||||
ddc_data = ddc_path.read_bytes()
|
||||
await self.load_device_config(ddc_data)
|
||||
if self.ddc_addon:
|
||||
logger.debug("loading DDC addon")
|
||||
await self.load_device_config(self.ddc_addon)
|
||||
|
||||
async def load_device_config(self, ddc_data: bytes) -> None:
|
||||
while ddc_data:
|
||||
ddc_len = 1 + ddc_data[0]
|
||||
ddc_payload = ddc_data[:ddc_len]
|
||||
await self.host.send_command(
|
||||
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
|
||||
)
|
||||
ddc_data = ddc_data[ddc_len:]
|
||||
|
||||
async def reboot_bootloader(self) -> None:
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x01,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x01,
|
||||
boot_option=0x00,
|
||||
boot_address=0,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(_POST_RESET_DELAY)
|
||||
|
||||
async def read_device_info(self) -> dict[ValueType, Any]:
|
||||
self.host.ready = True
|
||||
response = await self.host.send_command(hci.HCI_Reset_Command())
|
||||
if not (
|
||||
isinstance(response, hci.HCI_Command_Complete_Event)
|
||||
and response.return_parameters
|
||||
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
|
||||
):
|
||||
# When the controller is in operational mode, the response is a
|
||||
# successful response.
|
||||
# When the controller is in bootloader mode,
|
||||
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
|
||||
# else is a failure.
|
||||
logger.warning(f"unexpected response: {response}")
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
# Read the firmware version.
|
||||
response = await self.host.send_command(
|
||||
HCI_Intel_Read_Version_Command(param0=0xFF)
|
||||
)
|
||||
if not isinstance(response, hci.HCI_Command_Complete_Event):
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
if response.return_parameters.status != 0: # type: ignore
|
||||
raise DriverError("HCI_Intel_Read_Version_Command error")
|
||||
|
||||
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
|
||||
|
||||
# Convert the list to a dict. That's Ok here because we only expect each type
|
||||
# to appear just once.
|
||||
return dict(tlvs)
|
||||
|
||||
async def init_controller(self):
|
||||
await self.load_firmware()
|
||||
|
||||
@@ -17,10 +17,6 @@ Based on various online bits of information, including the Linux kernel.
|
||||
(see `drivers/bluetooth/btrtl.c`)
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
@@ -29,19 +25,14 @@ import os
|
||||
import pathlib
|
||||
import platform
|
||||
import struct
|
||||
from typing import Tuple
|
||||
import weakref
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code,
|
||||
STATUS_SPEC,
|
||||
HCI_SUCCESS,
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble import core, hci
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -183,27 +174,29 @@ RTK_USB_PRODUCTS = {
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
|
||||
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
|
||||
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
|
||||
HCI_Command.register_commands(globals())
|
||||
HCI_RTK_READ_ROM_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x6D)
|
||||
HCI_RTK_DOWNLOAD_COMMAND = hci.hci_vendor_command_op_code(0x20)
|
||||
HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66)
|
||||
hci.HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
|
||||
class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
|
||||
pass
|
||||
@hci.HCI_Command.command
|
||||
@dataclass
|
||||
class HCI_RTK_Read_ROM_Version_Command(hci.HCI_Command):
|
||||
return_parameters_fields = [("status", hci.STATUS_SPEC), ("version", 1)]
|
||||
|
||||
|
||||
@HCI_Command.command(
|
||||
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
|
||||
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
|
||||
)
|
||||
class HCI_RTK_Download_Command(HCI_Command):
|
||||
pass
|
||||
@hci.HCI_Command.command
|
||||
@dataclass
|
||||
class HCI_RTK_Download_Command(hci.HCI_Command):
|
||||
index: int = field(metadata=hci.metadata(1))
|
||||
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
|
||||
return_parameters_fields = [("status", hci.STATUS_SPEC), ("index", 1)]
|
||||
|
||||
|
||||
@HCI_Command.command()
|
||||
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
|
||||
@hci.HCI_Command.command
|
||||
@dataclass
|
||||
class HCI_RTK_Drop_Firmware_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@@ -294,7 +287,7 @@ class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
hci: Tuple[int, int]
|
||||
hci: tuple[int, int]
|
||||
config_needed: bool
|
||||
has_rom_version: bool
|
||||
has_msft_ext: bool = False
|
||||
@@ -495,21 +488,36 @@ class Driver(common.Driver):
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_loaded_firmware_version(host):
|
||||
response = await host.send_command(HCI_RTK_Read_ROM_Version_Command())
|
||||
|
||||
if response.return_parameters.status != hci.HCI_SUCCESS:
|
||||
return None
|
||||
|
||||
response = await host.send_command(
|
||||
hci.HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
return (
|
||||
response.return_parameters.hci_subversion << 16
|
||||
| response.return_parameters.lmp_subversion
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
try:
|
||||
await host.send_command(
|
||||
HCI_Reset_Command(),
|
||||
hci.HCI_Reset_Command(),
|
||||
check_result=True,
|
||||
response_timeout=cls.POST_RESET_DELAY,
|
||||
)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
except asyncio.exceptions.TimeoutError:
|
||||
logger.warning("timeout waiting for hci reset, retrying")
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
await host.send_command(hci.HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True
|
||||
|
||||
command = HCI_Read_Local_Version_Information_Command()
|
||||
command = hci.HCI_Read_Local_Version_Information_Command()
|
||||
response = await host.send_command(command, check_result=True)
|
||||
if response.command_opcode != command.op_code:
|
||||
logger.error("failed to probe local version information")
|
||||
@@ -596,9 +604,9 @@ class Driver(common.Driver):
|
||||
response = await self.host.send_command(
|
||||
HCI_RTK_Read_ROM_Version_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.status != HCI_SUCCESS:
|
||||
if response.return_parameters.status != hci.HCI_SUCCESS:
|
||||
logger.warning("can't get ROM version")
|
||||
return
|
||||
return None
|
||||
rom_version = response.return_parameters.version
|
||||
logger.debug(f"ROM version before download: {rom_version:04X}")
|
||||
else:
|
||||
@@ -606,13 +614,14 @@ class Driver(common.Driver):
|
||||
|
||||
firmware = Firmware(self.firmware)
|
||||
logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}")
|
||||
logger.debug(f"firmware: version=0x{firmware.version:04X}")
|
||||
for patch in firmware.patches:
|
||||
if patch[0] == rom_version + 1:
|
||||
logger.debug(f"using patch {patch[0]}")
|
||||
break
|
||||
else:
|
||||
logger.warning("no valid patch found for rom version {rom_version}")
|
||||
return
|
||||
return None
|
||||
|
||||
# Append the config if there is one.
|
||||
if self.config:
|
||||
@@ -634,9 +643,8 @@ class Driver(common.Driver):
|
||||
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
|
||||
logger.debug(f"downloading fragment {fragment_index}")
|
||||
await self.host.send_command(
|
||||
HCI_RTK_Download_Command(
|
||||
index=download_index, payload=fragment, check_result=True
|
||||
)
|
||||
HCI_RTK_Download_Command(index=download_index, payload=fragment),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
logger.debug("download complete!")
|
||||
@@ -645,11 +653,13 @@ class Driver(common.Driver):
|
||||
response = await self.host.send_command(
|
||||
HCI_RTK_Read_ROM_Version_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.status != HCI_SUCCESS:
|
||||
if response.return_parameters.status != hci.HCI_SUCCESS:
|
||||
logger.warning("can't get ROM version")
|
||||
else:
|
||||
rom_version = response.return_parameters.version
|
||||
logger.debug(f"ROM version after download: {rom_version:04X}")
|
||||
logger.debug(f"ROM version after download: {rom_version:02X}")
|
||||
|
||||
return firmware.version
|
||||
|
||||
async def download_firmware(self):
|
||||
if self.driver_info.rom == RTK_ROM_LMP_8723A:
|
||||
@@ -668,7 +678,7 @@ class Driver(common.Driver):
|
||||
|
||||
async def init_controller(self):
|
||||
await self.download_firmware()
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
await self.host.send_command(hci.HCI_Reset_Command(), check_result=True)
|
||||
logger.info(f"loaded FW image {self.driver_info.fw_name}")
|
||||
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from .gatt import (
|
||||
Service,
|
||||
Characteristic,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
from bumble.gatt import (
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
Characteristic,
|
||||
Service,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
289
bumble/gatt.py
289
bumble/gatt.py
@@ -23,29 +23,21 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing import Iterable, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.core import BaseBumbleError, UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, BaseBumbleError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing
|
||||
# -----------------------------------------------------------------------------
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -275,6 +267,13 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
|
||||
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
|
||||
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
|
||||
|
||||
# Gaming Audio Service (GMAS)
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
|
||||
|
||||
# Hearing Access Service
|
||||
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
|
||||
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
|
||||
@@ -288,6 +287,22 @@ GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-32
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
|
||||
|
||||
# Apple Notification Center Service
|
||||
GATT_ANCS_SERVICE = UUID('7905F431-B5CE-4E99-A40F-4B1E122D00D0', 'Apple Notification Center')
|
||||
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC = UUID('9FBF120D-6301-42D9-8C58-25E699A21DBD', 'Notification Source')
|
||||
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC = UUID('69D1D8F3-45E1-49A8-9821-9BBDFDAAD9D9', 'Control Point')
|
||||
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC = UUID('22EAC6E9-24D6-4BB5-BE44-B36ACE7C7BFB', 'Data Source')
|
||||
|
||||
# Apple Media Service
|
||||
GATT_AMS_SERVICE = UUID('89D3502B-0F36-433A-8EF4-C502AD55F8DC', 'Apple Media')
|
||||
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC = UUID('9B3C81D8-57B1-4A8A-B8DF-0E56F7CA51C2', 'Remote Command')
|
||||
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC = UUID('2F7CABCE-808D-411F-9A0C-BB92BA96C102', 'Entity Update')
|
||||
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC = UUID('C6B2F38C-23AB-46D8-A6AB-A3A870BBD5D7', 'Entity Attribute')
|
||||
|
||||
# Misc Apple Services
|
||||
GATT_APPLE_CONTINUITY_SERVICE = UUID('D0611E78-BBB4-4591-A5F8-487910AE4366', 'Apple Continuity')
|
||||
GATT_APPLE_NEARBY_SERVICE = UUID('9FA480E0-4967-4542-9390-D343DC5D04AE', 'Apple Nearby')
|
||||
|
||||
# Misc
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
|
||||
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
|
||||
@@ -304,6 +319,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
||||
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
|
||||
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
|
||||
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
|
||||
GATT_LE_GATT_SECURITY_LEVELS_CHARACTERISTIC = UUID.from_16_bits(0x2BF5, 'E GATT Security Levels')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -312,8 +328,6 @@ GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bi
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
@@ -337,13 +351,13 @@ class Service(Attribute):
|
||||
'''
|
||||
|
||||
uuid: UUID
|
||||
characteristics: List[Characteristic]
|
||||
included_services: List[Service]
|
||||
characteristics: list[Characteristic]
|
||||
included_services: list[Service]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid: Union[str, UUID],
|
||||
characteristics: List[Characteristic],
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary=True,
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
@@ -362,7 +376,7 @@ class Service(Attribute):
|
||||
)
|
||||
self.uuid = uuid
|
||||
self.included_services = list(included_services)
|
||||
self.characteristics = characteristics[:]
|
||||
self.characteristics = list(characteristics)
|
||||
self.primary = primary
|
||||
|
||||
def get_advertising_data(self) -> Optional[bytes]:
|
||||
@@ -393,7 +407,7 @@ class TemplateService(Service):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
@@ -410,7 +424,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
def __init__(self, service: Service) -> None:
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
'<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
|
||||
)
|
||||
super().__init__(
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
|
||||
@@ -427,7 +441,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Characteristic(Attribute):
|
||||
class Characteristic(Attribute[_T]):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION
|
||||
'''
|
||||
@@ -435,6 +449,8 @@ class Characteristic(Attribute):
|
||||
uuid: UUID
|
||||
properties: Characteristic.Properties
|
||||
|
||||
EVENT_SUBSCRIPTION = "subscription"
|
||||
|
||||
class Properties(enum.IntFlag):
|
||||
"""Property flags"""
|
||||
|
||||
@@ -459,7 +475,7 @@ class Characteristic(Attribute):
|
||||
# The check for `p.name is not None` here is needed because for InFlag
|
||||
# enums, the .name property can be None, when the enum value is 0,
|
||||
# so the type hint for .name is Optional[str].
|
||||
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list: list[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list_str = ",".join(enum_list)
|
||||
raise TypeError(
|
||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
||||
@@ -490,7 +506,7 @@ class Characteristic(Attribute):
|
||||
uuid: Union[str, bytes, UUID],
|
||||
properties: Characteristic.Properties,
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, CharacteristicValue] = b'',
|
||||
value: Union[AttributeValue[_T], _T, None] = None,
|
||||
descriptors: Sequence[Descriptor] = (),
|
||||
):
|
||||
super().__init__(uuid, permissions, value)
|
||||
@@ -525,7 +541,11 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
characteristic: Characteristic
|
||||
|
||||
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
characteristic: Characteristic,
|
||||
value_handle: int,
|
||||
) -> None:
|
||||
declaration_bytes = (
|
||||
struct.pack('<BH', characteristic.properties, value_handle)
|
||||
+ characteristic.uuid.to_pdu_bytes()
|
||||
@@ -546,195 +566,10 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue(AttributeValue):
|
||||
class CharacteristicValue(AttributeValue[_T]):
|
||||
"""Same as AttributeValue, for backward compatibility"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter:
|
||||
'''
|
||||
An adapter that can adapt Characteristic and AttributeProxy objects
|
||||
by wrapping their `read_value()` and `write_value()` methods with ones that
|
||||
return/accept encoded/decoded values.
|
||||
|
||||
For proxies (i.e used by a GATT client), the adaptation is one where the return
|
||||
value of `read_value()` is decoded and the value passed to `write_value()` is
|
||||
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
|
||||
before being passed to the subscriber.
|
||||
|
||||
For local values (i.e hosted by a GATT server) the adaptation is one where the
|
||||
return value of `read_value()` is encoded and the value passed to `write_value()`
|
||||
is decoded.
|
||||
'''
|
||||
|
||||
read_value: Callable
|
||||
write_value: Callable
|
||||
|
||||
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers: Dict[Callable, Callable] = (
|
||||
{}
|
||||
) # Map from subscriber to proxy subscriber
|
||||
|
||||
if isinstance(characteristic, Characteristic):
|
||||
self.read_value = self.read_encoded_value
|
||||
self.write_value = self.write_encoded_value
|
||||
else:
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.wrapped_characteristic, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in (
|
||||
'wrapped_characteristic',
|
||||
'subscribers',
|
||||
'read_value',
|
||||
'write_value',
|
||||
'subscribe',
|
||||
'unsubscribe',
|
||||
):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
|
||||
async def read_encoded_value(self, connection):
|
||||
return self.encode_value(
|
||||
await self.wrapped_characteristic.read_value(connection)
|
||||
)
|
||||
|
||||
async def write_encoded_value(self, connection, value):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
connection, self.decode_value(value)
|
||||
)
|
||||
|
||||
async def read_decoded_value(self):
|
||||
return self.decode_value(await self.wrapped_characteristic.read_value())
|
||||
|
||||
async def write_decoded_value(self, value, with_response=False):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
self.encode_value(value), with_response
|
||||
)
|
||||
|
||||
def encode_value(self, value):
|
||||
return value
|
||||
|
||||
def decode_value(self, value):
|
||||
return value
|
||||
|
||||
def wrapped_subscribe(self, subscriber=None):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
# We already have a proxy subscriber
|
||||
subscriber = self.subscribers[subscriber]
|
||||
else:
|
||||
# Create and register a proxy that will decode the value
|
||||
original_subscriber = subscriber
|
||||
|
||||
def on_change(value):
|
||||
original_subscriber(self.decode_value(value))
|
||||
|
||||
self.subscribers[subscriber] = on_change
|
||||
subscriber = on_change
|
||||
|
||||
return self.wrapped_characteristic.subscribe(subscriber)
|
||||
|
||||
def wrapped_unsubscribe(self, subscriber=None):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||
|
||||
def __str__(self) -> str:
|
||||
wrapped = str(self.wrapped_characteristic)
|
||||
return f'{self.__class__.__name__}({wrapped})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class DelegatedCharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that converts bytes values using an encode and a decode function.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic, encode=None, decode=None):
|
||||
super().__init__(characteristic)
|
||||
self.encode = encode
|
||||
self.decode = decode
|
||||
|
||||
def encode_value(self, value):
|
||||
return self.encode(value) if self.encode else value
|
||||
|
||||
def decode_value(self, value):
|
||||
return self.decode(value) if self.decode else value
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PackedCharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
For formats with a single value, the adapted `read_value` and `write_value`
|
||||
methods return/accept single values. For formats with multiple values,
|
||||
they return/accept a tuple with the same number of elements as is required for
|
||||
the format.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic, pack_format):
|
||||
super().__init__(characteristic)
|
||||
self.struct = struct.Struct(pack_format)
|
||||
|
||||
def pack(self, *values):
|
||||
return self.struct.pack(*values)
|
||||
|
||||
def unpack(self, buffer):
|
||||
return self.struct.unpack(buffer)
|
||||
|
||||
def encode_value(self, value):
|
||||
return self.pack(*value if isinstance(value, tuple) else (value,))
|
||||
|
||||
def decode_value(self, value):
|
||||
unpacked = self.unpack(value)
|
||||
return unpacked[0] if len(unpacked) == 1 else unpacked
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
|
||||
is packed/unpacked according to format, with the arguments extracted from the
|
||||
dictionary by key, in the same order as they occur in the `keys` parameter.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic, pack_format, keys):
|
||||
super().__init__(characteristic, pack_format)
|
||||
self.keys = keys
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def pack(self, values):
|
||||
return super().pack(*(values[key] for key in self.keys))
|
||||
|
||||
def unpack(self, buffer):
|
||||
return dict(zip(self.keys, super().unpack(buffer)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Descriptor(Attribute):
|
||||
'''
|
||||
@@ -745,11 +580,7 @@ class Descriptor(Attribute):
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
value_str = '<dynamic>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
@@ -769,3 +600,23 @@ class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
DEFAULT = 0x0000
|
||||
NOTIFICATION = 0x0001
|
||||
INDICATION = 0x0002
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientSupportedFeatures(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 7.2 - Table 7.6: Client Supported Features bit assignments.
|
||||
'''
|
||||
|
||||
ROBUST_CACHING = 0x01
|
||||
ENHANCED_ATT_BEARER = 0x02
|
||||
MULTIPLE_HANDLE_VALUE_NOTIFICATIONS = 0x04
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServerSupportedFeatures(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 7.4 - Table 7.11: Server Supported Features bit assignments.
|
||||
'''
|
||||
|
||||
EATT_SUPPORTED = 0x01
|
||||
|
||||
365
bumble/gatt_adapters.py
Normal file
365
bumble/gatt_adapters.py
Normal file
@@ -0,0 +1,365 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT - Type Adapters
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from typing import Any, Callable, Generic, Iterable, Literal, Optional, TypeVar
|
||||
|
||||
from bumble import utils
|
||||
from bumble.core import InvalidOperationError
|
||||
from bumble.gatt import Characteristic
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing
|
||||
# -----------------------------------------------------------------------------
|
||||
_T = TypeVar('_T')
|
||||
_T2 = TypeVar('_T2', bound=utils.ByteSerializable)
|
||||
_T3 = TypeVar('_T3', bound=utils.IntConvertible)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter(Characteristic, Generic[_T]):
|
||||
'''Base class for GATT Characteristic adapters.'''
|
||||
|
||||
def __init__(self, characteristic: Characteristic) -> None:
|
||||
super().__init__(
|
||||
characteristic.uuid,
|
||||
characteristic.properties,
|
||||
characteristic.permissions,
|
||||
characteristic.value,
|
||||
characteristic.descriptors,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicProxyAdapter(CharacteristicProxy[_T]):
|
||||
'''Base class for GATT CharacteristicProxy adapters.'''
|
||||
|
||||
def __init__(self, characteristic_proxy: CharacteristicProxy):
|
||||
super().__init__(
|
||||
characteristic_proxy.client,
|
||||
characteristic_proxy.handle,
|
||||
characteristic_proxy.end_group_handle,
|
||||
characteristic_proxy.uuid,
|
||||
characteristic_proxy.properties,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class DelegatedCharacteristicAdapter(CharacteristicAdapter[_T]):
|
||||
'''
|
||||
Adapter that converts bytes values using an encode and/or a decode function.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristic: Characteristic,
|
||||
encode: Optional[Callable[[_T], bytes]] = None,
|
||||
decode: Optional[Callable[[bytes], _T]] = None,
|
||||
):
|
||||
super().__init__(characteristic)
|
||||
self.encode = encode
|
||||
self.decode = decode
|
||||
|
||||
def encode_value(self, value: _T) -> bytes:
|
||||
if self.encode is None:
|
||||
raise InvalidOperationError('delegated adapter does not have an encoder')
|
||||
return self.encode(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T:
|
||||
if self.decode is None:
|
||||
raise InvalidOperationError('delegate adapter does not have a decoder')
|
||||
return self.decode(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class DelegatedCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T]):
|
||||
'''
|
||||
Adapter that converts bytes values using an encode and a decode function.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristic_proxy: CharacteristicProxy,
|
||||
encode: Optional[Callable[[_T], bytes]] = None,
|
||||
decode: Optional[Callable[[bytes], _T]] = None,
|
||||
):
|
||||
super().__init__(characteristic_proxy)
|
||||
self.encode = encode
|
||||
self.decode = decode
|
||||
|
||||
def encode_value(self, value: _T) -> bytes:
|
||||
if self.encode is None:
|
||||
raise InvalidOperationError('delegated adapter does not have an encoder')
|
||||
return self.encode(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T:
|
||||
if self.decode is None:
|
||||
raise InvalidOperationError('delegate adapter does not have a decoder')
|
||||
return self.decode(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PackedCharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
For formats with a single value, the adapted `read_value` and `write_value`
|
||||
methods return/accept single values. For formats with multiple values,
|
||||
they return/accept a tuple with the same number of elements as is required for
|
||||
the format.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic: Characteristic, pack_format: str) -> None:
|
||||
super().__init__(characteristic)
|
||||
self.struct = struct.Struct(pack_format)
|
||||
|
||||
def pack(self, *values) -> bytes:
|
||||
return self.struct.pack(*values)
|
||||
|
||||
def unpack(self, buffer: bytes) -> tuple:
|
||||
return self.struct.unpack(buffer)
|
||||
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return self.pack(*value if isinstance(value, tuple) else (value,))
|
||||
|
||||
def decode_value(self, value: bytes) -> Any:
|
||||
unpacked = self.unpack(value)
|
||||
return unpacked[0] if len(unpacked) == 1 else unpacked
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PackedCharacteristicProxyAdapter(CharacteristicProxyAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
For formats with a single value, the adapted `read_value` and `write_value`
|
||||
methods return/accept single values. For formats with multiple values,
|
||||
they return/accept a tuple with the same number of elements as is required for
|
||||
the format.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic_proxy, pack_format):
|
||||
super().__init__(characteristic_proxy)
|
||||
self.struct = struct.Struct(pack_format)
|
||||
|
||||
def pack(self, *values) -> bytes:
|
||||
return self.struct.pack(*values)
|
||||
|
||||
def unpack(self, buffer: bytes) -> tuple:
|
||||
return self.struct.unpack(buffer)
|
||||
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return self.pack(*value if isinstance(value, tuple) else (value,))
|
||||
|
||||
def decode_value(self, value: bytes) -> Any:
|
||||
unpacked = self.unpack(value)
|
||||
return unpacked[0] if len(unpacked) == 1 else unpacked
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
The adapted `read_value` and `write_value` methods return/accept a dictionary which
|
||||
is packed/unpacked according to format, with the arguments extracted from the
|
||||
dictionary by key, in the same order as they occur in the `keys` parameter.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self, characteristic: Characteristic, pack_format: str, keys: Iterable[str]
|
||||
) -> None:
|
||||
super().__init__(characteristic, pack_format)
|
||||
self.keys = keys
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def pack(self, values) -> bytes:
|
||||
return super().pack(*(values[key] for key in self.keys))
|
||||
|
||||
def unpack(self, buffer: bytes) -> Any:
|
||||
return dict(zip(self.keys, super().unpack(buffer)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MappedCharacteristicProxyAdapter(PackedCharacteristicProxyAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
The adapted `read_value` and `write_value` methods return/accept a dictionary which
|
||||
is packed/unpacked according to format, with the arguments extracted from the
|
||||
dictionary by key, in the same order as they occur in the `keys` parameter.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristic_proxy: CharacteristicProxy,
|
||||
pack_format: str,
|
||||
keys: Iterable[str],
|
||||
) -> None:
|
||||
super().__init__(characteristic_proxy, pack_format)
|
||||
self.keys = keys
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def pack(self, values) -> bytes:
|
||||
return super().pack(*(values[key] for key in self.keys))
|
||||
|
||||
def unpack(self, buffer: bytes) -> Any:
|
||||
return dict(zip(self.keys, super().unpack(buffer)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class UTF8CharacteristicAdapter(CharacteristicAdapter[str]):
|
||||
'''
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class UTF8CharacteristicProxyAdapter(CharacteristicProxyAdapter[str]):
|
||||
'''
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SerializableCharacteristicAdapter(CharacteristicAdapter[_T2]):
|
||||
'''
|
||||
Adapter that converts any class to/from bytes using the class'
|
||||
`to_bytes` and `__bytes__` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic: Characteristic, cls: type[_T2]) -> None:
|
||||
super().__init__(characteristic)
|
||||
self.cls = cls
|
||||
|
||||
def encode_value(self, value: _T2) -> bytes:
|
||||
return bytes(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T2:
|
||||
return self.cls.from_bytes(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SerializableCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T2]):
|
||||
'''
|
||||
Adapter that converts any class to/from bytes using the class'
|
||||
`to_bytes` and `__bytes__` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self, characteristic_proxy: CharacteristicProxy, cls: type[_T2]
|
||||
) -> None:
|
||||
super().__init__(characteristic_proxy)
|
||||
self.cls = cls
|
||||
|
||||
def encode_value(self, value: _T2) -> bytes:
|
||||
return bytes(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T2:
|
||||
return self.cls.from_bytes(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class EnumCharacteristicAdapter(CharacteristicAdapter[_T3]):
|
||||
'''
|
||||
Adapter that converts int-enum-like classes to/from bytes using the class'
|
||||
`int().to_bytes()` and `from_bytes()` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristic: Characteristic,
|
||||
cls: type[_T3],
|
||||
length: int,
|
||||
byteorder: Literal['little', 'big'] = 'little',
|
||||
):
|
||||
"""
|
||||
Initialize an instance.
|
||||
|
||||
Params:
|
||||
characteristic: the Characteristic to adapt to/from
|
||||
cls: the class to/from which to convert integer values
|
||||
length: number of bytes used to represent integer values
|
||||
byteorder: byte order of the byte representation of integers.
|
||||
"""
|
||||
super().__init__(characteristic)
|
||||
self.cls = cls
|
||||
self.length = length
|
||||
self.byteorder = byteorder
|
||||
|
||||
def encode_value(self, value: _T3) -> bytes:
|
||||
return int(value).to_bytes(self.length, self.byteorder)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T3:
|
||||
int_value = int.from_bytes(value, self.byteorder)
|
||||
return self.cls(int_value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
|
||||
'''
|
||||
Adapter that converts int-enum-like classes to/from bytes using the class'
|
||||
`int().to_bytes()` and `from_bytes()` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristic_proxy: CharacteristicProxy,
|
||||
cls: type[_T3],
|
||||
length: int,
|
||||
byteorder: Literal['little', 'big'] = 'little',
|
||||
):
|
||||
"""
|
||||
Initialize an instance.
|
||||
|
||||
Params:
|
||||
characteristic_proxy: the CharacteristicProxy to adapt to/from
|
||||
cls: the class to/from which to convert integer values
|
||||
length: number of bytes used to represent integer values
|
||||
byteorder: byte order of the byte representation of integers.
|
||||
"""
|
||||
super().__init__(characteristic_proxy)
|
||||
self.cls = cls
|
||||
self.length = length
|
||||
self.byteorder = byteorder
|
||||
|
||||
def encode_value(self, value: _T3) -> bytes:
|
||||
return int(value).to_bytes(self.length, self.byteorder)
|
||||
|
||||
def decode_value(self, value: bytes) -> _T3:
|
||||
int_value = int.from_bytes(value, self.byteorder)
|
||||
a = self.cls(int_value)
|
||||
return self.cls(int_value)
|
||||
@@ -24,66 +24,47 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
Callable,
|
||||
Union,
|
||||
Any,
|
||||
Iterable,
|
||||
Type,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .hci import HCI_Constant
|
||||
from .att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
ATT_DEFAULT_MTU,
|
||||
ATT_ERROR_RESPONSE,
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
ATT_PDU,
|
||||
ATT_RESPONSES,
|
||||
ATT_Exchange_MTU_Request,
|
||||
ATT_Find_By_Type_Value_Request,
|
||||
ATT_Find_Information_Request,
|
||||
ATT_Handle_Value_Confirmation,
|
||||
ATT_Read_Blob_Request,
|
||||
ATT_Read_By_Group_Type_Request,
|
||||
ATT_Read_By_Type_Request,
|
||||
ATT_Read_Request,
|
||||
ATT_Write_Command,
|
||||
ATT_Write_Request,
|
||||
ATT_Error,
|
||||
)
|
||||
from . import core
|
||||
from .core import UUID, InvalidStateError
|
||||
from .gatt import (
|
||||
from bumble import att, core, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID, InvalidStateError
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
InvalidServiceError,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.hci import HCI_Constant
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing
|
||||
# -----------------------------------------------------------------------------
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -109,31 +90,31 @@ def show_services(services: Iterable[ServiceProxy]) -> None:
|
||||
# -----------------------------------------------------------------------------
|
||||
# Proxies
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeProxy(EventEmitter):
|
||||
class AttributeProxy(utils.EventEmitter, Generic[_T]):
|
||||
def __init__(
|
||||
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
utils.EventEmitter.__init__(self)
|
||||
self.client = client
|
||||
self.handle = handle
|
||||
self.end_group_handle = end_group_handle
|
||||
self.type = attribute_type
|
||||
|
||||
async def read_value(self, no_long_read: bool = False) -> bytes:
|
||||
async def read_value(self, no_long_read: bool = False) -> _T:
|
||||
return self.decode_value(
|
||||
await self.client.read_value(self.handle, no_long_read)
|
||||
)
|
||||
|
||||
async def write_value(self, value, with_response=False):
|
||||
async def write_value(self, value: _T, with_response=False):
|
||||
return await self.client.write_value(
|
||||
self.handle, self.encode_value(value), with_response
|
||||
)
|
||||
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
def encode_value(self, value: _T) -> bytes:
|
||||
return value # type: ignore
|
||||
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
def decode_value(self, value: bytes) -> _T:
|
||||
return value # type: ignore
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
||||
@@ -141,8 +122,8 @@ class AttributeProxy(EventEmitter):
|
||||
|
||||
class ServiceProxy(AttributeProxy):
|
||||
uuid: UUID
|
||||
characteristics: List[CharacteristicProxy]
|
||||
included_services: List[ServiceProxy]
|
||||
characteristics: list[CharacteristicProxy[bytes]]
|
||||
included_services: list[ServiceProxy]
|
||||
|
||||
@staticmethod
|
||||
def from_client(service_class, client: Client, service_uuid: UUID):
|
||||
@@ -162,29 +143,48 @@ class ServiceProxy(AttributeProxy):
|
||||
self.uuid = uuid
|
||||
self.characteristics = []
|
||||
|
||||
async def discover_characteristics(self, uuids=()):
|
||||
async def discover_characteristics(
|
||||
self, uuids=()
|
||||
) -> list[CharacteristicProxy[bytes]]:
|
||||
return await self.client.discover_characteristics(uuids, self)
|
||||
|
||||
def get_characteristics_by_uuid(self, uuid):
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: UUID
|
||||
) -> list[CharacteristicProxy[bytes]]:
|
||||
"""Get all the characteristics with a specified UUID."""
|
||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||
|
||||
def get_required_characteristic_by_uuid(
|
||||
self, uuid: UUID
|
||||
) -> CharacteristicProxy[bytes]:
|
||||
"""
|
||||
Get the first characteristic with a specified UUID.
|
||||
|
||||
If no characteristic with that UUID is found, an InvalidServiceError is raised.
|
||||
"""
|
||||
if not (characteristics := self.get_characteristics_by_uuid(uuid)):
|
||||
raise InvalidServiceError(f'{uuid} characteristic not found')
|
||||
return characteristics[0]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||
|
||||
|
||||
class CharacteristicProxy(AttributeProxy):
|
||||
class CharacteristicProxy(AttributeProxy[_T]):
|
||||
properties: Characteristic.Properties
|
||||
descriptors: List[DescriptorProxy]
|
||||
subscribers: Dict[Any, Callable[[bytes], Any]]
|
||||
descriptors: list[DescriptorProxy]
|
||||
subscribers: dict[Any, Callable[[_T], Any]]
|
||||
|
||||
EVENT_UPDATE = "update"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
handle,
|
||||
end_group_handle,
|
||||
uuid,
|
||||
client: Client,
|
||||
handle: int,
|
||||
end_group_handle: int,
|
||||
uuid: UUID,
|
||||
properties: int,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(client, handle, end_group_handle, uuid)
|
||||
self.uuid = uuid
|
||||
self.properties = Characteristic.Properties(properties)
|
||||
@@ -192,21 +192,21 @@ class CharacteristicProxy(AttributeProxy):
|
||||
self.descriptors_discovered = False
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
|
||||
def get_descriptor(self, descriptor_type):
|
||||
def get_descriptor(self, descriptor_type: UUID) -> Optional[DescriptorProxy]:
|
||||
for descriptor in self.descriptors:
|
||||
if descriptor.type == descriptor_type:
|
||||
return descriptor
|
||||
|
||||
return None
|
||||
|
||||
async def discover_descriptors(self):
|
||||
async def discover_descriptors(self) -> list[DescriptorProxy]:
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
subscriber: Optional[Callable[[_T], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
# We already have a proxy subscriber
|
||||
@@ -221,13 +221,13 @@ class CharacteristicProxy(AttributeProxy):
|
||||
self.subscribers[subscriber] = on_change
|
||||
subscriber = on_change
|
||||
|
||||
return await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, subscriber=None, force=False):
|
||||
async def unsubscribe(self, subscriber=None, force=False) -> None:
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber, force)
|
||||
await self.client.unsubscribe(self, subscriber, force)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
@@ -237,8 +237,8 @@ class CharacteristicProxy(AttributeProxy):
|
||||
)
|
||||
|
||||
|
||||
class DescriptorProxy(AttributeProxy):
|
||||
def __init__(self, client, handle, descriptor_type):
|
||||
class DescriptorProxy(AttributeProxy[bytes]):
|
||||
def __init__(self, client: Client, handle: int, descriptor_type: UUID) -> None:
|
||||
super().__init__(client, handle, 0, descriptor_type)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -250,7 +250,7 @@ class ProfileServiceProxy:
|
||||
Base class for profile-specific service proxies
|
||||
'''
|
||||
|
||||
SERVICE_CLASS: Type[TemplateService]
|
||||
SERVICE_CLASS: type[TemplateService]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
|
||||
@@ -261,16 +261,16 @@ class ProfileServiceProxy:
|
||||
# GATT Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
services: List[ServiceProxy]
|
||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||
notification_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
services: list[ServiceProxy]
|
||||
cached_values: dict[int, tuple[datetime, bytes]]
|
||||
notification_subscribers: dict[
|
||||
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
indication_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
indication_subscribers: dict[
|
||||
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||
pending_request: Optional[ATT_PDU]
|
||||
pending_response: Optional[asyncio.futures.Future[att.ATT_PDU]]
|
||||
pending_request: Optional[att.ATT_PDU]
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
@@ -283,18 +283,18 @@ class Client:
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
connection.on('disconnection', self.on_disconnection)
|
||||
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
|
||||
|
||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
|
||||
|
||||
async def send_command(self, command: ATT_PDU) -> None:
|
||||
async def send_command(self, command: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
self.send_gatt_pdu(command.to_bytes())
|
||||
self.send_gatt_pdu(bytes(command))
|
||||
|
||||
async def send_request(self, request: ATT_PDU):
|
||||
async def send_request(self, request: att.ATT_PDU):
|
||||
logger.debug(
|
||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||
)
|
||||
@@ -310,7 +310,7 @@ class Client:
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(request.to_bytes())
|
||||
self.send_gatt_pdu(bytes(request))
|
||||
response = await asyncio.wait_for(
|
||||
self.pending_response, GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
@@ -323,17 +323,19 @@ class Client:
|
||||
|
||||
return response
|
||||
|
||||
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
|
||||
def send_confirmation(
|
||||
self, confirmation: att.ATT_Handle_Value_Confirmation
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
self.send_gatt_pdu(confirmation.to_bytes())
|
||||
self.send_gatt_pdu(bytes(confirmation))
|
||||
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
if mtu < att.ATT_DEFAULT_MTU:
|
||||
raise core.InvalidArgumentError(f'MTU must be >= {att.ATT_DEFAULT_MTU}')
|
||||
if mtu > 0xFFFF:
|
||||
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
|
||||
|
||||
@@ -343,21 +345,23 @@ class Client:
|
||||
|
||||
# Send the request
|
||||
self.mtu_exchange_done = True
|
||||
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
response = await self.send_request(
|
||||
att.ATT_Exchange_MTU_Request(client_rx_mtu=mtu)
|
||||
)
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
raise att.ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
# Compute the final MTU
|
||||
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
|
||||
|
||||
return self.connection.att_mtu
|
||||
|
||||
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
|
||||
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
|
||||
return [service for service in self.services if service.uuid == uuid]
|
||||
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy] = None
|
||||
) -> List[CharacteristicProxy]:
|
||||
) -> list[CharacteristicProxy[bytes]]:
|
||||
services = [service] if service else self.services
|
||||
return [
|
||||
c
|
||||
@@ -368,8 +372,8 @@ class Client:
|
||||
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
|
||||
Union[
|
||||
ServiceProxy,
|
||||
Tuple[ServiceProxy, CharacteristicProxy],
|
||||
Tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
|
||||
tuple[ServiceProxy, CharacteristicProxy],
|
||||
tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
|
||||
]
|
||||
]:
|
||||
"""
|
||||
@@ -402,7 +406,7 @@ class Client:
|
||||
if not already_known:
|
||||
self.services.append(service)
|
||||
|
||||
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
|
||||
async def discover_services(self, uuids: Iterable[UUID] = ()) -> list[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||
'''
|
||||
@@ -410,7 +414,7 @@ class Client:
|
||||
services = []
|
||||
while starting_handle < 0xFFFF:
|
||||
response = await self.send_request(
|
||||
ATT_Read_By_Group_Type_Request(
|
||||
att.ATT_Read_By_Group_Type_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=0xFFFF,
|
||||
attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
@@ -421,14 +425,14 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering services: '
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
raise ATT_Error(
|
||||
raise att.ATT_Error(
|
||||
error_code=response.error_code,
|
||||
message='Unexpected error while discovering services',
|
||||
)
|
||||
@@ -474,7 +478,7 @@ class Client:
|
||||
|
||||
return services
|
||||
|
||||
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
|
||||
async def discover_service(self, uuid: Union[str, UUID]) -> list[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
||||
'''
|
||||
@@ -487,7 +491,7 @@ class Client:
|
||||
services = []
|
||||
while starting_handle < 0xFFFF:
|
||||
response = await self.send_request(
|
||||
ATT_Find_By_Type_Value_Request(
|
||||
att.ATT_Find_By_Type_Value_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=0xFFFF,
|
||||
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
@@ -499,8 +503,8 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering services: '
|
||||
@@ -545,7 +549,7 @@ class Client:
|
||||
|
||||
async def discover_included_services(
|
||||
self, service: ServiceProxy
|
||||
) -> List[ServiceProxy]:
|
||||
) -> list[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.5.1 Find Included Services
|
||||
'''
|
||||
@@ -553,10 +557,10 @@ class Client:
|
||||
starting_handle = service.handle
|
||||
ending_handle = service.end_group_handle
|
||||
|
||||
included_services: List[ServiceProxy] = []
|
||||
included_services: list[ServiceProxy] = []
|
||||
while starting_handle <= ending_handle:
|
||||
response = await self.send_request(
|
||||
ATT_Read_By_Type_Request(
|
||||
att.ATT_Read_By_Type_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=ending_handle,
|
||||
attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
@@ -567,14 +571,14 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering included services: '
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
raise ATT_Error(
|
||||
raise att.ATT_Error(
|
||||
error_code=response.error_code,
|
||||
message='Unexpected error while discovering included services',
|
||||
)
|
||||
@@ -609,7 +613,7 @@ class Client:
|
||||
|
||||
async def discover_characteristics(
|
||||
self, uuids, service: Optional[ServiceProxy]
|
||||
) -> List[CharacteristicProxy]:
|
||||
) -> list[CharacteristicProxy[bytes]]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
|
||||
Discover Characteristics by UUID
|
||||
@@ -622,15 +626,15 @@ class Client:
|
||||
services = [service] if service else self.services
|
||||
|
||||
# Perform characteristic discovery for each service
|
||||
discovered_characteristics: List[CharacteristicProxy] = []
|
||||
discovered_characteristics: list[CharacteristicProxy[bytes]] = []
|
||||
for service in services:
|
||||
starting_handle = service.handle
|
||||
ending_handle = service.end_group_handle
|
||||
|
||||
characteristics: List[CharacteristicProxy] = []
|
||||
characteristics: list[CharacteristicProxy[bytes]] = []
|
||||
while starting_handle <= ending_handle:
|
||||
response = await self.send_request(
|
||||
ATT_Read_By_Type_Request(
|
||||
att.ATT_Read_By_Type_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=ending_handle,
|
||||
attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
@@ -641,14 +645,14 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering characteristics: '
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
raise ATT_Error(
|
||||
raise att.ATT_Error(
|
||||
error_code=response.error_code,
|
||||
message='Unexpected error while discovering characteristics',
|
||||
)
|
||||
@@ -667,7 +671,7 @@ class Client:
|
||||
|
||||
properties, handle = struct.unpack_from('<BH', attribute_value)
|
||||
characteristic_uuid = UUID.from_bytes(attribute_value[3:])
|
||||
characteristic = CharacteristicProxy(
|
||||
characteristic = CharacteristicProxy[bytes](
|
||||
self, handle, 0, characteristic_uuid, properties
|
||||
)
|
||||
|
||||
@@ -698,7 +702,7 @@ class Client:
|
||||
characteristic: Optional[CharacteristicProxy] = None,
|
||||
start_handle: Optional[int] = None,
|
||||
end_handle: Optional[int] = None,
|
||||
) -> List[DescriptorProxy]:
|
||||
) -> list[DescriptorProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
|
||||
'''
|
||||
@@ -711,10 +715,10 @@ class Client:
|
||||
else:
|
||||
return []
|
||||
|
||||
descriptors: List[DescriptorProxy] = []
|
||||
descriptors: list[DescriptorProxy] = []
|
||||
while starting_handle <= ending_handle:
|
||||
response = await self.send_request(
|
||||
ATT_Find_Information_Request(
|
||||
att.ATT_Find_Information_Request(
|
||||
starting_handle=starting_handle, ending_handle=ending_handle
|
||||
)
|
||||
)
|
||||
@@ -723,8 +727,8 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering descriptors: '
|
||||
@@ -760,7 +764,7 @@ class Client:
|
||||
|
||||
return descriptors
|
||||
|
||||
async def discover_attributes(self) -> List[AttributeProxy]:
|
||||
async def discover_attributes(self) -> list[AttributeProxy[bytes]]:
|
||||
'''
|
||||
Discover all attributes, regardless of type
|
||||
'''
|
||||
@@ -769,7 +773,7 @@ class Client:
|
||||
attributes = []
|
||||
while True:
|
||||
response = await self.send_request(
|
||||
ATT_Find_Information_Request(
|
||||
att.ATT_Find_Information_Request(
|
||||
starting_handle=starting_handle, ending_handle=ending_handle
|
||||
)
|
||||
)
|
||||
@@ -777,8 +781,8 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while discovering attributes: '
|
||||
@@ -793,7 +797,7 @@ class Client:
|
||||
logger.warning(f'bogus handle value: {attribute_handle}')
|
||||
return []
|
||||
|
||||
attribute = AttributeProxy(
|
||||
attribute = AttributeProxy[bytes](
|
||||
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
|
||||
)
|
||||
attributes.append(attribute)
|
||||
@@ -806,7 +810,7 @@ class Client:
|
||||
async def subscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
subscriber: Optional[Callable[[Any], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
) -> None:
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
@@ -856,7 +860,7 @@ class Client:
|
||||
async def unsubscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
subscriber: Optional[Callable[[Any], Any]] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
@@ -898,6 +902,12 @@ class Client:
|
||||
) and subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# The characteristic itself is added as subscriber. If it is the
|
||||
# last remaining subscriber, we remove it, such that the clean up
|
||||
# works correctly. Otherwise the CCCD never is set back to 0.
|
||||
if len(subscribers) == 1 and characteristic in subscribers:
|
||||
subscribers.remove(characteristic)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
del subscriber_set[characteristic.handle]
|
||||
@@ -926,12 +936,12 @@ class Client:
|
||||
# Send a request to read
|
||||
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
|
||||
response = await self.send_request(
|
||||
ATT_Read_Request(attribute_handle=attribute_handle)
|
||||
att.ATT_Read_Request(attribute_handle=attribute_handle)
|
||||
)
|
||||
if response is None:
|
||||
raise TimeoutError('read timeout')
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
raise att.ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
# If the value is the max size for the MTU, try to read more unless the caller
|
||||
# specifically asked not to do that
|
||||
@@ -941,19 +951,21 @@ class Client:
|
||||
offset = len(attribute_value)
|
||||
while True:
|
||||
response = await self.send_request(
|
||||
ATT_Read_Blob_Request(
|
||||
att.ATT_Read_Blob_Request(
|
||||
attribute_handle=attribute_handle, value_offset=offset
|
||||
)
|
||||
)
|
||||
if response is None:
|
||||
raise TimeoutError('read timeout')
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code in (
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
att.ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
att.ATT_INVALID_OFFSET_ERROR,
|
||||
):
|
||||
break
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
raise att.ATT_Error(
|
||||
error_code=response.error_code, message=response
|
||||
)
|
||||
|
||||
part = response.part_attribute_value
|
||||
attribute_value += part
|
||||
@@ -969,7 +981,7 @@ class Client:
|
||||
|
||||
async def read_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy]
|
||||
) -> List[bytes]:
|
||||
) -> list[bytes]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
||||
'''
|
||||
@@ -984,7 +996,7 @@ class Client:
|
||||
characteristics_values = []
|
||||
while starting_handle <= ending_handle:
|
||||
response = await self.send_request(
|
||||
ATT_Read_By_Type_Request(
|
||||
att.ATT_Read_By_Type_Request(
|
||||
starting_handle=starting_handle,
|
||||
ending_handle=ending_handle,
|
||||
attribute_type=uuid,
|
||||
@@ -995,8 +1007,8 @@ class Client:
|
||||
return []
|
||||
|
||||
# Check if we reached the end of the iteration
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.warning(
|
||||
'!!! unexpected error while reading characteristics: '
|
||||
@@ -1041,15 +1053,15 @@ class Client:
|
||||
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
|
||||
if with_response:
|
||||
response = await self.send_request(
|
||||
ATT_Write_Request(
|
||||
att.ATT_Write_Request(
|
||||
attribute_handle=attribute_handle, attribute_value=value
|
||||
)
|
||||
)
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
|
||||
raise att.ATT_Error(error_code=response.error_code, message=response)
|
||||
else:
|
||||
await self.send_command(
|
||||
ATT_Write_Command(
|
||||
att.ATT_Write_Command(
|
||||
attribute_handle=attribute_handle, attribute_value=value
|
||||
)
|
||||
)
|
||||
@@ -1058,11 +1070,11 @@ class Client:
|
||||
if self.pending_response and not self.pending_response.done():
|
||||
self.pending_response.cancel()
|
||||
|
||||
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||
)
|
||||
if att_pdu.op_code in ATT_RESPONSES:
|
||||
if att_pdu.op_code in att.ATT_RESPONSES:
|
||||
if self.pending_request is None:
|
||||
# Not expected!
|
||||
logger.warning('!!! unexpected response, there is no pending request')
|
||||
@@ -1070,7 +1082,7 @@ class Client:
|
||||
|
||||
# The response should match the pending request unless it is
|
||||
# an error response
|
||||
if att_pdu.op_code != ATT_ERROR_RESPONSE:
|
||||
if att_pdu.op_code != att.Opcode.ATT_ERROR_RESPONSE:
|
||||
expected_response_name = self.pending_request.name.replace(
|
||||
'_REQUEST', '_RESPONSE'
|
||||
)
|
||||
@@ -1098,7 +1110,9 @@ class Client:
|
||||
+ str(att_pdu)
|
||||
)
|
||||
|
||||
def on_att_handle_value_notification(self, notification):
|
||||
def on_att_handle_value_notification(
|
||||
self, notification: att.ATT_Handle_Value_Notification
|
||||
):
|
||||
# Call all subscribers
|
||||
subscribers = self.notification_subscribers.get(
|
||||
notification.attribute_handle, set()
|
||||
@@ -1111,9 +1125,11 @@ class Client:
|
||||
if callable(subscriber):
|
||||
subscriber(notification.attribute_value)
|
||||
else:
|
||||
subscriber.emit('update', notification.attribute_value)
|
||||
subscriber.emit(subscriber.EVENT_UPDATE, notification.attribute_value)
|
||||
|
||||
def on_att_handle_value_indication(self, indication):
|
||||
def on_att_handle_value_indication(
|
||||
self, indication: att.ATT_Handle_Value_Indication
|
||||
):
|
||||
# Call all subscribers
|
||||
subscribers = self.indication_subscribers.get(
|
||||
indication.attribute_handle, set()
|
||||
@@ -1126,10 +1142,10 @@ class Client:
|
||||
if callable(subscriber):
|
||||
subscriber(indication.attribute_value)
|
||||
else:
|
||||
subscriber.emit('update', indication.attribute_value)
|
||||
subscriber.emit(subscriber.EVENT_UPDATE, indication.attribute_value)
|
||||
|
||||
# Confirm that we received the indication
|
||||
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
||||
self.send_confirmation(att.ATT_Handle_Value_Confirmation())
|
||||
|
||||
def cache_value(self, attribute_handle: int, value: bytes) -> None:
|
||||
self.cached_values[attribute_handle] = (
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT - Generic Attribute Profile
|
||||
# GATT - Generic att.Attribute Profile
|
||||
# Server
|
||||
#
|
||||
# See Bluetooth spec @ Vol 3, Part G
|
||||
@@ -24,42 +24,16 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from pyee import EventEmitter
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Iterable, Optional, TypeVar
|
||||
|
||||
from bumble import att, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
ATT_DEFAULT_MTU,
|
||||
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
|
||||
ATT_INVALID_HANDLE_ERROR,
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
ATT_REQUESTS,
|
||||
ATT_PDU,
|
||||
ATT_UNLIKELY_ERROR_ERROR,
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
ATT_Error,
|
||||
ATT_Error_Response,
|
||||
ATT_Exchange_MTU_Response,
|
||||
ATT_Find_By_Type_Value_Response,
|
||||
ATT_Find_Information_Response,
|
||||
ATT_Handle_Value_Indication,
|
||||
ATT_Handle_Value_Notification,
|
||||
ATT_Read_Blob_Response,
|
||||
ATT_Read_By_Group_Type_Response,
|
||||
ATT_Read_By_Type_Response,
|
||||
ATT_Read_Response,
|
||||
ATT_Write_Response,
|
||||
Attribute,
|
||||
)
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
@@ -70,14 +44,13 @@ from bumble.gatt import (
|
||||
Characteristic,
|
||||
CharacteristicDeclaration,
|
||||
CharacteristicValue,
|
||||
IncludedServiceDeclaration,
|
||||
Descriptor,
|
||||
IncludedServiceDeclaration,
|
||||
Service,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
from bumble.device import Connection, Device
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -94,19 +67,21 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
attributes: List[Attribute]
|
||||
services: List[Service]
|
||||
attributes_by_handle: Dict[int, Attribute]
|
||||
subscribers: Dict[int, Dict[int, bytes]]
|
||||
class Server(utils.EventEmitter):
|
||||
attributes: list[att.Attribute]
|
||||
services: list[Service]
|
||||
attributes_by_handle: dict[int, att.Attribute]
|
||||
subscribers: dict[int, dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
|
||||
|
||||
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.services = []
|
||||
self.attributes = [] # Attributes, ordered by increasing handle values
|
||||
self.attributes = [] # att.Attributes, ordered by increasing handle values
|
||||
self.attributes_by_handle = {} # Map for fast attribute access by handle
|
||||
self.max_mtu = (
|
||||
GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
|
||||
@@ -121,12 +96,12 @@ class Server(EventEmitter):
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
|
||||
|
||||
def next_handle(self) -> int:
|
||||
return 1 + len(self.attributes)
|
||||
|
||||
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
|
||||
def get_advertising_service_data(self) -> dict[att.Attribute, bytes]:
|
||||
return {
|
||||
attribute: data
|
||||
for attribute in self.attributes
|
||||
@@ -134,7 +109,7 @@ class Server(EventEmitter):
|
||||
and (data := attribute.get_advertising_data())
|
||||
}
|
||||
|
||||
def get_attribute(self, handle: int) -> Optional[Attribute]:
|
||||
def get_attribute(self, handle: int) -> Optional[att.Attribute]:
|
||||
attribute = self.attributes_by_handle.get(handle)
|
||||
if attribute:
|
||||
return attribute
|
||||
@@ -150,7 +125,7 @@ class Server(EventEmitter):
|
||||
AttributeGroupType = TypeVar('AttributeGroupType', Service, Characteristic)
|
||||
|
||||
def get_attribute_group(
|
||||
self, handle: int, group_type: Type[AttributeGroupType]
|
||||
self, handle: int, group_type: type[AttributeGroupType]
|
||||
) -> Optional[AttributeGroupType]:
|
||||
return next(
|
||||
(
|
||||
@@ -176,7 +151,7 @@ class Server(EventEmitter):
|
||||
|
||||
def get_characteristic_attributes(
|
||||
self, service_uuid: UUID, characteristic_uuid: UUID
|
||||
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
|
||||
) -> Optional[tuple[CharacteristicDeclaration, Characteristic]]:
|
||||
service_handle = self.get_service_attribute(service_uuid)
|
||||
if not service_handle:
|
||||
return None
|
||||
@@ -225,7 +200,7 @@ class Server(EventEmitter):
|
||||
None,
|
||||
)
|
||||
|
||||
def add_attribute(self, attribute: Attribute) -> None:
|
||||
def add_attribute(self, attribute: att.Attribute) -> None:
|
||||
# Assign a handle to this attribute
|
||||
attribute.handle = self.next_handle()
|
||||
attribute.end_group_handle = (
|
||||
@@ -280,7 +255,7 @@ class Server(EventEmitter):
|
||||
# pylint: disable=line-too-long
|
||||
Descriptor(
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
Attribute.READABLE | Attribute.WRITEABLE,
|
||||
att.Attribute.READABLE | att.Attribute.WRITEABLE,
|
||||
CharacteristicValue(
|
||||
read=lambda connection, characteristic=characteristic: self.read_cccd(
|
||||
connection, characteristic
|
||||
@@ -305,11 +280,8 @@ class Server(EventEmitter):
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(
|
||||
self, connection: Optional[Connection], characteristic: Characteristic
|
||||
self, connection: Connection, characteristic: Characteristic
|
||||
) -> bytes:
|
||||
if connection is None:
|
||||
return bytes([0, 0])
|
||||
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
cccd = None
|
||||
if subscribers:
|
||||
@@ -339,26 +311,29 @@ class Server(EventEmitter):
|
||||
notify_enabled = value[0] & 0x01 != 0
|
||||
indicate_enabled = value[0] & 0x02 != 0
|
||||
characteristic.emit(
|
||||
'subscription', connection, notify_enabled, indicate_enabled
|
||||
characteristic.EVENT_SUBSCRIPTION,
|
||||
connection,
|
||||
notify_enabled,
|
||||
indicate_enabled,
|
||||
)
|
||||
self.emit(
|
||||
'characteristic_subscription',
|
||||
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
|
||||
connection,
|
||||
characteristic,
|
||||
notify_enabled,
|
||||
indicate_enabled,
|
||||
)
|
||||
|
||||
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
|
||||
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(response))
|
||||
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
attribute: att.Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
@@ -390,7 +365,7 @@ class Server(EventEmitter):
|
||||
value = value[: connection.att_mtu - 3]
|
||||
|
||||
# Notify
|
||||
notification = ATT_Handle_Value_Notification(
|
||||
notification = att.ATT_Handle_Value_Notification(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
)
|
||||
logger.debug(
|
||||
@@ -401,7 +376,7 @@ class Server(EventEmitter):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
attribute: att.Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
@@ -433,7 +408,7 @@ class Server(EventEmitter):
|
||||
value = value[: connection.att_mtu - 3]
|
||||
|
||||
# Indicate
|
||||
indication = ATT_Handle_Value_Indication(
|
||||
indication = att.ATT_Handle_Value_Indication(
|
||||
attribute_handle=attribute.handle, attribute_value=value
|
||||
)
|
||||
logger.debug(
|
||||
@@ -450,7 +425,7 @@ class Server(EventEmitter):
|
||||
)
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(indication))
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
@@ -458,10 +433,10 @@ class Server(EventEmitter):
|
||||
finally:
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
|
||||
async def notify_or_indicate_subscribers(
|
||||
async def _notify_or_indicate_subscribers(
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: Attribute,
|
||||
attribute: att.Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
@@ -488,19 +463,21 @@ class Server(EventEmitter):
|
||||
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
attribute: att.Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||
return await self._notify_or_indicate_subscribers(
|
||||
False, attribute, value, force
|
||||
)
|
||||
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
attribute: att.Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection: Connection) -> None:
|
||||
if connection.handle in self.subscribers:
|
||||
@@ -510,33 +487,33 @@ class Server(EventEmitter):
|
||||
if connection.handle in self.pending_confirmations:
|
||||
del self.pending_confirmations[connection.handle]
|
||||
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler is not None:
|
||||
try:
|
||||
handler(connection, att_pdu)
|
||||
except ATT_Error as error:
|
||||
except att.ATT_Error as error:
|
||||
logger.debug(f'normal exception returned by handler: {error}')
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=att_pdu.op_code,
|
||||
attribute_handle_in_error=error.att_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
response = ATT_Error_Response(
|
||||
except Exception:
|
||||
logger.exception(color("!!! Exception in handler:", "red"))
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=att_pdu.op_code,
|
||||
attribute_handle_in_error=0x0000,
|
||||
error_code=ATT_UNLIKELY_ERROR_ERROR,
|
||||
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
raise error
|
||||
raise
|
||||
else:
|
||||
# No specific handler registered
|
||||
if att_pdu.op_code in ATT_REQUESTS:
|
||||
if att_pdu.op_code in att.ATT_REQUESTS:
|
||||
# Invoke the generic handler
|
||||
self.on_att_request(connection, att_pdu)
|
||||
else:
|
||||
@@ -552,7 +529,7 @@ class Server(EventEmitter):
|
||||
#######################################################
|
||||
# ATT handlers
|
||||
#######################################################
|
||||
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
|
||||
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
|
||||
'''
|
||||
Handler for requests without a more specific handler
|
||||
'''
|
||||
@@ -562,23 +539,25 @@ class Server(EventEmitter):
|
||||
)
|
||||
+ str(pdu)
|
||||
)
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=pdu.op_code,
|
||||
attribute_handle_in_error=0x0000,
|
||||
error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_exchange_mtu_request(self, connection, request):
|
||||
def on_att_exchange_mtu_request(
|
||||
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
|
||||
'''
|
||||
self.send_response(
|
||||
connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
|
||||
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
|
||||
)
|
||||
|
||||
# Compute the final MTU
|
||||
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
|
||||
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
|
||||
mtu = min(self.max_mtu, request.client_rx_mtu)
|
||||
|
||||
# Notify the device
|
||||
@@ -586,11 +565,14 @@ class Server(EventEmitter):
|
||||
else:
|
||||
logger.warning('invalid client_rx_mtu received, MTU not changed')
|
||||
|
||||
def on_att_find_information_request(self, connection, request):
|
||||
def on_att_find_information_request(
|
||||
self, connection: Connection, request: att.ATT_Find_Information_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
|
||||
'''
|
||||
|
||||
response: att.ATT_PDU
|
||||
# Check the request parameters
|
||||
if (
|
||||
request.starting_handle == 0
|
||||
@@ -598,17 +580,17 @@ class Server(EventEmitter):
|
||||
):
|
||||
self.send_response(
|
||||
connection,
|
||||
ATT_Error_Response(
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_INVALID_HANDLE_ERROR,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
attributes: list[att.Attribute] = []
|
||||
uuid_size = 0
|
||||
for attribute in (
|
||||
attribute
|
||||
@@ -638,21 +620,23 @@ class Server(EventEmitter):
|
||||
struct.pack('<H', attribute.handle) + attribute.type.to_pdu_bytes()
|
||||
for attribute in attributes
|
||||
]
|
||||
response = ATT_Find_Information_Response(
|
||||
response = att.ATT_Find_Information_Response(
|
||||
format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
|
||||
information_data=b''.join(information_data_list),
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(
|
||||
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
@@ -660,6 +644,7 @@ class Server(EventEmitter):
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
response: att.ATT_PDU
|
||||
async for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
@@ -692,33 +677,35 @@ class Server(EventEmitter):
|
||||
handles_information_list.append(
|
||||
struct.pack('<HH', attribute.handle, group_end_handle)
|
||||
)
|
||||
response = ATT_Find_By_Type_Value_Response(
|
||||
response = att.ATT_Find_By_Type_Value_Response(
|
||||
handles_information_list=b''.join(handles_information_list)
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(
|
||||
self, connection: Connection, request: att.ATT_Read_By_Type_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
|
||||
response = ATT_Error_Response(
|
||||
response: att.ATT_PDU = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
attributes = []
|
||||
attributes: list[tuple[int, bytes]] = []
|
||||
for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
@@ -729,11 +716,11 @@ class Server(EventEmitter):
|
||||
):
|
||||
try:
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
except att.ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
if not attributes:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=attribute.handle,
|
||||
error_code=error.error_code,
|
||||
@@ -762,7 +749,7 @@ class Server(EventEmitter):
|
||||
attribute_data_list = [
|
||||
struct.pack('<H', handle) + value for handle, value in attributes
|
||||
]
|
||||
response = ATT_Read_By_Type_Response(
|
||||
response = att.ATT_Read_By_Type_Response(
|
||||
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
|
||||
)
|
||||
else:
|
||||
@@ -770,96 +757,105 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(
|
||||
self, connection: Connection, request: att.ATT_Read_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
'''
|
||||
|
||||
response: att.ATT_PDU
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
value_size = min(connection.att_mtu - 1, len(value))
|
||||
response = ATT_Read_Response(attribute_value=value[:value_size])
|
||||
response = att.ATT_Read_Response(attribute_value=value[:value_size])
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_HANDLE_ERROR,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(
|
||||
self, connection: Connection, request: att.ATT_Read_Blob_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
'''
|
||||
|
||||
response: att.ATT_PDU
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
if request.value_offset > len(value):
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_OFFSET_ERROR,
|
||||
error_code=att.ATT_INVALID_OFFSET_ERROR,
|
||||
)
|
||||
elif len(value) <= connection.att_mtu - 1:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
)
|
||||
else:
|
||||
part_size = min(
|
||||
connection.att_mtu - 1, len(value) - request.value_offset
|
||||
)
|
||||
response = ATT_Read_Blob_Response(
|
||||
response = att.ATT_Read_Blob_Response(
|
||||
part_attribute_value=value[
|
||||
request.value_offset : request.value_offset + part_size
|
||||
]
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_HANDLE_ERROR,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(
|
||||
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
'''
|
||||
response: att.ATT_PDU
|
||||
if request.attribute_group_type not in (
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
):
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
return
|
||||
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
attributes: list[tuple[int, int, bytes]] = []
|
||||
for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
@@ -896,21 +892,23 @@ class Server(EventEmitter):
|
||||
struct.pack('<HH', handle, end_group_handle) + value
|
||||
for handle, end_group_handle, value in attributes
|
||||
]
|
||||
response = ATT_Read_By_Group_Type_Response(
|
||||
response = att.ATT_Read_By_Group_Type_Response(
|
||||
length=len(attribute_data_list[0]),
|
||||
attribute_data_list=b''.join(attribute_data_list),
|
||||
)
|
||||
else:
|
||||
response = ATT_Error_Response(
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.starting_handle,
|
||||
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
)
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(
|
||||
self, connection: Connection, request: att.ATT_Write_Request
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
'''
|
||||
@@ -920,10 +918,10 @@ class Server(EventEmitter):
|
||||
if attribute is None:
|
||||
self.send_response(
|
||||
connection,
|
||||
ATT_Error_Response(
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_HANDLE_ERROR,
|
||||
error_code=att.ATT_INVALID_HANDLE_ERROR,
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -934,30 +932,33 @@ class Server(EventEmitter):
|
||||
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
|
||||
self.send_response(
|
||||
connection,
|
||||
ATT_Error_Response(
|
||||
att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
|
||||
error_code=att.ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
response: att.ATT_PDU
|
||||
try:
|
||||
# Accept the value
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
except att.ATT_Error as error:
|
||||
response = att.ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
# Done
|
||||
response = ATT_Write_Response()
|
||||
response = att.ATT_Write_Response()
|
||||
self.send_response(connection, response)
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(self, connection, request):
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(
|
||||
self, connection: Connection, request: att.ATT_Write_Command
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
'''
|
||||
@@ -976,18 +977,25 @@ class Server(EventEmitter):
|
||||
# Accept the value
|
||||
try:
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
except Exception:
|
||||
logger.exception('!!! ignoring exception')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
def on_att_handle_value_confirmation(
|
||||
self,
|
||||
connection: Connection,
|
||||
confirmation: att.ATT_Handle_Value_Confirmation,
|
||||
):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
|
||||
'''
|
||||
if self.pending_confirmations[connection.handle] is None:
|
||||
del confirmation # Unused.
|
||||
if (
|
||||
pending_confirmation := self.pending_confirmations[connection.handle]
|
||||
) is None:
|
||||
# Not expected!
|
||||
logger.warning(
|
||||
'!!! unexpected confirmation, there is no pending indication'
|
||||
)
|
||||
return
|
||||
|
||||
self.pending_confirmations[connection.handle].set_result(None)
|
||||
pending_confirmation.set_result(None)
|
||||
|
||||
5649
bumble/hci.py
5649
bumble/hci.py
File diff suppressed because it is too large
Load Diff
@@ -17,44 +17,36 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, MutableMapping
|
||||
import datetime
|
||||
from typing import cast, Any, Optional
|
||||
import logging
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from bumble import avc
|
||||
from bumble import avctp
|
||||
from bumble import avdtp
|
||||
from bumble import avrcp
|
||||
from bumble import crypto
|
||||
from bumble import rfcomm
|
||||
from bumble import sdp
|
||||
from bumble.colors import color
|
||||
from bumble import avc, avctp, avdtp, avrcp, crypto, rfcomm, sdp
|
||||
from bumble.att import ATT_CID, ATT_PDU
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
from bumble.colors import color
|
||||
from bumble.core import name_or_number
|
||||
from bumble.l2cap import (
|
||||
L2CAP_PDU,
|
||||
L2CAP_CONNECTION_REQUEST,
|
||||
L2CAP_CONNECTION_RESPONSE,
|
||||
L2CAP_SIGNALING_CID,
|
||||
L2CAP_LE_SIGNALING_CID,
|
||||
L2CAP_Control_Frame,
|
||||
L2CAP_Connection_Request,
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Packet,
|
||||
HCI_Event,
|
||||
HCI_EVENT_PACKET,
|
||||
Address,
|
||||
HCI_AclDataPacket,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
HCI_Event,
|
||||
HCI_Packet,
|
||||
)
|
||||
|
||||
from bumble.l2cap import (
|
||||
L2CAP_LE_SIGNALING_CID,
|
||||
L2CAP_PDU,
|
||||
L2CAP_SIGNALING_CID,
|
||||
CommandCode,
|
||||
L2CAP_Connection_Request,
|
||||
L2CAP_Connection_Response,
|
||||
L2CAP_Control_Frame,
|
||||
)
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -106,14 +98,14 @@ class PacketTracer:
|
||||
self.analyzer.emit(control_frame)
|
||||
|
||||
# Check if this signals a new channel
|
||||
if control_frame.code == L2CAP_CONNECTION_REQUEST:
|
||||
if control_frame.code == CommandCode.L2CAP_CONNECTION_REQUEST:
|
||||
connection_request = cast(L2CAP_Connection_Request, control_frame)
|
||||
self.psms[connection_request.source_cid] = connection_request.psm
|
||||
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
|
||||
elif control_frame.code == CommandCode.L2CAP_CONNECTION_RESPONSE:
|
||||
connection_response = cast(L2CAP_Connection_Response, control_frame)
|
||||
if (
|
||||
connection_response.result
|
||||
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
|
||||
== L2CAP_Connection_Response.Result.CONNECTION_SUCCESSFUL
|
||||
):
|
||||
if self.peer and (
|
||||
psm := self.peer.psms.get(connection_response.source_cid)
|
||||
|
||||
289
bumble/hfp.py
289
bumble/hfp.py
@@ -17,50 +17,34 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import collections.abc
|
||||
import logging
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import traceback
|
||||
import pyee
|
||||
import logging
|
||||
import re
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Union,
|
||||
Set,
|
||||
Any,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import at
|
||||
from bumble import device
|
||||
from bumble import rfcomm
|
||||
from bumble import sdp
|
||||
from bumble import at, device, rfcomm, sdp, utils
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
ProtocolError,
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
BT_HANDSFREE_SERVICE,
|
||||
BT_HANDSFREE_AUDIO_GATEWAY_SERVICE,
|
||||
BT_HANDSFREE_SERVICE,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_RFCOMM_PROTOCOL_ID,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.hci import (
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command,
|
||||
CodingFormat,
|
||||
CodecID,
|
||||
CodingFormat,
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -141,7 +125,7 @@ class HfFeature(enum.IntFlag):
|
||||
"""
|
||||
HF supported features (AT+BRSF=) (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
"""
|
||||
|
||||
EC_NR = 0x001 # Echo Cancel & Noise reduction
|
||||
@@ -155,14 +139,14 @@ class HfFeature(enum.IntFlag):
|
||||
HF_INDICATORS = 0x100
|
||||
ESCO_S4_SETTINGS_SUPPORTED = 0x200
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
|
||||
VOICE_RECOGNITION_TEST = 0x800
|
||||
VOICE_RECOGNITION_TEXT = 0x800
|
||||
|
||||
|
||||
class AgFeature(enum.IntFlag):
|
||||
"""
|
||||
AG supported features (+BRSF:) (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
"""
|
||||
|
||||
THREE_WAY_CALLING = 0x001
|
||||
@@ -178,7 +162,7 @@ class AgFeature(enum.IntFlag):
|
||||
HF_INDICATORS = 0x400
|
||||
ESCO_S4_SETTINGS_SUPPORTED = 0x800
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
|
||||
VOICE_RECOGNITION_TEST = 0x2000
|
||||
VOICE_RECOGNITION_TEXT = 0x2000
|
||||
|
||||
|
||||
class AudioCodec(enum.IntEnum):
|
||||
@@ -375,7 +359,7 @@ class CallLineIdentification:
|
||||
cli_validity: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def parse_from(cls: Type[Self], parameters: List[bytes]) -> Self:
|
||||
def parse_from(cls, parameters: list[bytes]) -> Self:
|
||||
return cls(
|
||||
number=parameters[0].decode(),
|
||||
type=int(parameters[1]),
|
||||
@@ -505,9 +489,9 @@ STATUS_CODES = {
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HfConfiguration:
|
||||
supported_hf_features: List[HfFeature]
|
||||
supported_hf_indicators: List[HfIndicator]
|
||||
supported_audio_codecs: List[AudioCodec]
|
||||
supported_hf_features: list[HfFeature]
|
||||
supported_hf_indicators: list[HfIndicator]
|
||||
supported_audio_codecs: list[AudioCodec]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -535,7 +519,7 @@ class AtResponse:
|
||||
parameters: list
|
||||
|
||||
@classmethod
|
||||
def parse_from(cls: Type[Self], buffer: bytearray) -> Self:
|
||||
def parse_from(cls: type[Self], buffer: bytearray) -> Self:
|
||||
code_and_parameters = buffer.split(b':')
|
||||
parameters = (
|
||||
code_and_parameters[1] if len(code_and_parameters) > 1 else bytearray()
|
||||
@@ -563,7 +547,7 @@ class AtCommand:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_from(cls: Type[Self], buffer: bytearray) -> Self:
|
||||
def parse_from(cls: type[Self], buffer: bytearray) -> Self:
|
||||
if not (match := cls._PARSE_PATTERN.fullmatch(buffer.decode())):
|
||||
if buffer.startswith(b'ATA'):
|
||||
return cls(code='A', sub_code=AtCommand.SubCode.NONE, parameters=[])
|
||||
@@ -598,7 +582,7 @@ class AgIndicatorState:
|
||||
"""
|
||||
|
||||
indicator: AgIndicator
|
||||
supported_values: Set[int]
|
||||
supported_values: set[int]
|
||||
current_status: int
|
||||
index: Optional[int] = None
|
||||
enabled: bool = True
|
||||
@@ -616,14 +600,14 @@ class AgIndicatorState:
|
||||
return f'(\"{self.indicator.value}\",{supported_values_text})'
|
||||
|
||||
@classmethod
|
||||
def call(cls: Type[Self]) -> Self:
|
||||
def call(cls: type[Self]) -> Self:
|
||||
"""Default call indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def callsetup(cls: Type[Self]) -> Self:
|
||||
def callsetup(cls: type[Self]) -> Self:
|
||||
"""Default callsetup indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.CALL_SETUP,
|
||||
@@ -632,7 +616,7 @@ class AgIndicatorState:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def callheld(cls: Type[Self]) -> Self:
|
||||
def callheld(cls: type[Self]) -> Self:
|
||||
"""Default call indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.CALL_HELD,
|
||||
@@ -641,14 +625,14 @@ class AgIndicatorState:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def service(cls: Type[Self]) -> Self:
|
||||
def service(cls: type[Self]) -> Self:
|
||||
"""Default service indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.SERVICE, supported_values={0, 1}, current_status=0
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def signal(cls: Type[Self]) -> Self:
|
||||
def signal(cls: type[Self]) -> Self:
|
||||
"""Default signal indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.SIGNAL,
|
||||
@@ -657,14 +641,14 @@ class AgIndicatorState:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def roam(cls: Type[Self]) -> Self:
|
||||
def roam(cls: type[Self]) -> Self:
|
||||
"""Default roam indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def battchg(cls: Type[Self]) -> Self:
|
||||
def battchg(cls: type[Self]) -> Self:
|
||||
"""Default battery charge indicator state."""
|
||||
return cls(
|
||||
indicator=AgIndicator.BATTERY_CHARGE,
|
||||
@@ -690,7 +674,7 @@ class HfIndicatorState:
|
||||
current_status: int = 0
|
||||
|
||||
|
||||
class HfProtocol(pyee.EventEmitter):
|
||||
class HfProtocol(utils.EventEmitter):
|
||||
"""
|
||||
Implementation for the Hands-Free side of the Hands-Free profile.
|
||||
|
||||
@@ -720,17 +704,25 @@ class HfProtocol(pyee.EventEmitter):
|
||||
vrec: VoiceRecognitionState
|
||||
"""
|
||||
|
||||
EVENT_CODEC_NEGOTIATION = "codec_negotiation"
|
||||
EVENT_AG_INDICATOR = "ag_indicator"
|
||||
EVENT_SPEAKER_VOLUME = "speaker_volume"
|
||||
EVENT_MICROPHONE_VOLUME = "microphone_volume"
|
||||
EVENT_RING = "ring"
|
||||
EVENT_CLI_NOTIFICATION = "cli_notification"
|
||||
EVENT_VOICE_RECOGNITION = "voice_recognition"
|
||||
|
||||
class HfLoopTermination(HfpProtocolError):
|
||||
"""Termination signal for run() loop."""
|
||||
|
||||
supported_hf_features: int
|
||||
supported_audio_codecs: List[AudioCodec]
|
||||
supported_audio_codecs: list[AudioCodec]
|
||||
|
||||
supported_ag_features: int
|
||||
supported_ag_call_hold_operations: List[CallHoldOperation]
|
||||
supported_ag_call_hold_operations: list[CallHoldOperation]
|
||||
|
||||
ag_indicators: List[AgIndicatorState]
|
||||
hf_indicators: Dict[HfIndicator, HfIndicatorState]
|
||||
ag_indicators: list[AgIndicatorState]
|
||||
hf_indicators: dict[HfIndicator, HfIndicatorState]
|
||||
|
||||
dlc: rfcomm.DLC
|
||||
command_lock: asyncio.Lock
|
||||
@@ -777,7 +769,8 @@ class HfProtocol(pyee.EventEmitter):
|
||||
self.dlc.sink = self._read_at
|
||||
# Stop the run() loop when L2CAP is closed.
|
||||
self.dlc.multiplexer.l2cap_channel.on(
|
||||
'close', lambda: self.unsolicited_queue.put_nowait(None)
|
||||
self.dlc.multiplexer.l2cap_channel.EVENT_CLOSE,
|
||||
lambda: self.unsolicited_queue.put_nowait(None),
|
||||
)
|
||||
|
||||
def supports_hf_feature(self, feature: HfFeature) -> bool:
|
||||
@@ -795,36 +788,39 @@ class HfProtocol(pyee.EventEmitter):
|
||||
# Append to the read buffer.
|
||||
self.read_buffer.extend(data)
|
||||
|
||||
# Locate header and trailer.
|
||||
header = self.read_buffer.find(b'\r\n')
|
||||
trailer = self.read_buffer.find(b'\r\n', header + 2)
|
||||
if header == -1 or trailer == -1:
|
||||
return
|
||||
while self.read_buffer:
|
||||
# Locate header and trailer.
|
||||
header = self.read_buffer.find(b'\r\n')
|
||||
trailer = self.read_buffer.find(b'\r\n', header + 2)
|
||||
if header == -1 or trailer == -1:
|
||||
return
|
||||
|
||||
# Isolate the AT response code and parameters.
|
||||
raw_response = self.read_buffer[header + 2 : trailer]
|
||||
response = AtResponse.parse_from(raw_response)
|
||||
logger.debug(f"<<< {raw_response.decode()}")
|
||||
# Isolate the AT response code and parameters.
|
||||
raw_response = self.read_buffer[header + 2 : trailer]
|
||||
response = AtResponse.parse_from(raw_response)
|
||||
logger.debug(f"<<< {raw_response.decode()}")
|
||||
|
||||
# Consume the response bytes.
|
||||
self.read_buffer = self.read_buffer[trailer + 2 :]
|
||||
# Consume the response bytes.
|
||||
self.read_buffer = self.read_buffer[trailer + 2 :]
|
||||
|
||||
# Forward the received code to the correct queue.
|
||||
if self.command_lock.locked() and (
|
||||
response.code in STATUS_CODES or response.code in RESPONSE_CODES
|
||||
):
|
||||
self.response_queue.put_nowait(response)
|
||||
elif response.code in UNSOLICITED_CODES:
|
||||
self.unsolicited_queue.put_nowait(response)
|
||||
else:
|
||||
logger.warning(f"dropping unexpected response with code '{response.code}'")
|
||||
# Forward the received code to the correct queue.
|
||||
if self.command_lock.locked() and (
|
||||
response.code in STATUS_CODES or response.code in RESPONSE_CODES
|
||||
):
|
||||
self.response_queue.put_nowait(response)
|
||||
elif response.code in UNSOLICITED_CODES:
|
||||
self.unsolicited_queue.put_nowait(response)
|
||||
else:
|
||||
logger.warning(
|
||||
f"dropping unexpected response with code '{response.code}'"
|
||||
)
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
cmd: str,
|
||||
timeout: float = 1.0,
|
||||
response_type: AtResponseType = AtResponseType.NONE,
|
||||
) -> Union[None, AtResponse, List[AtResponse]]:
|
||||
) -> Union[None, AtResponse, list[AtResponse]]:
|
||||
"""
|
||||
Sends an AT command and wait for the peer response.
|
||||
Wait for the AT responses sent by the peer, to the status code.
|
||||
@@ -841,7 +837,7 @@ class HfProtocol(pyee.EventEmitter):
|
||||
async with self.command_lock:
|
||||
logger.debug(f">>> {cmd}")
|
||||
self.dlc.write(cmd + '\r')
|
||||
responses: List[AtResponse] = []
|
||||
responses: list[AtResponse] = []
|
||||
|
||||
while True:
|
||||
result = await asyncio.wait_for(
|
||||
@@ -1031,7 +1027,7 @@ class HfProtocol(pyee.EventEmitter):
|
||||
# ID. The HF shall be ready to accept the synchronous connection
|
||||
# establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>.
|
||||
self.active_codec = AudioCodec(codec_id)
|
||||
self.emit('codec_negotiation', self.active_codec)
|
||||
self.emit(self.EVENT_CODEC_NEGOTIATION, self.active_codec)
|
||||
|
||||
logger.info("codec connection setup completed")
|
||||
|
||||
@@ -1061,7 +1057,7 @@ class HfProtocol(pyee.EventEmitter):
|
||||
# code, with the value indicating (call=0).
|
||||
await self.execute_command("AT+CHUP")
|
||||
|
||||
async def query_current_calls(self) -> List[CallInfo]:
|
||||
async def query_current_calls(self) -> list[CallInfo]:
|
||||
"""4.32.1 Query List of Current Calls in AG.
|
||||
|
||||
Return:
|
||||
@@ -1092,7 +1088,7 @@ class HfProtocol(pyee.EventEmitter):
|
||||
# CIEV is in 1-index, while ag_indicators is in 0-index.
|
||||
ag_indicator = self.ag_indicators[index - 1]
|
||||
ag_indicator.current_status = value
|
||||
self.emit('ag_indicator', ag_indicator)
|
||||
self.emit(self.EVENT_AG_INDICATOR, ag_indicator)
|
||||
logger.info(f"AG indicator updated: {ag_indicator.indicator}, {value}")
|
||||
|
||||
async def handle_unsolicited(self):
|
||||
@@ -1107,19 +1103,21 @@ class HfProtocol(pyee.EventEmitter):
|
||||
int(result.parameters[0]), int(result.parameters[1])
|
||||
)
|
||||
elif result.code == "+VGS":
|
||||
self.emit('speaker_volume', int(result.parameters[0]))
|
||||
self.emit(self.EVENT_SPEAKER_VOLUME, int(result.parameters[0]))
|
||||
elif result.code == "+VGM":
|
||||
self.emit('microphone_volume', int(result.parameters[0]))
|
||||
self.emit(self.EVENT_MICROPHONE_VOLUME, int(result.parameters[0]))
|
||||
elif result.code == "RING":
|
||||
self.emit('ring')
|
||||
self.emit(self.EVENT_RING)
|
||||
elif result.code == "+CLIP":
|
||||
self.emit(
|
||||
'cli_notification', CallLineIdentification.parse_from(result.parameters)
|
||||
self.EVENT_CLI_NOTIFICATION,
|
||||
CallLineIdentification.parse_from(result.parameters),
|
||||
)
|
||||
elif result.code == "+BVRA":
|
||||
# TODO: Support Enhanced Voice Recognition.
|
||||
self.emit(
|
||||
'voice_recognition', VoiceRecognitionState(int(result.parameters[0]))
|
||||
self.EVENT_VOICE_RECOGNITION,
|
||||
VoiceRecognitionState(int(result.parameters[0])),
|
||||
)
|
||||
else:
|
||||
logging.info(f"unhandled unsolicited response {result.code}")
|
||||
@@ -1143,7 +1141,7 @@ class HfProtocol(pyee.EventEmitter):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
class AgProtocol(pyee.EventEmitter):
|
||||
class AgProtocol(utils.EventEmitter):
|
||||
"""
|
||||
Implementation for the Audio-Gateway side of the Hands-Free profile.
|
||||
|
||||
@@ -1176,28 +1174,41 @@ class AgProtocol(pyee.EventEmitter):
|
||||
volume: Int
|
||||
"""
|
||||
|
||||
EVENT_SLC_COMPLETE = "slc_complete"
|
||||
EVENT_SUPPORTED_AUDIO_CODECS = "supported_audio_codecs"
|
||||
EVENT_CODEC_NEGOTIATION = "codec_negotiation"
|
||||
EVENT_VOICE_RECOGNITION = "voice_recognition"
|
||||
EVENT_CALL_HOLD = "call_hold"
|
||||
EVENT_HF_INDICATOR = "hf_indicator"
|
||||
EVENT_CODEC_CONNECTION_REQUEST = "codec_connection_request"
|
||||
EVENT_ANSWER = "answer"
|
||||
EVENT_DIAL = "dial"
|
||||
EVENT_HANG_UP = "hang_up"
|
||||
EVENT_SPEAKER_VOLUME = "speaker_volume"
|
||||
EVENT_MICROPHONE_VOLUME = "microphone_volume"
|
||||
|
||||
supported_hf_features: int
|
||||
supported_hf_indicators: Set[HfIndicator]
|
||||
supported_audio_codecs: List[AudioCodec]
|
||||
supported_hf_indicators: set[HfIndicator]
|
||||
supported_audio_codecs: list[AudioCodec]
|
||||
|
||||
supported_ag_features: int
|
||||
supported_ag_call_hold_operations: List[CallHoldOperation]
|
||||
supported_ag_call_hold_operations: list[CallHoldOperation]
|
||||
|
||||
ag_indicators: List[AgIndicatorState]
|
||||
ag_indicators: list[AgIndicatorState]
|
||||
hf_indicators: collections.OrderedDict[HfIndicator, HfIndicatorState]
|
||||
|
||||
dlc: rfcomm.DLC
|
||||
|
||||
read_buffer: bytearray
|
||||
active_codec: AudioCodec
|
||||
calls: List[CallInfo]
|
||||
calls: list[CallInfo]
|
||||
|
||||
indicator_report_enabled: bool
|
||||
inband_ringtone_enabled: bool
|
||||
cme_error_enabled: bool
|
||||
cli_notification_enabled: bool
|
||||
call_waiting_enabled: bool
|
||||
_remained_slc_setup_features: Set[HfFeature]
|
||||
_remained_slc_setup_features: set[HfFeature]
|
||||
|
||||
def __init__(self, dlc: rfcomm.DLC, configuration: AgConfiguration) -> None:
|
||||
super().__init__()
|
||||
@@ -1244,31 +1255,32 @@ class AgProtocol(pyee.EventEmitter):
|
||||
# Append to the read buffer.
|
||||
self.read_buffer.extend(data)
|
||||
|
||||
# Locate the trailer.
|
||||
trailer = self.read_buffer.find(b'\r')
|
||||
if trailer == -1:
|
||||
return
|
||||
while self.read_buffer:
|
||||
# Locate the trailer.
|
||||
trailer = self.read_buffer.find(b'\r')
|
||||
if trailer == -1:
|
||||
return
|
||||
|
||||
# Isolate the AT response code and parameters.
|
||||
raw_command = self.read_buffer[:trailer]
|
||||
command = AtCommand.parse_from(raw_command)
|
||||
logger.debug(f"<<< {raw_command.decode()}")
|
||||
# Isolate the AT response code and parameters.
|
||||
raw_command = self.read_buffer[:trailer]
|
||||
command = AtCommand.parse_from(raw_command)
|
||||
logger.debug(f"<<< {raw_command.decode()}")
|
||||
|
||||
# Consume the response bytes.
|
||||
self.read_buffer = self.read_buffer[trailer + 1 :]
|
||||
# Consume the response bytes.
|
||||
self.read_buffer = self.read_buffer[trailer + 1 :]
|
||||
|
||||
if command.sub_code == AtCommand.SubCode.TEST:
|
||||
handler_name = f'_on_{command.code.lower()}_test'
|
||||
elif command.sub_code == AtCommand.SubCode.READ:
|
||||
handler_name = f'_on_{command.code.lower()}_read'
|
||||
else:
|
||||
handler_name = f'_on_{command.code.lower()}'
|
||||
if command.sub_code == AtCommand.SubCode.TEST:
|
||||
handler_name = f'_on_{command.code.lower()}_test'
|
||||
elif command.sub_code == AtCommand.SubCode.READ:
|
||||
handler_name = f'_on_{command.code.lower()}_read'
|
||||
else:
|
||||
handler_name = f'_on_{command.code.lower()}'
|
||||
|
||||
if handler := getattr(self, handler_name, None):
|
||||
handler(*command.parameters)
|
||||
else:
|
||||
logger.warning('Handler %s not found', handler_name)
|
||||
self.send_response('ERROR')
|
||||
if handler := getattr(self, handler_name, None):
|
||||
handler(*command.parameters)
|
||||
else:
|
||||
logger.warning('Handler %s not found', handler_name)
|
||||
self.send_response('ERROR')
|
||||
|
||||
def send_response(self, response: str) -> None:
|
||||
"""Sends an AT response."""
|
||||
@@ -1367,7 +1379,7 @@ class AgProtocol(pyee.EventEmitter):
|
||||
|
||||
def _check_remained_slc_commands(self) -> None:
|
||||
if not self._remained_slc_setup_features:
|
||||
self.emit('slc_complete')
|
||||
self.emit(self.EVENT_SLC_COMPLETE)
|
||||
|
||||
def _on_brsf(self, hf_features: bytes) -> None:
|
||||
self.supported_hf_features = int(hf_features)
|
||||
@@ -1386,16 +1398,17 @@ class AgProtocol(pyee.EventEmitter):
|
||||
|
||||
def _on_bac(self, *args) -> None:
|
||||
self.supported_audio_codecs = [AudioCodec(int(value)) for value in args]
|
||||
self.emit(self.EVENT_SUPPORTED_AUDIO_CODECS, self.supported_audio_codecs)
|
||||
self.send_ok()
|
||||
|
||||
def _on_bcs(self, codec: bytes) -> None:
|
||||
self.active_codec = AudioCodec(int(codec))
|
||||
self.send_ok()
|
||||
self.emit('codec_negotiation', self.active_codec)
|
||||
self.emit(self.EVENT_CODEC_NEGOTIATION, self.active_codec)
|
||||
|
||||
def _on_bvra(self, vrec: bytes) -> None:
|
||||
self.send_ok()
|
||||
self.emit('voice_recognition', VoiceRecognitionState(int(vrec)))
|
||||
self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec)))
|
||||
|
||||
def _on_chld(self, operation_code: bytes) -> None:
|
||||
call_index: Optional[int] = None
|
||||
@@ -1422,7 +1435,7 @@ class AgProtocol(pyee.EventEmitter):
|
||||
# Real three-way calls have more complicated situations, but this is not a popular issue - let users to handle the remaining :)
|
||||
|
||||
self.send_ok()
|
||||
self.emit('call_hold', operation, call_index)
|
||||
self.emit(self.EVENT_CALL_HOLD, operation, call_index)
|
||||
|
||||
def _on_chld_test(self) -> None:
|
||||
if not self.supports_ag_feature(AgFeature.THREE_WAY_CALLING):
|
||||
@@ -1548,7 +1561,7 @@ class AgProtocol(pyee.EventEmitter):
|
||||
return
|
||||
|
||||
self.hf_indicators[index].current_status = int(value_bytes)
|
||||
self.emit('hf_indicator', self.hf_indicators[index])
|
||||
self.emit(self.EVENT_HF_INDICATOR, self.hf_indicators[index])
|
||||
self.send_ok()
|
||||
|
||||
def _on_bia(self, *args) -> None:
|
||||
@@ -1557,21 +1570,21 @@ class AgProtocol(pyee.EventEmitter):
|
||||
self.send_ok()
|
||||
|
||||
def _on_bcc(self) -> None:
|
||||
self.emit('codec_connection_request')
|
||||
self.emit(self.EVENT_CODEC_CONNECTION_REQUEST)
|
||||
self.send_ok()
|
||||
|
||||
def _on_a(self) -> None:
|
||||
"""ATA handler."""
|
||||
self.emit('answer')
|
||||
self.emit(self.EVENT_ANSWER)
|
||||
self.send_ok()
|
||||
|
||||
def _on_d(self, number: bytes) -> None:
|
||||
"""ATD handler."""
|
||||
self.emit('dial', number.decode())
|
||||
self.emit(self.EVENT_DIAL, number.decode())
|
||||
self.send_ok()
|
||||
|
||||
def _on_chup(self) -> None:
|
||||
self.emit('hang_up')
|
||||
self.emit(self.EVENT_HANG_UP)
|
||||
self.send_ok()
|
||||
|
||||
def _on_clcc(self) -> None:
|
||||
@@ -1597,11 +1610,11 @@ class AgProtocol(pyee.EventEmitter):
|
||||
self.send_ok()
|
||||
|
||||
def _on_vgs(self, level: bytes) -> None:
|
||||
self.emit('speaker_volume', int(level))
|
||||
self.emit(self.EVENT_SPEAKER_VOLUME, int(level))
|
||||
self.send_ok()
|
||||
|
||||
def _on_vgm(self, level: bytes) -> None:
|
||||
self.emit('microphone_volume', int(level))
|
||||
self.emit(self.EVENT_MICROPHONE_VOLUME, int(level))
|
||||
self.send_ok()
|
||||
|
||||
|
||||
@@ -1614,7 +1627,7 @@ class ProfileVersion(enum.IntEnum):
|
||||
"""
|
||||
Profile version (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.8, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
V1_5 = 0x0105
|
||||
@@ -1628,7 +1641,7 @@ class HfSdpFeature(enum.IntFlag):
|
||||
"""
|
||||
HF supported features (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
EC_NR = 0x01 # Echo Cancel & Noise reduction
|
||||
@@ -1636,16 +1649,17 @@ class HfSdpFeature(enum.IntFlag):
|
||||
CLI_PRESENTATION_CAPABILITY = 0x04
|
||||
VOICE_RECOGNITION_ACTIVATION = 0x08
|
||||
REMOTE_VOLUME_CONTROL = 0x10
|
||||
WIDE_BAND = 0x20 # Wide band speech
|
||||
WIDE_BAND_SPEECH = 0x20
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
|
||||
VOICE_RECOGNITION_TEST = 0x80
|
||||
VOICE_RECOGNITION_TEXT = 0x80
|
||||
SUPER_WIDE_BAND = 0x100
|
||||
|
||||
|
||||
class AgSdpFeature(enum.IntFlag):
|
||||
"""
|
||||
AG supported features (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
THREE_WAY_CALLING = 0x01
|
||||
@@ -1653,9 +1667,10 @@ class AgSdpFeature(enum.IntFlag):
|
||||
VOICE_RECOGNITION_FUNCTION = 0x04
|
||||
IN_BAND_RING_TONE_CAPABILITY = 0x08
|
||||
VOICE_TAG = 0x10 # Attach a number to voice tag
|
||||
WIDE_BAND = 0x20 # Wide band speech
|
||||
WIDE_BAND_SPEECH = 0x20
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
|
||||
VOICE_RECOGNITION_TEST = 0x80
|
||||
VOICE_RECOGNITION_TEXT = 0x80
|
||||
SUPER_WIDE_BAND_SPEED_SPEECH = 0x100
|
||||
|
||||
|
||||
def make_hf_sdp_records(
|
||||
@@ -1663,7 +1678,7 @@ def make_hf_sdp_records(
|
||||
rfcomm_channel: int,
|
||||
configuration: HfConfiguration,
|
||||
version: ProfileVersion = ProfileVersion.V1_8,
|
||||
) -> List[sdp.ServiceAttribute]:
|
||||
) -> list[sdp.ServiceAttribute]:
|
||||
"""
|
||||
Generates the SDP record for HFP Hands-Free support.
|
||||
|
||||
@@ -1688,11 +1703,11 @@ def make_hf_sdp_records(
|
||||
in configuration.supported_hf_features
|
||||
):
|
||||
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
|
||||
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features:
|
||||
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST
|
||||
if HfFeature.VOICE_RECOGNITION_TEXT in configuration.supported_hf_features:
|
||||
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEXT
|
||||
|
||||
if AudioCodec.MSBC in configuration.supported_audio_codecs:
|
||||
hf_supported_features |= HfSdpFeature.WIDE_BAND
|
||||
hf_supported_features |= HfSdpFeature.WIDE_BAND_SPEECH
|
||||
|
||||
return [
|
||||
sdp.ServiceAttribute(
|
||||
@@ -1749,7 +1764,7 @@ def make_ag_sdp_records(
|
||||
rfcomm_channel: int,
|
||||
configuration: AgConfiguration,
|
||||
version: ProfileVersion = ProfileVersion.V1_8,
|
||||
) -> List[sdp.ServiceAttribute]:
|
||||
) -> list[sdp.ServiceAttribute]:
|
||||
"""
|
||||
Generates the SDP record for HFP Audio-Gateway support.
|
||||
|
||||
@@ -1768,14 +1783,14 @@ def make_ag_sdp_records(
|
||||
in configuration.supported_ag_features
|
||||
):
|
||||
ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
|
||||
if AgFeature.VOICE_RECOGNITION_TEST in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEST
|
||||
if AgFeature.VOICE_RECOGNITION_TEXT in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEXT
|
||||
if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY
|
||||
if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION
|
||||
if AudioCodec.MSBC in configuration.supported_audio_codecs:
|
||||
ag_supported_features |= AgSdpFeature.WIDE_BAND
|
||||
ag_supported_features |= AgSdpFeature.WIDE_BAND_SPEECH
|
||||
|
||||
return [
|
||||
sdp.ServiceAttribute(
|
||||
@@ -1829,7 +1844,7 @@ def make_ag_sdp_records(
|
||||
|
||||
async def find_hf_sdp_record(
|
||||
connection: device.Connection,
|
||||
) -> Optional[Tuple[int, ProfileVersion, HfSdpFeature]]:
|
||||
) -> Optional[tuple[int, ProfileVersion, HfSdpFeature]]:
|
||||
"""Searches a Hands-Free SDP record from remote device.
|
||||
|
||||
Args:
|
||||
@@ -1881,7 +1896,7 @@ async def find_hf_sdp_record(
|
||||
|
||||
async def find_ag_sdp_record(
|
||||
connection: device.Connection,
|
||||
) -> Optional[Tuple[int, ProfileVersion, AgSdpFeature]]:
|
||||
) -> Optional[tuple[int, ProfileVersion, AgSdpFeature]]:
|
||||
"""Searches an Audio-Gateway SDP record from remote device.
|
||||
|
||||
Args:
|
||||
@@ -1979,7 +1994,7 @@ class EscoParameters:
|
||||
transmit_codec_frame_size: int = 60
|
||||
receive_codec_frame_size: int = 60
|
||||
|
||||
def asdict(self) -> Dict[str, Any]:
|
||||
def asdict(self) -> dict[str, Any]:
|
||||
# dataclasses.asdict() will recursively deep-copy the entire object,
|
||||
# which is expensive and breaks CodingFormat object, so let it simply copy here.
|
||||
return self.__dict__
|
||||
|
||||
@@ -16,21 +16,20 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap, device
|
||||
from bumble import device, l2cap, utils
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
from bumble.hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -195,11 +194,18 @@ class SendHandshakeMessage(Message):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HID(ABC, EventEmitter):
|
||||
class HID(ABC, utils.EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
EVENT_INTERRUPT_DATA = "interrupt_data"
|
||||
EVENT_CONTROL_DATA = "control_data"
|
||||
EVENT_SUSPEND = "suspend"
|
||||
EVENT_EXIT_SUSPEND = "exit_suspend"
|
||||
EVENT_VIRTUAL_CABLE_UNPLUG = "virtual_cable_unplug"
|
||||
EVENT_HANDSHAKE = "handshake"
|
||||
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
@@ -211,33 +217,41 @@ class HID(ABC, EventEmitter):
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
device.create_l2cap_server(
|
||||
l2cap.ClassicChannelSpec(HID_CONTROL_PSM), self.on_l2cap_connection
|
||||
)
|
||||
device.create_l2cap_server(
|
||||
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM), self.on_l2cap_connection
|
||||
)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
device.on(device.EVENT_CONNECTION, self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
if not self.connection:
|
||||
raise InvalidStateError("Connection is not established!")
|
||||
# Create a new L2CAP connection - control channel
|
||||
try:
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_CONTROL_PSM
|
||||
channel = await self.connection.create_l2cap_channel(
|
||||
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
|
||||
)
|
||||
channel.sink = self.on_ctrl_pdu
|
||||
self.l2cap_ctrl_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
logging.exception('L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
async def connect_interrupt_channel(self) -> None:
|
||||
if not self.connection:
|
||||
raise InvalidStateError("Connection is not established!")
|
||||
# Create a new L2CAP connection - interrupt channel
|
||||
try:
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_INTERRUPT_PSM
|
||||
channel = await self.connection.create_l2cap_channel(
|
||||
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
|
||||
)
|
||||
channel.sink = self.on_intr_pdu
|
||||
self.l2cap_intr_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
logging.exception('L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
async def disconnect_interrupt_channel(self) -> None:
|
||||
@@ -257,15 +271,20 @@ class HID(ABC, EventEmitter):
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
connection.on(connection.EVENT_DISCONNECTION, self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
l2cap_channel.on(
|
||||
l2cap_channel.EVENT_OPEN, lambda: self.on_l2cap_channel_open(l2cap_channel)
|
||||
)
|
||||
l2cap_channel.on(
|
||||
l2cap_channel.EVENT_CLOSE,
|
||||
lambda: self.on_l2cap_channel_close(l2cap_channel),
|
||||
)
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
@@ -289,7 +308,7 @@ class HID(ABC, EventEmitter):
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("interrupt_data", pdu)
|
||||
self.emit(self.EVENT_INTERRUPT_DATA, pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
@@ -362,17 +381,17 @@ class Device(HID):
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
self.emit(self.EVENT_CONTROL_DATA, pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
self.emit(self.EVENT_SUSPEND)
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
self.emit(self.EVENT_EXIT_SUSPEND)
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG)
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
@@ -537,14 +556,14 @@ class Host(HID):
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
self.emit(self.EVENT_HANDSHAKE, Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
self.emit(self.EVENT_CONTROL_DATA, pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG)
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
|
||||
863
bumble/host.py
863
bumble/host.py
File diff suppressed because it is too large
Load Diff
124
bumble/keys.py
124
bumble/keys.py
@@ -21,18 +21,21 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device import Device
|
||||
from bumble.device import Device
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -42,16 +45,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PairingKeys:
|
||||
@dataclasses.dataclass
|
||||
class Key:
|
||||
def __init__(self, value, authenticated=False, ediv=None, rand=None):
|
||||
self.value = value
|
||||
self.authenticated = authenticated
|
||||
self.ediv = ediv
|
||||
self.rand = rand
|
||||
value: bytes
|
||||
authenticated: bool = False
|
||||
ediv: Optional[int] = None
|
||||
rand: Optional[bytes] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, key_dict):
|
||||
def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
|
||||
value = bytes.fromhex(key_dict['value'])
|
||||
authenticated = key_dict.get('authenticated', False)
|
||||
ediv = key_dict.get('ediv')
|
||||
@@ -61,7 +65,7 @@ class PairingKeys:
|
||||
|
||||
return cls(value, authenticated, ediv, rand)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated}
|
||||
if self.ediv is not None:
|
||||
key_dict['ediv'] = self.ediv
|
||||
@@ -70,39 +74,42 @@ class PairingKeys:
|
||||
|
||||
return key_dict
|
||||
|
||||
def __init__(self):
|
||||
self.address_type = None
|
||||
self.ltk = None
|
||||
self.ltk_central = None
|
||||
self.ltk_peripheral = None
|
||||
self.irk = None
|
||||
self.csrk = None
|
||||
self.link_key = None # Classic
|
||||
address_type: Optional[hci.AddressType] = None
|
||||
ltk: Optional[Key] = None
|
||||
ltk_central: Optional[Key] = None
|
||||
ltk_peripheral: Optional[Key] = None
|
||||
irk: Optional[Key] = None
|
||||
csrk: Optional[Key] = None
|
||||
link_key: Optional[Key] = None # Classic
|
||||
link_key_type: Optional[int] = None # Classic
|
||||
|
||||
@staticmethod
|
||||
def key_from_dict(keys_dict, key_name):
|
||||
@classmethod
|
||||
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
|
||||
key_dict = keys_dict.get(key_name)
|
||||
if key_dict is None:
|
||||
return None
|
||||
|
||||
return PairingKeys.Key.from_dict(key_dict)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(keys_dict):
|
||||
keys = PairingKeys()
|
||||
@classmethod
|
||||
def from_dict(cls, keys_dict: dict[str, Any]) -> PairingKeys:
|
||||
return PairingKeys(
|
||||
address_type=(
|
||||
hci.AddressType(t)
|
||||
if (t := keys_dict.get('address_type')) is not None
|
||||
else None
|
||||
),
|
||||
ltk=PairingKeys.key_from_dict(keys_dict, 'ltk'),
|
||||
ltk_central=PairingKeys.key_from_dict(keys_dict, 'ltk_central'),
|
||||
ltk_peripheral=PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral'),
|
||||
irk=PairingKeys.key_from_dict(keys_dict, 'irk'),
|
||||
csrk=PairingKeys.key_from_dict(keys_dict, 'csrk'),
|
||||
link_key=PairingKeys.key_from_dict(keys_dict, 'link_key'),
|
||||
link_key_type=keys_dict.get('link_key_type'),
|
||||
)
|
||||
|
||||
keys.address_type = keys_dict.get('address_type')
|
||||
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
|
||||
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
|
||||
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
|
||||
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
|
||||
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
|
||||
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
|
||||
|
||||
return keys
|
||||
|
||||
def to_dict(self):
|
||||
keys = {}
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
keys: dict[str, Any] = {}
|
||||
|
||||
if self.address_type is not None:
|
||||
keys['address_type'] = self.address_type
|
||||
@@ -125,9 +132,12 @@ class PairingKeys:
|
||||
if self.link_key is not None:
|
||||
keys['link_key'] = self.link_key.to_dict()
|
||||
|
||||
if self.link_key_type is not None:
|
||||
keys['link_key_type'] = self.link_key_type
|
||||
|
||||
return keys
|
||||
|
||||
def print(self, prefix=''):
|
||||
def print(self, prefix: str = '') -> None:
|
||||
keys_dict = self.to_dict()
|
||||
for container_property, value in keys_dict.items():
|
||||
if isinstance(value, dict):
|
||||
@@ -149,27 +159,35 @@ class KeyStore:
|
||||
async def get(self, _name: str) -> Optional[PairingKeys]:
|
||||
return None
|
||||
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
async def get_all(self) -> list[tuple[str, PairingKeys]]:
|
||||
return []
|
||||
|
||||
async def delete_all(self) -> None:
|
||||
all_keys = await self.get_all()
|
||||
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
|
||||
|
||||
async def get_resolving_keys(self):
|
||||
async def get_resolving_keys(self) -> list[tuple[bytes, hci.Address]]:
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
for name, keys in all_keys:
|
||||
if keys.irk is not None:
|
||||
if keys.address_type is None:
|
||||
address_type = Address.RANDOM_DEVICE_ADDRESS
|
||||
else:
|
||||
address_type = keys.address_type
|
||||
resolving_keys.append((keys.irk.value, Address(name, address_type)))
|
||||
resolving_keys.append(
|
||||
(
|
||||
keys.irk.value,
|
||||
hci.Address(
|
||||
name,
|
||||
(
|
||||
keys.address_type
|
||||
if keys.address_type is not None
|
||||
else hci.Address.RANDOM_DEVICE_ADDRESS
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return resolving_keys
|
||||
|
||||
async def print(self, prefix=''):
|
||||
async def print(self, prefix: str = '') -> None:
|
||||
entries = await self.get_all()
|
||||
separator = ''
|
||||
for name, keys in entries:
|
||||
@@ -177,8 +195,8 @@ class KeyStore:
|
||||
keys.print(prefix=prefix + ' ')
|
||||
separator = '\n'
|
||||
|
||||
@staticmethod
|
||||
def create_for_device(device: Device) -> KeyStore:
|
||||
@classmethod
|
||||
def create_for_device(cls, device: Device) -> KeyStore:
|
||||
if device.config.keystore is None:
|
||||
return MemoryKeyStore()
|
||||
|
||||
@@ -256,7 +274,7 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
@classmethod
|
||||
def from_device(
|
||||
cls: Type[Self], device: Device, filename: Optional[str] = None
|
||||
cls: type[Self], device: Device, filename: Optional[str] = None
|
||||
) -> Self:
|
||||
if not filename:
|
||||
# Extract the filename from the config if there is one
|
||||
@@ -266,9 +284,9 @@ class JsonKeyStore(KeyStore):
|
||||
filename = params[0]
|
||||
|
||||
# Use a namespace based on the device address
|
||||
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
|
||||
if device.public_address not in (hci.Address.ANY, hci.Address.ANY_RANDOM):
|
||||
namespace = str(device.public_address)
|
||||
elif device.random_address != Address.ANY_RANDOM:
|
||||
elif device.random_address != hci.Address.ANY_RANDOM:
|
||||
namespace = str(device.random_address)
|
||||
else:
|
||||
namespace = JsonKeyStore.DEFAULT_NAMESPACE
|
||||
@@ -340,7 +358,7 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MemoryKeyStore(KeyStore):
|
||||
all_keys: Dict[str, PairingKeys]
|
||||
all_keys: dict[str, PairingKeys]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.all_keys = {}
|
||||
@@ -355,5 +373,5 @@ class MemoryKeyStore(KeyStore):
|
||||
async def get(self, name: str) -> Optional[PairingKeys]:
|
||||
return self.all_keys.get(name)
|
||||
|
||||
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
|
||||
async def get_all(self) -> list[tuple[str, PairingKeys]]:
|
||||
return list(self.all_keys.items())
|
||||
|
||||
976
bumble/l2cap.py
976
bumble/l2cap.py
File diff suppressed because it is too large
Load Diff
314
bumble/link.py
314
bumble/link.py
@@ -12,32 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
InvalidStateError,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble import controller, core
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_SUCCESS,
|
||||
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_SUCCESS,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
Address,
|
||||
HCI_Connection_Complete_Event,
|
||||
Role,
|
||||
)
|
||||
from bumble import controller
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -66,7 +58,7 @@ class LocalLink:
|
||||
Link bus for controllers to communicate with each other
|
||||
'''
|
||||
|
||||
controllers: Set[controller.Controller]
|
||||
controllers: set[controller.Controller]
|
||||
|
||||
def __init__(self):
|
||||
self.controllers = set()
|
||||
@@ -116,12 +108,14 @@ class LocalLink:
|
||||
|
||||
def send_acl_data(self, sender_controller, destination_address, transport, data):
|
||||
# Send the data to the first controller with a matching address
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
if transport == core.PhysicalTransport.LE:
|
||||
destination_controller = self.find_controller(destination_address)
|
||||
source_address = sender_controller.random_address
|
||||
elif transport == BT_BR_EDR_TRANSPORT:
|
||||
elif transport == core.PhysicalTransport.BR_EDR:
|
||||
destination_controller = self.find_classic_controller(destination_address)
|
||||
source_address = sender_controller.public_address
|
||||
else:
|
||||
raise ValueError("unsupported transport type")
|
||||
|
||||
if destination_controller is not None:
|
||||
destination_controller.on_link_acl_data(source_address, transport, data)
|
||||
@@ -164,29 +158,29 @@ class LocalLink:
|
||||
asyncio.get_running_loop().call_soon(self.on_connection_complete)
|
||||
|
||||
def on_disconnection_complete(
|
||||
self, central_address, peripheral_address, disconnect_command
|
||||
self, initiating_address, target_address, disconnect_command
|
||||
):
|
||||
# Find the controller that initiated the disconnection
|
||||
if not (central_controller := self.find_controller(central_address)):
|
||||
if not (initiating_controller := self.find_controller(initiating_address)):
|
||||
logger.warning('!!! Initiating controller not found')
|
||||
return
|
||||
|
||||
# Disconnect from the first controller with a matching address
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
peripheral_controller.on_link_central_disconnected(
|
||||
central_address, disconnect_command.reason
|
||||
if target_controller := self.find_controller(target_address):
|
||||
target_controller.on_link_disconnected(
|
||||
initiating_address, disconnect_command.reason
|
||||
)
|
||||
|
||||
central_controller.on_link_peripheral_disconnection_complete(
|
||||
initiating_controller.on_link_disconnection_complete(
|
||||
disconnect_command, HCI_SUCCESS
|
||||
)
|
||||
|
||||
def disconnect(self, central_address, peripheral_address, disconnect_command):
|
||||
def disconnect(self, initiating_address, target_address, disconnect_command):
|
||||
logger.debug(
|
||||
f'$$$ DISCONNECTION {central_address} -> '
|
||||
f'{peripheral_address}: reason = {disconnect_command.reason}'
|
||||
f'$$$ DISCONNECTION {initiating_address} -> '
|
||||
f'{target_address}: reason = {disconnect_command.reason}'
|
||||
)
|
||||
args = [central_address, peripheral_address, disconnect_command]
|
||||
args = [initiating_address, target_address, disconnect_command]
|
||||
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
@@ -273,7 +267,7 @@ class LocalLink:
|
||||
|
||||
responder_controller.on_classic_connection_request(
|
||||
initiator_controller.public_address,
|
||||
HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
HCI_Connection_Complete_Event.LinkType.ACL,
|
||||
)
|
||||
|
||||
def classic_accept_connection(
|
||||
@@ -290,7 +284,7 @@ class LocalLink:
|
||||
return
|
||||
|
||||
async def task():
|
||||
if responder_role != BT_PERIPHERAL_ROLE:
|
||||
if responder_role != Role.PERIPHERAL:
|
||||
initiator_controller.on_classic_role_change(
|
||||
responder_controller.public_address, int(not (responder_role))
|
||||
)
|
||||
@@ -383,261 +377,3 @@ class LocalLink:
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
initiator_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RemoteLink:
|
||||
'''
|
||||
A Link implementation that communicates with other virtual controllers via a
|
||||
WebSocket relay
|
||||
'''
|
||||
|
||||
def __init__(self, uri):
|
||||
self.controller = None
|
||||
self.uri = uri
|
||||
self.execution_queue = asyncio.Queue()
|
||||
self.websocket = asyncio.get_running_loop().create_future()
|
||||
self.rpc_result = None
|
||||
self.pending_connection = None
|
||||
self.central_connections = set() # List of addresses that we have connected to
|
||||
self.peripheral_connections = (
|
||||
set()
|
||||
) # List of addresses that have connected to us
|
||||
|
||||
# Connect and run asynchronously
|
||||
asyncio.create_task(self.run_connection())
|
||||
asyncio.create_task(self.run_executor_loop())
|
||||
|
||||
def add_controller(self, controller):
|
||||
if self.controller:
|
||||
raise InvalidStateError('controller already set')
|
||||
self.controller = controller
|
||||
|
||||
def remove_controller(self, controller):
|
||||
if self.controller != controller:
|
||||
raise InvalidStateError('controller mismatch')
|
||||
self.controller = None
|
||||
|
||||
def get_pending_connection(self):
|
||||
return self.pending_connection
|
||||
|
||||
def get_pending_classic_connection(self):
|
||||
return self.pending_classic_connection
|
||||
|
||||
async def wait_until_connected(self):
|
||||
await self.websocket
|
||||
|
||||
def execute(self, async_function):
|
||||
self.execution_queue.put_nowait(async_function())
|
||||
|
||||
async def run_executor_loop(self):
|
||||
logger.debug('executor loop starting')
|
||||
while True:
|
||||
item = await self.execution_queue.get()
|
||||
try:
|
||||
await item
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
f'{color("!!! Exception in async handler:", "red")} {error}'
|
||||
)
|
||||
|
||||
async def run_connection(self):
|
||||
import websockets # lazy import
|
||||
|
||||
# Connect to the relay
|
||||
logger.debug(f'connecting to {self.uri}')
|
||||
# pylint: disable-next=no-member
|
||||
websocket = await websockets.connect(self.uri)
|
||||
self.websocket.set_result(websocket)
|
||||
logger.debug(f'connected to {self.uri}')
|
||||
|
||||
while True:
|
||||
message = await websocket.recv()
|
||||
logger.debug(f'received message: {message}')
|
||||
keyword, *payload = message.split(':', 1)
|
||||
|
||||
handler_name = f'on_{keyword}_received'
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler:
|
||||
await handler(payload[0] if payload else None)
|
||||
|
||||
def close(self):
|
||||
if self.websocket.done():
|
||||
logger.debug('closing websocket')
|
||||
websocket = self.websocket.result()
|
||||
asyncio.create_task(websocket.close())
|
||||
|
||||
async def on_result_received(self, result):
|
||||
if self.rpc_result:
|
||||
self.rpc_result.set_result(result)
|
||||
|
||||
async def on_left_received(self, address):
|
||||
if address in self.central_connections:
|
||||
self.controller.on_link_peripheral_disconnected(Address(address))
|
||||
self.central_connections.remove(address)
|
||||
|
||||
if address in self.peripheral_connections:
|
||||
self.controller.on_link_central_disconnected(
|
||||
address, HCI_CONNECTION_TIMEOUT_ERROR
|
||||
)
|
||||
self.peripheral_connections.remove(address)
|
||||
|
||||
async def on_unreachable_received(self, target):
|
||||
await self.on_left_received(target)
|
||||
|
||||
async def on_message_received(self, message):
|
||||
sender, *payload = message.split('/', 1)
|
||||
if payload:
|
||||
keyword, *payload = payload[0].split(':', 1)
|
||||
handler_name = f'on_{keyword}_message_received'
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler:
|
||||
await handler(sender, payload[0] if payload else None)
|
||||
|
||||
async def on_advertisement_message_received(self, sender, advertisement):
|
||||
try:
|
||||
self.controller.on_link_advertising_data(
|
||||
Address(sender), bytes.fromhex(advertisement)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('exception')
|
||||
|
||||
async def on_acl_message_received(self, sender, acl_data):
|
||||
try:
|
||||
self.controller.on_link_acl_data(Address(sender), bytes.fromhex(acl_data))
|
||||
except Exception:
|
||||
logger.exception('exception')
|
||||
|
||||
async def on_connect_message_received(self, sender, _):
|
||||
# Remember the connection
|
||||
self.peripheral_connections.add(sender)
|
||||
|
||||
# Notify the controller
|
||||
logger.debug(f'connection from central {sender}')
|
||||
self.controller.on_link_central_connected(Address(sender))
|
||||
|
||||
# Accept the connection by responding to it
|
||||
await self.send_targeted_message(sender, 'connected')
|
||||
|
||||
async def on_connected_message_received(self, sender, _):
|
||||
if not self.pending_connection:
|
||||
logger.warning('received a connection ack, but no connection is pending')
|
||||
return
|
||||
|
||||
# Remember the connection
|
||||
self.central_connections.add(sender)
|
||||
|
||||
# Notify the controller
|
||||
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
|
||||
self.controller.on_link_peripheral_connection_complete(
|
||||
self.pending_connection, HCI_SUCCESS
|
||||
)
|
||||
|
||||
async def on_disconnect_message_received(self, sender, message):
|
||||
# Notify the controller
|
||||
params = parse_parameters(message)
|
||||
reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
|
||||
self.controller.on_link_central_disconnected(Address(sender), reason)
|
||||
|
||||
# Forget the connection
|
||||
if sender in self.peripheral_connections:
|
||||
self.peripheral_connections.remove(sender)
|
||||
|
||||
async def on_encrypted_message_received(self, sender, _):
|
||||
# TODO parse params to get real args
|
||||
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
|
||||
|
||||
async def send_rpc_command(self, command):
|
||||
# Ensure we have a connection
|
||||
websocket = await self.websocket
|
||||
|
||||
# Create a future value to hold the eventual result
|
||||
assert self.rpc_result is None
|
||||
self.rpc_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Send the command
|
||||
await websocket.send(command)
|
||||
|
||||
# Wait for the result
|
||||
rpc_result = await self.rpc_result
|
||||
self.rpc_result = None
|
||||
logger.debug(f'rpc_result: {rpc_result}')
|
||||
|
||||
# TODO: parse the result
|
||||
|
||||
async def send_targeted_message(self, target, message):
|
||||
# Ensure we have a connection
|
||||
websocket = await self.websocket
|
||||
|
||||
# Send the message
|
||||
await websocket.send(f'@{target} {message}')
|
||||
|
||||
async def notify_address_changed(self):
|
||||
await self.send_rpc_command(f'/set-address {self.controller.random_address}')
|
||||
|
||||
def on_address_changed(self, controller):
|
||||
logger.info(f'address changed for {controller}: {controller.random_address}')
|
||||
|
||||
# Notify the relay of the change
|
||||
self.execute(self.notify_address_changed)
|
||||
|
||||
async def send_advertising_data_to_relay(self, data):
|
||||
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
|
||||
|
||||
def send_advertising_data(self, _, data):
|
||||
self.execute(partial(self.send_advertising_data_to_relay, data))
|
||||
|
||||
async def send_acl_data_to_relay(self, peer_address, data):
|
||||
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
|
||||
|
||||
def send_acl_data(self, _, peer_address, _transport, data):
|
||||
# TODO: handle different transport
|
||||
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
|
||||
|
||||
async def send_connection_request_to_relay(self, peer_address):
|
||||
await self.send_targeted_message(peer_address, 'connect')
|
||||
|
||||
def connect(self, _, le_create_connection_command):
|
||||
if self.pending_connection:
|
||||
logger.warning('connection already pending')
|
||||
return
|
||||
self.pending_connection = le_create_connection_command
|
||||
self.execute(
|
||||
partial(
|
||||
self.send_connection_request_to_relay,
|
||||
str(le_create_connection_command.peer_address),
|
||||
)
|
||||
)
|
||||
|
||||
def on_disconnection_complete(self, disconnect_command):
|
||||
self.controller.on_link_peripheral_disconnection_complete(
|
||||
disconnect_command, HCI_SUCCESS
|
||||
)
|
||||
|
||||
def disconnect(self, central_address, peripheral_address, disconnect_command):
|
||||
logger.debug(
|
||||
f'disconnect {central_address} -> '
|
||||
f'{peripheral_address}: reason = {disconnect_command.reason}'
|
||||
)
|
||||
self.execute(
|
||||
partial(
|
||||
self.send_targeted_message,
|
||||
peripheral_address,
|
||||
f'disconnect:reason={disconnect_command.reason}',
|
||||
)
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
self.on_disconnection_complete, disconnect_command
|
||||
)
|
||||
|
||||
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
|
||||
)
|
||||
self.execute(
|
||||
partial(
|
||||
self.send_targeted_message,
|
||||
peripheral_address,
|
||||
f'encrypted:ltk={ltk.hex()}',
|
||||
)
|
||||
)
|
||||
|
||||
65
bumble/logging.py
Normal file
65
bumble/logging.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
|
||||
from bumble import colors
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ColorFormatter(logging.Formatter):
|
||||
_colorizers = {
|
||||
logging.DEBUG: functools.partial(colors.color, fg="white"),
|
||||
logging.INFO: functools.partial(colors.color, fg="green"),
|
||||
logging.WARNING: functools.partial(colors.color, fg="yellow"),
|
||||
logging.ERROR: functools.partial(colors.color, fg="red"),
|
||||
logging.CRITICAL: functools.partial(colors.color, fg="black", bg="red"),
|
||||
}
|
||||
|
||||
_formatters = {
|
||||
level: logging.Formatter(
|
||||
fmt=colorizer("{asctime}.{msecs:03.0f} {levelname:.1} {name}: ")
|
||||
+ "{message}",
|
||||
datefmt="%H:%M:%S",
|
||||
style="{",
|
||||
)
|
||||
for level, colorizer in _colorizers.items()
|
||||
}
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
return self._formatters[record.levelno].format(record)
|
||||
|
||||
|
||||
def setup_basic_logging(default_level: str = "INFO") -> None:
|
||||
"""
|
||||
Set up basic logging with logging.basicConfig, configured with a simple formatter
|
||||
that prints out the date and log level in color.
|
||||
If the BUMBLE_LOGLEVEL environment variable is set to the name of a log level, it
|
||||
is used. Otherwise the default_level argument is used.
|
||||
|
||||
Args:
|
||||
default_level: default logging level
|
||||
|
||||
"""
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(ColorFormatter())
|
||||
logging.basicConfig(
|
||||
level=os.environ.get("BUMBLE_LOGLEVEL", default_level).upper(),
|
||||
handlers=[handler],
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,32 +16,28 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .hci import (
|
||||
Address,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
)
|
||||
from .smp import (
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from bumble import hci
|
||||
from bumble.core import AdvertisingData, LeRole
|
||||
from bumble.smp import (
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
|
||||
SMP_ENC_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
OobSharedData,
|
||||
)
|
||||
from .core import AdvertisingData, LeRole
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -49,7 +45,7 @@ from .core import AdvertisingData, LeRole
|
||||
class OobData:
|
||||
"""OOB data that can be sent from one device to another."""
|
||||
|
||||
address: Optional[Address] = None
|
||||
address: Optional[hci.Address] = None
|
||||
role: Optional[LeRole] = None
|
||||
shared_data: Optional[OobSharedData] = None
|
||||
legacy_context: Optional[OobLegacyContext] = None
|
||||
@@ -61,7 +57,7 @@ class OobData:
|
||||
shared_data_r: Optional[bytes] = None
|
||||
for ad_type, ad_data in ad.ad_structures:
|
||||
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
|
||||
instance.address = Address(ad_data)
|
||||
instance.address = hci.Address(ad_data)
|
||||
elif ad_type == AdvertisingData.LE_ROLE:
|
||||
instance.role = LeRole(ad_data[0])
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
|
||||
@@ -76,18 +72,18 @@ class OobData:
|
||||
return instance
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
ad_structures = []
|
||||
ad_structures: list[tuple[int, bytes]] = []
|
||||
if self.address is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
|
||||
(AdvertisingData.Type.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
|
||||
)
|
||||
if self.role is not None:
|
||||
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
|
||||
ad_structures.append((AdvertisingData.Type.LE_ROLE, bytes([self.role])))
|
||||
if self.shared_data is not None:
|
||||
ad_structures.extend(self.shared_data.to_ad().ad_structures)
|
||||
if self.legacy_context is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
|
||||
(AdvertisingData.Type.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
@@ -129,26 +125,29 @@ class PairingDelegate:
|
||||
# Default mapping from abstract to Classic I/O capabilities.
|
||||
# Subclasses may override this if they prefer a different mapping.
|
||||
CLASSIC_IO_CAPABILITIES_MAP = {
|
||||
NO_OUTPUT_NO_INPUT: HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
KEYBOARD_INPUT_ONLY: HCI_KEYBOARD_ONLY_IO_CAPABILITY,
|
||||
DISPLAY_OUTPUT_ONLY: HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
DISPLAY_OUTPUT_AND_YES_NO_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
NO_OUTPUT_NO_INPUT: hci.IoCapability.NO_INPUT_NO_OUTPUT,
|
||||
KEYBOARD_INPUT_ONLY: hci.IoCapability.KEYBOARD_ONLY,
|
||||
DISPLAY_OUTPUT_ONLY: hci.IoCapability.DISPLAY_ONLY,
|
||||
DISPLAY_OUTPUT_AND_YES_NO_INPUT: hci.IoCapability.DISPLAY_YES_NO,
|
||||
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: hci.IoCapability.DISPLAY_YES_NO,
|
||||
}
|
||||
|
||||
io_capability: IoCapability
|
||||
local_initiator_key_distribution: KeyDistribution
|
||||
local_responder_key_distribution: KeyDistribution
|
||||
maximum_encryption_key_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
maximum_encryption_key_size: int = 16,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
self.local_responder_key_distribution = local_responder_key_distribution
|
||||
self.maximum_encryption_key_size = maximum_encryption_key_size
|
||||
|
||||
@property
|
||||
def classic_io_capability(self) -> int:
|
||||
@@ -156,7 +155,7 @@ class PairingDelegate:
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
return self.CLASSIC_IO_CAPABILITIES_MAP.get(
|
||||
self.io_capability, HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
self.io_capability, hci.IoCapability.NO_INPUT_NO_OUTPUT
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -202,7 +201,7 @@ class PairingDelegate:
|
||||
# [LE only]
|
||||
async def key_distribution_response(
|
||||
self, peer_initiator_key_distribution: int, peer_responder_key_distribution: int
|
||||
) -> Tuple[int, int]:
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Return the key distribution response in an SMP protocol context.
|
||||
|
||||
@@ -219,14 +218,22 @@ class PairingDelegate:
|
||||
),
|
||||
)
|
||||
|
||||
async def generate_passkey(self) -> int:
|
||||
"""
|
||||
Return a passkey value between 0 and 999999 (inclusive).
|
||||
"""
|
||||
|
||||
# By default, generate a random passkey.
|
||||
return secrets.randbelow(1000000)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PairingConfig:
|
||||
"""Configuration for the Pairing protocol."""
|
||||
|
||||
class AddressType(enum.IntEnum):
|
||||
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
|
||||
RANDOM = Address.RANDOM_DEVICE_ADDRESS
|
||||
PUBLIC = hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
RANDOM = hci.Address.RANDOM_DEVICE_ADDRESS
|
||||
|
||||
@dataclass
|
||||
class OobConfig:
|
||||
|
||||
@@ -19,21 +19,22 @@ This module implement the Pandora Bluetooth test APIs for the Bumble stack.
|
||||
|
||||
__version__ = "0.0.1"
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import grpc
|
||||
import grpc.aio
|
||||
|
||||
from .config import Config
|
||||
from .device import PandoraDevice
|
||||
from .host import HostService
|
||||
from .l2cap import L2CAPService
|
||||
from .security import SecurityService, SecurityStorageService
|
||||
from pandora.host_grpc_aio import add_HostServicer_to_server
|
||||
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
|
||||
from pandora.security_grpc_aio import (
|
||||
add_SecurityServicer_to_server,
|
||||
add_SecurityStorageServicer_to_server,
|
||||
)
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from bumble.pandora.config import Config
|
||||
from bumble.pandora.device import PandoraDevice
|
||||
from bumble.pandora.host import HostService
|
||||
from bumble.pandora.l2cap import L2CAPService
|
||||
from bumble.pandora.security import SecurityService, SecurityStorageService
|
||||
|
||||
# public symbols
|
||||
__all__ = [
|
||||
@@ -45,11 +46,11 @@ __all__ = [
|
||||
|
||||
|
||||
# Add servicers hooks.
|
||||
_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
|
||||
_SERVICERS_HOOKS: list[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
|
||||
|
||||
|
||||
def register_servicer_hook(
|
||||
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
|
||||
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None],
|
||||
) -> None:
|
||||
_SERVICERS_HOOKS.append(hook)
|
||||
|
||||
|
||||
@@ -13,9 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -32,7 +34,7 @@ class Config:
|
||||
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
|
||||
)
|
||||
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
def load_from_dict(self, config: dict[str, Any]) -> None:
|
||||
io_capability_name: str = config.get(
|
||||
'io_capability', 'no_output_no_input'
|
||||
).upper()
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
@@ -32,8 +35,6 @@ from bumble.sdp import (
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Default rootcanal HCI TCP address
|
||||
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
|
||||
@@ -49,13 +50,13 @@ class PandoraDevice:
|
||||
|
||||
# Bumble device instance & configuration.
|
||||
device: Device
|
||||
config: Dict[str, Any]
|
||||
config: dict[str, Any]
|
||||
|
||||
# HCI transport name & instance.
|
||||
_hci_name: str
|
||||
_hci: Optional[transport.Transport] # type: ignore[name-defined]
|
||||
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, config: dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.device = _make_device(config)
|
||||
self._hci_name = config.get(
|
||||
@@ -95,14 +96,14 @@ class PandoraDevice:
|
||||
await self.close()
|
||||
await self.open()
|
||||
|
||||
def info(self) -> Optional[Dict[str, str]]:
|
||||
def info(self) -> Optional[dict[str, str]]:
|
||||
return {
|
||||
'public_bd_address': str(self.device.public_address),
|
||||
'random_address': str(self.device.random_address),
|
||||
}
|
||||
|
||||
|
||||
def _make_device(config: Dict[str, Any]) -> Device:
|
||||
def _make_device(config: dict[str, Any]) -> Device:
|
||||
"""Initialize an idle Bumble device instance."""
|
||||
|
||||
# initialize bumble device.
|
||||
@@ -117,7 +118,7 @@ def _make_device(config: Dict[str, Any]) -> Device:
|
||||
|
||||
|
||||
# TODO(b/267540823): remove when Pandora A2dp is supported
|
||||
def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
|
||||
def _make_sdp_records(rfcomm_channel: int) -> dict[int, list[ServiceAttribute]]:
|
||||
return {
|
||||
0x00010001: [
|
||||
ServiceAttribute(
|
||||
|
||||
@@ -13,50 +13,23 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import logging
|
||||
import struct
|
||||
from typing import AsyncGenerator, Optional, cast
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
Appearance,
|
||||
ConnectionError,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingParameters,
|
||||
AdvertisingEventProperties,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
Phy,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
)
|
||||
import grpc
|
||||
import grpc.aio
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora import host_pb2
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora.host_pb2 import (
|
||||
DISCOVERABLE_GENERAL,
|
||||
DISCOVERABLE_LIMITED,
|
||||
NOT_CONNECTABLE,
|
||||
NOT_DISCOVERABLE,
|
||||
DISCOVERABLE_LIMITED,
|
||||
DISCOVERABLE_GENERAL,
|
||||
PRIMARY_1M,
|
||||
PRIMARY_CODED,
|
||||
SECONDARY_1M,
|
||||
@@ -72,7 +45,6 @@ from pandora.host_pb2 import (
|
||||
ConnectResponse,
|
||||
DataTypes,
|
||||
DisconnectRequest,
|
||||
DiscoverabilityMode,
|
||||
InquiryResponse,
|
||||
PrimaryPhy,
|
||||
ReadLocalAddressResponse,
|
||||
@@ -85,9 +57,39 @@ from pandora.host_pb2 import (
|
||||
WaitConnectionResponse,
|
||||
WaitDisconnectionRequest,
|
||||
)
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
|
||||
import bumble.device
|
||||
import bumble.utils
|
||||
from bumble.core import (
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
Appearance,
|
||||
ConnectionError,
|
||||
PhysicalTransport,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingEventProperties,
|
||||
AdvertisingParameters,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
OwnAddressType,
|
||||
Phy,
|
||||
Role,
|
||||
)
|
||||
from bumble.pandora import utils
|
||||
from bumble.pandora.config import Config
|
||||
|
||||
PRIMARY_PHY_MAP: dict[int, PrimaryPhy] = {
|
||||
# Default value reported by Bumble for legacy Advertising reports.
|
||||
# FIXME(uael): `None` might be a better value, but Bumble need to change accordingly.
|
||||
0: PRIMARY_1M,
|
||||
@@ -95,35 +97,35 @@ PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
|
||||
3: PRIMARY_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
|
||||
SECONDARY_PHY_MAP: dict[int, SecondaryPhy] = {
|
||||
0: SECONDARY_NONE,
|
||||
1: SECONDARY_1M,
|
||||
2: SECONDARY_2M,
|
||||
3: SECONDARY_CODED,
|
||||
}
|
||||
|
||||
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
|
||||
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: dict[PrimaryPhy, Phy] = {
|
||||
PRIMARY_1M: Phy.LE_1M,
|
||||
PRIMARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
|
||||
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: dict[SecondaryPhy, Phy] = {
|
||||
SECONDARY_NONE: Phy.LE_1M,
|
||||
SECONDARY_1M: Phy.LE_1M,
|
||||
SECONDARY_2M: Phy.LE_2M,
|
||||
SECONDARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
|
||||
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
|
||||
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
|
||||
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
|
||||
OWN_ADDRESS_MAP: dict[host_pb2.OwnAddressType, OwnAddressType] = {
|
||||
host_pb2.PUBLIC: OwnAddressType.PUBLIC,
|
||||
host_pb2.RANDOM: OwnAddressType.RANDOM,
|
||||
host_pb2.RESOLVABLE_OR_PUBLIC: OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
host_pb2.RESOLVABLE_OR_RANDOM: OwnAddressType.RESOLVABLE_OR_RANDOM,
|
||||
}
|
||||
|
||||
|
||||
class HostService(HostServicer):
|
||||
waited_connections: Set[int]
|
||||
waited_connections: set[int]
|
||||
|
||||
def __init__(
|
||||
self, grpc_server: grpc.aio.Server, device: Device, config: Config
|
||||
@@ -184,7 +186,7 @@ class HostService(HostServicer):
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
address, transport=PhysicalTransport.BR_EDR
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
@@ -217,7 +219,7 @@ class HostService(HostServicer):
|
||||
self.log.debug(f"WaitConnection from {address}...")
|
||||
|
||||
connection = self.device.find_connection_by_bd_addr(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
address, transport=PhysicalTransport.BR_EDR
|
||||
)
|
||||
if connection and id(connection) in self.waited_connections:
|
||||
# this connection was already returned: wait for a new one.
|
||||
@@ -249,8 +251,8 @@ class HostService(HostServicer):
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
address,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
own_address_type=request.own_address_type,
|
||||
transport=PhysicalTransport.LE,
|
||||
own_address_type=OwnAddressType(request.own_address_type),
|
||||
)
|
||||
except ConnectionError as e:
|
||||
if e.error_code == HCI_PAGE_TIMEOUT_ERROR:
|
||||
@@ -295,12 +297,12 @@ class HostService(HostServicer):
|
||||
def on_disconnection(_: None) -> None:
|
||||
disconnection_future.set_result(None)
|
||||
|
||||
connection.on('disconnection', on_disconnection)
|
||||
connection.on(connection.EVENT_DISCONNECTION, on_disconnection)
|
||||
try:
|
||||
await disconnection_future
|
||||
self.log.debug("Disconnected")
|
||||
finally:
|
||||
connection.remove_listener('disconnection', on_disconnection) # type: ignore
|
||||
connection.remove_listener(connection.EVENT_DISCONNECTION, on_disconnection) # type: ignore
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@@ -371,20 +373,18 @@ class HostService(HostServicer):
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
connection.transport == PhysicalTransport.LE
|
||||
and connection.role == Role.PERIPHERAL
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
connections.put_nowait(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
self.device.on(self.device.EVENT_CONNECTION, on_connection)
|
||||
|
||||
try:
|
||||
# Advertise until RPC is canceled
|
||||
@@ -397,8 +397,7 @@ class HostService(HostServicer):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
connection = await pending_connection
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
connection = await connections.get()
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
yield AdvertiseResponse(connection=Connection(cookie=cookie))
|
||||
@@ -492,16 +491,18 @@ class HostService(HostServicer):
|
||||
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
|
||||
|
||||
connections: asyncio.Queue[bumble.device.Connection] = asyncio.Queue()
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
connection.transport == PhysicalTransport.LE
|
||||
and connection.role == Role.PERIPHERAL
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
connections.put_nowait(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
self.device.on(self.device.EVENT_CONNECTION, on_connection)
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -510,19 +511,15 @@ class HostService(HostServicer):
|
||||
await self.device.start_advertising(
|
||||
target=target,
|
||||
advertising_type=advertising_type,
|
||||
own_address_type=request.own_address_type,
|
||||
own_address_type=OwnAddressType(request.own_address_type),
|
||||
)
|
||||
|
||||
if not request.connectable:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
self.log.debug('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
connection = await connections.get()
|
||||
|
||||
self.log.debug(
|
||||
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
|
||||
@@ -535,11 +532,13 @@ class HostService(HostServicer):
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
if request.connectable:
|
||||
self.device.remove_listener('connection', on_connection) # type: ignore
|
||||
self.device.remove_listener(self.device.EVENT_CONNECTION, on_connection) # type: ignore
|
||||
|
||||
try:
|
||||
self.log.debug('Stop advertising')
|
||||
await self.device.abort_on('flush', self.device.stop_advertising())
|
||||
await bumble.utils.cancel_on_event(
|
||||
self.device, 'flush', self.device.stop_advertising()
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -559,11 +558,11 @@ class HostService(HostServicer):
|
||||
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
handler = self.device.on(self.device.EVENT_ADVERTISEMENT, scan_queue.put_nowait)
|
||||
await self.device.start_scanning(
|
||||
legacy=request.legacy,
|
||||
active=not request.passive,
|
||||
own_address_type=request.own_address_type,
|
||||
own_address_type=OwnAddressType(request.own_address_type),
|
||||
scan_interval=(
|
||||
int(request.interval)
|
||||
if request.interval
|
||||
@@ -604,10 +603,12 @@ class HostService(HostServicer):
|
||||
yield sr
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('advertisement', handler) # type: ignore
|
||||
self.device.remove_listener(self.device.EVENT_ADVERTISEMENT, handler) # type: ignore
|
||||
try:
|
||||
self.log.debug('Stop scanning')
|
||||
await self.device.abort_on('flush', self.device.stop_scanning())
|
||||
await bumble.utils.cancel_on_event(
|
||||
self.device, 'flush', self.device.stop_scanning()
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -618,13 +619,13 @@ class HostService(HostServicer):
|
||||
self.log.debug('Inquiry')
|
||||
|
||||
inquiry_queue: asyncio.Queue[
|
||||
Optional[Tuple[Address, int, AdvertisingData, int]]
|
||||
Optional[tuple[Address, int, AdvertisingData, int]]
|
||||
] = asyncio.Queue()
|
||||
complete_handler = self.device.on(
|
||||
'inquiry_complete', lambda: inquiry_queue.put_nowait(None)
|
||||
self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None)
|
||||
)
|
||||
result_handler = self.device.on( # type: ignore
|
||||
'inquiry_result',
|
||||
self.device.EVENT_INQUIRY_RESULT,
|
||||
lambda address, class_of_device, eir_data, rssi: inquiry_queue.put_nowait( # type: ignore
|
||||
(address, class_of_device, eir_data, rssi) # type: ignore
|
||||
),
|
||||
@@ -643,11 +644,13 @@ class HostService(HostServicer):
|
||||
)
|
||||
|
||||
finally:
|
||||
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
|
||||
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
|
||||
self.device.remove_listener(self.device.EVENT_INQUIRY_COMPLETE, complete_handler) # type: ignore
|
||||
self.device.remove_listener(self.device.EVENT_INQUIRY_RESULT, result_handler) # type: ignore
|
||||
try:
|
||||
self.log.debug('Stop inquiry')
|
||||
await self.device.abort_on('flush', self.device.stop_discovery())
|
||||
await bumble.utils.cancel_on_event(
|
||||
self.device, 'flush', self.device.stop_discovery()
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -668,10 +671,10 @@ class HostService(HostServicer):
|
||||
return empty_pb2.Empty()
|
||||
|
||||
def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:
|
||||
ad_structures: List[Tuple[int, bytes]] = []
|
||||
ad_structures: list[tuple[int, bytes]] = []
|
||||
|
||||
uuids: List[str]
|
||||
datas: Dict[str, bytes]
|
||||
uuids: list[str]
|
||||
datas: dict[str, bytes]
|
||||
|
||||
def uuid128_from_str(uuid: str) -> bytes:
|
||||
"""Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
|
||||
@@ -885,50 +888,50 @@ class HostService(HostServicer):
|
||||
|
||||
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
|
||||
dt = DataTypes()
|
||||
uuids: List[UUID]
|
||||
uuids: list[UUID]
|
||||
s: str
|
||||
i: int
|
||||
ij: Tuple[int, int]
|
||||
uuid_data: Tuple[UUID, bytes]
|
||||
ij: tuple[int, int]
|
||||
uuid_data: tuple[UUID, bytes]
|
||||
data: bytes
|
||||
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.incomplete_service_class_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
|
||||
):
|
||||
dt.complete_service_class_uuids128.extend(
|
||||
@@ -943,42 +946,42 @@ class HostService(HostServicer):
|
||||
if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
|
||||
dt.class_of_device = i
|
||||
if ij := cast(
|
||||
Tuple[int, int],
|
||||
tuple[int, int],
|
||||
ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE),
|
||||
):
|
||||
dt.peripheral_connection_interval_min = ij[0]
|
||||
dt.peripheral_connection_interval_max = ij[1]
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids16.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids32.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuids := cast(
|
||||
List[UUID],
|
||||
list[UUID],
|
||||
ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS),
|
||||
):
|
||||
dt.service_solicitation_uuids128.extend(
|
||||
list(map(lambda x: x.to_hex_str('-'), uuids))
|
||||
)
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
|
||||
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
|
||||
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if uuid_data := cast(
|
||||
Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
|
||||
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
|
||||
):
|
||||
dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1]
|
||||
if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
|
||||
|
||||
@@ -12,31 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import grpc
|
||||
import json
|
||||
import logging
|
||||
from asyncio import Future
|
||||
from asyncio import Queue as AsyncQueue
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Optional, Union
|
||||
|
||||
from asyncio import Queue as AsyncQueue, Future
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble.core import OutOfResourcesError, InvalidArgumentError
|
||||
from bumble.device import Device
|
||||
from bumble.l2cap import (
|
||||
ClassicChannel,
|
||||
ClassicChannelServer,
|
||||
ClassicChannelSpec,
|
||||
LeCreditBasedChannel,
|
||||
LeCreditBasedChannelServer,
|
||||
LeCreditBasedChannelSpec,
|
||||
)
|
||||
import grpc
|
||||
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
|
||||
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
|
||||
COMMAND_NOT_UNDERSTOOD,
|
||||
INVALID_CID_IN_REQUEST,
|
||||
Channel as PandoraChannel,
|
||||
from pandora.l2cap_pb2 import COMMAND_NOT_UNDERSTOOD, INVALID_CID_IN_REQUEST
|
||||
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error
|
||||
from pandora.l2cap_pb2 import (
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
CreditBasedChannelRequest,
|
||||
@@ -51,8 +41,19 @@ from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
|
||||
WaitDisconnectionRequest,
|
||||
WaitDisconnectionResponse,
|
||||
)
|
||||
from typing import AsyncGenerator, Dict, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from bumble.core import InvalidArgumentError, OutOfResourcesError
|
||||
from bumble.device import Device
|
||||
from bumble.l2cap import (
|
||||
ClassicChannel,
|
||||
ClassicChannelServer,
|
||||
ClassicChannelSpec,
|
||||
LeCreditBasedChannel,
|
||||
LeCreditBasedChannelServer,
|
||||
LeCreditBasedChannelSpec,
|
||||
)
|
||||
from bumble.pandora import utils
|
||||
from bumble.pandora.config import Config
|
||||
|
||||
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
|
||||
|
||||
@@ -70,7 +71,7 @@ class L2CAPService(L2CAPServicer):
|
||||
)
|
||||
self.device = device
|
||||
self.config = config
|
||||
self.channels: Dict[bytes, ChannelContext] = {}
|
||||
self.channels: dict[bytes, ChannelContext] = {}
|
||||
|
||||
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
|
||||
close_future = asyncio.get_running_loop().create_future()
|
||||
@@ -83,7 +84,7 @@ class L2CAPService(L2CAPServicer):
|
||||
close_future.set_result(None)
|
||||
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_close)
|
||||
l2cap_channel.on(l2cap_channel.EVENT_CLOSE, on_close)
|
||||
|
||||
return ChannelContext(close_future, sdu_queue)
|
||||
|
||||
@@ -151,7 +152,7 @@ class L2CAPService(L2CAPServicer):
|
||||
spec=spec, handler=on_l2cap_channel
|
||||
)
|
||||
else:
|
||||
l2cap_server.on('connection', on_l2cap_channel)
|
||||
l2cap_server.on(l2cap_server.EVENT_CONNECTION, on_l2cap_channel)
|
||||
|
||||
try:
|
||||
self.log.debug('Waiting for a channel connection.')
|
||||
|
||||
@@ -13,24 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import grpc
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Optional, Union
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble import hci
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.utils import EventWatcher
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
import grpc
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||
@@ -57,7 +47,17 @@ from pandora.security_pb2 import (
|
||||
WaitSecurityRequest,
|
||||
WaitSecurityResponse,
|
||||
)
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
|
||||
|
||||
import bumble.utils
|
||||
from bumble import hci
|
||||
from bumble.core import InvalidArgumentError, PhysicalTransport, ProtocolError
|
||||
from bumble.device import Connection as BumbleConnection
|
||||
from bumble.device import Device
|
||||
from bumble.hci import HCI_Error, Role
|
||||
from bumble.pairing import PairingConfig
|
||||
from bumble.pairing import PairingDelegate as BasePairingDelegate
|
||||
from bumble.pandora import utils
|
||||
from bumble.pandora.config import Config
|
||||
|
||||
|
||||
class PairingDelegate(BasePairingDelegate):
|
||||
@@ -95,7 +95,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
else:
|
||||
# In BR/EDR, connection may not be complete,
|
||||
# use address instead
|
||||
assert self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
assert self.connection.transport == PhysicalTransport.BR_EDR
|
||||
ev.address = bytes(reversed(bytes(self.connection.peer_address)))
|
||||
|
||||
return ev
|
||||
@@ -174,7 +174,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
async def display_number(self, number: int, digits: int = 6) -> None:
|
||||
if (
|
||||
self.connection.transport == BT_BR_EDR_TRANSPORT
|
||||
self.connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
|
||||
):
|
||||
return
|
||||
@@ -190,35 +190,6 @@ class PairingDelegate(BasePairingDelegate):
|
||||
self.service.event_queue.put_nowait(event)
|
||||
|
||||
|
||||
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LEVEL0: lambda connection: True,
|
||||
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
|
||||
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
|
||||
LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
in (
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
|
||||
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
),
|
||||
LEVEL4: lambda connection: connection.encryption
|
||||
== hci.HCI_Encryption_Change_Event.AES_CCM
|
||||
and connection.authenticated
|
||||
and connection.link_key_type
|
||||
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
|
||||
}
|
||||
|
||||
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
|
||||
LE_LEVEL1: lambda connection: True,
|
||||
LE_LEVEL2: lambda connection: connection.encryption != 0,
|
||||
LE_LEVEL3: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated,
|
||||
LE_LEVEL4: lambda connection: connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.sc,
|
||||
}
|
||||
|
||||
|
||||
class SecurityService(SecurityServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
@@ -250,6 +221,59 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
self.device.pairing_config_factory = pairing_config_factory
|
||||
|
||||
async def _classic_level_reached(
|
||||
self, level: SecurityLevel, connection: BumbleConnection
|
||||
) -> bool:
|
||||
if level == LEVEL0:
|
||||
return True
|
||||
if level == LEVEL1:
|
||||
return connection.encryption == 0 or connection.authenticated
|
||||
if level == LEVEL2:
|
||||
return connection.encryption != 0 and connection.authenticated
|
||||
|
||||
link_key_type: Optional[int] = None
|
||||
if (keystore := connection.device.keystore) and (
|
||||
keys := await keystore.get(str(connection.peer_address))
|
||||
):
|
||||
link_key_type = keys.link_key_type
|
||||
self.log.debug("link_key_type: %d", link_key_type)
|
||||
|
||||
if level == LEVEL3:
|
||||
return (
|
||||
connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and link_key_type
|
||||
in (
|
||||
hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192,
|
||||
hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256,
|
||||
)
|
||||
)
|
||||
if level == LEVEL4:
|
||||
return (
|
||||
connection.encryption == hci.HCI_Encryption_Change_Event.Enabled.AES_CCM
|
||||
and connection.authenticated
|
||||
and link_key_type
|
||||
== hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256
|
||||
)
|
||||
raise InvalidArgumentError(f"Unexpected level {level}")
|
||||
|
||||
def _le_level_reached(
|
||||
self, level: LESecurityLevel, connection: BumbleConnection
|
||||
) -> bool:
|
||||
if level == LE_LEVEL1:
|
||||
return True
|
||||
if level == LE_LEVEL2:
|
||||
return connection.encryption != 0
|
||||
if level == LE_LEVEL3:
|
||||
return connection.encryption != 0 and connection.authenticated
|
||||
if level == LE_LEVEL4:
|
||||
return (
|
||||
connection.encryption != 0
|
||||
and connection.authenticated
|
||||
and connection.sc
|
||||
)
|
||||
raise InvalidArgumentError(f"Unexpected level {level}")
|
||||
|
||||
@utils.rpc
|
||||
async def OnPairing(
|
||||
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
|
||||
@@ -287,12 +311,12 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
oneof = request.WhichOneof('level')
|
||||
level = getattr(request, oneof)
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
assert {PhysicalTransport.BR_EDR: 'classic', PhysicalTransport.LE: 'le'}[
|
||||
connection.transport
|
||||
] == oneof
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
if await self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
|
||||
# trigger pairing if needed
|
||||
@@ -302,23 +326,23 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
security_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
with contextlib.closing(bumble.utils.EventWatcher()) as watcher:
|
||||
|
||||
@watcher.on(connection, 'pairing')
|
||||
@watcher.on(connection, connection.EVENT_PAIRING)
|
||||
def on_pairing(*_: Any) -> None:
|
||||
security_result.set_result('success')
|
||||
|
||||
@watcher.on(connection, 'pairing_failure')
|
||||
@watcher.on(connection, connection.EVENT_PAIRING_FAILURE)
|
||||
def on_pairing_failure(*_: Any) -> None:
|
||||
security_result.set_result('pairing_failure')
|
||||
|
||||
@watcher.on(connection, 'disconnection')
|
||||
@watcher.on(connection, connection.EVENT_DISCONNECTION)
|
||||
def on_disconnection(*_: Any) -> None:
|
||||
security_result.set_result('connection_died')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
connection.transport == PhysicalTransport.LE
|
||||
and connection.role == Role.PERIPHERAL
|
||||
):
|
||||
connection.request_pairing()
|
||||
else:
|
||||
@@ -363,7 +387,7 @@ class SecurityService(SecurityServicer):
|
||||
return SecureResponse(encryption_failure=empty_pb2.Empty())
|
||||
|
||||
# security level has been reached ?
|
||||
if self.reached_security_level(connection, level):
|
||||
if await self.reached_security_level(connection, level):
|
||||
return SecureResponse(success=empty_pb2.Empty())
|
||||
return SecureResponse(not_reached=empty_pb2.Empty())
|
||||
|
||||
@@ -379,7 +403,7 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
assert request.level
|
||||
level = request.level
|
||||
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
|
||||
assert {PhysicalTransport.BR_EDR: 'classic', PhysicalTransport.LE: 'le'}[
|
||||
connection.transport
|
||||
] == request.level_variant()
|
||||
|
||||
@@ -390,13 +414,10 @@ class SecurityService(SecurityServicer):
|
||||
pair_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
if (encryption := connection.encryption) != 0:
|
||||
self.log.debug('Disable encryption...')
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await connection.encrypt(enable=False)
|
||||
except:
|
||||
pass
|
||||
self.log.debug('Disable encryption: done')
|
||||
|
||||
self.log.debug('Authenticate...')
|
||||
@@ -415,19 +436,17 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
return wrapper
|
||||
|
||||
def try_set_success(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
async def try_set_success(*_: Any) -> None:
|
||||
if await self.reached_security_level(connection, level):
|
||||
self.log.debug('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
|
||||
def on_encryption_change(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
async def on_encryption_change(*_: Any) -> None:
|
||||
if await self.reached_security_level(connection, level):
|
||||
self.log.debug('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
elif (
|
||||
connection.transport == BT_BR_EDR_TRANSPORT
|
||||
connection.transport == PhysicalTransport.BR_EDR
|
||||
and self.need_authentication(connection, level)
|
||||
):
|
||||
nonlocal authenticate_task
|
||||
@@ -438,7 +457,7 @@ class SecurityService(SecurityServicer):
|
||||
if self.need_pairing(connection, level):
|
||||
pair_task = asyncio.create_task(connection.pair())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
listeners: dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
'connection_authentication_failure': set_failure('authentication_failure'),
|
||||
@@ -451,13 +470,13 @@ class SecurityService(SecurityServicer):
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
with contextlib.closing(bumble.utils.EventWatcher()) as watcher:
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
watcher.on(connection, event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
if await self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.debug('Wait for security...')
|
||||
@@ -467,24 +486,20 @@ class SecurityService(SecurityServicer):
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await authenticate_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.debug('Authenticated')
|
||||
|
||||
# wait for `pair` to finish if any
|
||||
if pair_task is not None:
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
await pair_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.debug('paired')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
def reached_security_level(
|
||||
async def reached_security_level(
|
||||
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
|
||||
) -> bool:
|
||||
self.log.debug(
|
||||
@@ -494,23 +509,22 @@ class SecurityService(SecurityServicer):
|
||||
'encryption': connection.encryption,
|
||||
'authenticated': connection.authenticated,
|
||||
'sc': connection.sc,
|
||||
'link_key_type': connection.link_key_type,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(level, LESecurityLevel):
|
||||
return LE_LEVEL_REACHED[level](connection)
|
||||
return self._le_level_reached(level, connection)
|
||||
|
||||
return BR_LEVEL_REACHED[level](connection)
|
||||
return await self._classic_level_reached(level, connection)
|
||||
|
||||
def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
if connection.transport == PhysicalTransport.LE:
|
||||
return level >= LE_LEVEL3 and not connection.authenticated
|
||||
return False
|
||||
|
||||
def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
if connection.transport == PhysicalTransport.LE:
|
||||
return False
|
||||
if level == LEVEL2 and connection.encryption != 0:
|
||||
return not connection.authenticated
|
||||
@@ -518,7 +532,7 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
|
||||
# TODO(abel): need to support MITM
|
||||
if connection.transport == BT_LE_TRANSPORT:
|
||||
if connection.transport == PhysicalTransport.LE:
|
||||
return level == LE_LEVEL2 and not connection.encryption
|
||||
return level >= LEVEL2 and not connection.encryption
|
||||
|
||||
|
||||
@@ -13,18 +13,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Generator, MutableMapping, Optional
|
||||
|
||||
import grpc
|
||||
from google.protobuf.message import Message # pytype: disable=pyi-error
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.hci import Address
|
||||
from google.protobuf.message import Message # pytype: disable=pyi-error
|
||||
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
|
||||
from bumble.hci import Address, AddressType
|
||||
|
||||
ADDRESS_TYPES: Dict[str, int] = {
|
||||
ADDRESS_TYPES: dict[str, AddressType] = {
|
||||
"public": Address.PUBLIC_DEVICE_ADDRESS,
|
||||
"random": Address.RANDOM_DEVICE_ADDRESS,
|
||||
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
@@ -43,7 +45,7 @@ class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> Tuple[str, MutableMapping[str, Any]]:
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
assert self.extra
|
||||
service_name = self.extra['service_name']
|
||||
assert isinstance(service_name, str)
|
||||
|
||||
@@ -17,31 +17,38 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from bumble import gatt
|
||||
from bumble.device import Connection
|
||||
from bumble import utils
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.device import Connection
|
||||
from bumble.gatt import (
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
CharacteristicValue,
|
||||
PackedCharacteristicAdapter,
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE,
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE,
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
Attribute,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import (
|
||||
CharacteristicProxy,
|
||||
PackedCharacteristicProxyAdapter,
|
||||
SerializableCharacteristicAdapter,
|
||||
SerializableCharacteristicProxyAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
UTF8CharacteristicProxyAdapter,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -57,7 +64,7 @@ GAIN_SETTINGS_MIN_VALUE = 0
|
||||
GAIN_SETTINGS_MAX_VALUE = 255
|
||||
|
||||
|
||||
class ErrorCode(OpenIntEnum):
|
||||
class ErrorCode(utils.OpenIntEnum):
|
||||
'''
|
||||
Cf. 1.6 Application error codes
|
||||
'''
|
||||
@@ -69,7 +76,7 @@ class ErrorCode(OpenIntEnum):
|
||||
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
|
||||
|
||||
|
||||
class Mute(OpenIntEnum):
|
||||
class Mute(utils.OpenIntEnum):
|
||||
'''
|
||||
Cf. 2.2.1.2 Mute Field
|
||||
'''
|
||||
@@ -79,7 +86,7 @@ class Mute(OpenIntEnum):
|
||||
DISABLED = 0x02
|
||||
|
||||
|
||||
class GainMode(OpenIntEnum):
|
||||
class GainMode(utils.OpenIntEnum):
|
||||
'''
|
||||
Cf. 2.2.1.3 Gain Mode
|
||||
'''
|
||||
@@ -90,21 +97,21 @@ class GainMode(OpenIntEnum):
|
||||
AUTOMATIC = 0x03
|
||||
|
||||
|
||||
class AudioInputStatus(OpenIntEnum):
|
||||
class AudioInputStatus(utils.OpenIntEnum):
|
||||
'''
|
||||
Cf. 3.4 Audio Input Status
|
||||
'''
|
||||
|
||||
INATIVE = 0x00
|
||||
INACTIVE = 0x00
|
||||
ACTIVE = 0x01
|
||||
|
||||
|
||||
class AudioInputControlPointOpCode(OpenIntEnum):
|
||||
class AudioInputControlPointOpCode(utils.OpenIntEnum):
|
||||
'''
|
||||
Cf. 3.5.1 Audio Input Control Point procedure requirements
|
||||
'''
|
||||
|
||||
SET_GAIN_SETTING = 0x00
|
||||
SET_GAIN_SETTING = 0x01
|
||||
UNMUTE = 0x02
|
||||
MUTE = 0x03
|
||||
SET_MANUAL_GAIN_MODE = 0x04
|
||||
@@ -122,7 +129,7 @@ class AudioInputState:
|
||||
mute: Mute = Mute.NOT_MUTED
|
||||
gain_mode: GainMode = GainMode.MANUAL
|
||||
change_counter: int = 0
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
attribute: Optional[Attribute] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
@@ -149,13 +156,8 @@ class AudioInputState:
|
||||
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
|
||||
|
||||
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
|
||||
assert self.attribute_value is not None
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=bytes(self)
|
||||
)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
assert self.attribute is not None
|
||||
await connection.device.notify_subscribers(attribute=self.attribute)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -173,7 +175,7 @@ class GainSettingsProperties:
|
||||
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
|
||||
struct.unpack('BBB', data)
|
||||
)
|
||||
GainSettingsProperties(
|
||||
return GainSettingsProperties(
|
||||
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
|
||||
)
|
||||
|
||||
@@ -186,9 +188,6 @@ class GainSettingsProperties:
|
||||
]
|
||||
)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioInputControlPoint:
|
||||
@@ -199,8 +198,7 @@ class AudioInputControlPoint:
|
||||
audio_input_state: AudioInputState
|
||||
gain_settings_properties: GainSettingsProperties
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
async def on_write(self, connection: Connection, value: bytes) -> None:
|
||||
|
||||
opcode = AudioInputControlPointOpCode(value[0])
|
||||
|
||||
@@ -239,7 +237,7 @@ class AudioInputControlPoint:
|
||||
or gain_settings_operand
|
||||
> self.gain_settings_properties.gain_settings_maximum
|
||||
):
|
||||
logger.error("gain_seetings value out of range")
|
||||
logger.error("gain_settings value out of range")
|
||||
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
|
||||
|
||||
if self.audio_input_state.gain_settings != gain_settings_operand:
|
||||
@@ -319,31 +317,27 @@ class AudioInputDescription:
|
||||
'''
|
||||
|
||||
audio_input_description: str = "Bluetooth"
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
attribute: Optional[Attribute] = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
return cls(audio_input_description=data.decode('utf-8'))
|
||||
def on_read(self, _connection: Connection) -> str:
|
||||
return self.audio_input_description
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.audio_input_description.encode('utf-8')
|
||||
async def on_write(self, connection: Connection, value: str) -> None:
|
||||
assert self.attribute
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return self.audio_input_description.encode('utf-8')
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_input_description = value.decode('utf-8')
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=value
|
||||
)
|
||||
self.audio_input_description = value
|
||||
await connection.device.notify_subscribers(attribute=self.attribute)
|
||||
|
||||
|
||||
class AICSService(TemplateService):
|
||||
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
|
||||
|
||||
audio_input_state_characteristic: Characteristic[AudioInputState]
|
||||
audio_input_type_characteristic: Characteristic[bytes]
|
||||
audio_input_status_characteristic: Characteristic[bytes]
|
||||
audio_input_control_point_characteristic: Characteristic[bytes]
|
||||
gain_settings_properties_characteristic: Characteristic[GainSettingsProperties]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_input_state: Optional[AudioInputState] = None,
|
||||
@@ -375,26 +369,27 @@ class AICSService(TemplateService):
|
||||
self.audio_input_state, self.gain_settings_properties
|
||||
)
|
||||
|
||||
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.audio_input_state.on_read),
|
||||
value=self.audio_input_state,
|
||||
),
|
||||
encode=lambda value: bytes(value),
|
||||
)
|
||||
self.audio_input_state.attribute_value = (
|
||||
self.audio_input_state_characteristic.value
|
||||
AudioInputState,
|
||||
)
|
||||
self.audio_input_state.attribute = self.audio_input_state_characteristic
|
||||
|
||||
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.gain_settings_properties.on_read),
|
||||
self.gain_settings_properties_characteristic = (
|
||||
SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=self.gain_settings_properties,
|
||||
),
|
||||
GainSettingsProperties,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -402,7 +397,7 @@ class AICSService(TemplateService):
|
||||
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=audio_input_type,
|
||||
value=bytes(audio_input_type, 'utf-8'),
|
||||
)
|
||||
|
||||
self.audio_input_status_characteristic = Characteristic(
|
||||
@@ -412,18 +407,14 @@ class AICSService(TemplateService):
|
||||
value=bytes([self.audio_input_status]),
|
||||
)
|
||||
|
||||
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(
|
||||
write=self.audio_input_control_point.on_write
|
||||
),
|
||||
)
|
||||
self.audio_input_control_point_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(write=self.audio_input_control_point.on_write),
|
||||
)
|
||||
|
||||
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
@@ -437,8 +428,8 @@ class AICSService(TemplateService):
|
||||
),
|
||||
)
|
||||
)
|
||||
self.audio_input_description.attribute_value = (
|
||||
self.audio_input_control_point_characteristic.value
|
||||
self.audio_input_description.attribute = (
|
||||
self.audio_input_control_point_characteristic
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -460,61 +451,43 @@ class AICSService(TemplateService):
|
||||
class AICSServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = AICSService
|
||||
|
||||
audio_input_state: CharacteristicProxy[AudioInputState]
|
||||
gain_settings_properties: CharacteristicProxy[GainSettingsProperties]
|
||||
audio_input_status: CharacteristicProxy[int]
|
||||
audio_input_control_point: CharacteristicProxy[bytes]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_state = SerializableCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
|
||||
self.audio_input_state = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0], decode=AudioInputState.from_bytes
|
||||
),
|
||||
AudioInputState,
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.gain_settings_properties = SerializableCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Gain Settings Attribute Characteristic not found"
|
||||
)
|
||||
self.gain_settings_properties = PackedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
'BBB',
|
||||
),
|
||||
GainSettingsProperties,
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_status = PackedCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Status Characteristic not found"
|
||||
)
|
||||
self.audio_input_status = PackedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
),
|
||||
'B',
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_control_point = (
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Control Point Characteristic not found"
|
||||
)
|
||||
self.audio_input_control_point = characteristics[0]
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_description = UTF8CharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Description Characteristic not found"
|
||||
)
|
||||
self.audio_input_description = characteristics[0]
|
||||
)
|
||||
|
||||
403
bumble/profiles/ams.py
Normal file
403
bumble/profiles/ams.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Apple Media Service (AMS).
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from bumble import utils
|
||||
from bumble.device import Peer
|
||||
from bumble.gatt import (
|
||||
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
|
||||
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
|
||||
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
|
||||
GATT_AMS_SERVICE,
|
||||
Characteristic,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Protocol
|
||||
# -----------------------------------------------------------------------------
|
||||
class RemoteCommandId(utils.OpenIntEnum):
|
||||
PLAY = 0
|
||||
PAUSE = 1
|
||||
TOGGLE_PLAY_PAUSE = 2
|
||||
NEXT_TRACK = 3
|
||||
PREVIOUS_TRACK = 4
|
||||
VOLUME_UP = 5
|
||||
VOLUME_DOWN = 6
|
||||
ADVANCE_REPEAT_MODE = 7
|
||||
ADVANCE_SHUFFLE_MODE = 8
|
||||
SKIP_FORWARD = 9
|
||||
SKIP_BACKWARD = 10
|
||||
LIKE_TRACK = 11
|
||||
DISLIKE_TRACK = 12
|
||||
BOOKMARK_TRACK = 13
|
||||
|
||||
|
||||
class EntityId(utils.OpenIntEnum):
|
||||
PLAYER = 0
|
||||
QUEUE = 1
|
||||
TRACK = 2
|
||||
|
||||
|
||||
class ActionId(utils.OpenIntEnum):
|
||||
POSITIVE = 0
|
||||
NEGATIVE = 1
|
||||
|
||||
|
||||
class EntityUpdateFlags(enum.IntFlag):
|
||||
TRUNCATED = 1
|
||||
|
||||
|
||||
class PlayerAttributeId(utils.OpenIntEnum):
|
||||
NAME = 0
|
||||
PLAYBACK_INFO = 1
|
||||
VOLUME = 2
|
||||
|
||||
|
||||
class QueueAttributeId(utils.OpenIntEnum):
|
||||
INDEX = 0
|
||||
COUNT = 1
|
||||
SHUFFLE_MODE = 2
|
||||
REPEAT_MODE = 3
|
||||
|
||||
|
||||
class ShuffleMode(utils.OpenIntEnum):
|
||||
OFF = 0
|
||||
ONE = 1
|
||||
ALL = 2
|
||||
|
||||
|
||||
class RepeatMode(utils.OpenIntEnum):
|
||||
OFF = 0
|
||||
ONE = 1
|
||||
ALL = 2
|
||||
|
||||
|
||||
class TrackAttributeId(utils.OpenIntEnum):
|
||||
ARTIST = 0
|
||||
ALBUM = 1
|
||||
TITLE = 2
|
||||
DURATION = 3
|
||||
|
||||
|
||||
class PlaybackState(utils.OpenIntEnum):
|
||||
PAUSED = 0
|
||||
PLAYING = 1
|
||||
REWINDING = 2
|
||||
FAST_FORWARDING = 3
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PlaybackInfo:
|
||||
playback_state: PlaybackState = PlaybackState.PAUSED
|
||||
playback_rate: float = 1.0
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Server-side
|
||||
# -----------------------------------------------------------------------------
|
||||
class Ams(TemplateService):
|
||||
UUID = GATT_AMS_SERVICE
|
||||
|
||||
remote_command_characteristic: Characteristic
|
||||
entity_update_characteristic: Characteristic
|
||||
entity_attribute_characteristic: Characteristic
|
||||
|
||||
def __init__(self) -> None:
|
||||
# TODO not the final implementation
|
||||
self.remote_command_characteristic = Characteristic(
|
||||
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
)
|
||||
|
||||
# TODO not the final implementation
|
||||
self.entity_update_characteristic = Characteristic(
|
||||
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY | Characteristic.Properties.WRITE,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
)
|
||||
|
||||
# TODO not the final implementation
|
||||
self.entity_attribute_characteristic = Characteristic(
|
||||
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.Permissions.WRITEABLE | Characteristic.Permissions.READABLE,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.remote_command_characteristic,
|
||||
self.entity_update_characteristic,
|
||||
self.entity_attribute_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Client-side
|
||||
# -----------------------------------------------------------------------------
|
||||
class AmsProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = Ams
|
||||
|
||||
# NOTE: these don't use adapters, because the format for write and notifications
|
||||
# are different.
|
||||
remote_command: CharacteristicProxy[bytes]
|
||||
entity_update: CharacteristicProxy[bytes]
|
||||
entity_attribute: CharacteristicProxy[bytes]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.remote_command = service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC
|
||||
)
|
||||
|
||||
self.entity_update = service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC
|
||||
)
|
||||
|
||||
self.entity_attribute = service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC
|
||||
)
|
||||
|
||||
|
||||
class AmsClient(utils.EventEmitter):
|
||||
EVENT_SUPPORTED_COMMANDS = "supported_commands"
|
||||
EVENT_PLAYER_NAME = "player_name"
|
||||
EVENT_PLAYER_PLAYBACK_INFO = "player_playback_info"
|
||||
EVENT_PLAYER_VOLUME = "player_volume"
|
||||
EVENT_QUEUE_COUNT = "queue_count"
|
||||
EVENT_QUEUE_INDEX = "queue_index"
|
||||
EVENT_QUEUE_SHUFFLE_MODE = "queue_shuffle_mode"
|
||||
EVENT_QUEUE_REPEAT_MODE = "queue_repeat_mode"
|
||||
EVENT_TRACK_ARTIST = "track_artist"
|
||||
EVENT_TRACK_ALBUM = "track_album"
|
||||
EVENT_TRACK_TITLE = "track_title"
|
||||
EVENT_TRACK_DURATION = "track_duration"
|
||||
|
||||
supported_commands: set[RemoteCommandId]
|
||||
player_name: str = ""
|
||||
player_playback_info: PlaybackInfo = PlaybackInfo(PlaybackState.PAUSED, 0.0, 0.0)
|
||||
player_volume: float = 1.0
|
||||
queue_count: int = 0
|
||||
queue_index: int = 0
|
||||
queue_shuffle_mode: ShuffleMode = ShuffleMode.OFF
|
||||
queue_repeat_mode: RepeatMode = RepeatMode.OFF
|
||||
track_artist: str = ""
|
||||
track_album: str = ""
|
||||
track_title: str = ""
|
||||
track_duration: float = 0.0
|
||||
|
||||
def __init__(self, ams_proxy: AmsProxy) -> None:
|
||||
super().__init__()
|
||||
self._ams_proxy = ams_proxy
|
||||
self._started = False
|
||||
self._read_attribute_semaphore = asyncio.Semaphore()
|
||||
self.supported_commands = set()
|
||||
|
||||
@classmethod
|
||||
async def for_peer(cls, peer: Peer) -> Optional[AmsClient]:
|
||||
ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy)
|
||||
if ams_proxy is None:
|
||||
return None
|
||||
return cls(ams_proxy)
|
||||
|
||||
async def start(self) -> None:
|
||||
logger.debug("subscribing to remote command characteristic")
|
||||
await self._ams_proxy.remote_command.subscribe(
|
||||
self._on_remote_command_notification
|
||||
)
|
||||
|
||||
logger.debug("subscribing to entity update characteristic")
|
||||
await self._ams_proxy.entity_update.subscribe(
|
||||
lambda data: utils.AsyncRunner.spawn(
|
||||
self._on_entity_update_notification(data)
|
||||
)
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self._ams_proxy.remote_command.unsubscribe(
|
||||
self._on_remote_command_notification
|
||||
)
|
||||
await self._ams_proxy.entity_update.unsubscribe(
|
||||
self._on_entity_update_notification
|
||||
)
|
||||
self._started = False
|
||||
|
||||
async def observe(
|
||||
self,
|
||||
entity: EntityId,
|
||||
attributes: Iterable[
|
||||
Union[PlayerAttributeId, QueueAttributeId, TrackAttributeId]
|
||||
],
|
||||
) -> None:
|
||||
await self._ams_proxy.entity_update.write_value(
|
||||
bytes([entity] + list(attributes)), with_response=True
|
||||
)
|
||||
|
||||
async def command(self, command: RemoteCommandId) -> None:
|
||||
await self._ams_proxy.remote_command.write_value(
|
||||
bytes([command]), with_response=True
|
||||
)
|
||||
|
||||
async def play(self) -> None:
|
||||
await self.command(RemoteCommandId.PLAY)
|
||||
|
||||
async def pause(self) -> None:
|
||||
await self.command(RemoteCommandId.PAUSE)
|
||||
|
||||
async def toggle_play_pause(self) -> None:
|
||||
await self.command(RemoteCommandId.TOGGLE_PLAY_PAUSE)
|
||||
|
||||
async def next_track(self) -> None:
|
||||
await self.command(RemoteCommandId.NEXT_TRACK)
|
||||
|
||||
async def previous_track(self) -> None:
|
||||
await self.command(RemoteCommandId.PREVIOUS_TRACK)
|
||||
|
||||
async def volume_up(self) -> None:
|
||||
await self.command(RemoteCommandId.VOLUME_UP)
|
||||
|
||||
async def volume_down(self) -> None:
|
||||
await self.command(RemoteCommandId.VOLUME_DOWN)
|
||||
|
||||
async def advance_repeat_mode(self) -> None:
|
||||
await self.command(RemoteCommandId.ADVANCE_REPEAT_MODE)
|
||||
|
||||
async def advance_shuffle_mode(self) -> None:
|
||||
await self.command(RemoteCommandId.ADVANCE_SHUFFLE_MODE)
|
||||
|
||||
async def skip_forward(self) -> None:
|
||||
await self.command(RemoteCommandId.SKIP_FORWARD)
|
||||
|
||||
async def skip_backward(self) -> None:
|
||||
await self.command(RemoteCommandId.SKIP_BACKWARD)
|
||||
|
||||
async def like_track(self) -> None:
|
||||
await self.command(RemoteCommandId.LIKE_TRACK)
|
||||
|
||||
async def dislike_track(self) -> None:
|
||||
await self.command(RemoteCommandId.DISLIKE_TRACK)
|
||||
|
||||
async def bookmark_track(self) -> None:
|
||||
await self.command(RemoteCommandId.BOOKMARK_TRACK)
|
||||
|
||||
def _on_remote_command_notification(self, data: bytes) -> None:
|
||||
supported_commands = [RemoteCommandId(command) for command in data]
|
||||
logger.debug(
|
||||
f"supported commands: {[command.name for command in supported_commands]}"
|
||||
)
|
||||
for command in supported_commands:
|
||||
self.supported_commands.add(command)
|
||||
|
||||
self.emit(self.EVENT_SUPPORTED_COMMANDS)
|
||||
|
||||
async def _on_entity_update_notification(self, data: bytes) -> None:
|
||||
entity = EntityId(data[0])
|
||||
flags = EntityUpdateFlags(data[2])
|
||||
value = data[3:]
|
||||
|
||||
if flags & EntityUpdateFlags.TRUNCATED:
|
||||
logger.debug("truncated attribute, fetching full value")
|
||||
|
||||
# Write the entity and attribute we're interested in
|
||||
# (protected by a semaphore, so that we only read one attribute at a time)
|
||||
async with self._read_attribute_semaphore:
|
||||
await self._ams_proxy.entity_attribute.write_value(
|
||||
data[:2], with_response=True
|
||||
)
|
||||
value = await self._ams_proxy.entity_attribute.read_value()
|
||||
|
||||
if entity == EntityId.PLAYER:
|
||||
player_attribute = PlayerAttributeId(data[1])
|
||||
if player_attribute == PlayerAttributeId.NAME:
|
||||
self.player_name = value.decode()
|
||||
self.emit(self.EVENT_PLAYER_NAME)
|
||||
elif player_attribute == PlayerAttributeId.PLAYBACK_INFO:
|
||||
playback_state_str, playback_rate_str, elapsed_time_str = (
|
||||
value.decode().split(",")
|
||||
)
|
||||
self.player_playback_info = PlaybackInfo(
|
||||
PlaybackState(int(playback_state_str)),
|
||||
float(playback_rate_str),
|
||||
float(elapsed_time_str),
|
||||
)
|
||||
self.emit(self.EVENT_PLAYER_PLAYBACK_INFO)
|
||||
elif player_attribute == PlayerAttributeId.VOLUME:
|
||||
self.player_volume = float(value.decode())
|
||||
self.emit(self.EVENT_PLAYER_VOLUME)
|
||||
else:
|
||||
logger.warning(f"received unknown player attribute {player_attribute}")
|
||||
|
||||
elif entity == EntityId.QUEUE:
|
||||
queue_attribute = QueueAttributeId(data[1])
|
||||
if queue_attribute == QueueAttributeId.COUNT:
|
||||
self.queue_count = int(value)
|
||||
self.emit(self.EVENT_QUEUE_COUNT)
|
||||
elif queue_attribute == QueueAttributeId.INDEX:
|
||||
self.queue_index = int(value)
|
||||
self.emit(self.EVENT_QUEUE_INDEX)
|
||||
elif queue_attribute == QueueAttributeId.REPEAT_MODE:
|
||||
self.queue_repeat_mode = RepeatMode(int(value))
|
||||
self.emit(self.EVENT_QUEUE_REPEAT_MODE)
|
||||
elif queue_attribute == QueueAttributeId.SHUFFLE_MODE:
|
||||
self.queue_shuffle_mode = ShuffleMode(int(value))
|
||||
self.emit(self.EVENT_QUEUE_SHUFFLE_MODE)
|
||||
else:
|
||||
logger.warning(f"received unknown queue attribute {queue_attribute}")
|
||||
|
||||
elif entity == EntityId.TRACK:
|
||||
track_attribute = TrackAttributeId(data[1])
|
||||
if track_attribute == TrackAttributeId.ARTIST:
|
||||
self.track_artist = value.decode()
|
||||
self.emit(self.EVENT_TRACK_ARTIST)
|
||||
elif track_attribute == TrackAttributeId.ALBUM:
|
||||
self.track_album = value.decode()
|
||||
self.emit(self.EVENT_TRACK_ALBUM)
|
||||
elif track_attribute == TrackAttributeId.TITLE:
|
||||
self.track_title = value.decode()
|
||||
self.emit(self.EVENT_TRACK_TITLE)
|
||||
elif track_attribute == TrackAttributeId.DURATION:
|
||||
self.track_duration = float(value.decode())
|
||||
self.emit(self.EVENT_TRACK_DURATION)
|
||||
else:
|
||||
logger.warning(f"received unknown track attribute {track_attribute}")
|
||||
|
||||
else:
|
||||
logger.warning(f"received unknown attribute ID {data[1]}")
|
||||
514
bumble/profiles/ancs.py
Normal file
514
bumble/profiles/ancs.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Apple Notification Center Service (ANCS).
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from bumble import utils
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.device import Peer
|
||||
from bumble.gatt import (
|
||||
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC,
|
||||
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
|
||||
GATT_ANCS_SERVICE,
|
||||
Characteristic,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import SerializableCharacteristicProxyAdapter
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
_DEFAULT_ATTRIBUTE_MAX_LENGTH = 65535
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Protocol
|
||||
# -----------------------------------------------------------------------------
|
||||
class ActionId(utils.OpenIntEnum):
|
||||
POSITIVE = 0
|
||||
NEGATIVE = 1
|
||||
|
||||
|
||||
class AppAttributeId(utils.OpenIntEnum):
|
||||
DISPLAY_NAME = 0
|
||||
|
||||
|
||||
class CategoryId(utils.OpenIntEnum):
|
||||
OTHER = 0
|
||||
INCOMING_CALL = 1
|
||||
MISSED_CALL = 2
|
||||
VOICEMAIL = 3
|
||||
SOCIAL = 4
|
||||
SCHEDULE = 5
|
||||
EMAIL = 6
|
||||
NEWS = 7
|
||||
HEALTH_AND_FITNESS = 8
|
||||
BUSINESS_AND_FINANCE = 9
|
||||
LOCATION = 10
|
||||
ENTERTAINMENT = 11
|
||||
|
||||
|
||||
class CommandId(utils.OpenIntEnum):
|
||||
GET_NOTIFICATION_ATTRIBUTES = 0
|
||||
GET_APP_ATTRIBUTES = 1
|
||||
PERFORM_NOTIFICATION_ACTION = 2
|
||||
|
||||
|
||||
class EventId(utils.OpenIntEnum):
|
||||
NOTIFICATION_ADDED = 0
|
||||
NOTIFICATION_MODIFIED = 1
|
||||
NOTIFICATION_REMOVED = 2
|
||||
|
||||
|
||||
class EventFlags(enum.IntFlag):
|
||||
SILENT = 1 << 0
|
||||
IMPORTANT = 1 << 1
|
||||
PRE_EXISTING = 1 << 2
|
||||
POSITIVE_ACTION = 1 << 3
|
||||
NEGATIVE_ACTION = 1 << 4
|
||||
|
||||
|
||||
class NotificationAttributeId(utils.OpenIntEnum):
|
||||
APP_IDENTIFIER = 0
|
||||
TITLE = 1
|
||||
SUBTITLE = 2
|
||||
MESSAGE = 3
|
||||
MESSAGE_SIZE = 4
|
||||
DATE = 5
|
||||
POSITIVE_ACTION_LABEL = 6
|
||||
NEGATIVE_ACTION_LABEL = 7
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NotificationAttribute:
|
||||
attribute_id: NotificationAttributeId
|
||||
value: Union[str, int, datetime.datetime]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AppAttribute:
|
||||
attribute_id: AppAttributeId
|
||||
value: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Notification:
|
||||
event_id: EventId
|
||||
event_flags: EventFlags
|
||||
category_id: CategoryId
|
||||
category_count: int
|
||||
notification_uid: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Notification:
|
||||
return cls(
|
||||
event_id=EventId(data[0]),
|
||||
event_flags=EventFlags(data[1]),
|
||||
category_id=CategoryId(data[2]),
|
||||
category_count=data[3],
|
||||
notification_uid=int.from_bytes(data[4:8], 'little'),
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(
|
||||
"<BBBBI",
|
||||
self.event_id,
|
||||
self.event_flags,
|
||||
self.category_id,
|
||||
self.category_count,
|
||||
self.notification_uid,
|
||||
)
|
||||
|
||||
|
||||
class ErrorCode(utils.OpenIntEnum):
|
||||
UNKNOWN_COMMAND = 0xA0
|
||||
INVALID_COMMAND = 0xA1
|
||||
INVALID_PARAMETER = 0xA2
|
||||
ACTION_FAILED = 0xA3
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
def __init__(self, error_code: ErrorCode) -> None:
|
||||
self.error_code = error_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"CommandError(error_code={self.error_code.name})"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Server-side
|
||||
# -----------------------------------------------------------------------------
|
||||
class Ancs(TemplateService):
|
||||
UUID = GATT_ANCS_SERVICE
|
||||
|
||||
notification_source_characteristic: Characteristic
|
||||
data_source_characteristic: Characteristic
|
||||
control_point_characteristic: Characteristic
|
||||
|
||||
def __init__(self) -> None:
|
||||
# TODO not the final implementation
|
||||
self.notification_source_characteristic = Characteristic(
|
||||
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
Characteristic.Permissions.READABLE,
|
||||
)
|
||||
|
||||
# TODO not the final implementation
|
||||
self.data_source_characteristic = Characteristic(
|
||||
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
Characteristic.Permissions.READABLE,
|
||||
)
|
||||
|
||||
# TODO not the final implementation
|
||||
self.control_point_characteristic = Characteristic(
|
||||
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC,
|
||||
Characteristic.Properties.WRITE,
|
||||
Characteristic.Permissions.WRITEABLE,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.notification_source_characteristic,
|
||||
self.data_source_characteristic,
|
||||
self.control_point_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GATT Client-side
|
||||
# -----------------------------------------------------------------------------
|
||||
class AncsProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = Ancs
|
||||
|
||||
notification_source: CharacteristicProxy[Notification]
|
||||
data_source: CharacteristicProxy
|
||||
control_point: CharacteristicProxy[bytes]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.notification_source = SerializableCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC
|
||||
),
|
||||
Notification,
|
||||
)
|
||||
|
||||
self.data_source = service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC
|
||||
)
|
||||
|
||||
self.control_point = service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
|
||||
|
||||
class AncsClient(utils.EventEmitter):
|
||||
_expected_response_command_id: Optional[CommandId]
|
||||
_expected_response_notification_uid: Optional[int]
|
||||
_expected_response_app_identifier: Optional[str]
|
||||
_expected_app_identifier: Optional[str]
|
||||
_expected_response_tuples: int
|
||||
_response_accumulator: bytes
|
||||
|
||||
EVENT_NOTIFICATION = "notification"
|
||||
|
||||
def __init__(self, ancs_proxy: AncsProxy) -> None:
|
||||
super().__init__()
|
||||
self._ancs_proxy = ancs_proxy
|
||||
self._command_semaphore = asyncio.Semaphore()
|
||||
self._response: Optional[asyncio.Future] = None
|
||||
self._reset_response()
|
||||
self._started = False
|
||||
|
||||
@classmethod
|
||||
async def for_peer(cls, peer: Peer) -> Optional[AncsClient]:
|
||||
ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy)
|
||||
if ancs_proxy is None:
|
||||
return None
|
||||
return cls(ancs_proxy)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self._ancs_proxy.notification_source.subscribe(self._on_notification)
|
||||
await self._ancs_proxy.data_source.subscribe(self._on_data)
|
||||
self._started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self._ancs_proxy.notification_source.unsubscribe(self._on_notification)
|
||||
await self._ancs_proxy.data_source.unsubscribe(self._on_data)
|
||||
self._started = False
|
||||
|
||||
def _reset_response(self) -> None:
|
||||
self._expected_response_command_id = None
|
||||
self._expected_response_notification_uid = None
|
||||
self._expected_app_identifier = None
|
||||
self._expected_response_tuples = 0
|
||||
self._response_accumulator = b""
|
||||
|
||||
def _on_notification(self, notification: Notification) -> None:
|
||||
logger.debug(f"ANCS NOTIFICATION: {notification}")
|
||||
self.emit(self.EVENT_NOTIFICATION, notification)
|
||||
|
||||
def _on_data(self, data: bytes) -> None:
|
||||
logger.debug(f"ANCS DATA: {data.hex()}")
|
||||
|
||||
if not self._response:
|
||||
logger.warning("received unexpected data, discarding")
|
||||
return
|
||||
|
||||
self._response_accumulator += data
|
||||
|
||||
# Try to parse the accumulated data until we have all we need.
|
||||
if not self._response_accumulator:
|
||||
logger.warning("empty data from data source")
|
||||
return
|
||||
|
||||
command_id = self._response_accumulator[0]
|
||||
if command_id != self._expected_response_command_id:
|
||||
logger.warning(
|
||||
"unexpected response command id: "
|
||||
f"expected {self._expected_response_command_id} "
|
||||
f"but got {command_id}"
|
||||
)
|
||||
self._reset_response()
|
||||
if not self._response.done():
|
||||
self._response.set_exception(ProtocolError())
|
||||
|
||||
if len(self._response_accumulator) < 5:
|
||||
# Not enough data yet.
|
||||
return
|
||||
|
||||
attributes: list[Union[NotificationAttribute, AppAttribute]] = []
|
||||
|
||||
if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES:
|
||||
(notification_uid,) = struct.unpack_from(
|
||||
"<I", self._response_accumulator, 1
|
||||
)
|
||||
if notification_uid != self._expected_response_notification_uid:
|
||||
logger.warning(
|
||||
"unexpected response notification uid: "
|
||||
f"expected {self._expected_response_notification_uid} "
|
||||
f"but got {notification_uid}"
|
||||
)
|
||||
self._reset_response()
|
||||
if not self._response.done():
|
||||
self._response.set_exception(ProtocolError())
|
||||
|
||||
attribute_data = self._response_accumulator[5:]
|
||||
while len(attribute_data) >= 3:
|
||||
attribute_id, attribute_data_length = struct.unpack_from(
|
||||
"<BH", attribute_data, 0
|
||||
)
|
||||
if len(attribute_data) < 3 + attribute_data_length:
|
||||
return
|
||||
str_value = attribute_data[3 : 3 + attribute_data_length].decode(
|
||||
"utf-8"
|
||||
)
|
||||
value: Union[str, int, datetime.datetime]
|
||||
if attribute_id == NotificationAttributeId.MESSAGE_SIZE:
|
||||
value = int(str_value)
|
||||
elif attribute_id == NotificationAttributeId.DATE:
|
||||
year = int(str_value[:4])
|
||||
month = int(str_value[4:6])
|
||||
day = int(str_value[6:8])
|
||||
hour = int(str_value[9:11])
|
||||
minute = int(str_value[11:13])
|
||||
second = int(str_value[13:15])
|
||||
value = datetime.datetime(year, month, day, hour, minute, second)
|
||||
else:
|
||||
value = str_value
|
||||
attributes.append(
|
||||
NotificationAttribute(NotificationAttributeId(attribute_id), value)
|
||||
)
|
||||
attribute_data = attribute_data[3 + attribute_data_length :]
|
||||
elif command_id == CommandId.GET_APP_ATTRIBUTES:
|
||||
if 0 not in self._response_accumulator[1:]:
|
||||
# No null-terminated string yet.
|
||||
return
|
||||
|
||||
app_identifier_length = self._response_accumulator.find(0, 1) - 1
|
||||
app_identifier = self._response_accumulator[
|
||||
1 : 1 + app_identifier_length
|
||||
].decode("utf-8")
|
||||
if app_identifier != self._expected_response_app_identifier:
|
||||
logger.warning(
|
||||
"unexpected response app identifier: "
|
||||
f"expected {self._expected_response_app_identifier} "
|
||||
f"but got {app_identifier}"
|
||||
)
|
||||
self._reset_response()
|
||||
if not self._response.done():
|
||||
self._response.set_exception(ProtocolError())
|
||||
|
||||
attribute_data = self._response_accumulator[1 + app_identifier_length + 1 :]
|
||||
while len(attribute_data) >= 3:
|
||||
attribute_id, attribute_data_length = struct.unpack_from(
|
||||
"<BH", attribute_data, 0
|
||||
)
|
||||
if len(attribute_data) < 3 + attribute_data_length:
|
||||
return
|
||||
attributes.append(
|
||||
AppAttribute(
|
||||
AppAttributeId(attribute_id),
|
||||
attribute_data[3 : 3 + attribute_data_length].decode("utf-8"),
|
||||
)
|
||||
)
|
||||
attribute_data = attribute_data[3 + attribute_data_length :]
|
||||
else:
|
||||
logger.warning(f"unexpected response command id {command_id}")
|
||||
return
|
||||
|
||||
if len(attributes) < self._expected_response_tuples:
|
||||
# We have not received all the tuples yet.
|
||||
return
|
||||
|
||||
if not self._response.done():
|
||||
self._response.set_result(attributes)
|
||||
|
||||
async def _send_command(self, command: bytes) -> None:
|
||||
try:
|
||||
await self._ancs_proxy.control_point.write_value(
|
||||
command, with_response=True
|
||||
)
|
||||
except ATT_Error as error:
|
||||
raise CommandError(error_code=ErrorCode(error.error_code)) from error
|
||||
|
||||
async def get_notification_attributes(
|
||||
self,
|
||||
notification_uid: int,
|
||||
attributes: Sequence[
|
||||
Union[NotificationAttributeId, tuple[NotificationAttributeId, int]]
|
||||
],
|
||||
) -> list[NotificationAttribute]:
|
||||
if not self._started:
|
||||
raise RuntimeError("client not started")
|
||||
|
||||
command = struct.pack(
|
||||
"<BI", CommandId.GET_NOTIFICATION_ATTRIBUTES, notification_uid
|
||||
)
|
||||
for attribute in attributes:
|
||||
attribute_max_length = 0
|
||||
if isinstance(attribute, tuple):
|
||||
attribute_id, attribute_max_length = attribute
|
||||
if attribute_id not in (
|
||||
NotificationAttributeId.TITLE,
|
||||
NotificationAttributeId.SUBTITLE,
|
||||
NotificationAttributeId.MESSAGE,
|
||||
):
|
||||
raise ValueError(
|
||||
"this attribute does not allow specifying a max length"
|
||||
)
|
||||
else:
|
||||
attribute_id = attribute
|
||||
if attribute_id in (
|
||||
NotificationAttributeId.TITLE,
|
||||
NotificationAttributeId.SUBTITLE,
|
||||
NotificationAttributeId.MESSAGE,
|
||||
):
|
||||
attribute_max_length = _DEFAULT_ATTRIBUTE_MAX_LENGTH
|
||||
|
||||
if attribute_max_length:
|
||||
command += struct.pack("<BH", attribute_id, attribute_max_length)
|
||||
else:
|
||||
command += struct.pack("B", attribute_id)
|
||||
|
||||
try:
|
||||
async with self._command_semaphore:
|
||||
self._expected_response_notification_uid = notification_uid
|
||||
self._expected_response_tuples = len(attributes)
|
||||
self._expected_response_command_id = (
|
||||
CommandId.GET_NOTIFICATION_ATTRIBUTES
|
||||
)
|
||||
self._response = asyncio.Future()
|
||||
|
||||
# Send the command.
|
||||
await self._send_command(command)
|
||||
|
||||
# Wait for the response.
|
||||
return await self._response
|
||||
finally:
|
||||
self._reset_response()
|
||||
|
||||
async def get_app_attributes(
|
||||
self, app_identifier: str, attributes: Sequence[AppAttributeId]
|
||||
) -> list[AppAttribute]:
|
||||
if not self._started:
|
||||
raise RuntimeError("client not started")
|
||||
|
||||
command = (
|
||||
bytes([CommandId.GET_APP_ATTRIBUTES])
|
||||
+ app_identifier.encode("utf-8")
|
||||
+ b"\0"
|
||||
)
|
||||
for attribute_id in attributes:
|
||||
command += struct.pack("B", attribute_id)
|
||||
|
||||
try:
|
||||
async with self._command_semaphore:
|
||||
self._expected_response_app_identifier = app_identifier
|
||||
self._expected_response_tuples = len(attributes)
|
||||
self._expected_response_command_id = CommandId.GET_APP_ATTRIBUTES
|
||||
self._response = asyncio.Future()
|
||||
|
||||
# Send the command.
|
||||
await self._send_command(command)
|
||||
|
||||
# Wait for the response.
|
||||
return await self._response
|
||||
finally:
|
||||
self._reset_response()
|
||||
|
||||
async def perform_action(self, notification_uid: int, action: ActionId) -> None:
|
||||
if not self._started:
|
||||
raise RuntimeError("client not started")
|
||||
|
||||
command = struct.pack(
|
||||
"<BIB", CommandId.PERFORM_NOTIFICATION_ACTION, notification_uid, action
|
||||
)
|
||||
|
||||
async with self._command_semaphore:
|
||||
await self._send_command(command)
|
||||
|
||||
async def perform_positive_action(self, notification_uid: int) -> None:
|
||||
return await self.perform_action(notification_uid, ActionId.POSITIVE)
|
||||
|
||||
async def perform_negative_action(self, notification_uid: int) -> None:
|
||||
return await self.perform_action(notification_uid, ActionId.NEGATIVE)
|
||||
@@ -17,18 +17,18 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
|
||||
from bumble import colors
|
||||
from bumble.profiles.bap import CodecSpecificConfiguration
|
||||
from bumble import colors, device, gatt, gatt_client, hci, utils
|
||||
from bumble.profiles import le_audio
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
from bumble.profiles.bap import CodecSpecificConfiguration
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -46,11 +46,11 @@ class ASE_Operation:
|
||||
See Audio Stream Control Service - 5 ASE Control operations.
|
||||
'''
|
||||
|
||||
classes: Dict[int, Type[ASE_Operation]] = {}
|
||||
op_code: int
|
||||
classes: dict[int, type[ASE_Operation]] = {}
|
||||
op_code: Opcode
|
||||
name: str
|
||||
fields: Optional[Sequence[Any]] = None
|
||||
ase_id: List[int]
|
||||
ase_id: Sequence[int]
|
||||
|
||||
class Opcode(enum.IntEnum):
|
||||
# fmt: off
|
||||
@@ -63,51 +63,30 @@ class ASE_Operation:
|
||||
UPDATE_METADATA = 0x07
|
||||
RELEASE = 0x08
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu: bytes) -> ASE_Operation:
|
||||
@classmethod
|
||||
def from_bytes(cls, pdu: bytes) -> ASE_Operation:
|
||||
op_code = pdu[0]
|
||||
|
||||
cls = ASE_Operation.classes.get(op_code)
|
||||
if cls is None:
|
||||
instance = ASE_Operation(pdu)
|
||||
instance.name = ASE_Operation.Opcode(op_code).name
|
||||
instance.op_code = op_code
|
||||
return instance
|
||||
self = cls.__new__(cls)
|
||||
ASE_Operation.__init__(self, pdu)
|
||||
if self.fields is not None:
|
||||
self.init_from_bytes(pdu, 1)
|
||||
return self
|
||||
clazz = ASE_Operation.classes[op_code]
|
||||
return clazz(
|
||||
**hci.HCI_Object.dict_from_bytes(pdu, offset=1, fields=clazz.fields)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def subclass(fields):
|
||||
def inner(cls: Type[ASE_Operation]):
|
||||
try:
|
||||
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
|
||||
cls.name = operation.name
|
||||
cls.op_code = operation
|
||||
except:
|
||||
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
|
||||
cls.fields = fields
|
||||
_OP = TypeVar("_OP", bound="ASE_Operation")
|
||||
|
||||
# Register a factory for this class
|
||||
ASE_Operation.classes[cls.op_code] = cls
|
||||
@classmethod
|
||||
def subclass(cls, clazz: type[_OP]) -> type[_OP]:
|
||||
clazz.name = f"ASE_{clazz.op_code.name.upper()}"
|
||||
clazz.fields = hci.HCI_Object.fields_from_dataclass(clazz)
|
||||
# Register a factory for this class
|
||||
ASE_Operation.classes[clazz.op_code] = clazz
|
||||
return clazz
|
||||
|
||||
return cls
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
|
||||
if self.fields is not None and kwargs:
|
||||
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
||||
kwargs, self.fields
|
||||
)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu: bytes, offset: int):
|
||||
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
@functools.cached_property
|
||||
def pdu(self) -> bytes:
|
||||
return bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
||||
self.__dict__, self.fields
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.pdu
|
||||
@@ -122,105 +101,128 @@ class ASE_Operation:
|
||||
return result
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('target_latency', 1),
|
||||
('target_phy', 1),
|
||||
('codec_id', hci.CodingFormat.parse_from_bytes),
|
||||
('codec_specific_configuration', 'v'),
|
||||
],
|
||||
]
|
||||
)
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Config_Codec(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.1 - Config Codec Operation
|
||||
'''
|
||||
|
||||
target_latency: List[int]
|
||||
target_phy: List[int]
|
||||
codec_id: List[hci.CodingFormat]
|
||||
codec_specific_configuration: List[bytes]
|
||||
op_code = ASE_Operation.Opcode.CONFIG_CODEC
|
||||
|
||||
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
|
||||
target_latency: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
target_phy: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
codec_id: Sequence[hci.CodingFormat] = field(
|
||||
metadata=hci.metadata(hci.CodingFormat.parse_from_bytes)
|
||||
)
|
||||
codec_specific_configuration: Sequence[bytes] = field(
|
||||
metadata=hci.metadata('v', list_end=True)
|
||||
)
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('cig_id', 1),
|
||||
('cis_id', 1),
|
||||
('sdu_interval', 3),
|
||||
('framing', 1),
|
||||
('phy', 1),
|
||||
('max_sdu', 2),
|
||||
('retransmission_number', 1),
|
||||
('max_transport_latency', 2),
|
||||
('presentation_delay', 3),
|
||||
],
|
||||
]
|
||||
)
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Config_QOS(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.2 - Config Qos Operation
|
||||
'''
|
||||
|
||||
cig_id: List[int]
|
||||
cis_id: List[int]
|
||||
sdu_interval: List[int]
|
||||
framing: List[int]
|
||||
phy: List[int]
|
||||
max_sdu: List[int]
|
||||
retransmission_number: List[int]
|
||||
max_transport_latency: List[int]
|
||||
presentation_delay: List[int]
|
||||
op_code = ASE_Operation.Opcode.CONFIG_QOS
|
||||
|
||||
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
|
||||
cig_id: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
cis_id: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
sdu_interval: Sequence[int] = field(metadata=hci.metadata(3))
|
||||
framing: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
phy: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
max_sdu: Sequence[int] = field(metadata=hci.metadata(2))
|
||||
retransmission_number: Sequence[int] = field(metadata=hci.metadata(1))
|
||||
max_transport_latency: Sequence[int] = field(metadata=hci.metadata(2))
|
||||
presentation_delay: Sequence[int] = field(metadata=hci.metadata(3, list_end=True))
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Enable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.3 - Enable Operation
|
||||
'''
|
||||
|
||||
metadata: bytes
|
||||
op_code = ASE_Operation.Opcode.ENABLE
|
||||
|
||||
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
|
||||
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Receiver_Start_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
|
||||
'''
|
||||
|
||||
op_code = ASE_Operation.Opcode.RECEIVER_START_READY
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
ase_id: Sequence[int] = field(
|
||||
metadata=hci.metadata(1, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Disable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.5 - Disable Operation
|
||||
'''
|
||||
|
||||
op_code = ASE_Operation.Opcode.DISABLE
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
ase_id: Sequence[int] = field(
|
||||
metadata=hci.metadata(1, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Receiver_Stop_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
|
||||
'''
|
||||
|
||||
op_code = ASE_Operation.Opcode.RECEIVER_STOP_READY
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
ase_id: Sequence[int] = field(
|
||||
metadata=hci.metadata(1, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Update_Metadata(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.7 - Update Metadata Operation
|
||||
'''
|
||||
|
||||
metadata: List[bytes]
|
||||
op_code = ASE_Operation.Opcode.UPDATE_METADATA
|
||||
|
||||
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
|
||||
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
@ASE_Operation.subclass
|
||||
@dataclass
|
||||
class ASE_Release(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.8 - Release Operation
|
||||
'''
|
||||
|
||||
op_code = ASE_Operation.Opcode.RELEASE
|
||||
|
||||
ase_id: Sequence[int] = field(
|
||||
metadata=hci.metadata(1, list_begin=True, list_end=True)
|
||||
)
|
||||
|
||||
|
||||
class AseResponseCode(enum.IntEnum):
|
||||
# fmt: off
|
||||
@@ -258,8 +260,8 @@ class AseReasonCode(enum.IntEnum):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioRole(enum.IntEnum):
|
||||
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||
SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -274,6 +276,8 @@ class AseStateMachine(gatt.Characteristic):
|
||||
DISABLING = 0x05
|
||||
RELEASING = 0x06
|
||||
|
||||
EVENT_STATE_CHANGE = "state_change"
|
||||
|
||||
cis_link: Optional[device.CisLink] = None
|
||||
|
||||
# Additional parameters in CODEC_CONFIGURED State
|
||||
@@ -300,7 +304,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
presentation_delay = 0
|
||||
|
||||
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
||||
metadata = le_audio.Metadata()
|
||||
metadata: le_audio.Metadata
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -312,6 +316,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.ase_id = ase_id
|
||||
self._state = AseStateMachine.State.IDLE
|
||||
self.role = role
|
||||
self.metadata = le_audio.Metadata()
|
||||
|
||||
uuid = (
|
||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||
@@ -326,23 +331,23 @@ class AseStateMachine(gatt.Characteristic):
|
||||
value=gatt.CharacteristicValue(read=self.on_read),
|
||||
)
|
||||
|
||||
self.service.device.on('cis_request', self.on_cis_request)
|
||||
self.service.device.on('cis_establishment', self.on_cis_establishment)
|
||||
self.service.device.on(
|
||||
self.service.device.EVENT_CIS_REQUEST, self.on_cis_request
|
||||
)
|
||||
self.service.device.on(
|
||||
self.service.device.EVENT_CIS_ESTABLISHMENT, self.on_cis_establishment
|
||||
)
|
||||
|
||||
def on_cis_request(
|
||||
self,
|
||||
acl_connection: device.Connection,
|
||||
cis_handle: int,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
def on_cis_request(self, cis_link: device.CisLink) -> None:
|
||||
if (
|
||||
cig_id == self.cig_id
|
||||
and cis_id == self.cis_id
|
||||
cis_link.cig_id == self.cig_id
|
||||
and cis_link.cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
acl_connection.abort_on(
|
||||
'flush', self.service.device.accept_cis_request(cis_handle)
|
||||
utils.cancel_on_event(
|
||||
cis_link.acl_connection,
|
||||
'flush',
|
||||
self.service.device.accept_cis_request(cis_link),
|
||||
)
|
||||
|
||||
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
||||
@@ -351,24 +356,17 @@ class AseStateMachine(gatt.Characteristic):
|
||||
and cis_link.cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||
cis_link.on(cis_link.EVENT_DISCONNECTION, self.on_cis_disconnection)
|
||||
|
||||
async def post_cis_established():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
data_path_id=0x00, # Fixed HCI
|
||||
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=0,
|
||||
codec_configuration=b'',
|
||||
)
|
||||
)
|
||||
await cis_link.setup_data_path(direction=self.role)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.STREAMING
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
||||
utils.cancel_on_event(
|
||||
cis_link.acl_connection, 'flush', post_cis_established()
|
||||
)
|
||||
self.cis_link = cis_link
|
||||
|
||||
def on_cis_disconnection(self, _reason) -> None:
|
||||
@@ -380,7 +378,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
target_phy: int,
|
||||
codec_id: hci.CodingFormat,
|
||||
codec_specific_configuration: bytes,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
self.State.IDLE,
|
||||
self.State.CODEC_CONFIGURED,
|
||||
@@ -416,7 +414,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
retransmission_number: int,
|
||||
max_transport_latency: int,
|
||||
presentation_delay: int,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.CODEC_CONFIGURED,
|
||||
AseStateMachine.State.QOS_CONFIGURED,
|
||||
@@ -440,7 +438,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
def on_enable(self, metadata: bytes) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.QOS_CONFIGURED:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
@@ -449,10 +447,20 @@ class AseStateMachine(gatt.Characteristic):
|
||||
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
self.state = self.State.ENABLING
|
||||
# CIS could be established before enable.
|
||||
if cis_link := next(
|
||||
(
|
||||
cis_link
|
||||
for cis_link in self.service.device.cis_links.values()
|
||||
if cis_link.cig_id == self.cig_id and cis_link.cis_id == self.cis_id
|
||||
),
|
||||
None,
|
||||
):
|
||||
self.on_cis_establishment(cis_link)
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
def on_receiver_start_ready(self) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.ENABLING:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
@@ -461,7 +469,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.state = self.State.STREAMING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
def on_disable(self) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
@@ -476,7 +484,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.state = self.State.DISABLING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
def on_receiver_stop_ready(self) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if (
|
||||
self.role != AudioRole.SOURCE
|
||||
or self.state != AseStateMachine.State.DISABLING
|
||||
@@ -490,7 +498,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
|
||||
def on_update_metadata(
|
||||
self, metadata: bytes
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
@@ -502,7 +510,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
def on_release(self) -> tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state == AseStateMachine.State.IDLE:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
@@ -511,16 +519,12 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.state = self.State.RELEASING
|
||||
|
||||
async def remove_cis_async():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
)
|
||||
)
|
||||
if self.cis_link:
|
||||
await self.cis_link.remove_data_path([self.role])
|
||||
self.state = self.State.IDLE
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
self.service.device.abort_on('flush', remove_cis_async())
|
||||
utils.cancel_on_event(self.service.device, 'flush', remove_cis_async())
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
@property
|
||||
@@ -531,7 +535,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
def state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
||||
self._state = new_state
|
||||
self.emit('state_change')
|
||||
self.emit(self.EVENT_STATE_CHANGE)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
@@ -590,7 +594,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
# Readonly. Do nothing in the setter.
|
||||
pass
|
||||
|
||||
def on_read(self, _: Optional[device.Connection]) -> bytes:
|
||||
def on_read(self, _: device.Connection) -> bytes:
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -604,8 +608,8 @@ class AseStateMachine(gatt.Characteristic):
|
||||
class AudioStreamControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
|
||||
|
||||
ase_state_machines: Dict[int, AseStateMachine]
|
||||
ase_control_point: gatt.Characteristic
|
||||
ase_state_machines: dict[int, AseStateMachine]
|
||||
ase_control_point: gatt.Characteristic[bytes]
|
||||
_active_client: Optional[device.Connection] = None
|
||||
|
||||
def __init__(
|
||||
@@ -649,7 +653,9 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
ase.state = AseStateMachine.State.IDLE
|
||||
self._active_client = None
|
||||
|
||||
def on_write_ase_control_point(self, connection, data):
|
||||
def on_write_ase_control_point(
|
||||
self, connection: device.Connection, data: bytes
|
||||
) -> None:
|
||||
if not self._active_client and connection:
|
||||
self._active_client = connection
|
||||
connection.once('disconnection', self._on_client_disconnected)
|
||||
@@ -658,7 +664,7 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
responses = []
|
||||
logger.debug(f'*** ASCS Write {operation} ***')
|
||||
|
||||
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
|
||||
if isinstance(operation, ASE_Config_Codec):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.target_latency,
|
||||
@@ -667,7 +673,7 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
operation.codec_specific_configuration,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
|
||||
elif isinstance(operation, ASE_Config_QOS):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.cig_id,
|
||||
@@ -681,20 +687,20 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
operation.presentation_delay,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.ENABLE,
|
||||
ASE_Operation.Opcode.UPDATE_METADATA,
|
||||
):
|
||||
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.metadata,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.RECEIVER_START_READY,
|
||||
ASE_Operation.Opcode.DISABLE,
|
||||
ASE_Operation.Opcode.RECEIVER_STOP_READY,
|
||||
ASE_Operation.Opcode.RELEASE,
|
||||
elif isinstance(
|
||||
operation,
|
||||
(
|
||||
ASE_Receiver_Start_Ready,
|
||||
ASE_Disable,
|
||||
ASE_Receiver_Stop_Ready,
|
||||
ASE_Release,
|
||||
),
|
||||
):
|
||||
for ase_id in operation.ase_id:
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||
@@ -702,7 +708,8 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
control_point_notification = bytes(
|
||||
[operation.op_code, len(responses)]
|
||||
) + b''.join(map(bytes, responses))
|
||||
self.device.abort_on(
|
||||
utils.cancel_on_event(
|
||||
self.device,
|
||||
'flush',
|
||||
self.device.notify_subscribers(
|
||||
self.ase_control_point, control_point_notification
|
||||
@@ -711,7 +718,8 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
|
||||
for ase_id, *_ in responses:
|
||||
if ase := self.ase_state_machines.get(ase_id):
|
||||
self.device.abort_on(
|
||||
utils.cancel_on_event(
|
||||
self.device,
|
||||
'flush',
|
||||
self.device.notify_subscribers(ase, ase.value),
|
||||
)
|
||||
@@ -721,9 +729,9 @@ class AudioStreamControlService(gatt.TemplateService):
|
||||
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = AudioStreamControlService
|
||||
|
||||
sink_ase: List[gatt_client.CharacteristicProxy]
|
||||
source_ase: List[gatt_client.CharacteristicProxy]
|
||||
ase_control_point: gatt_client.CharacteristicProxy
|
||||
sink_ase: list[gatt_client.CharacteristicProxy[bytes]]
|
||||
source_ase: list[gatt_client.CharacteristicProxy[bytes]]
|
||||
ase_control_point: gatt_client.CharacteristicProxy[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
@@ -17,16 +17,13 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import enum
|
||||
import struct
|
||||
import logging
|
||||
from typing import List, Optional, Callable, Union, Any
|
||||
import struct
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble import utils
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import data_types, gatt, gatt_client, l2cap, utils
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device, Connection
|
||||
from bumble.device import Connection, Device
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -88,6 +85,11 @@ class AudioStatus(utils.OpenIntEnum):
|
||||
class AshaService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_ASHA_SERVICE
|
||||
|
||||
EVENT_STARTED = "started"
|
||||
EVENT_STOPPED = "stopped"
|
||||
EVENT_DISCONNECTED = "disconnected"
|
||||
EVENT_VOLUME_CHANGED = "volume_changed"
|
||||
|
||||
audio_sink: Optional[Callable[[bytes], Any]]
|
||||
active_codec: Optional[Codec] = None
|
||||
audio_type: Optional[AudioType] = None
|
||||
@@ -98,7 +100,7 @@ class AshaService(gatt.TemplateService):
|
||||
def __init__(
|
||||
self,
|
||||
capability: int,
|
||||
hisyncid: Union[List[int], bytes],
|
||||
hisyncid: Union[list[int], bytes],
|
||||
device: Device,
|
||||
psm: int = 0,
|
||||
audio_sink: Optional[Callable[[bytes], Any]] = None,
|
||||
@@ -134,12 +136,14 @@ class AshaService(gatt.TemplateService):
|
||||
),
|
||||
)
|
||||
|
||||
self.audio_control_point_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
|
||||
self.audio_control_point_characteristic: gatt.Characteristic[bytes] = (
|
||||
gatt.Characteristic(
|
||||
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
|
||||
)
|
||||
)
|
||||
self.audio_status_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
@@ -147,7 +151,7 @@ class AshaService(gatt.TemplateService):
|
||||
gatt.Characteristic.READABLE,
|
||||
bytes([AudioStatus.OK]),
|
||||
)
|
||||
self.volume_characteristic = gatt.Characteristic(
|
||||
self.volume_characteristic: gatt.Characteristic[bytes] = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
@@ -166,13 +170,13 @@ class AshaService(gatt.TemplateService):
|
||||
struct.pack('<H', self.psm),
|
||||
)
|
||||
|
||||
characteristics = [
|
||||
characteristics = (
|
||||
self.read_only_properties_characteristic,
|
||||
self.audio_control_point_characteristic,
|
||||
self.audio_status_characteristic,
|
||||
self.volume_characteristic,
|
||||
self.le_psm_out_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
@@ -181,19 +185,18 @@ class AshaService(gatt.TemplateService):
|
||||
return bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
bytes(gatt.GATT_ASHA_SERVICE)
|
||||
+ bytes([self.protocol_version, self.capability])
|
||||
data_types.ServiceData16BitUUID(
|
||||
gatt.GATT_ASHA_SERVICE,
|
||||
bytes([self.protocol_version, self.capability])
|
||||
+ self.hisyncid[:4],
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Handler for audio control commands
|
||||
async def _on_audio_control_point_write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
) -> None:
|
||||
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
@@ -209,14 +212,14 @@ class AshaService(gatt.TemplateService):
|
||||
f'volume={self.volume}, '
|
||||
f'other_state={self.other_state}'
|
||||
)
|
||||
self.emit('started')
|
||||
self.emit(self.EVENT_STARTED)
|
||||
elif opcode == OpCode.STOP:
|
||||
_logger.debug('### STOP')
|
||||
self.active_codec = None
|
||||
self.audio_type = None
|
||||
self.volume = None
|
||||
self.other_state = None
|
||||
self.emit('stopped')
|
||||
self.emit(self.EVENT_STOPPED)
|
||||
elif opcode == OpCode.STATUS:
|
||||
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
|
||||
|
||||
@@ -229,7 +232,7 @@ class AshaService(gatt.TemplateService):
|
||||
self.audio_type = None
|
||||
self.volume = None
|
||||
self.other_state = None
|
||||
self.emit('disconnected')
|
||||
self.emit(self.EVENT_DISCONNECTED)
|
||||
|
||||
connection.once('disconnection', on_disconnection)
|
||||
|
||||
@@ -240,10 +243,10 @@ class AshaService(gatt.TemplateService):
|
||||
)
|
||||
|
||||
# Handler for volume control
|
||||
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
def _on_volume_write(self, connection: Connection, value: bytes) -> None:
|
||||
_logger.debug(f'--- VOLUME Write:{value[0]}')
|
||||
self.volume = value[0]
|
||||
self.emit('volume_changed')
|
||||
self.emit(self.EVENT_VOLUME_CHANGED)
|
||||
|
||||
# Register an L2CAP CoC server
|
||||
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
|
||||
@@ -257,11 +260,11 @@ class AshaService(gatt.TemplateService):
|
||||
# -----------------------------------------------------------------------------
|
||||
class AshaServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = AshaService
|
||||
read_only_properties_characteristic: gatt_client.CharacteristicProxy
|
||||
audio_control_point_characteristic: gatt_client.CharacteristicProxy
|
||||
audio_status_point_characteristic: gatt_client.CharacteristicProxy
|
||||
volume_characteristic: gatt_client.CharacteristicProxy
|
||||
psm_characteristic: gatt_client.CharacteristicProxy
|
||||
read_only_properties_characteristic: gatt_client.CharacteristicProxy[bytes]
|
||||
audio_control_point_characteristic: gatt_client.CharacteristicProxy[bytes]
|
||||
audio_status_point_characteristic: gatt_client.CharacteristicProxy[bytes]
|
||||
volume_characteristic: gatt_client.CharacteristicProxy[bytes]
|
||||
psm_characteristic: gatt_client.CharacteristicProxy[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
@@ -288,8 +291,8 @@ class AshaServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
'psm_characteristic',
|
||||
),
|
||||
):
|
||||
if not (
|
||||
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
|
||||
):
|
||||
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
|
||||
setattr(self, attribute_name, characteristics[0])
|
||||
setattr(
|
||||
self,
|
||||
attribute_name,
|
||||
self.service_proxy.get_required_characteristic_by_uuid(uuid),
|
||||
)
|
||||
|
||||
@@ -18,22 +18,18 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
import functools
|
||||
import logging
|
||||
from typing import List
|
||||
import struct
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble import gatt
|
||||
from bumble import utils
|
||||
from bumble import core, data_types, gatt, hci, utils
|
||||
from bumble.profiles import le_audio
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -102,6 +98,7 @@ class ContextType(enum.IntFlag):
|
||||
|
||||
# fmt: off
|
||||
PROHIBITED = 0x0000
|
||||
UNSPECIFIED = 0x0001
|
||||
CONVERSATIONAL = 0x0002
|
||||
MEDIA = 0x0004
|
||||
GAME = 0x0008
|
||||
@@ -260,11 +257,10 @@ class UnicastServerAdvertisingData:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
data_types.ServiceData16BitUUID(
|
||||
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE,
|
||||
struct.pack(
|
||||
'<2sBIB',
|
||||
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE.to_bytes(),
|
||||
'<BIB',
|
||||
self.announcement_type,
|
||||
self.available_audio_contexts,
|
||||
len(self.metadata),
|
||||
@@ -281,7 +277,7 @@ class UnicastServerAdvertisingData:
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bits_to_channel_counts(data: int) -> List[int]:
|
||||
def bits_to_channel_counts(data: int) -> list[int]:
|
||||
pos = 0
|
||||
counts = []
|
||||
while data != 0:
|
||||
@@ -350,6 +346,7 @@ class CodecSpecificCapabilities:
|
||||
supported_max_codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
return CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=supported_sampling_frequencies,
|
||||
supported_frame_durations=supported_frame_durations,
|
||||
@@ -396,18 +393,21 @@ class CodecSpecificConfiguration:
|
||||
OCTETS_PER_FRAME = 0x04
|
||||
CODEC_FRAMES_PER_SDU = 0x05
|
||||
|
||||
sampling_frequency: SamplingFrequency
|
||||
frame_duration: FrameDuration
|
||||
audio_channel_allocation: AudioLocation
|
||||
octets_per_codec_frame: int
|
||||
codec_frames_per_sdu: int
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
|
||||
offset = 0
|
||||
# Allowed default values.
|
||||
audio_channel_allocation = AudioLocation.NOT_ALLOWED
|
||||
codec_frames_per_sdu = 1
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
offset += 2
|
||||
@@ -425,7 +425,6 @@ class CodecSpecificConfiguration:
|
||||
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
|
||||
codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
return CodecSpecificConfiguration(
|
||||
sampling_frequency=sampling_frequency,
|
||||
frame_duration=frame_duration,
|
||||
@@ -435,23 +434,43 @@ class CodecSpecificConfiguration:
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(
|
||||
'<BBBBBBBBIBBHBBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
return b''.join(
|
||||
[
|
||||
struct.pack(fmt, length, tag, value)
|
||||
for fmt, length, tag, value in [
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
),
|
||||
(
|
||||
'<BBI',
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
),
|
||||
(
|
||||
'<BBH',
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
),
|
||||
]
|
||||
if value is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -463,6 +482,20 @@ class BroadcastAudioAnnouncement:
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data[:3], 'little'))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.broadcast_id.to_bytes(3, 'little')
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
data_types.ServiceData16BitUUID(
|
||||
gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BasicAudioAnnouncement:
|
||||
@@ -471,28 +504,39 @@ class BasicAudioAnnouncement:
|
||||
index: int
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodecInfo:
|
||||
coding_format: hci.CodecID
|
||||
company_id: int
|
||||
vendor_specific_codec_id: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
coding_format = hci.CodecID(data[0])
|
||||
company_id = int.from_bytes(data[1:3], 'little')
|
||||
vendor_specific_codec_id = int.from_bytes(data[3:5], 'little')
|
||||
return cls(coding_format, company_id, vendor_specific_codec_id)
|
||||
def __bytes__(self) -> bytes:
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([self.index, len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Subgroup:
|
||||
codec_id: BasicAudioAnnouncement.CodecInfo
|
||||
codec_id: hci.CodingFormat
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
metadata: le_audio.Metadata
|
||||
bis: List[BasicAudioAnnouncement.BIS]
|
||||
bis: list[BasicAudioAnnouncement.BIS]
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([len(self.bis)])
|
||||
+ bytes(self.codec_id)
|
||||
+ bytes([len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
+ bytes([len(metadata_bytes)])
|
||||
+ metadata_bytes
|
||||
+ b''.join(map(bytes, self.bis))
|
||||
)
|
||||
|
||||
presentation_delay: int
|
||||
subgroups: List[BasicAudioAnnouncement.Subgroup]
|
||||
subgroups: list[BasicAudioAnnouncement.Subgroup]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
@@ -502,7 +546,7 @@ class BasicAudioAnnouncement:
|
||||
for _ in range(data[3]):
|
||||
num_bis = data[offset]
|
||||
offset += 1
|
||||
codec_id = cls.CodecInfo.from_bytes(data[offset : offset + 5])
|
||||
codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
|
||||
offset += 5
|
||||
codec_specific_configuration_length = data[offset]
|
||||
offset += 1
|
||||
@@ -546,3 +590,21 @@ class BasicAudioAnnouncement:
|
||||
)
|
||||
|
||||
return cls(presentation_delay, subgroups)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (
|
||||
self.presentation_delay.to_bytes(3, 'little')
|
||||
+ bytes([len(self.subgroups)])
|
||||
+ b''.join(map(bytes, self.subgroups))
|
||||
)
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
data_types.ServiceData16BitUUID(
|
||||
gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -17,17 +17,13 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from typing import ClassVar, List, Optional, Sequence
|
||||
from typing import ClassVar, Optional, Sequence
|
||||
|
||||
from bumble import core
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
from bumble import utils
|
||||
from bumble import core, device, gatt, gatt_adapters, gatt_client, hci, utils
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -52,7 +48,7 @@ def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
|
||||
)
|
||||
|
||||
|
||||
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
|
||||
def decode_subgroups(data: bytes) -> list[SubgroupInfo]:
|
||||
num_subgroups = data[0]
|
||||
offset = 1
|
||||
subgroups = []
|
||||
@@ -273,13 +269,10 @@ class BroadcastReceiveState:
|
||||
pa_sync_state: PeriodicAdvertisingSyncState
|
||||
big_encryption: BigEncryption
|
||||
bad_code: bytes
|
||||
subgroups: List[SubgroupInfo]
|
||||
subgroups: list[SubgroupInfo]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
|
||||
source_id = data[0]
|
||||
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
|
||||
source_adv_sid = data[8]
|
||||
@@ -356,35 +349,28 @@ class BroadcastAudioScanService(gatt.TemplateService):
|
||||
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = BroadcastAudioScanService
|
||||
|
||||
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
|
||||
broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]
|
||||
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes]
|
||||
broadcast_receive_states: list[
|
||||
gatt_client.CharacteristicProxy[Optional[BroadcastReceiveState]]
|
||||
]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.broadcast_audio_scan_control_point = (
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Audio Scan Control Point characteristic not found"
|
||||
)
|
||||
self.broadcast_audio_scan_control_point = characteristics[0]
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.broadcast_receive_states = [
|
||||
gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
characteristic,
|
||||
decode=lambda x: BroadcastReceiveState.from_bytes(x) if x else None,
|
||||
)
|
||||
for characteristic in service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Receive State characteristic not found"
|
||||
)
|
||||
self.broadcast_receive_states = [
|
||||
gatt.DelegatedCharacteristicAdapter(
|
||||
characteristic, decode=BroadcastReceiveState.from_bytes
|
||||
)
|
||||
for characteristic in characteristics
|
||||
]
|
||||
|
||||
async def send_control_point_operation(
|
||||
|
||||
@@ -16,15 +16,20 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..gatt import (
|
||||
GATT_BATTERY_SERVICE,
|
||||
from typing import Optional
|
||||
|
||||
from bumble.gatt import (
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC,
|
||||
TemplateService,
|
||||
GATT_BATTERY_SERVICE,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
PackedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import (
|
||||
PackedCharacteristicAdapter,
|
||||
PackedCharacteristicProxyAdapter,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -32,6 +37,8 @@ class BatteryService(TemplateService):
|
||||
UUID = GATT_BATTERY_SERVICE
|
||||
BATTERY_LEVEL_FORMAT = 'B'
|
||||
|
||||
battery_level_characteristic: Characteristic[int]
|
||||
|
||||
def __init__(self, read_battery_level):
|
||||
self.battery_level_characteristic = PackedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
@@ -49,13 +56,15 @@ class BatteryService(TemplateService):
|
||||
class BatteryServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = BatteryService
|
||||
|
||||
battery_level: Optional[CharacteristicProxy[int]]
|
||||
|
||||
def __init__(self, service_proxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC
|
||||
):
|
||||
self.battery_level = PackedCharacteristicAdapter(
|
||||
self.battery_level = PackedCharacteristicProxyAdapter(
|
||||
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -18,8 +18,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import gatt, gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
|
||||
@@ -17,16 +17,12 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from typing import Optional
|
||||
|
||||
from bumble import core, crypto, device, gatt, gatt_client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -99,10 +95,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic[bytes]] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic[bytes]] = None
|
||||
set_member_rank_characteristic: Optional[gatt.Characteristic[bytes]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -164,13 +160,11 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
async def on_sirk_read(self, connection: device.Connection) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
if connection.transport == core.PhysicalTransport.LE:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
@@ -203,10 +197,10 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CoordinatedSetIdentificationService
|
||||
|
||||
set_identity_resolving_key: gatt_client.CharacteristicProxy
|
||||
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
|
||||
coordinated_set_size: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
set_member_lock: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
set_member_rank: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
@@ -230,7 +224,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
):
|
||||
self.set_member_rank = characteristics[0]
|
||||
|
||||
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
|
||||
async def read_set_identity_resolving_key(self) -> tuple[SirkType, bytes]:
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
@@ -242,7 +236,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
else:
|
||||
connection = self.service_proxy.client.connection
|
||||
device = connection.device
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
if connection.transport == core.PhysicalTransport.LE:
|
||||
key = await device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
|
||||
@@ -17,24 +17,26 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_INFORMATION_SERVICE,
|
||||
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
|
||||
GATT_MODEL_NUMBER_STRING_CHARACTERISTIC,
|
||||
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
|
||||
GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC,
|
||||
GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
GATT_SYSTEM_ID_CHARACTERISTIC,
|
||||
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import (
|
||||
DelegatedCharacteristicProxyAdapter,
|
||||
UTF8CharacteristicProxyAdapter,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -58,13 +60,16 @@ class DeviceInformationService(TemplateService):
|
||||
hardware_revision: Optional[str] = None,
|
||||
firmware_revision: Optional[str] = None,
|
||||
software_revision: Optional[str] = None,
|
||||
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
|
||||
system_id: Optional[tuple[int, int]] = None, # (OUI, Manufacturer ID)
|
||||
ieee_regulatory_certification_data_list: Optional[bytes] = None,
|
||||
# TODO: pnp_id
|
||||
):
|
||||
characteristics = [
|
||||
characteristics: list[Characteristic[bytes]] = [
|
||||
Characteristic(
|
||||
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field
|
||||
uuid,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes(field, 'utf-8'),
|
||||
)
|
||||
for (field, uuid) in (
|
||||
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
|
||||
@@ -104,14 +109,14 @@ class DeviceInformationService(TemplateService):
|
||||
class DeviceInformationServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = DeviceInformationService
|
||||
|
||||
manufacturer_name: Optional[UTF8CharacteristicAdapter]
|
||||
model_number: Optional[UTF8CharacteristicAdapter]
|
||||
serial_number: Optional[UTF8CharacteristicAdapter]
|
||||
hardware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
firmware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
software_revision: Optional[UTF8CharacteristicAdapter]
|
||||
system_id: Optional[DelegatedCharacteristicAdapter]
|
||||
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
|
||||
manufacturer_name: Optional[CharacteristicProxy[str]]
|
||||
model_number: Optional[CharacteristicProxy[str]]
|
||||
serial_number: Optional[CharacteristicProxy[str]]
|
||||
hardware_revision: Optional[CharacteristicProxy[str]]
|
||||
firmware_revision: Optional[CharacteristicProxy[str]]
|
||||
software_revision: Optional[CharacteristicProxy[str]]
|
||||
system_id: Optional[CharacteristicProxy[tuple[int, int]]]
|
||||
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy[bytes]]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
@@ -125,7 +130,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
|
||||
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
|
||||
):
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
|
||||
characteristic = UTF8CharacteristicAdapter(characteristics[0])
|
||||
characteristic = UTF8CharacteristicProxyAdapter(characteristics[0])
|
||||
else:
|
||||
characteristic = None
|
||||
self.__setattr__(field, characteristic)
|
||||
@@ -133,7 +138,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_SYSTEM_ID_CHARACTERISTIC
|
||||
):
|
||||
self.system_id = DelegatedCharacteristicAdapter(
|
||||
self.system_id = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
encode=lambda v: DeviceInformationService.pack_system_id(*v),
|
||||
decode=DeviceInformationService.unpack_system_id,
|
||||
|
||||
@@ -19,20 +19,21 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from bumble.core import Appearance
|
||||
from bumble.gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicAdapter,
|
||||
DelegatedCharacteristicAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
Characteristic,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from bumble.gatt_adapters import (
|
||||
DelegatedCharacteristicProxyAdapter,
|
||||
UTF8CharacteristicProxyAdapter,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -49,8 +50,11 @@ logger = logging.getLogger(__name__)
|
||||
class GenericAccessService(TemplateService):
|
||||
UUID = GATT_GENERIC_ACCESS_SERVICE
|
||||
|
||||
device_name_characteristic: Characteristic[bytes]
|
||||
appearance_characteristic: Characteristic[bytes]
|
||||
|
||||
def __init__(
|
||||
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
|
||||
self, device_name: str, appearance: Union[Appearance, tuple[int, int], int] = 0
|
||||
):
|
||||
if isinstance(appearance, int):
|
||||
appearance_int = appearance
|
||||
@@ -84,8 +88,8 @@ class GenericAccessService(TemplateService):
|
||||
class GenericAccessServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = GenericAccessService
|
||||
|
||||
device_name: Optional[CharacteristicAdapter]
|
||||
appearance: Optional[DelegatedCharacteristicAdapter]
|
||||
device_name: Optional[CharacteristicProxy[str]]
|
||||
appearance: Optional[CharacteristicProxy[Appearance]]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
@@ -93,14 +97,14 @@ class GenericAccessServiceProxy(ProfileServiceProxy):
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC
|
||||
):
|
||||
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
|
||||
self.device_name = UTF8CharacteristicProxyAdapter(characteristics[0])
|
||||
else:
|
||||
self.device_name = None
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_APPEARANCE_CHARACTERISTIC
|
||||
):
|
||||
self.appearance = DelegatedCharacteristicAdapter(
|
||||
self.appearance = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: Appearance.from_int(
|
||||
struct.unpack_from('<H', value, 0)[0],
|
||||
|
||||
162
bumble/profiles/gatt_service.py
Normal file
162
bumble/profiles/gatt_service.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from bumble import att, crypto, gatt, gatt_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble import device
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAttributeProfileService(gatt.TemplateService):
|
||||
'''See Vol 3, Part G - 7 - DEFINED GENERIC ATTRIBUTE PROFILE SERVICE.'''
|
||||
|
||||
UUID = gatt.GATT_GENERIC_ATTRIBUTE_SERVICE
|
||||
|
||||
client_supported_features_characteristic: gatt.Characteristic[bytes] | None = None
|
||||
server_supported_features_characteristic: gatt.Characteristic[bytes] | None = None
|
||||
database_hash_characteristic: gatt.Characteristic[bytes] | None = None
|
||||
service_changed_characteristic: gatt.Characteristic[bytes] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_supported_features: gatt.ServerSupportedFeatures | None = None,
|
||||
database_hash_enabled: bool = True,
|
||||
service_change_enabled: bool = True,
|
||||
) -> None:
|
||||
|
||||
if server_supported_features is not None:
|
||||
self.server_supported_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=bytes([server_supported_features]),
|
||||
)
|
||||
|
||||
if database_hash_enabled:
|
||||
self.database_hash_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_DATABASE_HASH_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=gatt.CharacteristicValue(read=self.get_database_hash),
|
||||
)
|
||||
|
||||
if service_change_enabled:
|
||||
self.service_changed_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.INDICATE,
|
||||
permissions=gatt.Characteristic.Permissions(0),
|
||||
value=b'',
|
||||
)
|
||||
|
||||
if (database_hash_enabled and service_change_enabled) or (
|
||||
server_supported_features
|
||||
and (
|
||||
server_supported_features & gatt.ServerSupportedFeatures.EATT_SUPPORTED
|
||||
)
|
||||
): # TODO: Support Multiple Handle Value Notifications
|
||||
self.client_supported_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE
|
||||
),
|
||||
permissions=(
|
||||
gatt.Characteristic.Permissions.READABLE
|
||||
| gatt.Characteristic.Permissions.WRITEABLE
|
||||
),
|
||||
value=bytes(1),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
c
|
||||
for c in (
|
||||
self.service_changed_characteristic,
|
||||
self.client_supported_features_characteristic,
|
||||
self.database_hash_characteristic,
|
||||
self.server_supported_features_characteristic,
|
||||
)
|
||||
if c is not None
|
||||
],
|
||||
primary=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_attribute_data(cls, attribute: att.Attribute) -> bytes:
|
||||
if attribute.type in (
|
||||
gatt.GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_CHARACTERISTIC_EXTENDED_PROPERTIES_DESCRIPTOR,
|
||||
):
|
||||
assert isinstance(attribute.value, bytes)
|
||||
return (
|
||||
struct.pack("<H", attribute.handle)
|
||||
+ attribute.type.to_bytes()
|
||||
+ attribute.value
|
||||
)
|
||||
elif attribute.type in (
|
||||
gatt.GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
|
||||
gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
gatt.GATT_SERVER_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
gatt.GATT_CHARACTERISTIC_PRESENTATION_FORMAT_DESCRIPTOR,
|
||||
gatt.GATT_CHARACTERISTIC_AGGREGATE_FORMAT_DESCRIPTOR,
|
||||
):
|
||||
return struct.pack("<H", attribute.handle) + attribute.type.to_bytes()
|
||||
|
||||
return b''
|
||||
|
||||
def get_database_hash(self, connection: device.Connection) -> bytes:
|
||||
m = b''.join(
|
||||
[
|
||||
self.get_attribute_data(attribute)
|
||||
for attribute in connection.device.gatt_server.attributes
|
||||
]
|
||||
)
|
||||
|
||||
return crypto.aes_cmac(m=m, k=bytes(16))
|
||||
|
||||
|
||||
class GenericAttributeProfileServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = GenericAttributeProfileService
|
||||
|
||||
client_supported_features_characteristic: (
|
||||
gatt_client.CharacteristicProxy[bytes] | None
|
||||
) = None
|
||||
server_supported_features_characteristic: (
|
||||
gatt_client.CharacteristicProxy[bytes] | None
|
||||
) = None
|
||||
database_hash_characteristic: gatt_client.CharacteristicProxy[bytes] | None = None
|
||||
service_changed_characteristic: gatt_client.CharacteristicProxy[bytes] | None = None
|
||||
|
||||
_CHARACTERISTICS = {
|
||||
gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC: 'client_supported_features_characteristic',
|
||||
gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC: 'server_supported_features_characteristic',
|
||||
gatt.GATT_DATABASE_HASH_CHARACTERISTIC: 'database_hash_characteristic',
|
||||
gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC: 'service_changed_characteristic',
|
||||
}
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
for uuid, attribute_name in self._CHARACTERISTICS.items():
|
||||
if characteristics := self.service_proxy.get_characteristics_by_uuid(uuid):
|
||||
setattr(self, attribute_name, characteristics[0])
|
||||
198
bumble/profiles/gmap.py
Normal file
198
bumble/profiles/gmap.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LE Audio - Gaming Audio Profile"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from enum import IntFlag
|
||||
from typing import Optional
|
||||
|
||||
from bumble.gatt import (
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
GATT_GAMING_AUDIO_SERVICE,
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
Characteristic,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import DelegatedCharacteristicProxyAdapter
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class GmapRole(IntFlag):
|
||||
UNICAST_GAME_GATEWAY = 1 << 0
|
||||
UNICAST_GAME_TERMINAL = 1 << 1
|
||||
BROADCAST_GAME_SENDER = 1 << 2
|
||||
BROADCAST_GAME_RECEIVER = 1 << 3
|
||||
|
||||
|
||||
class UggFeatures(IntFlag):
|
||||
UGG_MULTIPLEX = 1 << 0
|
||||
UGG_96_KBPS_SOURCE = 1 << 1
|
||||
UGG_MULTISINK = 1 << 2
|
||||
|
||||
|
||||
class UgtFeatures(IntFlag):
|
||||
UGT_SOURCE = 1 << 0
|
||||
UGT_80_KBPS_SOURCE = 1 << 1
|
||||
UGT_SINK = 1 << 2
|
||||
UGT_64_KBPS_SINK = 1 << 3
|
||||
UGT_MULTIPLEX = 1 << 4
|
||||
UGT_MULTISINK = 1 << 5
|
||||
UGT_MULTISOURCE = 1 << 6
|
||||
|
||||
|
||||
class BgsFeatures(IntFlag):
|
||||
BGS_96_KBPS = 1 << 0
|
||||
|
||||
|
||||
class BgrFeatures(IntFlag):
|
||||
BGR_MULTISINK = 1 << 0
|
||||
BGR_MULTIPLEX = 1 << 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioService(TemplateService):
|
||||
UUID = GATT_GAMING_AUDIO_SERVICE
|
||||
|
||||
gmap_role: Characteristic
|
||||
ugg_features: Optional[Characteristic] = None
|
||||
ugt_features: Optional[Characteristic] = None
|
||||
bgs_features: Optional[Characteristic] = None
|
||||
bgr_features: Optional[Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gmap_role: GmapRole,
|
||||
ugg_features: Optional[UggFeatures] = None,
|
||||
ugt_features: Optional[UgtFeatures] = None,
|
||||
bgs_features: Optional[BgsFeatures] = None,
|
||||
bgr_features: Optional[BgrFeatures] = None,
|
||||
) -> None:
|
||||
characteristics = []
|
||||
|
||||
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
|
||||
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
|
||||
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
|
||||
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
|
||||
|
||||
self.gmap_role = Characteristic(
|
||||
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', gmap_role),
|
||||
)
|
||||
characteristics.append(self.gmap_role)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
|
||||
self.ugg_features = Characteristic(
|
||||
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugg_features),
|
||||
)
|
||||
characteristics.append(self.ugg_features)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
|
||||
self.ugt_features = Characteristic(
|
||||
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugt_features),
|
||||
)
|
||||
characteristics.append(self.ugt_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
|
||||
self.bgs_features = Characteristic(
|
||||
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgs_features),
|
||||
)
|
||||
characteristics.append(self.bgs_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
|
||||
self.bgr_features = Characteristic(
|
||||
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgr_features),
|
||||
)
|
||||
characteristics.append(self.bgr_features)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = GamingAudioService
|
||||
|
||||
ugg_features: Optional[CharacteristicProxy[UggFeatures]] = None
|
||||
ugt_features: Optional[CharacteristicProxy[UgtFeatures]] = None
|
||||
bgs_features: Optional[CharacteristicProxy[BgsFeatures]] = None
|
||||
bgr_features: Optional[CharacteristicProxy[BgrFeatures]] = None
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.gmap_role = DelegatedCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda value: GmapRole(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugg_features = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: UggFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugt_features = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: UgtFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgs_features = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: BgsFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgr_features = DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: BgrFeatures(value[0]),
|
||||
)
|
||||
@@ -16,22 +16,22 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from bumble import att, gatt, gatt_client
|
||||
from bumble.core import InvalidArgumentError, InvalidStateError
|
||||
from bumble.device import Device, Connection
|
||||
from bumble.utils import AsyncRunner, OpenIntEnum
|
||||
from bumble.hci import Address
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from bumble import att, gatt, gatt_adapters, gatt_client, utils
|
||||
from bumble.core import InvalidArgumentError, InvalidStateError
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
class ErrorCode(OpenIntEnum):
|
||||
class ErrorCode(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 2.4. Attribute Profile error codes.'''
|
||||
|
||||
INVALID_OPCODE = 0x80
|
||||
@@ -41,7 +41,7 @@ class ErrorCode(OpenIntEnum):
|
||||
INVALID_PARAMETERS_LENGTH = 0x84
|
||||
|
||||
|
||||
class HearingAidType(OpenIntEnum):
|
||||
class HearingAidType(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
BINAURAL_HEARING_AID = 0b00
|
||||
@@ -49,35 +49,35 @@ class HearingAidType(OpenIntEnum):
|
||||
BANDED_HEARING_AID = 0b10
|
||||
|
||||
|
||||
class PresetSynchronizationSupport(OpenIntEnum):
|
||||
class PresetSynchronizationSupport(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
|
||||
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
|
||||
|
||||
|
||||
class IndependentPresets(OpenIntEnum):
|
||||
class IndependentPresets(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
IDENTICAL_PRESET_RECORD = 0b0
|
||||
DIFFERENT_PRESET_RECORD = 0b1
|
||||
|
||||
|
||||
class DynamicPresets(OpenIntEnum):
|
||||
class DynamicPresets(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
|
||||
PRESET_RECORDS_MAY_CHANGE = 0b1
|
||||
|
||||
|
||||
class WritablePresetsSupport(OpenIntEnum):
|
||||
class WritablePresetsSupport(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
|
||||
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
|
||||
|
||||
|
||||
class HearingAidPresetControlPointOpcode(OpenIntEnum):
|
||||
class HearingAidPresetControlPointOpcode(utils.OpenIntEnum):
|
||||
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
|
||||
|
||||
# fmt: off
|
||||
@@ -129,7 +129,7 @@ def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
|
||||
class PresetChangedOperation:
|
||||
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
|
||||
|
||||
class ChangeId(OpenIntEnum):
|
||||
class ChangeId(utils.OpenIntEnum):
|
||||
# fmt: off
|
||||
GENERIC_UPDATE = 0x00
|
||||
PRESET_RECORD_DELETED = 0x01
|
||||
@@ -189,11 +189,11 @@ class PresetRecord:
|
||||
|
||||
@dataclass
|
||||
class Property:
|
||||
class Writable(OpenIntEnum):
|
||||
class Writable(utils.OpenIntEnum):
|
||||
CANNOT_BE_WRITTEN = 0b0
|
||||
CAN_BE_WRITTEN = 0b1
|
||||
|
||||
class IsAvailable(OpenIntEnum):
|
||||
class IsAvailable(utils.OpenIntEnum):
|
||||
IS_UNAVAILABLE = 0b0
|
||||
IS_AVAILABLE = 0b1
|
||||
|
||||
@@ -223,27 +223,29 @@ class PresetRecord:
|
||||
class HearingAccessService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_HEARING_ACCESS_SERVICE
|
||||
|
||||
hearing_aid_features_characteristic: gatt.Characteristic
|
||||
hearing_aid_preset_control_point: gatt.Characteristic
|
||||
active_preset_index_characteristic: gatt.Characteristic
|
||||
hearing_aid_features_characteristic: gatt.Characteristic[bytes]
|
||||
hearing_aid_preset_control_point: gatt.Characteristic[bytes]
|
||||
active_preset_index_characteristic: gatt.Characteristic[bytes]
|
||||
active_preset_index: int
|
||||
active_preset_index_per_device: Dict[Address, int]
|
||||
active_preset_index_per_device: dict[Address, int]
|
||||
|
||||
device: Device
|
||||
|
||||
server_features: HearingAidFeatures
|
||||
preset_records: Dict[int, PresetRecord] # key is the preset index
|
||||
preset_records: dict[int, PresetRecord] # key is the preset index
|
||||
read_presets_request_in_progress: bool
|
||||
|
||||
preset_changed_operations_history_per_device: Dict[
|
||||
Address, List[PresetChangedOperation]
|
||||
other_server_in_binaural_set: Optional[HearingAccessService] = None
|
||||
|
||||
preset_changed_operations_history_per_device: dict[
|
||||
Address, list[PresetChangedOperation]
|
||||
]
|
||||
|
||||
# Keep an updated list of connected client to send notification to
|
||||
currently_connected_clients: Set[Connection]
|
||||
currently_connected_clients: set[Connection]
|
||||
|
||||
def __init__(
|
||||
self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
|
||||
self, device: Device, features: HearingAidFeatures, presets: list[PresetRecord]
|
||||
) -> None:
|
||||
self.active_preset_index_per_device = {}
|
||||
self.read_presets_request_in_progress = False
|
||||
@@ -265,30 +267,25 @@ class HearingAccessService(gatt.TemplateService):
|
||||
# associate the lowest index as the current active preset at startup
|
||||
self.active_preset_index = sorted(self.preset_records.keys())[0]
|
||||
|
||||
@device.on('connection') # type: ignore
|
||||
@device.on(device.EVENT_CONNECTION)
|
||||
def on_connection(connection: Connection) -> None:
|
||||
@connection.on('disconnection') # type: ignore
|
||||
@connection.on(connection.EVENT_DISCONNECTION)
|
||||
def on_disconnection(_reason) -> None:
|
||||
self.currently_connected_clients.remove(connection)
|
||||
self.currently_connected_clients.discard(connection)
|
||||
|
||||
# TODO Should we filter on device bonded && device is HAP ?
|
||||
self.currently_connected_clients.add(connection)
|
||||
if (
|
||||
connection.peer_address
|
||||
not in self.preset_changed_operations_history_per_device
|
||||
):
|
||||
self.preset_changed_operations_history_per_device[
|
||||
connection.peer_address
|
||||
] = []
|
||||
return
|
||||
@connection.on(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
|
||||
def on_mtu_update(*_: Any) -> None:
|
||||
self.on_incoming_connection(connection)
|
||||
|
||||
async def on_connection_async() -> None:
|
||||
# Send all the PresetChangedOperation that occur when not connected
|
||||
await self._preset_changed_operation(connection)
|
||||
# Update the active preset index if needed
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
@connection.on(connection.EVENT_CONNECTION_ENCRYPTION_CHANGE)
|
||||
def on_encryption_change(*_: Any) -> None:
|
||||
self.on_incoming_connection(connection)
|
||||
|
||||
connection.abort_on('disconnection', on_connection_async())
|
||||
@connection.on(connection.EVENT_PAIRING)
|
||||
def on_pairing(*_: Any) -> None:
|
||||
self.on_incoming_connection(connection)
|
||||
|
||||
self.on_incoming_connection(connection)
|
||||
|
||||
self.hearing_aid_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
|
||||
@@ -325,9 +322,50 @@ class HearingAccessService(gatt.TemplateService):
|
||||
]
|
||||
)
|
||||
|
||||
def _on_read_active_preset_index(
|
||||
self, __connection__: Optional[Connection]
|
||||
) -> bytes:
|
||||
def on_incoming_connection(self, connection: Connection):
|
||||
'''Setup initial operations to handle a remote bonded HAP device'''
|
||||
# TODO Should we filter on HAP device only ?
|
||||
|
||||
if not connection.is_encrypted:
|
||||
logging.debug(f'HAS: {connection.peer_address} is not encrypted')
|
||||
return
|
||||
|
||||
if not connection.peer_resolvable_address:
|
||||
logging.debug(f'HAS: {connection.peer_address} is not paired')
|
||||
return
|
||||
|
||||
if connection.att_mtu < 49:
|
||||
logging.debug(
|
||||
f'HAS: {connection.peer_address} invalid MTU={connection.att_mtu}'
|
||||
)
|
||||
return
|
||||
|
||||
if connection.peer_address in self.currently_connected_clients:
|
||||
logging.debug(
|
||||
f'HAS: Already connected to {connection.peer_address} nothing to do'
|
||||
)
|
||||
return
|
||||
|
||||
self.currently_connected_clients.add(connection)
|
||||
if (
|
||||
connection.peer_address
|
||||
not in self.preset_changed_operations_history_per_device
|
||||
):
|
||||
self.preset_changed_operations_history_per_device[
|
||||
connection.peer_address
|
||||
] = []
|
||||
return
|
||||
|
||||
async def on_connection_async() -> None:
|
||||
# Send all the PresetChangedOperation that occur when not connected
|
||||
await self._preset_changed_operation(connection)
|
||||
# Update the active preset index if needed
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
connection.cancel_on_disconnection(on_connection_async())
|
||||
|
||||
def _on_read_active_preset_index(self, connection: Connection) -> bytes:
|
||||
del connection # Unused
|
||||
return bytes([self.active_preset_index])
|
||||
|
||||
# TODO this need to be triggered when device is unbonded
|
||||
@@ -335,18 +373,13 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.preset_changed_operations_history_per_device.pop(addr)
|
||||
|
||||
async def _on_write_hearing_aid_preset_control_point(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
assert connection
|
||||
|
||||
opcode = HearingAidPresetControlPointOpcode(value[0])
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
await handler(connection, value)
|
||||
|
||||
async def _on_read_presets_request(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
async def _on_read_presets_request(self, connection: Connection, value: bytes):
|
||||
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
|
||||
logging.warning(f'HAS require MTU >= 49: {connection}')
|
||||
|
||||
@@ -367,17 +400,19 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.preset_records[key]
|
||||
for key in sorted(self.preset_records.keys())
|
||||
if self.preset_records[key].index >= start_index
|
||||
]
|
||||
del presets[num_presets:]
|
||||
][:num_presets]
|
||||
if len(presets) == 0:
|
||||
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
|
||||
|
||||
AsyncRunner.spawn(self._read_preset_response(connection, presets))
|
||||
utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))
|
||||
|
||||
async def _read_preset_response(
|
||||
self, connection: Connection, presets: List[PresetRecord]
|
||||
self, connection: Connection, presets: list[PresetRecord]
|
||||
):
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
|
||||
# If the ATT bearer is terminated before all notifications or indications are
|
||||
# sent, then the server shall consider the Read Presets Request operation
|
||||
# aborted and shall not either continue or restart the operation when the client
|
||||
# reconnects.
|
||||
try:
|
||||
for i, preset in enumerate(presets):
|
||||
await connection.device.indicate_subscriber(
|
||||
@@ -398,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
|
||||
async def generic_update(self, op: PresetChangedOperation) -> None:
|
||||
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
|
||||
await self._notifyPresetOperations(op)
|
||||
await self._notify_preset_operations(op)
|
||||
|
||||
async def delete_preset(self, index: int) -> None:
|
||||
'''Server API to delete a preset. It should not be the current active preset'''
|
||||
@@ -407,14 +442,14 @@ class HearingAccessService(gatt.TemplateService):
|
||||
raise InvalidStateError('Cannot delete active preset')
|
||||
|
||||
del self.preset_records[index]
|
||||
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationDeleted(index))
|
||||
|
||||
async def available_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset available'''
|
||||
|
||||
preset = self.preset_records[index]
|
||||
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
|
||||
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationAvailable(index))
|
||||
|
||||
async def unavailable_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset unavailable. It should not be the current active preset'''
|
||||
@@ -426,7 +461,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
preset.properties.is_available = (
|
||||
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
|
||||
)
|
||||
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
|
||||
await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
|
||||
|
||||
async def _preset_changed_operation(self, connection: Connection) -> None:
|
||||
'''Send all PresetChangedOperation saved for a given connection'''
|
||||
@@ -441,30 +476,31 @@ class HearingAccessService(gatt.TemplateService):
|
||||
return op.additional_parameters
|
||||
|
||||
op_list.sort(key=get_op_index)
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
|
||||
while len(op_list) > 0:
|
||||
# If the ATT bearer is terminated before all notifications or indications are
|
||||
# sent, then the server shall consider the Preset Changed operation aborted and
|
||||
# shall continue the operation when the client reconnects.
|
||||
while op_list:
|
||||
try:
|
||||
await connection.device.indicate_subscriber(
|
||||
connection,
|
||||
self.hearing_aid_preset_control_point,
|
||||
value=op_list[0].to_bytes(len(op_list) == 1),
|
||||
force=True, # TODO GATT notification subscription should be persistent
|
||||
)
|
||||
# Remove item once sent, and keep the non sent item in the list
|
||||
op_list.pop(0)
|
||||
except TimeoutError:
|
||||
break
|
||||
|
||||
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
|
||||
for historyList in self.preset_changed_operations_history_per_device.values():
|
||||
historyList.append(op)
|
||||
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
|
||||
for history_list in self.preset_changed_operations_history_per_device.values():
|
||||
history_list.append(op)
|
||||
|
||||
for connection in self.currently_connected_clients:
|
||||
await self._preset_changed_operation(connection)
|
||||
|
||||
async def _on_write_preset_name(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
async def _on_write_preset_name(self, connection: Connection, value: bytes):
|
||||
del connection # Unused
|
||||
|
||||
if self.read_presets_request_in_progress:
|
||||
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
|
||||
@@ -512,10 +548,7 @@ class HearingAccessService(gatt.TemplateService):
|
||||
for connection in self.currently_connected_clients:
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
async def set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
async def set_active_preset(self, value: bytes) -> None:
|
||||
index = value[1]
|
||||
preset = self.preset_records.get(index, None)
|
||||
if (
|
||||
@@ -532,86 +565,85 @@ class HearingAccessService(gatt.TemplateService):
|
||||
self.active_preset_index = index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
await self.set_active_preset(connection, value)
|
||||
async def _on_set_active_preset(self, connection: Connection, value: bytes):
|
||||
del connection # Unused
|
||||
await self.set_active_preset(value)
|
||||
|
||||
async def set_next_or_previous_preset(
|
||||
self, connection: Optional[Connection], is_previous
|
||||
):
|
||||
async def set_next_or_previous_preset(self, is_previous: bool) -> None:
|
||||
'''Set the next or the previous preset as active'''
|
||||
assert connection
|
||||
|
||||
if self.active_preset_index == 0x00:
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
first_preset: Optional[PresetRecord] = None # To loop to first preset
|
||||
next_preset: Optional[PresetRecord] = None
|
||||
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
|
||||
if not record.is_available():
|
||||
continue
|
||||
if first_preset == None:
|
||||
first_preset = record
|
||||
if is_previous:
|
||||
if index >= self.active_preset_index:
|
||||
continue
|
||||
elif index <= self.active_preset_index:
|
||||
continue
|
||||
next_preset = record
|
||||
break
|
||||
presets = sorted(
|
||||
[
|
||||
record
|
||||
for record in self.preset_records.values()
|
||||
if record.is_available()
|
||||
],
|
||||
key=lambda record: record.index,
|
||||
)
|
||||
current_preset = self.preset_records[self.active_preset_index]
|
||||
current_preset_pos = presets.index(current_preset)
|
||||
if is_previous:
|
||||
new_preset = presets[(current_preset_pos - 1) % len(presets)]
|
||||
else:
|
||||
new_preset = presets[(current_preset_pos + 1) % len(presets)]
|
||||
|
||||
if not first_preset: # If no other preset are available
|
||||
if current_preset == new_preset: # If no other preset are available
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
if next_preset:
|
||||
self.active_preset_index = next_preset.index
|
||||
else:
|
||||
self.active_preset_index = first_preset.index
|
||||
self.active_preset_index = new_preset.index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_next_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, False)
|
||||
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
|
||||
del connection, value # Unused.
|
||||
await self.set_next_or_previous_preset(False)
|
||||
|
||||
async def _on_set_previous_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, True)
|
||||
del connection, value # Unused.
|
||||
await self.set_next_or_previous_preset(True)
|
||||
|
||||
async def _on_set_active_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_active_preset(connection, value)
|
||||
# TODO (low priority) inform other server of the change
|
||||
await self.set_active_preset(value)
|
||||
if self.other_server_in_binaural_set:
|
||||
await self.other_server_in_binaural_set.set_active_preset(value)
|
||||
|
||||
async def _on_set_next_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection, value # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_next_or_previous_preset(connection, False)
|
||||
# TODO (low priority) inform other server of the change
|
||||
await self.set_next_or_previous_preset(False)
|
||||
if self.other_server_in_binaural_set:
|
||||
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
|
||||
|
||||
async def _on_set_previous_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
self, connection: Connection, value: bytes
|
||||
):
|
||||
del connection, value # Unused.
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_next_or_previous_preset(connection, True)
|
||||
# TODO (low priority) inform other server of the change
|
||||
await self.set_next_or_previous_preset(True)
|
||||
if self.other_server_in_binaural_set:
|
||||
await self.other_server_in_binaural_set.set_next_or_previous_preset(True)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -621,12 +653,15 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = HearingAccessService
|
||||
|
||||
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
|
||||
preset_control_point_indications: asyncio.Queue
|
||||
preset_control_point_indications: asyncio.Queue[bytes]
|
||||
active_preset_index_notification: asyncio.Queue[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
self.preset_control_point_indications = asyncio.Queue()
|
||||
self.active_preset_index_notification = asyncio.Queue()
|
||||
|
||||
self.server_features = gatt.PackedCharacteristicAdapter(
|
||||
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
|
||||
)[0],
|
||||
@@ -639,27 +674,19 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
)[0]
|
||||
)
|
||||
|
||||
self.active_preset_index = gatt.PackedCharacteristicAdapter(
|
||||
self.active_preset_index = gatt_adapters.PackedCharacteristicProxyAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
|
||||
async def setup_subscription(self):
|
||||
self.preset_control_point_indications = asyncio.Queue()
|
||||
self.active_preset_index_notification = asyncio.Queue()
|
||||
|
||||
def on_active_preset_index_notification(data: bytes):
|
||||
self.active_preset_index_notification.put_nowait(data)
|
||||
|
||||
def on_preset_control_point_indication(data: bytes):
|
||||
self.preset_control_point_indications.put_nowait(data)
|
||||
|
||||
async def setup_subscription(self) -> None:
|
||||
await self.hearing_aid_preset_control_point.subscribe(
|
||||
functools.partial(on_preset_control_point_indication), prefer_notify=False
|
||||
self.preset_control_point_indications.put_nowait,
|
||||
prefer_notify=False,
|
||||
)
|
||||
|
||||
await self.active_preset_index.subscribe(
|
||||
functools.partial(on_active_preset_index_notification)
|
||||
self.active_preset_index_notification.put_nowait
|
||||
)
|
||||
|
||||
@@ -16,23 +16,29 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from enum import IntEnum
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
from bumble import core
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..att import ATT_Error
|
||||
from ..gatt import (
|
||||
GATT_HEART_RATE_SERVICE,
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.gatt import (
|
||||
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
|
||||
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
|
||||
TemplateService,
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
GATT_HEART_RATE_SERVICE,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
TemplateService,
|
||||
)
|
||||
from bumble.gatt_adapters import (
|
||||
DelegatedCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
SerializableCharacteristicAdapter,
|
||||
)
|
||||
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -42,6 +48,10 @@ class HeartRateService(TemplateService):
|
||||
CONTROL_POINT_NOT_SUPPORTED = 0x80
|
||||
RESET_ENERGY_EXPENDED = 0x01
|
||||
|
||||
heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
|
||||
body_sensor_location_characteristic: Characteristic[BodySensorLocation]
|
||||
heart_rate_control_point_characteristic: Characteristic[int]
|
||||
|
||||
class BodySensorLocation(IntEnum):
|
||||
OTHER = 0
|
||||
CHEST = 1
|
||||
@@ -150,15 +160,14 @@ class HeartRateService(TemplateService):
|
||||
body_sensor_location=None,
|
||||
reset_energy_expended=None,
|
||||
):
|
||||
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
0,
|
||||
CharacteristicValue(read=read_heart_rate_measurement),
|
||||
),
|
||||
# pylint: disable=unnecessary-lambda
|
||||
encode=lambda value: bytes(value),
|
||||
HeartRateService.HeartRateMeasurement,
|
||||
)
|
||||
characteristics = [self.heart_rate_measurement_characteristic]
|
||||
|
||||
@@ -198,15 +207,22 @@ class HeartRateService(TemplateService):
|
||||
class HeartRateServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = HeartRateService
|
||||
|
||||
heart_rate_measurement: Optional[
|
||||
CharacteristicProxy[HeartRateService.HeartRateMeasurement]
|
||||
]
|
||||
body_sensor_location: Optional[
|
||||
CharacteristicProxy[HeartRateService.BodySensorLocation]
|
||||
]
|
||||
heart_rate_control_point: Optional[CharacteristicProxy[int]]
|
||||
|
||||
def __init__(self, service_proxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
|
||||
):
|
||||
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=HeartRateService.HeartRateMeasurement.from_bytes,
|
||||
self.heart_rate_measurement = SerializableCharacteristicAdapter(
|
||||
characteristics[0], HeartRateService.HeartRateMeasurement
|
||||
)
|
||||
else:
|
||||
self.heart_rate_measurement = None
|
||||
|
||||
@@ -16,24 +16,38 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
from typing import List, Type
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import utils
|
||||
from bumble.profiles import bap
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioActiveState(utils.OpenIntEnum):
|
||||
NO_AUDIO_DATA_TRANSMITTED = 0x00
|
||||
AUDIO_DATA_TRANSMITTED = 0x01
|
||||
|
||||
|
||||
class AssistedListeningStream(utils.OpenIntEnum):
|
||||
UNSPECIFIED_AUDIO_ENHANCEMENT = 0x00
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Metadata:
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
|
||||
|
||||
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
|
||||
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
|
||||
again outside the lib.
|
||||
As Metadata fields may extend, and the spec may not guarantee the uniqueness of
|
||||
tags, we don't automatically parse the Metadata data into specific classes.
|
||||
Users of this class may decode the data by themselves, or use the Entry.decode
|
||||
method.
|
||||
'''
|
||||
|
||||
class Tag(utils.OpenIntEnum):
|
||||
@@ -57,17 +71,78 @@ class Metadata:
|
||||
tag: Metadata.Tag
|
||||
data: bytes
|
||||
|
||||
def decode(self) -> Any:
|
||||
"""
|
||||
Decode the data into an object, if possible.
|
||||
|
||||
If no specific object class exists to represent the data, the raw data
|
||||
bytes are returned.
|
||||
"""
|
||||
|
||||
if self.tag in (
|
||||
Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
|
||||
Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
|
||||
):
|
||||
return bap.ContextType(struct.unpack("<H", self.data)[0])
|
||||
|
||||
if self.tag in (
|
||||
Metadata.Tag.PROGRAM_INFO,
|
||||
Metadata.Tag.PROGRAM_INFO_URI,
|
||||
Metadata.Tag.BROADCAST_NAME,
|
||||
):
|
||||
return self.data.decode("utf-8")
|
||||
|
||||
if self.tag == Metadata.Tag.LANGUAGE:
|
||||
return self.data.decode("ascii")
|
||||
|
||||
if self.tag == Metadata.Tag.CCID_LIST:
|
||||
return list(self.data)
|
||||
|
||||
if self.tag == Metadata.Tag.PARENTAL_RATING:
|
||||
return self.data[0]
|
||||
|
||||
if self.tag == Metadata.Tag.AUDIO_ACTIVE_STATE:
|
||||
return AudioActiveState(self.data[0])
|
||||
|
||||
if self.tag == Metadata.Tag.ASSISTED_LISTENING_STREAM:
|
||||
return AssistedListeningStream(self.data[0])
|
||||
|
||||
return self.data
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
def from_bytes(cls: type[Self], data: bytes) -> Self:
|
||||
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([len(self.data) + 1, self.tag]) + self.data
|
||||
|
||||
entries: List[Entry] = dataclasses.field(default_factory=list)
|
||||
entries: list[Entry] = dataclasses.field(default_factory=list)
|
||||
|
||||
def pretty_print(self, indent: str) -> str:
|
||||
"""Convenience method to generate a string with one key-value pair per line."""
|
||||
|
||||
max_key_length = 0
|
||||
keys = []
|
||||
values = []
|
||||
for entry in self.entries:
|
||||
key = entry.tag.name
|
||||
max_key_length = max(max_key_length, len(key))
|
||||
keys.append(key)
|
||||
decoded = entry.decode()
|
||||
if isinstance(decoded, enum.Enum):
|
||||
values.append(decoded.name)
|
||||
elif isinstance(decoded, bytes):
|
||||
values.append(decoded.hex())
|
||||
else:
|
||||
values.append(str(decoded))
|
||||
|
||||
return '\n'.join(
|
||||
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
|
||||
for key, value in zip(keys, values)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
def from_bytes(cls: type[Self], data: bytes) -> Self:
|
||||
entries = []
|
||||
offset = 0
|
||||
length = len(data)
|
||||
@@ -81,3 +156,13 @@ class Metadata:
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return b''.join([bytes(entry) for entry in self.entries])
|
||||
|
||||
def __str__(self) -> str:
|
||||
entries_str = []
|
||||
for entry in self.entries:
|
||||
decoded = entry.decode()
|
||||
entries_str.append(
|
||||
f'{entry.tag.name}: '
|
||||
f'{decoded.hex() if isinstance(decoded, bytes) else decoded!r}'
|
||||
)
|
||||
return f'Metadata(entries={", ".join(entry_str for entry_str in entries_str)})'
|
||||
|
||||
@@ -22,16 +22,12 @@ import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
from bumble import core
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import utils
|
||||
|
||||
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import core, device, gatt, gatt_client, utils
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -167,7 +163,7 @@ class ObjectId(int):
|
||||
'''See Media Control Service 4.4.2. Object ID field.'''
|
||||
|
||||
@classmethod
|
||||
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
def create_from_bytes(cls: type[Self], data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data, byteorder='little', signed=False))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -182,7 +178,7 @@ class GroupObjectType:
|
||||
object_id: ObjectId
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
def from_bytes(cls: type[Self], data: bytes) -> Self:
|
||||
return cls(
|
||||
object_type=ObjectType(data[0]),
|
||||
object_id=ObjectId.create_from_bytes(data[1:]),
|
||||
@@ -208,7 +204,7 @@ class MediaControlService(gatt.TemplateService):
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=media_player_name or 'Bumble Player',
|
||||
value=(media_player_name or 'Bumble Player').encode(),
|
||||
)
|
||||
self.track_changed_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||
@@ -247,14 +243,16 @@ class MediaControlService(gatt.TemplateService):
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.media_control_point_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self.on_media_control_point),
|
||||
self.media_control_point_characteristic: gatt.Characteristic[bytes] = (
|
||||
gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self.on_media_control_point),
|
||||
)
|
||||
)
|
||||
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||
@@ -285,11 +283,8 @@ class MediaControlService(gatt.TemplateService):
|
||||
)
|
||||
|
||||
async def on_media_control_point(
|
||||
self, connection: Optional[device.Connection], data: bytes
|
||||
self, connection: device.Connection, data: bytes
|
||||
) -> None:
|
||||
if not connection:
|
||||
raise core.InvalidStateError()
|
||||
|
||||
opcode = MediaControlPointOpcode(data[0])
|
||||
|
||||
await connection.device.notify_subscriber(
|
||||
@@ -311,7 +306,7 @@ class MediaControlServiceProxy(
|
||||
):
|
||||
SERVICE_CLASS = MediaControlService
|
||||
|
||||
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
|
||||
_CHARACTERISTICS: ClassVar[dict[str, core.UUID]] = {
|
||||
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
|
||||
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
|
||||
@@ -336,30 +331,38 @@ class MediaControlServiceProxy(
|
||||
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||
}
|
||||
|
||||
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_changed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_title: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_duration: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_position: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_order: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_state: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
|
||||
None
|
||||
)
|
||||
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
EVENT_MEDIA_STATE = "media_state"
|
||||
EVENT_TRACK_CHANGED = "track_changed"
|
||||
EVENT_TRACK_TITLE = "track_title"
|
||||
EVENT_TRACK_DURATION = "track_duration"
|
||||
EVENT_TRACK_POSITION = "track_position"
|
||||
|
||||
media_player_name: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
media_player_icon_url: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
track_changed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
track_title: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
track_duration: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
track_position: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
playback_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
seeking_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
current_track_segments_object_id: Optional[
|
||||
gatt_client.CharacteristicProxy[bytes]
|
||||
] = None
|
||||
current_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
next_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
parent_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
current_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
playing_order: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
playing_orders_supported: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
media_state: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
media_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
media_control_point_opcodes_supported: Optional[
|
||||
gatt_client.CharacteristicProxy[bytes]
|
||||
] = None
|
||||
search_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
search_results_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
content_control_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
media_control_point_notifications: asyncio.Queue[bytes]
|
||||
@@ -428,20 +431,20 @@ class MediaControlServiceProxy(
|
||||
self.media_control_point_notifications.put_nowait(data)
|
||||
|
||||
def _on_media_state(self, data: bytes) -> None:
|
||||
self.emit('media_state', MediaState(data[0]))
|
||||
self.emit(self.EVENT_MEDIA_STATE, MediaState(data[0]))
|
||||
|
||||
def _on_track_changed(self, data: bytes) -> None:
|
||||
del data
|
||||
self.emit('track_changed')
|
||||
self.emit(self.EVENT_TRACK_CHANGED)
|
||||
|
||||
def _on_track_title(self, data: bytes) -> None:
|
||||
self.emit('track_title', data.decode("utf-8"))
|
||||
self.emit(self.EVENT_TRACK_TITLE, data.decode("utf-8"))
|
||||
|
||||
def _on_track_duration(self, data: bytes) -> None:
|
||||
self.emit('track_duration', struct.unpack_from('<i', data)[0])
|
||||
self.emit(self.EVENT_TRACK_DURATION, struct.unpack_from('<i', data)[0])
|
||||
|
||||
def _on_track_position(self, data: bytes) -> None:
|
||||
self.emit('track_position', struct.unpack_from('<i', data)[0])
|
||||
self.emit(self.EVENT_TRACK_POSITION, struct.unpack_from('<i', data)[0])
|
||||
|
||||
|
||||
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
|
||||
|
||||
@@ -17,17 +17,15 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
|
||||
from bumble import gatt, gatt_adapters, gatt_client, hci
|
||||
from bumble.profiles import le_audio
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
|
||||
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -72,6 +70,19 @@ class PacRecord:
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_from_bytes(cls, data: bytes) -> list[PacRecord]:
|
||||
"""Parse a serialized list of records preceded by a one byte list length."""
|
||||
record_count = data[0]
|
||||
records = []
|
||||
offset = 1
|
||||
for _ in range(record_count):
|
||||
record = PacRecord.from_bytes(data[offset:])
|
||||
offset += len(bytes(record))
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
@@ -90,12 +101,12 @@ class PacRecord:
|
||||
class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
|
||||
|
||||
sink_pac: Optional[gatt.Characteristic]
|
||||
sink_audio_locations: Optional[gatt.Characteristic]
|
||||
source_pac: Optional[gatt.Characteristic]
|
||||
source_audio_locations: Optional[gatt.Characteristic]
|
||||
available_audio_contexts: gatt.Characteristic
|
||||
supported_audio_contexts: gatt.Characteristic
|
||||
sink_pac: Optional[gatt.Characteristic[bytes]]
|
||||
sink_audio_locations: Optional[gatt.Characteristic[bytes]]
|
||||
source_pac: Optional[gatt.Characteristic[bytes]]
|
||||
source_audio_locations: Optional[gatt.Characteristic[bytes]]
|
||||
available_audio_contexts: gatt.Characteristic[bytes]
|
||||
supported_audio_contexts: gatt.Characteristic[bytes]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -172,39 +183,70 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = PublishedAudioCapabilitiesService
|
||||
|
||||
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
available_audio_contexts: gatt_client.CharacteristicProxy
|
||||
supported_audio_contexts: gatt_client.CharacteristicProxy
|
||||
sink_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
|
||||
sink_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
|
||||
None
|
||||
)
|
||||
source_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
|
||||
source_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
|
||||
None
|
||||
)
|
||||
available_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
|
||||
supported_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
self.available_audio_contexts = (
|
||||
gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
|
||||
)
|
||||
)
|
||||
|
||||
self.supported_audio_contexts = (
|
||||
gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
|
||||
)
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.sink_pac = characteristics[0]
|
||||
self.sink_pac = gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=PacRecord.list_from_bytes,
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.source_pac = characteristics[0]
|
||||
self.source_pac = gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=PacRecord.list_from_bytes,
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.sink_audio_locations = characteristics[0]
|
||||
self.sink_audio_locations = (
|
||||
gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
|
||||
)
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.source_audio_locations = characteristics[0]
|
||||
self.source_audio_locations = (
|
||||
gatt_adapters.DelegatedCharacteristicProxyAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -16,8 +16,10 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble.profiles import le_audio
|
||||
@@ -40,7 +42,7 @@ class PublicBroadcastAnnouncement:
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
features = cls.Features(data[0])
|
||||
metadata_length = data[1]
|
||||
metadata_ltv = data[1 : 1 + metadata_length]
|
||||
metadata_ltv = data[2 : 2 + metadata_length]
|
||||
return cls(
|
||||
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user