mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Compare commits
445 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4aeaa6eb3 | ||
|
|
d7489a644a | ||
|
|
a877283360 | ||
|
|
6d91e7e79b | ||
|
|
567146b143 | ||
|
|
1a3272d7ca | ||
|
|
1ee1ff0b62 | ||
|
|
729fd97748 | ||
|
|
e308051885 | ||
|
|
10e53553d7 | ||
|
|
ef0b30d059 | ||
|
|
e7e9f9509a | ||
|
|
c6cfd101df | ||
|
|
d2dcf063ee | ||
|
|
d15bc7d664 | ||
|
|
e4364d18a7 | ||
|
|
6a34c9f224 | ||
|
|
2a764fd6bb | ||
|
|
3e8ce38eba | ||
|
|
8d2f37aa7a | ||
|
|
b7b70ebcbb | ||
|
|
8ba91f4986 | ||
|
|
79a5e953bc | ||
|
|
20de5ea250 | ||
|
|
bad9ce272c | ||
|
|
d3273ffa8c | ||
|
|
071fc2723a | ||
|
|
ef4ea86f58 | ||
|
|
dfdaa149d0 | ||
|
|
986343a807 | ||
|
|
5211d7ba96 | ||
|
|
a167342778 | ||
|
|
1efb8cdbee | ||
|
|
80d83e6a70 | ||
|
|
31ec1c41ce | ||
|
|
aba1ac0cea | ||
|
|
c40824e51c | ||
|
|
2920f05dae | ||
|
|
bc911d6da0 | ||
|
|
4f87f587e4 | ||
|
|
3e38ab3638 | ||
|
|
21bb911fea | ||
|
|
744dfa33a2 | ||
|
|
ec5f8535a8 | ||
|
|
5a83734a00 | ||
|
|
b4ae8af3a7 | ||
|
|
da60386385 | ||
|
|
45c4c4f4c5 | ||
|
|
9187c75d68 | ||
|
|
abeec22546 | ||
|
|
a6bab755cf | ||
|
|
acd9d994c3 | ||
|
|
37afda3ed3 | ||
|
|
54f2981267 | ||
|
|
bb025514e7 | ||
|
|
e228597269 | ||
|
|
95b0d6c6f2 | ||
|
|
fa4df6e3a2 | ||
|
|
46ceea7ecd | ||
|
|
30f89d5739 | ||
|
|
481cf40831 | ||
|
|
eff05afb7a | ||
|
|
d8e6700611 | ||
|
|
56eb5a933b | ||
|
|
caacc0c133 | ||
|
|
5f377c024b | ||
|
|
00cd8fbdd0 | ||
|
|
aeeff18428 | ||
|
|
c48e3f5e9c | ||
|
|
d6bbc1145a | ||
|
|
e2fec67bd9 | ||
|
|
88cb3b2a4d | ||
|
|
9ebb03be46 | ||
|
|
80d84af76c | ||
|
|
8f4721758f | ||
|
|
8864af4acd | ||
|
|
8980fb8cc7 | ||
|
|
2c5f3472a9 | ||
|
|
f18277ac78 | ||
|
|
8d46bc04d2 | ||
|
|
09e5ea5dec | ||
|
|
d43281c57e | ||
|
|
6810865670 | ||
|
|
3e9e06a02c | ||
|
|
ccd12f6591 | ||
|
|
f9a7843f7e | ||
|
|
210c334db7 | ||
|
|
f297cdfcce | ||
|
|
5b536d00ab | ||
|
|
b4af46ebd5 | ||
|
|
c08da3193e | ||
|
|
f2925ca647 | ||
|
|
fd4d68e5c0 | ||
|
|
5d83deffa4 | ||
|
|
2878cca478 | ||
|
|
53934716db | ||
|
|
d885d45824 | ||
|
|
b90d0f8710 | ||
|
|
8ccfc90fe6 | ||
|
|
92aa7e9e2a | ||
|
|
afc6d19e04 | ||
|
|
c05f073b33 | ||
|
|
2b4c2a22f4 | ||
|
|
47fe93a148 | ||
|
|
6139ca8045 | ||
|
|
87c76a4a0e | ||
|
|
f7b66db873 | ||
|
|
0b314bd7f7 | ||
|
|
9da2e32ad7 | ||
|
|
93c0875740 | ||
|
|
a286700239 | ||
|
|
98ed772e8a | ||
|
|
f0b55a4f97 | ||
|
|
b74503d345 | ||
|
|
f911163e49 | ||
|
|
b083cc99ad | ||
|
|
d35643524e | ||
|
|
62a8ced447 | ||
|
|
085f163c92 | ||
|
|
81a6b1e097 | ||
|
|
dd090c9e6b | ||
|
|
11faa48422 | ||
|
|
55596176c2 | ||
|
|
4d6822d312 | ||
|
|
985c365e6d | ||
|
|
af57762227 | ||
|
|
3575f9030e | ||
|
|
698d947d85 | ||
|
|
ff6528d2bf | ||
|
|
72ac75a98d | ||
|
|
5e3ecb74e4 | ||
|
|
c59be293c8 | ||
|
|
88b4cbdf1a | ||
|
|
d6afbc6f4e | ||
|
|
fc90de3e7b | ||
|
|
847c2ef114 | ||
|
|
a0bf0c1f4d | ||
|
|
8400ff0802 | ||
|
|
0ed6aa230b | ||
|
|
6d22ed80ec | ||
|
|
72d5360af9 | ||
|
|
ac3961e763 | ||
|
|
843466c822 | ||
|
|
8385035400 | ||
|
|
3adcc8be09 | ||
|
|
c853d56302 | ||
|
|
dc97be5b35 | ||
|
|
73dbdfff9f | ||
|
|
dff14e1258 | ||
|
|
10a3833893 | ||
|
|
247cb89332 | ||
|
|
3fc71a0266 | ||
|
|
392dcc3a05 | ||
|
|
f27015d1b7 | ||
|
|
86a19b41aa | ||
|
|
320164d476 | ||
|
|
40ae661ee5 | ||
|
|
ffb3eca68b | ||
|
|
c5def93bb8 | ||
|
|
a9c4c5833d | ||
|
|
58c9c4f590 | ||
|
|
24524d88cb | ||
|
|
b8849ab311 | ||
|
|
f3cd8f8ed0 | ||
|
|
2b26de3f3a | ||
|
|
0149c4c212 | ||
|
|
f2ed898784 | ||
|
|
464a476f9f | ||
|
|
e85d067fb5 | ||
|
|
7eb493990f | ||
|
|
04d5bf3afc | ||
|
|
403a13e4c6 | ||
|
|
ad0f035df5 | ||
|
|
a13e193d3b | ||
|
|
28a1a5ebc2 | ||
|
|
6310dc777f | ||
|
|
07f71fc895 | ||
|
|
f47b9178ad | ||
|
|
863de18877 | ||
|
|
4f399249bd | ||
|
|
f0e5cdee1a | ||
|
|
7bc7d0f5af | ||
|
|
a65a215fd7 | ||
|
|
80d34a226d | ||
|
|
a9628f73e3 | ||
|
|
9324237828 | ||
|
|
d1033c018a | ||
|
|
0f29052ade | ||
|
|
0578e84586 | ||
|
|
6ab41c466f | ||
|
|
98a1093ebf | ||
|
|
caf04373f3 | ||
|
|
d4e8526766 | ||
|
|
515b83a8c7 | ||
|
|
9bf2e03354 | ||
|
|
dc18595c8a | ||
|
|
488bcfe9c6 | ||
|
|
2900b93bb3 | ||
|
|
284cc8a321 | ||
|
|
3dc2e4036c | ||
|
|
268f6b0d51 | ||
|
|
46239b321b | ||
|
|
8a536cd522 | ||
|
|
f9f5d7ccbd | ||
|
|
d6cefdff8e | ||
|
|
dc410b14c4 | ||
|
|
4c49ef9403 | ||
|
|
ba85dcbda5 | ||
|
|
e08c84dd20 | ||
|
|
8b46136703 | ||
|
|
9c7089c8ff | ||
|
|
aac8d89cd0 | ||
|
|
24e75bfeab | ||
|
|
42868b08d3 | ||
|
|
19b61d9ac0 | ||
|
|
db2a2e2bb9 | ||
|
|
e1fdb12647 | ||
|
|
a8ec1b0949 | ||
|
|
2e30b2de77 | ||
|
|
7e407ccae1 | ||
|
|
0667e83919 | ||
|
|
1a6c9a4d04 | ||
|
|
14f5b912ad | ||
|
|
46d6242171 | ||
|
|
753b966148 | ||
|
|
5a307c19b8 | ||
|
|
2cd4f84800 | ||
|
|
4ae612090b | ||
|
|
c67ca4a09e | ||
|
|
94506220d3 | ||
|
|
dbd865a484 | ||
|
|
9d2f3e932a | ||
|
|
49d32f5b5b | ||
|
|
f7b74c0bcb | ||
|
|
c75cb0c7b7 | ||
|
|
a63b335149 | ||
|
|
d8517ce407 | ||
|
|
ad13b11464 | ||
|
|
99bc92d53d | ||
|
|
72199f5615 | ||
|
|
78b8b50082 | ||
|
|
3ab64ce00d | ||
|
|
651e44e0b6 | ||
|
|
963fa41a49 | ||
|
|
493f4f8b95 | ||
|
|
fc1bf36ace | ||
|
|
5ddee17411 | ||
|
|
5ce353bcde | ||
|
|
16d33199eb | ||
|
|
e02303a448 | ||
|
|
36fc966ad6 | ||
|
|
644f74400d | ||
|
|
b7cd451ddb | ||
|
|
59d7717963 | ||
|
|
88392efca4 | ||
|
|
907f2acc7e | ||
|
|
6616477bcf | ||
|
|
5b173cb879 | ||
|
|
dc6b466a42 | ||
|
|
8b04161da3 | ||
|
|
5a85765360 | ||
|
|
333940919b | ||
|
|
b9476be9ad | ||
|
|
704c60491c | ||
|
|
4a8e612c6e | ||
|
|
5e5c9c2580 | ||
|
|
4e71ec5738 | ||
|
|
1004f10384 | ||
|
|
1051648ffb | ||
|
|
7255a09705 | ||
|
|
c2bf6b5f13 | ||
|
|
d8e699b588 | ||
|
|
3e4d4705f5 | ||
|
|
c8b2804446 | ||
|
|
e732f2589f | ||
|
|
aec5543081 | ||
|
|
e03d90ca57 | ||
|
|
495ce62d9c | ||
|
|
fbc3959a5a | ||
|
|
246b11925c | ||
|
|
dfa9131192 | ||
|
|
88c801b4c2 | ||
|
|
a1b55b94e0 | ||
|
|
80db9e2e2f | ||
|
|
ce74690420 | ||
|
|
50de4dfb5d | ||
|
|
9bcdf860f4 | ||
|
|
511ab4b630 | ||
|
|
6f2b623e3c | ||
|
|
fa12165cd3 | ||
|
|
c0c6f3329d | ||
|
|
406a932467 | ||
|
|
cc96d4245f | ||
|
|
c6cdca8923 | ||
|
|
45edcafb06 | ||
|
|
9f0bcc131f | ||
|
|
7e331c2944 | ||
|
|
10347765cb | ||
|
|
c12dee4e76 | ||
|
|
772c188674 | ||
|
|
7c1a3bb8f9 | ||
|
|
8c3c0b1e13 | ||
|
|
1ad84ad51c | ||
|
|
64937c3f77 | ||
|
|
50fd2218fa | ||
|
|
4c29a16271 | ||
|
|
762d3e92de | ||
|
|
2f97531d78 | ||
|
|
f6c7cae661 | ||
|
|
f1777a5bd2 | ||
|
|
78a06ae8cf | ||
|
|
d290df4aa9 | ||
|
|
e559744f32 | ||
|
|
67418e649a | ||
|
|
5adf9fab53 | ||
|
|
2491b686fa | ||
|
|
efd02b2f3e | ||
|
|
3b14078646 | ||
|
|
eb9d5632bc | ||
|
|
45f60edbb6 | ||
|
|
393ea6a7bb | ||
|
|
6ec6f1efe5 | ||
|
|
5d9598ea51 | ||
|
|
0d36d99a73 | ||
|
|
d8a9f5a724 | ||
|
|
2c66e1a042 | ||
|
|
d5eccdb00f | ||
|
|
32626573a6 | ||
|
|
caa82b8f7e | ||
|
|
5af347b499 | ||
|
|
4ed5bb5a9e | ||
|
|
2478d45673 | ||
|
|
1bc7d94111 | ||
|
|
6432414cd5 | ||
|
|
179064ba15 | ||
|
|
783b2d70a5 | ||
|
|
80824f3fc1 | ||
|
|
f39f5f531c | ||
|
|
56139c622f | ||
|
|
da02f6a39b | ||
|
|
548d5597c0 | ||
|
|
7fd65d2412 | ||
|
|
05a54a4af9 | ||
|
|
1e00c8f456 | ||
|
|
90d165aa01 | ||
|
|
01603ca9e4 | ||
|
|
a1b6eb61f2 | ||
|
|
25f300d3ec | ||
|
|
41fe63df06 | ||
|
|
b312170d5f | ||
|
|
cf7f2e8f44 | ||
|
|
d292083ed1 | ||
|
|
9b11142b45 | ||
|
|
acdbc4d7b9 | ||
|
|
838d10a09d | ||
|
|
3852aa056b | ||
|
|
ae77e4528f | ||
|
|
9303f4fc5b | ||
|
|
8be9f4cb0e | ||
|
|
1ea12b1bf7 | ||
|
|
65e6d68355 | ||
|
|
9732eb8836 | ||
|
|
5ae668bc70 | ||
|
|
fd4d1bcca3 | ||
|
|
0a251c9f8e | ||
|
|
351d77be59 | ||
|
|
0e2fc80509 | ||
|
|
8f3fdecb93 | ||
|
|
249a205d8e | ||
|
|
7485801222 | ||
|
|
4678e59737 | ||
|
|
952d351c00 | ||
|
|
901eb55b0e | ||
|
|
727586e40e | ||
|
|
3aa678a58e | ||
|
|
fc7c1a8113 | ||
|
|
f62a0bbe75 | ||
|
|
7341172739 | ||
|
|
91b9fbe450 | ||
|
|
e6b566b848 | ||
|
|
2527a711dc | ||
|
|
5fba6b1cae | ||
|
|
43e632f83c | ||
|
|
623298b0e9 | ||
|
|
85a61dc39d | ||
|
|
6e8c44b5e6 | ||
|
|
ec4dcc174e | ||
|
|
b247aca3b4 | ||
|
|
6226bfd196 | ||
|
|
71e11b7cf8 | ||
|
|
800c62fdb6 | ||
|
|
640b9cd53a | ||
|
|
f4add16aea | ||
|
|
2bfec3c4ed | ||
|
|
9963b51c04 | ||
|
|
2af3494d8c | ||
|
|
fe28473ba8 | ||
|
|
53d66bc74a | ||
|
|
e2c1ad5342 | ||
|
|
6399c5fb04 | ||
|
|
784cf4f26a | ||
|
|
0301b1a999 | ||
|
|
3ab2cd5e71 | ||
|
|
6ea669531a | ||
|
|
cbbada4748 | ||
|
|
152b8d1233 | ||
|
|
bdad225033 | ||
|
|
8eeb58e467 | ||
|
|
91971433d2 | ||
|
|
a0a4bd457f | ||
|
|
4ffc050eed | ||
|
|
60678419a0 | ||
|
|
648dcc9305 | ||
|
|
190529184e | ||
|
|
46eb81466d | ||
|
|
9c70c487b9 | ||
|
|
43234d7c3e | ||
|
|
dbf878dc3f | ||
|
|
f6c0bd88d7 | ||
|
|
8440b7fbf1 | ||
|
|
808ab54135 | ||
|
|
52b29ad680 | ||
|
|
d41bf9c587 | ||
|
|
b758825164 | ||
|
|
779dfe5473 | ||
|
|
afb21220e2 | ||
|
|
f9a4c7518e | ||
|
|
bad2fdf69f | ||
|
|
a84df469cd | ||
|
|
03e33e39bd | ||
|
|
753fb69272 | ||
|
|
81a5f3a395 | ||
|
|
696a8d82fd | ||
|
|
5f294b1fea | ||
|
|
2d8f5e80fb | ||
|
|
7a042db78e | ||
|
|
41ce311836 | ||
|
|
03538d0f8a | ||
|
|
86bc222dc0 | ||
|
|
e8d285fdab | ||
|
|
852c933c92 | ||
|
|
7867a99a54 | ||
|
|
6cd14bb503 | ||
|
|
532b99ffea | ||
|
|
d80f40ff5d |
6
.github/workflows/code-check.yml
vendored
6
.github/workflows/code-check.yml
vendored
@@ -14,6 +14,10 @@ jobs:
|
||||
check:
|
||||
name: Check Code
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -25,7 +29,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
43
.github/workflows/python-avatar.yml
vendored
Normal file
43
.github/workflows/python-avatar.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Python Avatar
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Avatar [${{ matrix.shard }}]
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
shard: [
|
||||
1/24, 2/24, 3/24, 4/24,
|
||||
5/24, 6/24, 7/24, 8/24,
|
||||
9/24, 10/24, 11/24, 12/24,
|
||||
13/24, 14/24, 15/24, 16/24,
|
||||
17/24, 18/24, 19/24, 20/24,
|
||||
21/24, 22/24, 23/24, 24/24,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set Up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
run: |
|
||||
avatar --list | grep -Ev '^=' > test-names.txt
|
||||
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
|
||||
- name: Rootcanal Logs
|
||||
run: cat rootcanal.log
|
||||
45
.github/workflows/python-build-test.yml
vendored
45
.github/workflows/python-build-test.yml
vendored
@@ -12,11 +12,11 @@ permissions:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -41,3 +41,42 @@ jobs:
|
||||
run: |
|
||||
inv build
|
||||
inv build.mkdocs
|
||||
|
||||
build-rust:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
|
||||
rust-version: [ "1.70.0", "stable" ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,documentation]"
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||
with:
|
||||
components: clippy,rustfmt
|
||||
toolchain: ${{ matrix.rust-version }}
|
||||
- name: Install Rust dependencies
|
||||
run: cargo install cargo-all-features # allows building/testing combinations of features
|
||||
- name: Check License Headers
|
||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||
- name: Rust Build
|
||||
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
|
||||
# Lints after build so what clippy needs is already built
|
||||
- name: Rust Lints
|
||||
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
|
||||
- name: Rust Tests
|
||||
run: cd rust && cargo test-all-features
|
||||
# At some point, hook up publishing the binary. For now, just make sure it builds.
|
||||
# Once we're ready to publish binaries, this should be built with `--release`.
|
||||
- name: Build Bumble CLI
|
||||
run: cd rust && cargo build --features bumble-tools --bin bumble
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -9,3 +9,6 @@ __pycache__
|
||||
# generated by setuptools_scm
|
||||
bumble/_version.py
|
||||
.vscode/launch.json
|
||||
/.idea
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
15
.vscode/settings.json
vendored
15
.vscode/settings.json
vendored
@@ -12,7 +12,9 @@
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -21,7 +23,10 @@
|
||||
"cccds",
|
||||
"cmac",
|
||||
"CONNECTIONLESS",
|
||||
"csip",
|
||||
"csis",
|
||||
"csrcs",
|
||||
"CVSD",
|
||||
"datagram",
|
||||
"DATALINK",
|
||||
"delayreport",
|
||||
@@ -29,6 +34,8 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
"GATTLINK",
|
||||
"HANDSFREE",
|
||||
@@ -38,13 +45,18 @@
|
||||
"libc",
|
||||
"libusb",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
"netsim",
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
"Pyodide",
|
||||
"pyusb",
|
||||
"rfcomm",
|
||||
"ROHC",
|
||||
@@ -52,6 +64,7 @@
|
||||
"SEID",
|
||||
"seids",
|
||||
"SERV",
|
||||
"SIRK",
|
||||
"ssrc",
|
||||
"strerror",
|
||||
"subband",
|
||||
@@ -61,6 +74,8 @@
|
||||
"substates",
|
||||
"tobytes",
|
||||
"tsep",
|
||||
"UNMUTE",
|
||||
"unmuted",
|
||||
"usbmodem",
|
||||
"vhci",
|
||||
"websockets",
|
||||
|
||||
741
apps/bench.py
741
apps/bench.py
File diff suppressed because it is too large
Load Diff
63
apps/ble_rpa_tool.py
Normal file
63
apps/ble_rpa_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
'''
|
||||
This is a tool for generating IRK, RPA,
|
||||
and verifying IRK/RPA pairs
|
||||
'''
|
||||
|
||||
|
||||
@click.command()
|
||||
def gen_irk() -> None:
|
||||
print(generate_irk().hex())
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
def gen_rpa(irk: str) -> None:
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
rpa = Address.generate_private_address(irk_bytes)
|
||||
print(rpa.to_string(with_type_qualifier=False))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
@click.argument("rpa", type=str)
|
||||
def verify_rpa(irk: str, rpa: str) -> None:
|
||||
address = Address(rpa)
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
if verify_rpa_with_irk(address, irk_bytes):
|
||||
print(color("Verified", "green"))
|
||||
else:
|
||||
print(color("Not Verified", "red"))
|
||||
|
||||
|
||||
def main():
|
||||
cli.add_command(gen_irk)
|
||||
cli.add_command(gen_rpa)
|
||||
cli.add_command(verify_rpa)
|
||||
cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -777,7 +777,7 @@ class ConsoleApp:
|
||||
if not service:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
@@ -796,11 +796,11 @@ class ConsoleApp:
|
||||
if not characteristic:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
values = [attribute.read_value(None)]
|
||||
values = [await attribute.read_value(None)]
|
||||
|
||||
# TODO: future optimization: convert CCCD value to human readable string
|
||||
|
||||
@@ -944,7 +944,7 @@ class ConsoleApp:
|
||||
|
||||
# send data to any subscribers
|
||||
if isinstance(attribute, Characteristic):
|
||||
attribute.write_value(None, value)
|
||||
await attribute.write_value(None, value)
|
||||
if attribute.has_properties(Characteristic.NOTIFY):
|
||||
await self.device.gatt_server.notify_subscribers(attribute)
|
||||
if attribute.has_properties(Characteristic.INDICATE):
|
||||
@@ -1172,7 +1172,7 @@ class ScanResult:
|
||||
name = ''
|
||||
|
||||
# Remove any '/P' qualifier suffix from the address string
|
||||
address_str = str(self.address).replace('/P', '')
|
||||
address_str = self.address.to_string(with_type_qualifier=False)
|
||||
|
||||
# RSSI bar
|
||||
bar_string = rssi_bar(self.rssi)
|
||||
|
||||
@@ -18,30 +18,39 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.colors import color
|
||||
from bumble.core import name_or_number
|
||||
from bumble.hci import (
|
||||
map_null_terminated_utf8_string,
|
||||
LeFeatureMask,
|
||||
HCI_SUCCESS,
|
||||
HCI_LE_SUPPORTED_FEATURES_NAMES,
|
||||
HCI_VERSION_NAMES,
|
||||
LMP_VERSION_NAMES,
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
|
||||
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
@@ -57,13 +66,14 @@ def command_succeeded(response):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_classic_info(host):
|
||||
async def get_classic_info(host: Host) -> None:
|
||||
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
|
||||
response = await host.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if command_succeeded(response):
|
||||
print()
|
||||
print(
|
||||
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
|
||||
color('Classic Address:', 'yellow'),
|
||||
response.return_parameters.bd_addr.to_string(False),
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
|
||||
@@ -77,7 +87,7 @@ async def get_classic_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_le_info(host):
|
||||
async def get_le_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
|
||||
@@ -116,13 +126,50 @@ async def get_le_info(host):
|
||||
'\n',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command()
|
||||
)
|
||||
if command_succeeded(response):
|
||||
print(
|
||||
color('Suggested Default Data Length:', 'yellow'),
|
||||
f'{response.return_parameters.suggested_max_tx_octets}/'
|
||||
f'{response.return_parameters.suggested_max_tx_time}',
|
||||
'\n',
|
||||
)
|
||||
|
||||
print(color('LE Features:', 'yellow'))
|
||||
for feature in host.supported_le_features:
|
||||
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
|
||||
print(LeFeatureMask(feature).name)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(transport):
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(latency_probes, transport):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
@@ -130,6 +177,23 @@ async def async_main(transport):
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# Measure the latency if requested
|
||||
latencies = []
|
||||
if latency_probes:
|
||||
for _ in range(latency_probes):
|
||||
start = time.time()
|
||||
await host.send_command(HCI_Read_Local_Version_Information_Command())
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
print(
|
||||
color('HCI Command Latency:', 'yellow'),
|
||||
(
|
||||
f'min={min(latencies):.2f}, '
|
||||
f'max={max(latencies):.2f}, '
|
||||
f'average={sum(latencies)/len(latencies):.2f}'
|
||||
),
|
||||
'\n',
|
||||
)
|
||||
|
||||
# Print version
|
||||
print(color('Version:', 'yellow'))
|
||||
print(
|
||||
@@ -153,6 +217,9 @@ async def async_main(transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
print(color('Supported Commands:', 'yellow'))
|
||||
@@ -162,10 +229,16 @@ async def async_main(transport):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--latency-probes',
|
||||
metavar='N',
|
||||
type=int,
|
||||
help='Send N commands to measure HCI transport latency statistics',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(transport):
|
||||
def main(latency_probes, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
asyncio.run(async_main(transport))
|
||||
asyncio.run(async_main(latency_probes, transport))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
205
apps/controller_loopback.py
Normal file
205
apps/controller_loopback.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
import click
|
||||
|
||||
|
||||
class Loopback:
|
||||
"""Send and receive ACL data packets in local loopback mode"""
|
||||
|
||||
def __init__(self, packet_size: int, packet_count: int, transport: str):
|
||||
self.transport = transport
|
||||
self.packet_size = packet_size
|
||||
self.packet_count = packet_count
|
||||
self.connection_handle: Optional[int] = None
|
||||
self.connection_event = asyncio.Event()
|
||||
self.done = asyncio.Event()
|
||||
self.expected_cid = 0
|
||||
self.bytes_received = 0
|
||||
self.start_timestamp = 0.0
|
||||
self.last_timestamp = 0.0
|
||||
|
||||
def on_connection(self, connection_handle: int, *args):
|
||||
"""Retrieve connection handle from new connection event"""
|
||||
if not self.connection_event.is_set():
|
||||
# save first connection handle for ACL
|
||||
# subsequent connections are SCO
|
||||
self.connection_handle = connection_handle
|
||||
self.connection_event.set()
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
"""Calculate packet receive speed"""
|
||||
now = time.time()
|
||||
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
|
||||
assert connection_handle == self.connection_handle
|
||||
assert cid == self.expected_cid
|
||||
self.expected_cid += 1
|
||||
if cid == 0:
|
||||
self.start_timestamp = now
|
||||
else:
|
||||
elapsed_since_start = now - self.start_timestamp
|
||||
elapsed_since_last = now - self.last_timestamp
|
||||
self.bytes_received += len(pdu)
|
||||
instant_rx_speed = len(pdu) / elapsed_since_last
|
||||
average_rx_speed = self.bytes_received / elapsed_since_start
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f}',
|
||||
'cyan',
|
||||
)
|
||||
)
|
||||
|
||||
self.last_timestamp = now
|
||||
|
||||
if self.expected_cid == self.packet_count:
|
||||
print(color('@@@ Received last packet', 'green'))
|
||||
self.done.set()
|
||||
|
||||
async def run(self):
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
async with await open_transport_or_link(self.transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
print(color('>>> Connected', 'green'))
|
||||
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# make sure data can fit in one l2cap pdu
|
||||
l2cap_header_size = 4
|
||||
|
||||
max_packet_size = (
|
||||
host.acl_packet_queue
|
||||
if host.acl_packet_queue
|
||||
else host.le_acl_packet_queue
|
||||
).max_packet_size - l2cap_header_size
|
||||
if self.packet_size > max_packet_size:
|
||||
print(
|
||||
color(
|
||||
f'!!! Packet size ({self.packet_size}) larger than max supported'
|
||||
f' size ({max_packet_size})',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not host.supports_command(
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND
|
||||
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
|
||||
print(color('!!! Loopback mode not supported', 'red'))
|
||||
return
|
||||
|
||||
# set event callbacks
|
||||
host.on('connection', self.on_connection)
|
||||
host.on('l2cap_pdu', self.on_l2cap_pdu)
|
||||
|
||||
loopback_mode = LoopbackMode.LOCAL
|
||||
|
||||
print(color('### Setting loopback mode', 'blue'))
|
||||
await host.send_command(
|
||||
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
print(color('### Checking loopback mode', 'blue'))
|
||||
response = await host.send_command(
|
||||
HCI_Read_Loopback_Mode_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.loopback_mode != loopback_mode:
|
||||
print(color('!!! Loopback mode mismatch', 'red'))
|
||||
return
|
||||
|
||||
await self.connection_event.wait()
|
||||
print(color('### Connected', 'cyan'))
|
||||
|
||||
print(color('=== Start sending', 'magenta'))
|
||||
start_time = time.time()
|
||||
bytes_sent = 0
|
||||
for cid in range(0, self.packet_count):
|
||||
# using the cid as an incremental index
|
||||
host.send_l2cap_pdu(
|
||||
self.connection_handle, cid, bytes(self.packet_size)
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
|
||||
)
|
||||
)
|
||||
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
|
||||
await asyncio.sleep(0) # yield to allow packet receive
|
||||
|
||||
await self.done.wait()
|
||||
print(color('=== Done!', 'magenta'))
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
average_tx_speed = bytes_sent / elapsed
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
|
||||
f' in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--packet-size',
|
||||
'-s',
|
||||
metavar='SIZE',
|
||||
type=click.IntRange(8, 4096),
|
||||
default=500,
|
||||
help='Packet size',
|
||||
)
|
||||
@click.option(
|
||||
'--packet-count',
|
||||
'-c',
|
||||
metavar='COUNT',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=10,
|
||||
help='Packet count',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -21,6 +21,7 @@ import struct
|
||||
import logging
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.core import AdvertisingData
|
||||
@@ -204,7 +205,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.peer = None
|
||||
@@ -218,7 +219,12 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
psm = 0xFB
|
||||
device.register_l2cap_channel_server(0xFB, self.on_coc)
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=0xFB,
|
||||
),
|
||||
handler=self.on_coc,
|
||||
)
|
||||
print(f'### Listening for CoC connection on PSM {psm}')
|
||||
|
||||
# Setup the Gattlink service
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device
|
||||
@@ -47,16 +48,17 @@ class ServerBridge:
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device):
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
device.register_l2cap_channel_server(
|
||||
psm=self.psm,
|
||||
server=self.on_coc,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
async def start(self, device: Device) -> None:
|
||||
# Listen for incoming L2CAP channel connections
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
|
||||
),
|
||||
handler=self.on_channel,
|
||||
)
|
||||
print(
|
||||
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
@@ -73,7 +75,7 @@ class ServerBridge:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
def on_channel(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
@@ -83,7 +85,7 @@ class ServerBridge:
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
l2cap_channel.sink = self.on_channel_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
@@ -105,7 +107,7 @@ class ServerBridge:
|
||||
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
|
||||
|
||||
def data_received(self, data):
|
||||
print(f'<<< Received on TCP: {len(data)}')
|
||||
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
|
||||
self.pipe.l2cap_channel.write(data)
|
||||
|
||||
try:
|
||||
@@ -123,11 +125,12 @@ class ServerBridge:
|
||||
await self.l2cap_channel.disconnect()
|
||||
|
||||
def on_l2cap_close(self):
|
||||
print(color('*** L2CAP channel closed', 'red'))
|
||||
self.l2cap_channel = None
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
def on_channel_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
@@ -182,7 +185,7 @@ class ClientBridge:
|
||||
peer_name = writer.get_extra_info('peer_name')
|
||||
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
def on_channel_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
@@ -194,11 +197,13 @@ class ClientBridge:
|
||||
# Connect a new L2CAP channel
|
||||
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
|
||||
try:
|
||||
l2cap_channel = await connection.open_l2cap_channel(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
l2cap_channel = await connection.create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
)
|
||||
)
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
except Exception as error:
|
||||
@@ -206,7 +211,7 @@ class ClientBridge:
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
@@ -271,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
|
||||
@click.option(
|
||||
'--l2cap-coc-max-credits',
|
||||
help='Maximum L2CAP CoC Credits',
|
||||
'--l2cap-max-credits',
|
||||
help='Maximum L2CAP Credits',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=128,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mtu',
|
||||
help='L2CAP CoC MTU',
|
||||
type=click.IntRange(23, 65535),
|
||||
default=1022,
|
||||
'--l2cap-mtu',
|
||||
help='L2CAP MTU',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mps',
|
||||
help='L2CAP CoC MPS',
|
||||
type=click.IntRange(23, 65533),
|
||||
'--l2cap-mps',
|
||||
help='L2CAP MPS',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
def cli(
|
||||
@@ -295,17 +306,17 @@ def cli(
|
||||
device_config,
|
||||
hci_transport,
|
||||
psm,
|
||||
l2cap_coc_max_credits,
|
||||
l2cap_coc_mtu,
|
||||
l2cap_coc_mps,
|
||||
l2cap_max_credits,
|
||||
l2cap_mtu,
|
||||
l2cap_mps,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
context.obj['max_credits'] = l2cap_max_credits
|
||||
context.obj['mtu'] = l2cap_mtu
|
||||
context.obj['mps'] = l2cap_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
79
apps/pair.py
79
apps/pair.py
@@ -24,10 +24,16 @@ from prompt_toolkit.shortcuts import PromptSession
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.pairing import PairingDelegate, PairingConfig
|
||||
from bumble.pairing import OobData, PairingDelegate, PairingConfig
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
ProtocolError,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
@@ -46,11 +52,13 @@ from bumble.att import (
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
if not self.linger:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
@@ -60,7 +68,7 @@ class Waiter:
|
||||
class Delegate(PairingDelegate):
|
||||
def __init__(self, mode, connection, capability_string, do_prompt):
|
||||
super().__init__(
|
||||
{
|
||||
io_capability={
|
||||
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
|
||||
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
|
||||
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
|
||||
@@ -285,7 +293,9 @@ async def pair(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -294,7 +304,7 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
@@ -306,6 +316,7 @@ async def pair(
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
device.le_enabled = True
|
||||
device.add_service(
|
||||
Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
@@ -326,7 +337,6 @@ async def pair(
|
||||
# Select LE or Classic
|
||||
if mode == 'classic':
|
||||
device.classic_enabled = True
|
||||
device.le_enabled = False
|
||||
device.classic_smp_enabled = ctkd
|
||||
|
||||
# Get things going
|
||||
@@ -343,16 +353,51 @@ async def pair(
|
||||
await device.keystore.print(prefix=color('@@@ ', 'blue'))
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
|
||||
# Create an OOB context if needed
|
||||
if oob:
|
||||
our_oob_context = OobContext()
|
||||
shared_data = (
|
||||
None
|
||||
if oob == '-'
|
||||
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
|
||||
)
|
||||
legacy_context = OobLegacyContext()
|
||||
oob_contexts = PairingConfig.OobConfig(
|
||||
our_context=our_oob_context,
|
||||
peer_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
oob_data = OobData(
|
||||
address=device.random_address,
|
||||
shared_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
print(color('@@@ OOB Data:', 'yellow'))
|
||||
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
else:
|
||||
oob_contexts = None
|
||||
|
||||
# Set up a pairing config factory
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(
|
||||
sc, mitm, bond, Delegate(mode, connection, io, prompt)
|
||||
sc=sc,
|
||||
mitm=mitm,
|
||||
bonding=bond,
|
||||
oob=oob_contexts,
|
||||
delegate=Delegate(mode, connection, io, prompt),
|
||||
)
|
||||
|
||||
# Connect to a peer or wait for a connection
|
||||
device.on('connection', lambda connection: on_connection(connection, request))
|
||||
if address_or_name is not None:
|
||||
print(color(f'=== Connecting to {address_or_name}...', 'green'))
|
||||
connection = await device.connect(address_or_name)
|
||||
connection = await device.connect(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
|
||||
if not request:
|
||||
try:
|
||||
@@ -360,10 +405,9 @@ async def pair(
|
||||
await connection.pair()
|
||||
else:
|
||||
await connection.authenticate()
|
||||
return
|
||||
except ProtocolError as error:
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
return
|
||||
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
@@ -413,6 +457,7 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
|
||||
@click.option(
|
||||
'--io',
|
||||
type=click.Choice(
|
||||
@@ -421,6 +466,14 @@ class LogHandler(logging.Handler):
|
||||
default='display+keyboard',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--oob',
|
||||
metavar='<oob-data-hex>',
|
||||
help=(
|
||||
'Use OOB pairing with this data from the peer '
|
||||
'(use "-" to enable OOB without peer data)'
|
||||
),
|
||||
)
|
||||
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
|
||||
@click.option(
|
||||
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
|
||||
@@ -440,7 +493,9 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -463,7 +518,9 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import click
|
||||
import logging
|
||||
import json
|
||||
|
||||
from bumble.pandora import PandoraDevice, serve
|
||||
from bumble.pandora import PandoraDevice, Config, serve
|
||||
from typing import Dict, Any
|
||||
|
||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||
ROOTCANAL_PORT_CUTTLEFISH = 7300
|
||||
@@ -18,12 +20,31 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
|
||||
help='HCI transport',
|
||||
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
|
||||
)
|
||||
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
|
||||
@click.option(
|
||||
'--config',
|
||||
help='Bumble json configuration file',
|
||||
)
|
||||
def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
|
||||
if '<rootcanal-port>' in transport:
|
||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||
device = PandoraDevice({'transport': transport})
|
||||
|
||||
bumble_config = retrieve_config(config)
|
||||
bumble_config.setdefault('transport', transport)
|
||||
device = PandoraDevice(bumble_config)
|
||||
|
||||
server_config = Config()
|
||||
server_config.load_from_dict(bumble_config.get('server', {}))
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(serve(device, port=grpc_port))
|
||||
asyncio.run(serve(device, config=server_config, port=grpc_port))
|
||||
|
||||
|
||||
def retrieve_config(config: str) -> Dict[str, Any]:
|
||||
if not config:
|
||||
return {}
|
||||
|
||||
with open(config, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
52
apps/scan.py
52
apps/scan.py
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.smp import AddressResolver
|
||||
from bumble.device import Advertisement
|
||||
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
|
||||
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
|
||||
address.address_type
|
||||
]
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
if address.address_type in (
|
||||
Address.RANDOM_IDENTITY_ADDRESS,
|
||||
Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
):
|
||||
type_color = 'yellow'
|
||||
else:
|
||||
if address.is_static:
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
elif address.is_static:
|
||||
type_color = 'green'
|
||||
address_qualifier = '(static)'
|
||||
elif address.is_resolvable:
|
||||
@@ -116,6 +121,7 @@ async def scan(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irks,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -140,9 +146,21 @@ async def scan(
|
||||
|
||||
if device.keystore:
|
||||
resolving_keys = await device.keystore.get_resolving_keys()
|
||||
resolver = AddressResolver(resolving_keys)
|
||||
else:
|
||||
resolver = None
|
||||
resolving_keys = []
|
||||
|
||||
for irk_and_address in irks:
|
||||
if ':' not in irk_and_address:
|
||||
raise ValueError('invalid IRK:ADDRESS value')
|
||||
irk_hex, address_str = irk_and_address.split(':', 1)
|
||||
resolving_keys.append(
|
||||
(
|
||||
bytes.fromhex(irk_hex),
|
||||
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
|
||||
)
|
||||
)
|
||||
|
||||
resolver = AddressResolver(resolving_keys) if resolving_keys else None
|
||||
|
||||
printer = AdvertisementPrinter(min_rssi, resolver)
|
||||
if raw:
|
||||
@@ -187,8 +205,24 @@ async def scan(
|
||||
default=False,
|
||||
help='Listen for raw advertising reports instead of processed ones',
|
||||
)
|
||||
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
|
||||
@click.option('--device-config', help='Device config file for the scanning device')
|
||||
@click.option(
|
||||
'--irk',
|
||||
metavar='<IRK_HEX>:<ADDRESS>',
|
||||
help=(
|
||||
'Use this IRK for resolving private addresses ' '(may be used more than once)'
|
||||
),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
'--keystore-file',
|
||||
metavar='FILE_PATH',
|
||||
help='Keystore file to use when resolving addresses',
|
||||
)
|
||||
@click.option(
|
||||
'--device-config',
|
||||
metavar='FILE_PATH',
|
||||
help='Device config file for the scanning device',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(
|
||||
min_rssi,
|
||||
@@ -198,6 +232,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -212,6 +247,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
|
||||
15
apps/show.py
15
apps/show.py
@@ -102,9 +102,21 @@ class SnoopPacketReader:
|
||||
default='h4',
|
||||
help='Format of the input file',
|
||||
)
|
||||
@click.option(
|
||||
'--vendors',
|
||||
type=click.Choice(['android', 'zephyr']),
|
||||
multiple=True,
|
||||
help='Support vendor-specific commands (list one or more)',
|
||||
)
|
||||
@click.argument('filename')
|
||||
# pylint: disable=redefined-builtin
|
||||
def main(format, filename):
|
||||
def main(format, vendors, filename):
|
||||
for vendor in vendors:
|
||||
if vendor == 'android':
|
||||
import bumble.vendor.android.hci
|
||||
elif vendor == 'zephyr':
|
||||
import bumble.vendor.zephyr.hci
|
||||
|
||||
input = open(filename, 'rb')
|
||||
if format == 'h4':
|
||||
packet_reader = PacketReader(input)
|
||||
@@ -124,7 +136,6 @@ def main(format, filename):
|
||||
if packet is None:
|
||||
break
|
||||
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
|
||||
|
||||
except Exception as error:
|
||||
print(color(f'!!! {error}', 'red'))
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ body, h1, h2, h3, h4, h5, h6 {
|
||||
border-radius: 4px;
|
||||
padding: 4px;
|
||||
margin: 6px;
|
||||
margin-left: 0px;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
th, td {
|
||||
@@ -65,7 +65,7 @@ th, td {
|
||||
}
|
||||
|
||||
.properties td:nth-child(even) {
|
||||
background-color: #D6EEEE;
|
||||
background-color: #d6eeee;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<html>
|
||||
<head>
|
||||
<title>Bumble Speaker</title>
|
||||
<script type="text/javascript" src="speaker.js"></script>
|
||||
<script src="speaker.js"></script>
|
||||
<link rel="stylesheet" href="speaker.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput):
|
||||
except HCI_StatusError:
|
||||
pass
|
||||
peer_name = '' if connection.peer_name is None else connection.peer_name
|
||||
peer_address = str(connection.peer_address).replace('/P', '')
|
||||
peer_address = connection.peer_address.to_string(False)
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=peer_address,
|
||||
@@ -228,10 +228,11 @@ class FfplayOutput(QueuedOutput):
|
||||
subprocess: Optional[asyncio.subprocess.Process]
|
||||
ffplay_task: Optional[asyncio.Task]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(AacAudioExtractor())
|
||||
def __init__(self, codec: str) -> None:
|
||||
super().__init__(AudioExtractor.create(codec))
|
||||
self.subprocess = None
|
||||
self.ffplay_task = None
|
||||
self.codec = codec
|
||||
|
||||
async def start(self):
|
||||
if self.started:
|
||||
@@ -240,7 +241,7 @@ class FfplayOutput(QueuedOutput):
|
||||
await super().start()
|
||||
|
||||
self.subprocess = await asyncio.create_subprocess_shell(
|
||||
'ffplay -acodec aac pipe:0',
|
||||
f'ffplay -f {self.codec} pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
@@ -375,7 +376,7 @@ class UiServer:
|
||||
if connection := self.speaker().connection:
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=str(connection.peer_address).replace('/P', ''),
|
||||
peer_address=connection.peer_address.to_string(False),
|
||||
peer_name=connection.peer_name,
|
||||
)
|
||||
|
||||
@@ -419,7 +420,7 @@ class Speaker:
|
||||
self.outputs = []
|
||||
for output in outputs:
|
||||
if output == '@ffplay':
|
||||
self.outputs.append(FfplayOutput())
|
||||
self.outputs.append(FfplayOutput(codec))
|
||||
continue
|
||||
|
||||
# Default to FileOutput
|
||||
@@ -640,7 +641,7 @@ class Speaker:
|
||||
self.device.on('connection', self.on_bluetooth_connection)
|
||||
|
||||
# Create a listener to wait for AVDTP connections
|
||||
self.listener = Listener(Listener.create_registrar(self.device))
|
||||
self.listener = Listener.for_device(self.device)
|
||||
self.listener.on('connection', self.on_avdtp_connection)
|
||||
|
||||
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')
|
||||
@@ -708,17 +709,6 @@ def speaker(
|
||||
):
|
||||
"""Run the speaker."""
|
||||
|
||||
# ffplay only works with AAC for now
|
||||
if codec != 'aac' and '@ffplay' in output:
|
||||
print(
|
||||
color(
|
||||
f'{codec} not supported with @ffplay output, '
|
||||
'@ffplay output will be skipped',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
output = list(filter(lambda x: x != '@ffplay', output))
|
||||
|
||||
if '@ffplay' in output:
|
||||
# Check if ffplay is installed
|
||||
try:
|
||||
|
||||
167
bumble/a2dp.py
167
bumble/a2dp.py
@@ -15,9 +15,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import struct
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import List, Callable, Awaitable
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
from .sdp import (
|
||||
@@ -180,8 +184,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -230,8 +238,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -239,24 +251,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcMediaCodecInformation(
|
||||
namedtuple(
|
||||
'SbcMediaCodecInformation',
|
||||
[
|
||||
'sampling_frequency',
|
||||
'channel_mode',
|
||||
'block_length',
|
||||
'subbands',
|
||||
'allocation_method',
|
||||
'minimum_bitpool_value',
|
||||
'maximum_bitpool_value',
|
||||
],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class SbcMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.3.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
sampling_frequency: int
|
||||
channel_mode: int
|
||||
block_length: int
|
||||
subbands: int
|
||||
allocation_method: int
|
||||
minimum_bitpool_value: int
|
||||
maximum_bitpool_value: int
|
||||
|
||||
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
|
||||
CHANNEL_MODE_BITS = {
|
||||
SBC_MONO_CHANNEL_MODE: 1 << 3,
|
||||
@@ -272,7 +280,7 @@ class SbcMediaCodecInformation(
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
|
||||
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
|
||||
sampling_frequency = (data[0] >> 4) & 0x0F
|
||||
channel_mode = (data[0] >> 0) & 0x0F
|
||||
block_length = (data[1] >> 4) & 0x0F
|
||||
@@ -293,14 +301,14 @@ class SbcMediaCodecInformation(
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls,
|
||||
sampling_frequency,
|
||||
channel_mode,
|
||||
block_length,
|
||||
subbands,
|
||||
allocation_method,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
sampling_frequency: int,
|
||||
channel_mode: int,
|
||||
block_length: int,
|
||||
subbands: int,
|
||||
allocation_method: int,
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
|
||||
@@ -314,14 +322,14 @@ class SbcMediaCodecInformation(
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
sampling_frequencies,
|
||||
channel_modes,
|
||||
block_lengths,
|
||||
subbands,
|
||||
allocation_methods,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
sampling_frequencies: List[int],
|
||||
channel_modes: List[int],
|
||||
block_lengths: List[int],
|
||||
subbands: List[int],
|
||||
allocation_methods: List[int],
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
@@ -348,7 +356,7 @@ class SbcMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
|
||||
allocation_methods = ['SNR', 'Loudness']
|
||||
return '\n'.join(
|
||||
@@ -367,16 +375,19 @@ class SbcMediaCodecInformation(
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacMediaCodecInformation(
|
||||
namedtuple(
|
||||
'AacMediaCodecInformation',
|
||||
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class AacMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.5.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
object_type: int
|
||||
sampling_frequency: int
|
||||
channels: int
|
||||
rfa: int
|
||||
vbr: int
|
||||
bitrate: int
|
||||
|
||||
OBJECT_TYPE_BITS = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
|
||||
@@ -400,7 +411,7 @@ class AacMediaCodecInformation(
|
||||
CHANNELS_BITS = {1: 1 << 1, 2: 1}
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
|
||||
def from_bytes(data: bytes) -> AacMediaCodecInformation:
|
||||
object_type = data[0]
|
||||
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
channels = (data[2] >> 2) & 0x03
|
||||
@@ -413,8 +424,13 @@ class AacMediaCodecInformation(
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls, object_type, sampling_frequency, channels, vbr, bitrate
|
||||
):
|
||||
cls,
|
||||
object_type: int,
|
||||
sampling_frequency: int,
|
||||
channels: int,
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=cls.OBJECT_TYPE_BITS[object_type],
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
@@ -425,7 +441,14 @@ class AacMediaCodecInformation(
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
|
||||
def from_lists(
|
||||
cls,
|
||||
object_types: List[int],
|
||||
sampling_frequencies: List[int],
|
||||
channels: List[int],
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
|
||||
sampling_frequency=sum(
|
||||
@@ -449,7 +472,7 @@ class AacMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
object_types = [
|
||||
'MPEG_2_AAC_LC',
|
||||
'MPEG_4_AAC_LC',
|
||||
@@ -474,26 +497,26 @@ class AacMediaCodecInformation(
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorSpecificMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.7.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
vendor_id: int
|
||||
codec_id: int
|
||||
value: bytes
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
|
||||
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __init__(self, vendor_id, codec_id, value):
|
||||
self.vendor_id = vendor_id
|
||||
self.codec_id = codec_id
|
||||
self.value = value
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
@@ -506,29 +529,27 @@ class VendorSpecificMediaCodecInformation:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
def __init__(
|
||||
self, sampling_frequency, block_count, channel_mode, subband_count, payload
|
||||
):
|
||||
self.sampling_frequency = sampling_frequency
|
||||
self.block_count = block_count
|
||||
self.channel_mode = channel_mode
|
||||
self.subband_count = subband_count
|
||||
self.payload = payload
|
||||
sampling_frequency: int
|
||||
block_count: int
|
||||
channel_mode: int
|
||||
subband_count: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
def sample_count(self):
|
||||
def sample_count(self) -> int:
|
||||
return self.subband_count * self.block_count
|
||||
|
||||
@property
|
||||
def bitrate(self):
|
||||
def bitrate(self) -> int:
|
||||
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
def duration(self) -> float:
|
||||
return self.sample_count / self.sampling_frequency
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'SBC(sf={self.sampling_frequency},'
|
||||
f'cm={self.channel_mode},'
|
||||
@@ -540,12 +561,12 @@ class SbcFrame:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcParser:
|
||||
def __init__(self, read):
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def frames(self):
|
||||
async def generate_frames():
|
||||
def frames(self) -> AsyncGenerator[SbcFrame, None]:
|
||||
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
|
||||
while True:
|
||||
# Read 4 bytes of header
|
||||
header = await self.read(4)
|
||||
@@ -589,7 +610,9 @@ class SbcParser:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcPacketSource:
|
||||
def __init__(self, read, mtu, codec_capabilities):
|
||||
def __init__(
|
||||
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
|
||||
) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
|
||||
85
bumble/at.py
Normal file
85
bumble/at.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
"""Split input parameters into tokens.
|
||||
Removes space characters outside of double quote blocks:
|
||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||
are ignored [..], unless they are embedded in numeric or string constants"
|
||||
Raises ValueError in case of invalid input string."""
|
||||
|
||||
tokens = []
|
||||
in_quotes = False
|
||||
token = bytearray()
|
||||
for b in buffer:
|
||||
char = bytearray([b])
|
||||
|
||||
if in_quotes:
|
||||
token.extend(char)
|
||||
if char == b'\"':
|
||||
in_quotes = False
|
||||
tokens.append(token[1:-1])
|
||||
token = bytearray()
|
||||
else:
|
||||
if char == b' ':
|
||||
pass
|
||||
elif char == b',' or char == b')':
|
||||
tokens.append(token)
|
||||
tokens.append(char)
|
||||
token = bytearray()
|
||||
elif char == b'(':
|
||||
if len(token) > 0:
|
||||
raise ValueError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
elif char == b'"':
|
||||
if len(token) > 0:
|
||||
raise ValueError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
else:
|
||||
token.extend(char)
|
||||
|
||||
tokens.append(token)
|
||||
return [bytes(token) for token in tokens if len(token) > 0]
|
||||
|
||||
|
||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
"""Parse the parameters using the comma and parenthesis separators.
|
||||
Raises ValueError in case of invalid input string."""
|
||||
|
||||
tokens = tokenize_parameters(buffer)
|
||||
accumulator: List[list] = [[]]
|
||||
current: Union[bytes, list] = bytes()
|
||||
|
||||
for token in tokens:
|
||||
if token == b',':
|
||||
accumulator[-1].append(current)
|
||||
current = bytes()
|
||||
elif token == b'(':
|
||||
accumulator.append([])
|
||||
elif token == b')':
|
||||
if len(accumulator) < 2:
|
||||
raise ValueError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
else:
|
||||
current = token
|
||||
|
||||
accumulator[-1].append(current)
|
||||
if len(accumulator) > 1:
|
||||
raise ValueError("missing close_paren")
|
||||
return accumulator[0]
|
||||
168
bumble/att.py
168
bumble/att.py
@@ -23,13 +23,26 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type, TYPE_CHECKING
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
|
||||
from pyee import EventEmitter
|
||||
|
||||
from bumble.core import UUID, name_or_number, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -182,6 +195,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
|
||||
# pylint: enable=line-too-long
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -209,7 +223,7 @@ class ATT_PDU:
|
||||
|
||||
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
|
||||
op_code = 0
|
||||
name = None
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
@@ -719,48 +733,94 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeValue:
|
||||
'''
|
||||
Attribute value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], bytes],
|
||||
Callable[[Optional[Connection]], Awaitable[bytes]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], bytes], None],
|
||||
Callable[[Optional[Connection], bytes], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> Union[Awaitable[None], None]:
|
||||
if self._write:
|
||||
return self._write(connection, value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Attribute(EventEmitter):
|
||||
# Permission flags
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
class Permissions(enum.IntFlag):
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
|
||||
PERMISSION_NAMES = {
|
||||
READABLE: 'READABLE',
|
||||
WRITEABLE: 'WRITEABLE',
|
||||
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
|
||||
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
|
||||
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
|
||||
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
|
||||
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
|
||||
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
|
||||
}
|
||||
@classmethod
|
||||
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | Attribute.Permissions[y],
|
||||
permissions_str.replace('|', ',').split(","),
|
||||
Attribute.Permissions(0),
|
||||
)
|
||||
except TypeError as exc:
|
||||
# The check for `p.name is not None` here is needed because for InFlag
|
||||
# enums, the .name property can be None, when the enum value is 0,
|
||||
# so the type hint for .name is Optional[str].
|
||||
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list_str = ",".join(enum_list)
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def string_to_permissions(permissions_str: str):
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
|
||||
permissions_str.split(","),
|
||||
0,
|
||||
)
|
||||
except TypeError as exc:
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
|
||||
) from exc
|
||||
# Permission flags(legacy-use only)
|
||||
READABLE = Permissions.READABLE
|
||||
WRITEABLE = Permissions.WRITEABLE
|
||||
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
|
||||
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
|
||||
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
def __init__(self, attribute_type, permissions, value=b''):
|
||||
value: Union[bytes, AttributeValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, AttributeValue] = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
self.end_group_handle = 0
|
||||
if isinstance(permissions, str):
|
||||
self.permissions = self.string_to_permissions(permissions)
|
||||
self.permissions = Attribute.Permissions.from_string(permissions)
|
||||
else:
|
||||
self.permissions = permissions
|
||||
|
||||
@@ -778,22 +838,26 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection: Connection):
|
||||
async def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
and not connection.encryption
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
||||
) and not connection.authenticated:
|
||||
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
|
||||
and connection is not None
|
||||
and not connection.authenticated
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
@@ -803,9 +867,11 @@ class Attribute(EventEmitter):
|
||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
|
||||
if read := getattr(self.value, 'read', None):
|
||||
if hasattr(self.value, 'read'):
|
||||
try:
|
||||
value = read(connection) # pylint: disable=not-callable
|
||||
value = self.value.read(connection)
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
@@ -815,7 +881,7 @@ class Attribute(EventEmitter):
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection: Connection, value_bytes):
|
||||
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
@@ -836,9 +902,11 @@ class Attribute(EventEmitter):
|
||||
|
||||
value = self.decode_value(value_bytes)
|
||||
|
||||
if write := getattr(self.value, 'write', None):
|
||||
if hasattr(self.value, 'write'):
|
||||
try:
|
||||
write(connection, value) # pylint: disable=not-callable
|
||||
result = self.value.write(connection, value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
|
||||
520
bumble/avc.py
Normal file
520
bumble/avc.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Dict, Type, Union, Tuple
|
||||
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Frame:
|
||||
class SubunitType(enum.IntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.4
|
||||
MONITOR = 0x00
|
||||
AUDIO = 0x01
|
||||
PRINTER = 0x02
|
||||
DISC = 0x03
|
||||
TAPE_RECORDER_OR_PLAYER = 0x04
|
||||
TUNER = 0x05
|
||||
CA = 0x06
|
||||
CAMERA = 0x07
|
||||
PANEL = 0x09
|
||||
BULLETIN_BOARD = 0x0A
|
||||
VENDOR_UNIQUE = 0x1C
|
||||
EXTENDED = 0x1E
|
||||
UNIT = 0x1F
|
||||
|
||||
class OperationCode(OpenIntEnum):
|
||||
# 0x00 - 0x0F: Unit and subunit commands
|
||||
VENDOR_DEPENDENT = 0x00
|
||||
RESERVE = 0x01
|
||||
PLUG_INFO = 0x02
|
||||
|
||||
# 0x10 - 0x3F: Unit commands
|
||||
DIGITAL_OUTPUT = 0x10
|
||||
DIGITAL_INPUT = 0x11
|
||||
CHANNEL_USAGE = 0x12
|
||||
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
|
||||
INPUT_PLUG_SIGNAL_FORMAT = 0x19
|
||||
GENERAL_BUS_SETUP = 0x1F
|
||||
CONNECT_AV = 0x20
|
||||
DISCONNECT_AV = 0x21
|
||||
CONNECTIONS = 0x22
|
||||
CONNECT = 0x24
|
||||
DISCONNECT = 0x25
|
||||
UNIT_INFO = 0x30
|
||||
SUBUNIT_INFO = 0x31
|
||||
|
||||
# 0x40 - 0x7F: Subunit commands
|
||||
PASS_THROUGH = 0x7C
|
||||
GUI_UPDATE = 0x7D
|
||||
PUSH_GUI_DATA = 0x7E
|
||||
USER_ACTION = 0x7F
|
||||
|
||||
# 0xA0 - 0xBF: Unit and subunit commands
|
||||
VERSION = 0xB0
|
||||
POWER = 0xB2
|
||||
|
||||
subunit_type: SubunitType
|
||||
subunit_id: int
|
||||
opcode: OperationCode
|
||||
operands: bytes
|
||||
|
||||
@staticmethod
|
||||
def subclass(subclass):
|
||||
# Infer the opcode from the class name
|
||||
if subclass.__name__.endswith("CommandFrame"):
|
||||
short_name = subclass.__name__.replace("CommandFrame", "")
|
||||
category_class = CommandFrame
|
||||
elif subclass.__name__.endswith("ResponseFrame"):
|
||||
short_name = subclass.__name__.replace("ResponseFrame", "")
|
||||
category_class = ResponseFrame
|
||||
else:
|
||||
raise ValueError(f"invalid subclass name {subclass.__name__}")
|
||||
|
||||
uppercase_indexes = [
|
||||
i for i in range(len(short_name)) if short_name[i].isupper()
|
||||
]
|
||||
uppercase_indexes.append(len(short_name))
|
||||
words = [
|
||||
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
|
||||
for i in range(len(uppercase_indexes) - 1)
|
||||
]
|
||||
opcode_name = "_".join(words)
|
||||
opcode = Frame.OperationCode[opcode_name]
|
||||
category_class.subclasses[opcode] = subclass
|
||||
return subclass
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Frame:
|
||||
if data[0] >> 4 != 0:
|
||||
raise ValueError("first 4 bits must be 0s")
|
||||
|
||||
ctype_or_response = data[0] & 0xF
|
||||
subunit_type = Frame.SubunitType(data[1] >> 3)
|
||||
subunit_id = data[1] & 7
|
||||
|
||||
if subunit_type == Frame.SubunitType.EXTENDED:
|
||||
# Not supported
|
||||
raise NotImplementedError("extended subunit types not supported")
|
||||
|
||||
if subunit_id < 5:
|
||||
opcode_offset = 2
|
||||
elif subunit_id == 5:
|
||||
# Extended to the next byte
|
||||
extension = data[2]
|
||||
if extension == 0:
|
||||
raise ValueError("extended subunit ID value reserved")
|
||||
if extension == 0xFF:
|
||||
subunit_id = 5 + 254 + data[3]
|
||||
opcode_offset = 4
|
||||
else:
|
||||
subunit_id = 5 + extension
|
||||
opcode_offset = 3
|
||||
|
||||
elif subunit_id == 6:
|
||||
raise ValueError("reserved subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
|
||||
# Look for a registered subclass
|
||||
if ctype_or_response < 8:
|
||||
# Command
|
||||
ctype = CommandFrame.CommandType(ctype_or_response)
|
||||
if c_subclass := CommandFrame.subclasses.get(opcode):
|
||||
return c_subclass(
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*c_subclass.parse_operands(operands),
|
||||
)
|
||||
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
|
||||
else:
|
||||
# Response
|
||||
response = ResponseFrame.ResponseCode(ctype_or_response)
|
||||
if r_subclass := ResponseFrame.subclasses.get(opcode):
|
||||
return r_subclass(
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*r_subclass.parse_operands(operands),
|
||||
)
|
||||
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
|
||||
|
||||
def to_bytes(
|
||||
self,
|
||||
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
|
||||
) -> bytes:
|
||||
# TODO: support extended subunit types and ids.
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
ctype_or_response,
|
||||
self.subunit_type << 3 | self.subunit_id,
|
||||
self.opcode,
|
||||
]
|
||||
)
|
||||
+ self.operands
|
||||
)
|
||||
|
||||
def to_string(self, extra: str) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}({extra}"
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"opcode={self.opcode.name}, "
|
||||
f"operands={self.operands.hex()})"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subunit_type: SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
self.subunit_type = subunit_type
|
||||
self.subunit_id = subunit_id
|
||||
self.opcode = opcode
|
||||
self.operands = operands
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommandFrame(Frame):
|
||||
class CommandType(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.1
|
||||
CONTROL = 0x00
|
||||
STATUS = 0x01
|
||||
SPECIFIC_INQUIRY = 0x02
|
||||
NOTIFY = 0x03
|
||||
GENERAL_INQUIRY = 0x04
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
|
||||
ctype: CommandType
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.ctype = ctype
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.ctype)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"ctype={self.ctype.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ResponseFrame(Frame):
|
||||
class ResponseCode(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.2
|
||||
NOT_IMPLEMENTED = 0x08
|
||||
ACCEPTED = 0x09
|
||||
REJECTED = 0x0A
|
||||
IN_TRANSITION = 0x0B
|
||||
IMPLEMENTED_OR_STABLE = 0x0C
|
||||
CHANGED = 0x0D
|
||||
INTERIM = 0x0F
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
|
||||
response: ResponseCode
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.response = response
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.response)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"response={self.response.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorDependentFrame:
|
||||
company_id: int
|
||||
vendor_dependent_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
struct.unpack(">I", b"\x00" + operands[:3])[0],
|
||||
operands[3:],
|
||||
)
|
||||
|
||||
def make_operands(self) -> bytes:
|
||||
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
|
||||
|
||||
def __init__(self, company_id: int, vendor_dependent_data: bytes):
|
||||
self.company_id = company_id
|
||||
self.vendor_dependent_data = vendor_dependent_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PassThroughFrame:
|
||||
"""
|
||||
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
|
||||
"""
|
||||
|
||||
class StateFlag(enum.IntEnum):
|
||||
PRESSED = 0
|
||||
RELEASED = 1
|
||||
|
||||
class OperationId(OpenIntEnum):
|
||||
SELECT = 0x00
|
||||
UP = 0x01
|
||||
DOWN = 0x01
|
||||
LEFT = 0x03
|
||||
RIGHT = 0x04
|
||||
RIGHT_UP = 0x05
|
||||
RIGHT_DOWN = 0x06
|
||||
LEFT_UP = 0x07
|
||||
LEFT_DOWN = 0x08
|
||||
ROOT_MENU = 0x09
|
||||
SETUP_MENU = 0x0A
|
||||
CONTENTS_MENU = 0x0B
|
||||
FAVORITE_MENU = 0x0C
|
||||
EXIT = 0x0D
|
||||
NUMBER_0 = 0x20
|
||||
NUMBER_1 = 0x21
|
||||
NUMBER_2 = 0x22
|
||||
NUMBER_3 = 0x23
|
||||
NUMBER_4 = 0x24
|
||||
NUMBER_5 = 0x25
|
||||
NUMBER_6 = 0x26
|
||||
NUMBER_7 = 0x27
|
||||
NUMBER_8 = 0x28
|
||||
NUMBER_9 = 0x29
|
||||
DOT = 0x2A
|
||||
ENTER = 0x2B
|
||||
CLEAR = 0x2C
|
||||
CHANNEL_UP = 0x30
|
||||
CHANNEL_DOWN = 0x31
|
||||
PREVIOUS_CHANNEL = 0x32
|
||||
SOUND_SELECT = 0x33
|
||||
INPUT_SELECT = 0x34
|
||||
DISPLAY_INFORMATION = 0x35
|
||||
HELP = 0x36
|
||||
PAGE_UP = 0x37
|
||||
PAGE_DOWN = 0x38
|
||||
POWER = 0x40
|
||||
VOLUME_UP = 0x41
|
||||
VOLUME_DOWN = 0x42
|
||||
MUTE = 0x43
|
||||
PLAY = 0x44
|
||||
STOP = 0x45
|
||||
PAUSE = 0x46
|
||||
RECORD = 0x47
|
||||
REWIND = 0x48
|
||||
FAST_FORWARD = 0x49
|
||||
EJECT = 0x4A
|
||||
FORWARD = 0x4B
|
||||
BACKWARD = 0x4C
|
||||
ANGLE = 0x50
|
||||
SUBPICTURE = 0x51
|
||||
F1 = 0x71
|
||||
F2 = 0x72
|
||||
F3 = 0x73
|
||||
F4 = 0x74
|
||||
F5 = 0x75
|
||||
VENDOR_UNIQUE = 0x7E
|
||||
|
||||
state_flag: StateFlag
|
||||
operation_id: OperationId
|
||||
operation_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
PassThroughFrame.StateFlag(operands[0] >> 7),
|
||||
PassThroughFrame.OperationId(operands[0] & 0x7F),
|
||||
operands[1 : 1 + operands[1]],
|
||||
)
|
||||
|
||||
def make_operands(self):
|
||||
return (
|
||||
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
|
||||
+ self.operation_data
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_flag: StateFlag,
|
||||
operation_id: OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
if len(operation_data) > 255:
|
||||
raise ValueError("operation data must be <= 255 bytes")
|
||||
self.state_flag = state_flag
|
||||
self.operation_id = operation_id
|
||||
self.operation_data = operation_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
291
bumble/avctp.py
Normal file
291
bumble/avctp.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Callable, cast, Dict, Optional
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AVCTP_PSM = 0x0017
|
||||
AVCTP_BROWSING_PSM = 0x001B
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MessageAssembler:
|
||||
Callback = Callable[[int, bool, bool, int, bytes], None]
|
||||
|
||||
transaction_label: int
|
||||
pid: int
|
||||
c_r: int
|
||||
ipid: int
|
||||
payload: bytes
|
||||
number_of_packets: int
|
||||
packets_received: int
|
||||
|
||||
def __init__(self, callback: Callback) -> None:
|
||||
self.callback = callback
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.packets_received = 0
|
||||
self.transaction_label = -1
|
||||
self.pid = -1
|
||||
self.c_r = -1
|
||||
self.ipid = -1
|
||||
self.payload = b''
|
||||
self.number_of_packets = 0
|
||||
self.packet_count = 0
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packets_received += 1
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
c_r = (pdu[0] >> 1) & 1
|
||||
ipid = pdu[0] & 1
|
||||
|
||||
if c_r == 0 and ipid != 0:
|
||||
logger.warning("invalid IPID in command frame")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
pid_offset = 1
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
|
||||
if self.transaction_label >= 0:
|
||||
# We are already in a transaction
|
||||
logger.warning("received START or SINGLE fragment while in transaction")
|
||||
self.reset()
|
||||
self.packets_received = 1
|
||||
|
||||
if packet_type == Protocol.PacketType.START:
|
||||
self.number_of_packets = pdu[1]
|
||||
pid_offset = 2
|
||||
|
||||
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
|
||||
self.payload += pdu[pid_offset + 2 :]
|
||||
|
||||
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
|
||||
if transaction_label != self.transaction_label:
|
||||
logger.warning("transaction label does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if pid != self.pid:
|
||||
logger.warning("PID does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if c_r != self.c_r:
|
||||
logger.warning("C/R does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if self.packets_received > self.number_of_packets:
|
||||
logger.warning("too many fragments in transaction")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if packet_type == Protocol.PacketType.END:
|
||||
if self.packets_received != self.number_of_packets:
|
||||
logger.warning("premature END")
|
||||
self.reset()
|
||||
return
|
||||
else:
|
||||
self.transaction_label = transaction_label
|
||||
self.c_r = c_r
|
||||
self.ipid = ipid
|
||||
self.pid = pid
|
||||
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
|
||||
self.on_message_complete()
|
||||
|
||||
def on_message_complete(self):
|
||||
try:
|
||||
self.callback(
|
||||
self.transaction_label,
|
||||
self.c_r == 0,
|
||||
self.ipid != 0,
|
||||
self.pid,
|
||||
self.payload,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(color(f"!!! exception in callback: {error}", "red"))
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol:
|
||||
CommandHandler = Callable[[int, avc.CommandFrame], None]
|
||||
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
|
||||
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
|
||||
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
|
||||
next_transaction_label: int
|
||||
message_assembler: MessageAssembler
|
||||
|
||||
class PacketType(IntEnum):
|
||||
SINGLE = 0b00
|
||||
START = 0b01
|
||||
CONTINUE = 0b10
|
||||
END = 0b11
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
self.command_handlers = {}
|
||||
self.response_handlers = {}
|
||||
self.l2cap_channel = l2cap_channel
|
||||
self.message_assembler = MessageAssembler(self.on_message)
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on("open", self.on_l2cap_channel_open)
|
||||
l2cap_channel.on("close", self.on_l2cap_channel_close)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
logger.debug(color("<<< AVCTP channel open", "magenta"))
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
logger.debug(color("<<< AVCTP channel closed", "magenta"))
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.message_assembler.on_pdu(pdu)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"<<< AVCTP Message: pid={pid}, "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"is_command={is_command}, "
|
||||
f"ipid={ipid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
|
||||
# Check for invalid PID responses.
|
||||
if ipid:
|
||||
logger.debug(f"received IPID for PID={pid}")
|
||||
|
||||
# Find the appropriate handler.
|
||||
if is_command:
|
||||
if pid not in self.command_handlers:
|
||||
logger.warning(f"no command handler for PID {pid}")
|
||||
self.send_ipid(transaction_label, pid)
|
||||
return
|
||||
|
||||
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
|
||||
self.command_handlers[pid](transaction_label, command_frame)
|
||||
else:
|
||||
if pid not in self.response_handlers:
|
||||
logger.warning(f"no response handler for PID {pid}")
|
||||
return
|
||||
|
||||
# By convention, for an ipid, send a None payload to the response handler.
|
||||
if ipid:
|
||||
response_frame = None
|
||||
else:
|
||||
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
|
||||
|
||||
self.response_handlers[pid](transaction_label, response_frame)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
# TODO: fragment large messages
|
||||
packet_type = Protocol.PacketType.SINGLE
|
||||
pdu = (
|
||||
struct.pack(
|
||||
">BH",
|
||||
transaction_label << 4
|
||||
| packet_type << 2
|
||||
| (0 if is_command else 1) << 1
|
||||
| (1 if ipid else 0),
|
||||
pid,
|
||||
)
|
||||
+ payload
|
||||
)
|
||||
self.l2cap_channel.send_pdu(pdu)
|
||||
|
||||
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP command: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, True, False, pid, payload)
|
||||
|
||||
def send_response(self, transaction_label: int, pid: int, payload: bytes):
|
||||
logger.debug(
|
||||
">>> AVCTP response: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, False, False, pid, payload)
|
||||
|
||||
def send_ipid(self, transaction_label: int, pid: int) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
|
||||
)
|
||||
self.send_message(transaction_label, False, True, pid, b'')
|
||||
|
||||
def register_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
self.command_handlers[pid] = handler
|
||||
|
||||
def unregister_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||
raise ValueError("command handler not registered")
|
||||
del self.command_handlers[pid]
|
||||
|
||||
def register_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
self.response_handlers[pid] = handler
|
||||
|
||||
def unregister_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||
raise ValueError("response handler not registered")
|
||||
del self.response_handlers[pid]
|
||||
494
bumble/avdtp.py
494
bumble/avdtp.py
File diff suppressed because it is too large
Load Diff
1916
bumble/avrcp.py
Normal file
1916
bumble/avrcp.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
122
bumble/core.py
122
bumble/core.py
@@ -16,8 +16,9 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import List, Optional, Tuple, Union, cast, Dict
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
|
||||
@@ -53,7 +54,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
|
||||
return names
|
||||
|
||||
|
||||
def name_or_number(dictionary, number, width=2):
|
||||
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str:
|
||||
name = dictionary.get(number)
|
||||
if name is not None:
|
||||
return name
|
||||
@@ -78,7 +79,13 @@ def get_dict_key_by_value(dictionary, value):
|
||||
class BaseError(Exception):
|
||||
"""Base class for errors with an error code, error name and namespace"""
|
||||
|
||||
def __init__(self, error_code, error_namespace='', error_name='', details=''):
|
||||
def __init__(
|
||||
self,
|
||||
error_code: Optional[int],
|
||||
error_namespace: str = '',
|
||||
error_name: str = '',
|
||||
details: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.error_code = error_code
|
||||
self.error_namespace = error_namespace
|
||||
@@ -90,12 +97,18 @@ class BaseError(Exception):
|
||||
namespace = f'{self.error_namespace}/'
|
||||
else:
|
||||
namespace = ''
|
||||
if self.error_name:
|
||||
name = f'{self.error_name} [0x{self.error_code:X}]'
|
||||
have_name = self.error_name != ''
|
||||
have_code = self.error_code is not None
|
||||
if have_name and have_code:
|
||||
error_text = f'{self.error_name} [0x{self.error_code:X}]'
|
||||
elif have_name and not have_code:
|
||||
error_text = self.error_name
|
||||
elif not have_name and have_code:
|
||||
error_text = f'0x{self.error_code:X}'
|
||||
else:
|
||||
name = f'0x{self.error_code:X}'
|
||||
error_text = '<unspecified>'
|
||||
|
||||
return f'{type(self).__name__}({namespace}{name})'
|
||||
return f'{type(self).__name__}({namespace}{error_text})'
|
||||
|
||||
|
||||
class ProtocolError(BaseError):
|
||||
@@ -134,6 +147,10 @@ class ConnectionError(BaseError): # pylint: disable=redefined-builtin
|
||||
self.peer_address = peer_address
|
||||
|
||||
|
||||
class ConnectionParameterUpdateError(BaseError):
|
||||
"""Connection Parameter Update Error"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# UUID
|
||||
#
|
||||
@@ -306,7 +323,7 @@ BT_HIDP_PROTOCOL_ID = UUID.from_16_bits(0x0011, 'HIDP')
|
||||
BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel')
|
||||
BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel')
|
||||
BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification')
|
||||
BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
|
||||
BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
|
||||
BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP')
|
||||
BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP')
|
||||
BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel')
|
||||
@@ -562,11 +579,82 @@ class DeviceClass:
|
||||
PERIPHERAL_HANDHELD_GESTURAL_INPUT_DEVICE_MINOR_DEVICE_CLASS: 'Handheld gestural input device'
|
||||
}
|
||||
|
||||
WEARABLE_UNCATEGORIZED_MINOR_DEVICE_CLASS = 0x00
|
||||
WEARABLE_WRISTWATCH_MINOR_DEVICE_CLASS = 0x01
|
||||
WEARABLE_PAGER_MINOR_DEVICE_CLASS = 0x02
|
||||
WEARABLE_JACKET_MINOR_DEVICE_CLASS = 0x03
|
||||
WEARABLE_HELMET_MINOR_DEVICE_CLASS = 0x04
|
||||
WEARABLE_GLASSES_MINOR_DEVICE_CLASS = 0x05
|
||||
|
||||
WEARABLE_MINOR_DEVICE_CLASS_NAMES = {
|
||||
WEARABLE_UNCATEGORIZED_MINOR_DEVICE_CLASS: 'Uncategorized',
|
||||
WEARABLE_WRISTWATCH_MINOR_DEVICE_CLASS: 'Wristwatch',
|
||||
WEARABLE_PAGER_MINOR_DEVICE_CLASS: 'Pager',
|
||||
WEARABLE_JACKET_MINOR_DEVICE_CLASS: 'Jacket',
|
||||
WEARABLE_HELMET_MINOR_DEVICE_CLASS: 'Helmet',
|
||||
WEARABLE_GLASSES_MINOR_DEVICE_CLASS: 'Glasses',
|
||||
}
|
||||
|
||||
TOY_UNCATEGORIZED_MINOR_DEVICE_CLASS = 0x00
|
||||
TOY_ROBOT_MINOR_DEVICE_CLASS = 0x01
|
||||
TOY_VEHICLE_MINOR_DEVICE_CLASS = 0x02
|
||||
TOY_DOLL_ACTION_FIGURE_MINOR_DEVICE_CLASS = 0x03
|
||||
TOY_CONTROLLER_MINOR_DEVICE_CLASS = 0x04
|
||||
TOY_GAME_MINOR_DEVICE_CLASS = 0x05
|
||||
|
||||
TOY_MINOR_DEVICE_CLASS_NAMES = {
|
||||
TOY_UNCATEGORIZED_MINOR_DEVICE_CLASS: 'Uncategorized',
|
||||
TOY_ROBOT_MINOR_DEVICE_CLASS: 'Robot',
|
||||
TOY_VEHICLE_MINOR_DEVICE_CLASS: 'Vehicle',
|
||||
TOY_DOLL_ACTION_FIGURE_MINOR_DEVICE_CLASS: 'Doll/Action figure',
|
||||
TOY_CONTROLLER_MINOR_DEVICE_CLASS: 'Controller',
|
||||
TOY_GAME_MINOR_DEVICE_CLASS: 'Game',
|
||||
}
|
||||
|
||||
HEALTH_UNDEFINED_MINOR_DEVICE_CLASS = 0x00
|
||||
HEALTH_BLOOD_PRESSURE_MONITOR_MINOR_DEVICE_CLASS = 0x01
|
||||
HEALTH_THERMOMETER_MINOR_DEVICE_CLASS = 0x02
|
||||
HEALTH_WEIGHING_SCALE_MINOR_DEVICE_CLASS = 0x03
|
||||
HEALTH_GLUCOSE_METER_MINOR_DEVICE_CLASS = 0x04
|
||||
HEALTH_PULSE_OXIMETER_MINOR_DEVICE_CLASS = 0x05
|
||||
HEALTH_HEART_PULSE_RATE_MONITOR_MINOR_DEVICE_CLASS = 0x06
|
||||
HEALTH_HEALTH_DATA_DISPLAY_MINOR_DEVICE_CLASS = 0x07
|
||||
HEALTH_STEP_COUNTER_MINOR_DEVICE_CLASS = 0x08
|
||||
HEALTH_BODY_COMPOSITION_ANALYZER_MINOR_DEVICE_CLASS = 0x09
|
||||
HEALTH_PEAK_FLOW_MONITOR_MINOR_DEVICE_CLASS = 0x0A
|
||||
HEALTH_MEDICATION_MONITOR_MINOR_DEVICE_CLASS = 0x0B
|
||||
HEALTH_KNEE_PROSTHESIS_MINOR_DEVICE_CLASS = 0x0C
|
||||
HEALTH_ANKLE_PROSTHESIS_MINOR_DEVICE_CLASS = 0x0D
|
||||
HEALTH_GENERIC_HEALTH_MANAGER_MINOR_DEVICE_CLASS = 0x0E
|
||||
HEALTH_PERSONAL_MOBILITY_DEVICE_MINOR_DEVICE_CLASS = 0x0F
|
||||
|
||||
HEALTH_MINOR_DEVICE_CLASS_NAMES = {
|
||||
HEALTH_UNDEFINED_MINOR_DEVICE_CLASS: 'Undefined',
|
||||
HEALTH_BLOOD_PRESSURE_MONITOR_MINOR_DEVICE_CLASS: 'Blood Pressure Monitor',
|
||||
HEALTH_THERMOMETER_MINOR_DEVICE_CLASS: 'Thermometer',
|
||||
HEALTH_WEIGHING_SCALE_MINOR_DEVICE_CLASS: 'Weighing Scale',
|
||||
HEALTH_GLUCOSE_METER_MINOR_DEVICE_CLASS: 'Glucose Meter',
|
||||
HEALTH_PULSE_OXIMETER_MINOR_DEVICE_CLASS: 'Pulse Oximeter',
|
||||
HEALTH_HEART_PULSE_RATE_MONITOR_MINOR_DEVICE_CLASS: 'Heart/Pulse Rate Monitor',
|
||||
HEALTH_HEALTH_DATA_DISPLAY_MINOR_DEVICE_CLASS: 'Health Data Display',
|
||||
HEALTH_STEP_COUNTER_MINOR_DEVICE_CLASS: 'Step Counter',
|
||||
HEALTH_BODY_COMPOSITION_ANALYZER_MINOR_DEVICE_CLASS: 'Body Composition Analyzer',
|
||||
HEALTH_PEAK_FLOW_MONITOR_MINOR_DEVICE_CLASS: 'Peak Flow Monitor',
|
||||
HEALTH_MEDICATION_MONITOR_MINOR_DEVICE_CLASS: 'Medication Monitor',
|
||||
HEALTH_KNEE_PROSTHESIS_MINOR_DEVICE_CLASS: 'Knee Prosthesis',
|
||||
HEALTH_ANKLE_PROSTHESIS_MINOR_DEVICE_CLASS: 'Ankle Prosthesis',
|
||||
HEALTH_GENERIC_HEALTH_MANAGER_MINOR_DEVICE_CLASS: 'Generic Health Manager',
|
||||
HEALTH_PERSONAL_MOBILITY_DEVICE_MINOR_DEVICE_CLASS: 'Personal Mobility Device',
|
||||
}
|
||||
|
||||
MINOR_DEVICE_CLASS_NAMES = {
|
||||
COMPUTER_MAJOR_DEVICE_CLASS: COMPUTER_MINOR_DEVICE_CLASS_NAMES,
|
||||
PHONE_MAJOR_DEVICE_CLASS: PHONE_MINOR_DEVICE_CLASS_NAMES,
|
||||
AUDIO_VIDEO_MAJOR_DEVICE_CLASS: AUDIO_VIDEO_MINOR_DEVICE_CLASS_NAMES,
|
||||
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
|
||||
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES,
|
||||
WEARABLE_MAJOR_DEVICE_CLASS: WEARABLE_MINOR_DEVICE_CLASS_NAMES,
|
||||
TOY_MAJOR_DEVICE_CLASS: TOY_MINOR_DEVICE_CLASS_NAMES,
|
||||
HEALTH_MAJOR_DEVICE_CLASS: HEALTH_MINOR_DEVICE_CLASS_NAMES,
|
||||
}
|
||||
|
||||
# fmt: on
|
||||
@@ -737,8 +825,8 @@ class AdvertisingData:
|
||||
ad_structures = []
|
||||
self.ad_structures = ad_structures[:]
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> AdvertisingData:
|
||||
instance = AdvertisingData()
|
||||
instance.append(data)
|
||||
return instance
|
||||
@@ -894,7 +982,7 @@ class AdvertisingData:
|
||||
|
||||
return ad_data
|
||||
|
||||
def append(self, data):
|
||||
def append(self, data: bytes) -> None:
|
||||
offset = 0
|
||||
while offset + 1 < len(data):
|
||||
length = data[offset]
|
||||
@@ -968,3 +1056,13 @@ class ConnectionPHY:
|
||||
|
||||
def __str__(self):
|
||||
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# LE Role
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeRole(enum.IntEnum):
|
||||
PERIPHERAL_ONLY = 0x00
|
||||
CENTRAL_ONLY = 0x01
|
||||
BOTH_PERIPHERAL_PREFERRED = 0x02
|
||||
BOTH_CENTRAL_PREFERRED = 0x03
|
||||
|
||||
184
bumble/crypto.py
184
bumble/crypto.py
@@ -21,24 +21,24 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import operator
|
||||
import platform
|
||||
|
||||
if platform.system() != 'Emscripten':
|
||||
import secrets
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
else:
|
||||
# TODO: implement stubs
|
||||
pass
|
||||
import secrets
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -50,16 +50,18 @@ logger = logging.getLogger(__name__)
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class EccKey:
|
||||
def __init__(self, private_key):
|
||||
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
def generate(cls) -> EccKey:
|
||||
private_key = generate_private_key(SECP256R1())
|
||||
return cls(private_key)
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
|
||||
def from_private_key_bytes(
|
||||
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
|
||||
) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
|
||||
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
|
||||
@@ -69,7 +71,7 @@ class EccKey:
|
||||
return cls(private_key)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
def x(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
@@ -77,14 +79,14 @@ class EccKey:
|
||||
)
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
def y(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.y.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
def dh(self, public_key_x, public_key_y):
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
|
||||
@@ -97,14 +99,33 @@ class EccKey:
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x, y):
|
||||
def generate_prand() -> bytes:
|
||||
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part E - Table 1.2.
|
||||
'''
|
||||
prand_bytes = secrets.token_bytes(6)
|
||||
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x: bytes, y: bytes) -> bytes:
|
||||
assert len(x) == len(y)
|
||||
return bytes(map(operator.xor, x, y))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r():
|
||||
def reverse(input: bytes) -> bytes:
|
||||
'''
|
||||
Returns bytes of input in reversed endianness.
|
||||
'''
|
||||
return input[::-1]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r() -> bytes:
|
||||
'''
|
||||
Generate 16 bytes of random data
|
||||
'''
|
||||
@@ -112,20 +133,20 @@ def r():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def e(key, data):
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
|
||||
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
return bytes(reversed(encryptor.update(bytes(reversed(data)))))
|
||||
return reverse(encryptor.update(reverse(data)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
|
||||
'''
|
||||
@@ -136,7 +157,16 @@ def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
|
||||
def c1(
|
||||
k: bytes,
|
||||
r: bytes,
|
||||
preq: bytes,
|
||||
pres: bytes,
|
||||
iat: int,
|
||||
rat: int,
|
||||
ia: bytes,
|
||||
ra: bytes,
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
|
||||
LE Legacy Pairing
|
||||
@@ -148,7 +178,7 @@ def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-n
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(k, r1, r2):
|
||||
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
|
||||
Pairing
|
||||
@@ -158,7 +188,7 @@ def s1(k, r1, r2):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def aes_cmac(m, k):
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
@@ -170,20 +200,16 @@ def aes_cmac(m, k):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f4(u, v, x, z):
|
||||
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
|
||||
Generation Function f4
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
|
||||
)
|
||||
)
|
||||
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f5(w, n1, n2, a1, a2):
|
||||
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
|
||||
Function f5
|
||||
@@ -191,87 +217,83 @@ def f5(w, n1, n2, a1, a2):
|
||||
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
|
||||
'''
|
||||
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
|
||||
t = aes_cmac(bytes(reversed(w)), salt)
|
||||
t = aes_cmac(reverse(w), salt)
|
||||
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
|
||||
return (
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
|
||||
def f6(
|
||||
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
|
||||
Generation Function f6
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(r))
|
||||
+ bytes(reversed(io_cap))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2)),
|
||||
bytes(reversed(w)),
|
||||
)
|
||||
return reverse(
|
||||
aes_cmac(
|
||||
reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(r)
|
||||
+ reverse(io_cap)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2),
|
||||
reverse(w),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def g2(u, v, x, y):
|
||||
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
|
||||
Value Generation Function g2
|
||||
'''
|
||||
return int.from_bytes(
|
||||
aes_cmac(
|
||||
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
|
||||
bytes(reversed(x)),
|
||||
reverse(u) + reverse(v) + reverse(y),
|
||||
reverse(x),
|
||||
)[-4:],
|
||||
byteorder='big',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h6(w, key_id):
|
||||
def h6(w: bytes, key_id: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
|
||||
'''
|
||||
return aes_cmac(key_id, w)
|
||||
return reverse(aes_cmac(key_id, reverse(w)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h7(salt, w):
|
||||
def h7(salt: bytes, w: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
|
||||
'''
|
||||
return aes_cmac(w, salt)
|
||||
return reverse(aes_cmac(reverse(w), salt))
|
||||
|
||||
1959
bumble/device.py
1959
bumble/device.py
File diff suppressed because it is too large
Load Diff
87
bumble/drivers/__init__.py
Normal file
87
bumble/drivers/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Drivers that can be used to customize the interaction between a host and a controller,
|
||||
like loading firmware after a cold start.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import rtk
|
||||
from .common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
"""Probe diver classes until one returns a valid instance for a host, or none is
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
probe_list = [driver_name]
|
||||
else:
|
||||
# Probe all drivers
|
||||
probe_list = driver_classes.keys()
|
||||
|
||||
for driver_name in probe_list:
|
||||
if driver_class := driver_classes.get(driver_name):
|
||||
logger.debug(f"Probing driver class: {driver_name}")
|
||||
if driver := await driver_class.for_host(host):
|
||||
logger.debug(f"Instantiated {driver_name} driver")
|
||||
return driver
|
||||
else:
|
||||
logger.debug(f"Skipping unknown driver class: {driver_name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def project_data_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to an OS-specific directory for bumble data. The directory is created if
|
||||
it doesn't exist.
|
||||
"""
|
||||
import platformdirs
|
||||
|
||||
if platform.system() == 'Darwin':
|
||||
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
|
||||
# out of author & project
|
||||
return platformdirs.user_data_path(
|
||||
appname='com.google.bumble', ensure_exists=True
|
||||
)
|
||||
else:
|
||||
# windows and linux don't use the com qualifier
|
||||
return platformdirs.user_data_path(
|
||||
appname='bumble', appauthor='google', ensure_exists=True
|
||||
)
|
||||
45
bumble/drivers/common.py
Normal file
45
bumble/drivers/common.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Common types for drivers.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
666
bumble/drivers/rtk.py
Normal file
666
bumble/drivers/rtk.py
Normal file
@@ -0,0 +1,666 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Support for Realtek USB dongles.
|
||||
Based on various online bits of information, including the Linux kernel.
|
||||
(see `drivers/bluetooth/btrtl.c`)
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import struct
|
||||
from typing import Tuple
|
||||
import weakref
|
||||
|
||||
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code,
|
||||
STATUS_SPEC,
|
||||
HCI_SUCCESS,
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
RTK_ROM_LMP_8723A = 0x1200
|
||||
RTK_ROM_LMP_8723B = 0x8723
|
||||
RTK_ROM_LMP_8821A = 0x8821
|
||||
RTK_ROM_LMP_8761A = 0x8761
|
||||
RTK_ROM_LMP_8822B = 0x8822
|
||||
RTK_ROM_LMP_8852A = 0x8852
|
||||
RTK_CONFIG_MAGIC = 0x8723AB55
|
||||
|
||||
RTK_EPATCH_SIGNATURE = b"Realtech"
|
||||
|
||||
RTK_FRAGMENT_LENGTH = 252
|
||||
|
||||
RTK_FIRMWARE_DIR_ENV = "BUMBLE_RTK_FIRMWARE_DIR"
|
||||
RTK_LINUX_FIRMWARE_DIR = "/lib/firmware/rtl_bt"
|
||||
|
||||
|
||||
class RtlProjectId(enum.IntEnum):
|
||||
PROJECT_ID_8723A = 0
|
||||
PROJECT_ID_8723B = 1
|
||||
PROJECT_ID_8821A = 2
|
||||
PROJECT_ID_8761A = 3
|
||||
PROJECT_ID_8822B = 8
|
||||
PROJECT_ID_8723D = 9
|
||||
PROJECT_ID_8821C = 10
|
||||
PROJECT_ID_8822C = 13
|
||||
PROJECT_ID_8761B = 14
|
||||
PROJECT_ID_8852A = 18
|
||||
PROJECT_ID_8852B = 20
|
||||
PROJECT_ID_8852C = 25
|
||||
|
||||
|
||||
RTK_PROJECT_ID_TO_ROM = {
|
||||
0: RTK_ROM_LMP_8723A,
|
||||
1: RTK_ROM_LMP_8723B,
|
||||
2: RTK_ROM_LMP_8821A,
|
||||
3: RTK_ROM_LMP_8761A,
|
||||
8: RTK_ROM_LMP_8822B,
|
||||
9: RTK_ROM_LMP_8723B,
|
||||
10: RTK_ROM_LMP_8821A,
|
||||
13: RTK_ROM_LMP_8822B,
|
||||
14: RTK_ROM_LMP_8761A,
|
||||
18: RTK_ROM_LMP_8852A,
|
||||
20: RTK_ROM_LMP_8852A,
|
||||
25: RTK_ROM_LMP_8852A,
|
||||
}
|
||||
|
||||
# List of USB (VendorID, ProductID) for Realtek-based devices.
|
||||
RTK_USB_PRODUCTS = {
|
||||
# Realtek 8723AE
|
||||
(0x0930, 0x021D),
|
||||
(0x13D3, 0x3394),
|
||||
# Realtek 8723BE
|
||||
(0x0489, 0xE085),
|
||||
(0x0489, 0xE08B),
|
||||
(0x04F2, 0xB49F),
|
||||
(0x13D3, 0x3410),
|
||||
(0x13D3, 0x3416),
|
||||
(0x13D3, 0x3459),
|
||||
(0x13D3, 0x3494),
|
||||
# Realtek 8723BU
|
||||
(0x7392, 0xA611),
|
||||
# Realtek 8723DE
|
||||
(0x0BDA, 0xB009),
|
||||
(0x2FF8, 0xB011),
|
||||
# Realtek 8761BUV
|
||||
(0x0B05, 0x190E),
|
||||
(0x0BDA, 0x8771),
|
||||
(0x2230, 0x0016),
|
||||
(0x2357, 0x0604),
|
||||
(0x2550, 0x8761),
|
||||
(0x2B89, 0x8761),
|
||||
(0x7392, 0xC611),
|
||||
(0x0BDA, 0x877B),
|
||||
# Realtek 8821AE
|
||||
(0x0B05, 0x17DC),
|
||||
(0x13D3, 0x3414),
|
||||
(0x13D3, 0x3458),
|
||||
(0x13D3, 0x3461),
|
||||
(0x13D3, 0x3462),
|
||||
# Realtek 8821CE
|
||||
(0x0BDA, 0xB00C),
|
||||
(0x0BDA, 0xC822),
|
||||
(0x13D3, 0x3529),
|
||||
# Realtek 8822BE
|
||||
(0x0B05, 0x185C),
|
||||
(0x13D3, 0x3526),
|
||||
# Realtek 8822CE
|
||||
(0x04C5, 0x161F),
|
||||
(0x04CA, 0x4005),
|
||||
(0x0B05, 0x18EF),
|
||||
(0x0BDA, 0xB00C),
|
||||
(0x0BDA, 0xC123),
|
||||
(0x0BDA, 0xC822),
|
||||
(0x0CB5, 0xC547),
|
||||
(0x1358, 0xC123),
|
||||
(0x13D3, 0x3548),
|
||||
(0x13D3, 0x3549),
|
||||
(0x13D3, 0x3553),
|
||||
(0x13D3, 0x3555),
|
||||
(0x2FF8, 0x3051),
|
||||
# Realtek 8822CU
|
||||
(0x13D3, 0x3549),
|
||||
# Realtek 8852AE
|
||||
(0x04C5, 0x165C),
|
||||
(0x04CA, 0x4006),
|
||||
(0x0BDA, 0x2852),
|
||||
(0x0BDA, 0x385A),
|
||||
(0x0BDA, 0x4852),
|
||||
(0x0BDA, 0xC852),
|
||||
(0x0CB8, 0xC549),
|
||||
# Realtek 8852BE
|
||||
(0x0BDA, 0x887B),
|
||||
(0x0CB8, 0xC559),
|
||||
(0x13D3, 0x3571),
|
||||
# Realtek 8852CE
|
||||
(0x04C5, 0x1675),
|
||||
(0x04CA, 0x4007),
|
||||
(0x0CB8, 0xC558),
|
||||
(0x13D3, 0x3586),
|
||||
(0x13D3, 0x3587),
|
||||
(0x13D3, 0x3592),
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
|
||||
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
|
||||
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
|
||||
HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
|
||||
class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@HCI_Command.command(
|
||||
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
|
||||
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
|
||||
)
|
||||
class HCI_RTK_Download_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@HCI_Command.command()
|
||||
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Firmware:
|
||||
def __init__(self, firmware):
|
||||
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
||||
|
||||
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
||||
raise ValueError("Firmware does not start with epatch signature")
|
||||
|
||||
if not firmware.endswith(extension_sig):
|
||||
raise ValueError("Firmware does not end with extension sig")
|
||||
|
||||
# The firmware should start with a 14 byte header.
|
||||
epatch_header_size = 14
|
||||
if len(firmware) < epatch_header_size:
|
||||
raise ValueError("Firmware too short")
|
||||
|
||||
# Look for the "project ID", starting from the end.
|
||||
offset = len(firmware) - len(extension_sig)
|
||||
project_id = -1
|
||||
while offset >= epatch_header_size:
|
||||
length, opcode = firmware[offset - 2 : offset]
|
||||
offset -= 2
|
||||
|
||||
if opcode == 0xFF:
|
||||
# End
|
||||
break
|
||||
|
||||
if length == 0:
|
||||
raise ValueError("Invalid 0-length instruction")
|
||||
|
||||
if opcode == 0 and length == 1:
|
||||
project_id = firmware[offset - 1]
|
||||
break
|
||||
|
||||
offset -= length
|
||||
|
||||
if project_id < 0:
|
||||
raise ValueError("Project ID not found")
|
||||
|
||||
self.project_id = project_id
|
||||
|
||||
# Read the patch tables info.
|
||||
self.version, num_patches = struct.unpack("<IH", firmware[8:14])
|
||||
self.patches = []
|
||||
|
||||
# The patches tables are laid out as:
|
||||
# <ChipID_1><ChipID_2>...<ChipID_N> (16 bits each)
|
||||
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
||||
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
||||
if epatch_header_size + 8 * num_patches > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
chip_id_table_offset = epatch_header_size
|
||||
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
||||
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
||||
for patch_index in range(num_patches):
|
||||
chip_id_offset = chip_id_table_offset + 2 * patch_index
|
||||
(chip_id,) = struct.unpack_from("<H", firmware, chip_id_offset)
|
||||
(patch_length,) = struct.unpack_from(
|
||||
"<H", firmware, patch_length_table_offset + 2 * patch_index
|
||||
)
|
||||
(patch_offset,) = struct.unpack_from(
|
||||
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
||||
)
|
||||
if patch_offset + patch_length > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
|
||||
# Get the SVN version for the patch
|
||||
(svn_version,) = struct.unpack_from(
|
||||
"<I", firmware, patch_offset + patch_length - 8
|
||||
)
|
||||
|
||||
# Create a payload with the patch, replacing the last 4 bytes with
|
||||
# the firmware version.
|
||||
self.patches.append(
|
||||
(
|
||||
chip_id,
|
||||
firmware[patch_offset : patch_offset + patch_length - 4]
|
||||
+ struct.pack("<I", self.version),
|
||||
svn_version,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
hci: Tuple[int, int]
|
||||
config_needed: bool
|
||||
has_rom_version: bool
|
||||
has_msft_ext: bool = False
|
||||
fw_name: str = ""
|
||||
config_name: str = ""
|
||||
|
||||
DRIVER_INFOS = [
|
||||
# 8723A
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8723A,
|
||||
hci=(0x0B, 0x06),
|
||||
config_needed=False,
|
||||
has_rom_version=False,
|
||||
fw_name="rtl8723a_fw.bin",
|
||||
config_name="",
|
||||
),
|
||||
# 8723B
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8723B,
|
||||
hci=(0x0B, 0x06),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
fw_name="rtl8723b_fw.bin",
|
||||
config_name="rtl8723b_config.bin",
|
||||
),
|
||||
# 8723D
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8723B,
|
||||
hci=(0x0D, 0x08),
|
||||
config_needed=True,
|
||||
has_rom_version=True,
|
||||
fw_name="rtl8723d_fw.bin",
|
||||
config_name="rtl8723d_config.bin",
|
||||
),
|
||||
# 8821A
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8821A,
|
||||
hci=(0x0A, 0x06),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
fw_name="rtl8821a_fw.bin",
|
||||
config_name="rtl8821a_config.bin",
|
||||
),
|
||||
# 8821C
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8821A,
|
||||
hci=(0x0C, 0x08),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8821c_fw.bin",
|
||||
config_name="rtl8821c_config.bin",
|
||||
),
|
||||
# 8761A
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8761A,
|
||||
hci=(0x0A, 0x06),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
fw_name="rtl8761a_fw.bin",
|
||||
config_name="rtl8761a_config.bin",
|
||||
),
|
||||
# 8761BU
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8761A,
|
||||
hci=(0x0B, 0x0A),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
fw_name="rtl8761bu_fw.bin",
|
||||
config_name="rtl8761bu_config.bin",
|
||||
),
|
||||
# 8822C
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8822B,
|
||||
hci=(0x0C, 0x0A),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8822cu_fw.bin",
|
||||
config_name="rtl8822cu_config.bin",
|
||||
),
|
||||
# 8822B
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8822B,
|
||||
hci=(0x0B, 0x07),
|
||||
config_needed=True,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8822b_fw.bin",
|
||||
config_name="rtl8822b_config.bin",
|
||||
),
|
||||
# 8852A
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8852A,
|
||||
hci=(0x0A, 0x0B),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8852au_fw.bin",
|
||||
config_name="rtl8852au_config.bin",
|
||||
),
|
||||
# 8852B
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8852A,
|
||||
hci=(0xB, 0xB),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8852bu_fw.bin",
|
||||
config_name="rtl8852bu_config.bin",
|
||||
),
|
||||
# 8852C
|
||||
DriverInfo(
|
||||
rom=RTK_ROM_LMP_8852A,
|
||||
hci=(0x0C, 0x0C),
|
||||
config_needed=False,
|
||||
has_rom_version=True,
|
||||
has_msft_ext=True,
|
||||
fw_name="rtl8852cu_fw.bin",
|
||||
config_name="rtl8852cu_config.bin",
|
||||
),
|
||||
]
|
||||
|
||||
POST_DROP_DELAY = 0.2
|
||||
|
||||
@staticmethod
|
||||
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
|
||||
for driver_info in Driver.DRIVER_INFOS:
|
||||
if driver_info.rom == lmp_subversion and driver_info.hci == (
|
||||
hci_subversion,
|
||||
hci_version,
|
||||
):
|
||||
return driver_info
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def find_binary_path(file_name):
|
||||
# First check if an environment variable is set
|
||||
if RTK_FIRMWARE_DIR_ENV in os.environ:
|
||||
if (
|
||||
path := pathlib.Path(os.environ[RTK_FIRMWARE_DIR_ENV]) / file_name
|
||||
).is_file():
|
||||
logger.debug(f"{file_name} found in env dir")
|
||||
return path
|
||||
|
||||
# When the environment variable is set, don't look elsewhere
|
||||
return None
|
||||
|
||||
# Then, look where the firmware download tool writes by default
|
||||
if (path := rtk_firmware_dir() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in project data dir")
|
||||
return path
|
||||
|
||||
# Then, look in the package's driver directory
|
||||
if (path := pathlib.Path(__file__).parent / "rtk_fw" / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in package dir")
|
||||
return path
|
||||
|
||||
# On Linux, check the system's FW directory
|
||||
if (
|
||||
platform.system() == "Linux"
|
||||
and (path := pathlib.Path(RTK_LINUX_FIRMWARE_DIR) / file_name).is_file()
|
||||
):
|
||||
logger.debug(f"{file_name} found in Linux system FW dir")
|
||||
return path
|
||||
|
||||
# Finally look in the current directory
|
||||
if (path := pathlib.Path.cwd() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in CWD")
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def check(host):
|
||||
if not host.hci_metadata:
|
||||
logger.debug("USB metadata not found")
|
||||
return False
|
||||
|
||||
if host.hci_metadata.get('driver') == 'rtk':
|
||||
# Forced driver
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
|
||||
if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
|
||||
logger.debug(
|
||||
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
local_version = response.return_parameters
|
||||
|
||||
logger.debug(
|
||||
f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
|
||||
f"(0x{local_version.hci_version:02X}, "
|
||||
f"0x{local_version.hci_subversion:04X})"
|
||||
)
|
||||
|
||||
driver_info = cls.find_driver_info(
|
||||
local_version.hci_version,
|
||||
local_version.hci_subversion,
|
||||
local_version.lmp_subversion,
|
||||
)
|
||||
if driver_info is None:
|
||||
# TODO: it seems that the Linux driver will send command (0x3f, 0x66)
|
||||
# in this case and then re-read the local version, then re-match.
|
||||
logger.debug("firmware already loaded or no known driver for this device")
|
||||
|
||||
return driver_info
|
||||
|
||||
@classmethod
|
||||
async def for_host(cls, host, force=False):
|
||||
# Check that a driver is needed for this host
|
||||
if not force and not cls.check(host):
|
||||
return None
|
||||
|
||||
# Get the driver info
|
||||
driver_info = await cls.driver_info_for_host(host)
|
||||
if driver_info is None:
|
||||
return None
|
||||
|
||||
# Load the firmware
|
||||
firmware_path = cls.find_binary_path(driver_info.fw_name)
|
||||
if not firmware_path:
|
||||
logger.warning(f"Firmware file {driver_info.fw_name} not found")
|
||||
logger.warning("See https://google.github.io/bumble/drivers/realtek.html")
|
||||
return None
|
||||
with open(firmware_path, "rb") as firmware_file:
|
||||
firmware = firmware_file.read()
|
||||
|
||||
# Load the config
|
||||
config = None
|
||||
if driver_info.config_name:
|
||||
config_path = cls.find_binary_path(driver_info.config_name)
|
||||
if config_path:
|
||||
with open(config_path, "rb") as config_file:
|
||||
config = config_file.read()
|
||||
if driver_info.config_needed and not config:
|
||||
logger.warning("Config needed, but no config file available")
|
||||
return None
|
||||
|
||||
return cls(host, driver_info, firmware, config)
|
||||
|
||||
def __init__(self, host, driver_info, firmware, config):
|
||||
self.host = weakref.proxy(host)
|
||||
self.driver_info = driver_info
|
||||
self.firmware = firmware
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
async def drop_firmware(host):
|
||||
host.send_hci_packet(HCI_RTK_Drop_Firmware_Command())
|
||||
|
||||
# Wait for the command to be effective (no response is sent)
|
||||
await asyncio.sleep(Driver.POST_DROP_DELAY)
|
||||
|
||||
async def download_for_rtl8723a(self):
|
||||
# Check that the firmware image does not include an epatch signature.
|
||||
if RTK_EPATCH_SIGNATURE in self.firmware:
|
||||
logger.warning(
|
||||
"epatch signature found in firmware, it is probably the wrong firmware"
|
||||
)
|
||||
return
|
||||
|
||||
# TODO: load the firmware
|
||||
|
||||
async def download_for_rtl8723b(self):
|
||||
if self.driver_info.has_rom_version:
|
||||
response = await self.host.send_command(
|
||||
HCI_RTK_Read_ROM_Version_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.status != HCI_SUCCESS:
|
||||
logger.warning("can't get ROM version")
|
||||
return
|
||||
rom_version = response.return_parameters.version
|
||||
logger.debug(f"ROM version before download: {rom_version:04X}")
|
||||
else:
|
||||
rom_version = 0
|
||||
|
||||
firmware = Firmware(self.firmware)
|
||||
logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}")
|
||||
for patch in firmware.patches:
|
||||
if patch[0] == rom_version + 1:
|
||||
logger.debug(f"using patch {patch[0]}")
|
||||
break
|
||||
else:
|
||||
logger.warning("no valid patch found for rom version {rom_version}")
|
||||
return
|
||||
|
||||
# Append the config if there is one.
|
||||
if self.config:
|
||||
payload = patch[1] + self.config
|
||||
else:
|
||||
payload = patch[1]
|
||||
|
||||
# Download the payload, one fragment at a time.
|
||||
fragment_count = math.ceil(len(payload) / RTK_FRAGMENT_LENGTH)
|
||||
for fragment_index in range(fragment_count):
|
||||
# NOTE: the Linux driver somehow adds 1 to the index after it wraps around.
|
||||
# That's odd, but we"ll do the same here.
|
||||
download_index = fragment_index & 0x7F
|
||||
if download_index >= 0x80:
|
||||
download_index += 1
|
||||
if fragment_index == fragment_count - 1:
|
||||
download_index |= 0x80 # End marker.
|
||||
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
|
||||
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
|
||||
logger.debug(f"downloading fragment {fragment_index}")
|
||||
await self.host.send_command(
|
||||
HCI_RTK_Download_Command(
|
||||
index=download_index, payload=fragment, check_result=True
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("download complete!")
|
||||
|
||||
# Read the version again
|
||||
response = await self.host.send_command(
|
||||
HCI_RTK_Read_ROM_Version_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.status != HCI_SUCCESS:
|
||||
logger.warning("can't get ROM version")
|
||||
else:
|
||||
rom_version = response.return_parameters.version
|
||||
logger.debug(f"ROM version after download: {rom_version:04X}")
|
||||
|
||||
async def download_firmware(self):
|
||||
if self.driver_info.rom == RTK_ROM_LMP_8723A:
|
||||
return await self.download_for_rtl8723a()
|
||||
|
||||
if self.driver_info.rom in (
|
||||
RTK_ROM_LMP_8723B,
|
||||
RTK_ROM_LMP_8821A,
|
||||
RTK_ROM_LMP_8761A,
|
||||
RTK_ROM_LMP_8822B,
|
||||
RTK_ROM_LMP_8852A,
|
||||
):
|
||||
return await self.download_for_rtl8723b()
|
||||
|
||||
raise ValueError("ROM not supported")
|
||||
|
||||
async def init_controller(self):
|
||||
await self.download_firmware()
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
logger.info(f"loaded FW image {self.driver_info.fw_name}")
|
||||
|
||||
|
||||
def rtk_firmware_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to a subdir of the project data dir for Realtek firmware.
|
||||
The directory is created if it doesn't exist.
|
||||
"""
|
||||
from bumble.drivers import project_data_dir
|
||||
|
||||
p = project_data_dir() / "firmware" / "realtek"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
320
bumble/gatt.py
320
bumble/gatt.py
@@ -23,16 +23,28 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, List
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
from .att import Attribute
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -93,20 +105,35 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
|
||||
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
|
||||
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
|
||||
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
|
||||
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
|
||||
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
|
||||
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
|
||||
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
|
||||
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
|
||||
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
|
||||
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
|
||||
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
|
||||
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
|
||||
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
|
||||
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
|
||||
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
|
||||
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
|
||||
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
|
||||
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
|
||||
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
|
||||
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
|
||||
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
|
||||
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
|
||||
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
|
||||
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
|
||||
|
||||
# Types
|
||||
# Attribute Types
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
|
||||
@@ -129,6 +156,8 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
|
||||
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
|
||||
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
|
||||
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
|
||||
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
|
||||
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
|
||||
|
||||
# Device Information Service
|
||||
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
|
||||
@@ -156,6 +185,96 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
|
||||
# Battery Service
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
|
||||
|
||||
# Telephony And Media Audio Service (TMAS)
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
|
||||
|
||||
# Audio Input Control Service (AICS)
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
|
||||
|
||||
# Volume Control Service (VCS)
|
||||
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
|
||||
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
|
||||
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
|
||||
|
||||
# Volume Offset Control Service (VOCS)
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
|
||||
|
||||
# Coordinated Set Identification Service (CSIS)
|
||||
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
|
||||
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
|
||||
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
|
||||
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
|
||||
|
||||
# Media Control Service (MCS)
|
||||
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
|
||||
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
|
||||
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
|
||||
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
|
||||
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
|
||||
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
|
||||
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
|
||||
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
|
||||
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
|
||||
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
|
||||
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
|
||||
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
|
||||
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
|
||||
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
|
||||
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
|
||||
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
|
||||
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
|
||||
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
|
||||
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
|
||||
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
|
||||
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
|
||||
|
||||
# Telephone Bearer Service (TBS)
|
||||
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
|
||||
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
|
||||
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
|
||||
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
|
||||
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
|
||||
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
|
||||
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
|
||||
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
|
||||
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
|
||||
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
|
||||
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
|
||||
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
|
||||
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
|
||||
|
||||
# Microphone Control Service (MICS)
|
||||
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
|
||||
|
||||
# Audio Stream Control Service (ASCS)
|
||||
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
|
||||
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
|
||||
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
|
||||
|
||||
# Broadcast Audio Scan Service (BASS)
|
||||
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
|
||||
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
|
||||
|
||||
# Published Audio Capabilities Service (PACS)
|
||||
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
|
||||
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
|
||||
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
|
||||
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
|
||||
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
|
||||
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
|
||||
|
||||
# ASHA Service
|
||||
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
|
||||
@@ -177,6 +296,9 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
|
||||
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
|
||||
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
|
||||
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
|
||||
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
|
||||
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
|
||||
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -187,7 +309,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services):
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
@@ -210,11 +332,11 @@ class Service(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
uuid: Union[str, UUID],
|
||||
characteristics: List[Characteristic],
|
||||
primary=True,
|
||||
included_services: List[Service] = [],
|
||||
):
|
||||
) -> None:
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
uuid = UUID(uuid)
|
||||
@@ -239,7 +361,7 @@ class Service(Attribute):
|
||||
"""
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Service(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
@@ -255,10 +377,15 @@ class TemplateService(Service):
|
||||
to expose their UUID as a class property
|
||||
'''
|
||||
|
||||
UUID: Optional[UUID] = None
|
||||
UUID: UUID
|
||||
|
||||
def __init__(self, characteristics, primary=True):
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
def __init__(
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: List[Service] = [],
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary, included_services)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -269,7 +396,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
service: Service
|
||||
|
||||
def __init__(self, service):
|
||||
def __init__(self, service: Service) -> None:
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
)
|
||||
@@ -278,13 +405,12 @@ class IncludedServiceDeclaration(Attribute):
|
||||
)
|
||||
self.service = service
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
||||
f'group_starting_handle=0x{self.service.handle:04X}, '
|
||||
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
|
||||
f'uuid={self.service.uuid}, '
|
||||
f'{self.service.properties!s})'
|
||||
f'uuid={self.service.uuid})'
|
||||
)
|
||||
|
||||
|
||||
@@ -309,31 +435,33 @@ class Characteristic(Attribute):
|
||||
AUTHENTICATED_SIGNED_WRITES = 0x40
|
||||
EXTENDED_PROPERTIES = 0x80
|
||||
|
||||
@staticmethod
|
||||
def from_string(properties_str: str) -> Characteristic.Properties:
|
||||
property_names: List[str] = []
|
||||
for property in Characteristic.Properties:
|
||||
if property.name is None:
|
||||
raise TypeError()
|
||||
property_names.append(property.name)
|
||||
|
||||
def string_to_property(property_string) -> Characteristic.Properties:
|
||||
for property in zip(Characteristic.Properties, property_names):
|
||||
if property_string == property[1]:
|
||||
return property[0]
|
||||
raise TypeError(f"Unable to convert {property_string} to Property")
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, properties_str: str) -> Characteristic.Properties:
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | string_to_property(y),
|
||||
properties_str.split(","),
|
||||
lambda x, y: x | cls[y],
|
||||
properties_str.replace("|", ",").split(","),
|
||||
Characteristic.Properties(0),
|
||||
)
|
||||
except TypeError:
|
||||
except (TypeError, KeyError):
|
||||
# The check for `p.name is not None` here is needed because for InFlag
|
||||
# enums, the .name property can be None, when the enum value is 0,
|
||||
# so the type hint for .name is Optional[str].
|
||||
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list_str = ",".join(enum_list)
|
||||
raise TypeError(
|
||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by commas: {','.join(property_names)}\nGot: {properties_str}"
|
||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# NOTE: we override this method to offer a consistent result between python
|
||||
# versions: the value returned by IntFlag.__str__() changed in version 11.
|
||||
return '|'.join(
|
||||
flag.name
|
||||
for flag in Characteristic.Properties
|
||||
if self.value & flag.value and flag.name is not None
|
||||
)
|
||||
|
||||
# For backwards compatibility these are defined here
|
||||
# For new code, please use Characteristic.Properties.X
|
||||
BROADCAST = Properties.BROADCAST
|
||||
@@ -347,10 +475,10 @@ class Characteristic(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
uuid: Union[str, bytes, UUID],
|
||||
properties: Characteristic.Properties,
|
||||
permissions,
|
||||
value=b'',
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, CharacteristicValue] = b'',
|
||||
descriptors: Sequence[Descriptor] = (),
|
||||
):
|
||||
super().__init__(uuid, permissions, value)
|
||||
@@ -368,12 +496,12 @@ class Characteristic(Attribute):
|
||||
def has_properties(self, properties: Characteristic.Properties) -> bool:
|
||||
return self.properties & properties == properties
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
f'uuid={self.uuid}, '
|
||||
f'{self.properties!s})'
|
||||
f'{self.properties})'
|
||||
)
|
||||
|
||||
|
||||
@@ -385,7 +513,7 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
characteristic: Characteristic
|
||||
|
||||
def __init__(self, characteristic, value_handle):
|
||||
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
|
||||
declaration_bytes = (
|
||||
struct.pack('<BH', characteristic.properties, value_handle)
|
||||
+ characteristic.uuid.to_pdu_bytes()
|
||||
@@ -396,66 +524,53 @@ class CharacteristicDeclaration(Attribute):
|
||||
self.value_handle = value_handle
|
||||
self.characteristic = characteristic
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
|
||||
f'value_handle=0x{self.value_handle:04X}, '
|
||||
f'uuid={self.characteristic.uuid}, '
|
||||
f'{self.characteristic.properties!s})'
|
||||
f'{self.characteristic.properties})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue:
|
||||
'''
|
||||
Characteristic value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(self, read=None, write=None):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection):
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(self, connection, value):
|
||||
if self._write:
|
||||
self._write(connection, value)
|
||||
class CharacteristicValue(AttributeValue):
|
||||
"""Same as AttributeValue, for backward compatibility"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter:
|
||||
'''
|
||||
An adapter that can adapt any object with `read_value` and `write_value`
|
||||
methods (like Characteristic and CharacteristicProxy objects) by wrapping
|
||||
those methods with ones that return/accept encoded/decoded values.
|
||||
Objects with async methods are considered proxies, so the adaptation is one
|
||||
where the return value of `read_value` is decoded and the value passed to
|
||||
`write_value` is encoded. Other objects are considered local characteristics
|
||||
so the adaptation is one where the return value of `read_value` is encoded
|
||||
and the value passed to `write_value` is decoded.
|
||||
If the characteristic has a `subscribe` method, it is wrapped with one where
|
||||
the values are decoded before being passed to the subscriber.
|
||||
An adapter that can adapt Characteristic and AttributeProxy objects
|
||||
by wrapping their `read_value()` and `write_value()` methods with ones that
|
||||
return/accept encoded/decoded values.
|
||||
|
||||
For proxies (i.e used by a GATT client), the adaptation is one where the return
|
||||
value of `read_value()` is decoded and the value passed to `write_value()` is
|
||||
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
|
||||
before being passed to the subscriber.
|
||||
|
||||
For local values (i.e hosted by a GATT server) the adaptation is one where the
|
||||
return value of `read_value()` is encoded and the value passed to `write_value()`
|
||||
is decoded.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
read_value: Callable
|
||||
write_value: Callable
|
||||
|
||||
if asyncio.iscoroutinefunction(
|
||||
characteristic.read_value
|
||||
) and asyncio.iscoroutinefunction(characteristic.write_value):
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
else:
|
||||
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers: Dict[
|
||||
Callable, Callable
|
||||
] = {} # Map from subscriber to proxy subscriber
|
||||
|
||||
if isinstance(characteristic, Characteristic):
|
||||
self.read_value = self.read_encoded_value
|
||||
self.write_value = self.write_encoded_value
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||
else:
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -474,11 +589,13 @@ class CharacteristicAdapter:
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
|
||||
def read_encoded_value(self, connection):
|
||||
return self.encode_value(self.wrapped_characteristic.read_value(connection))
|
||||
async def read_encoded_value(self, connection):
|
||||
return self.encode_value(
|
||||
await self.wrapped_characteristic.read_value(connection)
|
||||
)
|
||||
|
||||
def write_encoded_value(self, connection, value):
|
||||
return self.wrapped_characteristic.write_value(
|
||||
async def write_encoded_value(self, connection, value):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
connection, self.decode_value(value)
|
||||
)
|
||||
|
||||
@@ -519,7 +636,7 @@ class CharacteristicAdapter:
|
||||
|
||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
wrapped = str(self.wrapped_characteristic)
|
||||
return f'{self.__class__.__name__}({wrapped})'
|
||||
|
||||
@@ -599,10 +716,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value):
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
@@ -612,14 +729,25 @@ class Descriptor(Attribute):
|
||||
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
|
||||
'''
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||
f'type={self.type}, '
|
||||
f'value={self.read_value(None).hex()})'
|
||||
f'value={value_str})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
|
||||
|
||||
@@ -28,7 +28,19 @@ import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
Callable,
|
||||
Union,
|
||||
Any,
|
||||
Iterable,
|
||||
Type,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -66,8 +78,12 @@ from .gatt import (
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
TemplateService,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -78,16 +94,16 @@ logger = logging.getLogger(__name__)
|
||||
# Proxies
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeProxy(EventEmitter):
|
||||
client: Client
|
||||
|
||||
def __init__(self, client, handle, end_group_handle, attribute_type):
|
||||
def __init__(
|
||||
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.client = client
|
||||
self.handle = handle
|
||||
self.end_group_handle = end_group_handle
|
||||
self.type = attribute_type
|
||||
|
||||
async def read_value(self, no_long_read=False):
|
||||
async def read_value(self, no_long_read: bool = False) -> bytes:
|
||||
return self.decode_value(
|
||||
await self.client.read_value(self.handle, no_long_read)
|
||||
)
|
||||
@@ -97,13 +113,13 @@ class AttributeProxy(EventEmitter):
|
||||
self.handle, self.encode_value(value), with_response
|
||||
)
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -113,7 +129,7 @@ class ServiceProxy(AttributeProxy):
|
||||
included_services: List[ServiceProxy]
|
||||
|
||||
@staticmethod
|
||||
def from_client(service_class, client, service_uuid):
|
||||
def from_client(service_class, client: Client, service_uuid: UUID):
|
||||
# The service and its characteristics are considered to have already been
|
||||
# discovered
|
||||
services = client.get_services_by_uuid(service_uuid)
|
||||
@@ -136,14 +152,14 @@ class ServiceProxy(AttributeProxy):
|
||||
def get_characteristics_by_uuid(self, uuid):
|
||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||
|
||||
|
||||
class CharacteristicProxy(AttributeProxy):
|
||||
properties: Characteristic.Properties
|
||||
descriptors: List[DescriptorProxy]
|
||||
subscribers: Dict[Any, Callable]
|
||||
subscribers: Dict[Any, Callable[[bytes], Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -171,7 +187,9 @@ class CharacteristicProxy(AttributeProxy):
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(
|
||||
self, subscriber: Optional[Callable] = None, prefer_notify=True
|
||||
self,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
@@ -189,13 +207,13 @@ class CharacteristicProxy(AttributeProxy):
|
||||
|
||||
return await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, subscriber=None):
|
||||
async def unsubscribe(self, subscriber=None, force=False):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber)
|
||||
return await self.client.unsubscribe(self, subscriber, force)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'uuid={self.uuid}, '
|
||||
@@ -207,7 +225,7 @@ class DescriptorProxy(AttributeProxy):
|
||||
def __init__(self, client, handle, descriptor_type):
|
||||
super().__init__(client, handle, 0, descriptor_type)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -216,8 +234,10 @@ class ProfileServiceProxy:
|
||||
Base class for profile-specific service proxies
|
||||
'''
|
||||
|
||||
SERVICE_CLASS: Type[TemplateService]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client):
|
||||
def from_client(cls, client: Client) -> ProfileServiceProxy:
|
||||
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
||||
|
||||
|
||||
@@ -227,30 +247,36 @@ class ProfileServiceProxy:
|
||||
class Client:
|
||||
services: List[ServiceProxy]
|
||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||
notification_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
indication_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||
pending_request: Optional[ATT_PDU]
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
self.mtu_exchange_done = False
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
self.notification_subscribers = (
|
||||
{}
|
||||
) # Notification subscribers, by attribute handle
|
||||
self.indication_subscribers = {} # Indication subscribers, by attribute handle
|
||||
self.notification_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.indication_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
def send_gatt_pdu(self, pdu):
|
||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||
|
||||
async def send_command(self, command):
|
||||
async def send_command(self, command: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
self.send_gatt_pdu(command.to_bytes())
|
||||
|
||||
async def send_request(self, request):
|
||||
async def send_request(self, request: ATT_PDU):
|
||||
logger.debug(
|
||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||
)
|
||||
@@ -279,14 +305,14 @@ class Client:
|
||||
|
||||
return response
|
||||
|
||||
def send_confirmation(self, confirmation):
|
||||
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
|
||||
logger.debug(
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
self.send_gatt_pdu(confirmation.to_bytes())
|
||||
|
||||
async def request_mtu(self, mtu):
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
@@ -313,10 +339,12 @@ class Client:
|
||||
|
||||
return self.connection.att_mtu
|
||||
|
||||
def get_services_by_uuid(self, uuid):
|
||||
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
|
||||
return [service for service in self.services if service.uuid == uuid]
|
||||
|
||||
def get_characteristics_by_uuid(self, uuid, service=None):
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy] = None
|
||||
) -> List[CharacteristicProxy]:
|
||||
services = [service] if service else self.services
|
||||
return [
|
||||
c
|
||||
@@ -363,7 +391,7 @@ class Client:
|
||||
if not already_known:
|
||||
self.services.append(service)
|
||||
|
||||
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
|
||||
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||
'''
|
||||
@@ -435,7 +463,7 @@ class Client:
|
||||
|
||||
return services
|
||||
|
||||
async def discover_service(self, uuid):
|
||||
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
||||
'''
|
||||
@@ -468,7 +496,7 @@ class Client:
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
# TODO raise appropriate exception
|
||||
return
|
||||
return []
|
||||
break
|
||||
|
||||
for attribute_handle, end_group_handle in response.handles_information:
|
||||
@@ -480,7 +508,7 @@ class Client:
|
||||
logger.warning(
|
||||
f'bogus handle values: {attribute_handle} {end_group_handle}'
|
||||
)
|
||||
return
|
||||
return []
|
||||
|
||||
# Create a service proxy for this service
|
||||
service = ServiceProxy(
|
||||
@@ -657,8 +685,8 @@ class Client:
|
||||
async def discover_descriptors(
|
||||
self,
|
||||
characteristic: Optional[CharacteristicProxy] = None,
|
||||
start_handle=None,
|
||||
end_handle=None,
|
||||
start_handle: Optional[int] = None,
|
||||
end_handle: Optional[int] = None,
|
||||
) -> List[DescriptorProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
|
||||
@@ -721,7 +749,7 @@ class Client:
|
||||
|
||||
return descriptors
|
||||
|
||||
async def discover_attributes(self):
|
||||
async def discover_attributes(self) -> List[AttributeProxy]:
|
||||
'''
|
||||
Discover all attributes, regardless of type
|
||||
'''
|
||||
@@ -764,7 +792,12 @@ class Client:
|
||||
|
||||
return attributes
|
||||
|
||||
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
|
||||
async def subscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
) -> None:
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -801,6 +834,7 @@ class Client:
|
||||
subscriber_set = subscribers.setdefault(characteristic.handle, set())
|
||||
if subscriber is not None:
|
||||
subscriber_set.add(subscriber)
|
||||
|
||||
# Add the characteristic as a subscriber, which will result in the
|
||||
# characteristic emitting an 'update' event when a notification or indication
|
||||
# is received
|
||||
@@ -808,7 +842,18 @@ class Client:
|
||||
|
||||
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
|
||||
|
||||
async def unsubscribe(self, characteristic, subscriber=None):
|
||||
async def unsubscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
Unsubscribe from a characteristic.
|
||||
|
||||
If `force` is True, this will write zeros to the CCCD when there are no
|
||||
subscribers left, even if there were already no registered subscribers.
|
||||
'''
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -822,29 +867,45 @@ class Client:
|
||||
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
|
||||
return
|
||||
|
||||
# Check if the characteristic has subscribers
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
if not force:
|
||||
return
|
||||
|
||||
# Remove the subscriber(s)
|
||||
if subscriber is not None:
|
||||
# Remove matching subscriber from subscriber sets
|
||||
for subscriber_set in (
|
||||
self.notification_subscribers,
|
||||
self.indication_subscribers,
|
||||
):
|
||||
subscribers = subscriber_set.get(characteristic.handle, [])
|
||||
if subscriber in subscribers:
|
||||
if (
|
||||
subscribers := subscriber_set.get(characteristic.handle)
|
||||
) and subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
del subscriber_set[characteristic.handle]
|
||||
else:
|
||||
# Remove all subscribers for this attribute from the sets!
|
||||
# Remove all subscribers for this attribute from the sets
|
||||
self.notification_subscribers.pop(characteristic.handle, None)
|
||||
self.indication_subscribers.pop(characteristic.handle, None)
|
||||
|
||||
if not self.notification_subscribers and not self.indication_subscribers:
|
||||
# Update the CCCD
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
# No more subscribers left
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
|
||||
async def read_value(self, attribute, no_long_read=False):
|
||||
async def read_value(
|
||||
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
|
||||
) -> bytes:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
||||
|
||||
@@ -905,7 +966,9 @@ class Client:
|
||||
# Return the value as bytes
|
||||
return attribute_value
|
||||
|
||||
async def read_characteristics_by_uuid(self, uuid, service):
|
||||
async def read_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy]
|
||||
) -> List[bytes]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
||||
'''
|
||||
@@ -960,7 +1023,12 @@ class Client:
|
||||
|
||||
return characteristics_values
|
||||
|
||||
async def write_value(self, attribute, value, with_response=False):
|
||||
async def write_value(
|
||||
self,
|
||||
attribute: Union[int, AttributeProxy],
|
||||
value: bytes,
|
||||
with_response: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
|
||||
Value
|
||||
@@ -990,7 +1058,7 @@ class Client:
|
||||
)
|
||||
)
|
||||
|
||||
def on_gatt_pdu(self, att_pdu):
|
||||
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||
)
|
||||
@@ -1000,7 +1068,7 @@ class Client:
|
||||
logger.warning('!!! unexpected response, there is no pending request')
|
||||
return
|
||||
|
||||
# Sanity check: the response should match the pending request unless it is
|
||||
# The response should match the pending request unless it is
|
||||
# an error response
|
||||
if att_pdu.op_code != ATT_ERROR_RESPONSE:
|
||||
expected_response_name = self.pending_request.name.replace(
|
||||
@@ -1013,6 +1081,7 @@ class Client:
|
||||
return
|
||||
|
||||
# Return the response to the coroutine that is waiting for it
|
||||
assert self.pending_response is not None
|
||||
self.pending_response.set_result(att_pdu)
|
||||
else:
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
@@ -1032,7 +1101,7 @@ class Client:
|
||||
def on_att_handle_value_notification(self, notification):
|
||||
# Call all subscribers
|
||||
subscribers = self.notification_subscribers.get(
|
||||
notification.attribute_handle, []
|
||||
notification.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received notification with no subscriber')
|
||||
@@ -1046,7 +1115,9 @@ class Client:
|
||||
|
||||
def on_att_handle_value_indication(self, indication):
|
||||
# Call all subscribers
|
||||
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
|
||||
subscribers = self.indication_subscribers.get(
|
||||
indication.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received indication with no subscriber')
|
||||
|
||||
@@ -1060,7 +1131,7 @@ class Client:
|
||||
# Confirm that we received the indication
|
||||
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
||||
|
||||
def cache_value(self, attribute_handle: int, value: bytes):
|
||||
def cache_value(self, attribute_handle: int, value: bytes) -> None:
|
||||
self.cached_values[attribute_handle] = (
|
||||
datetime.now(),
|
||||
value,
|
||||
|
||||
@@ -23,16 +23,17 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID
|
||||
from .att import (
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
@@ -42,6 +43,7 @@ from .att import (
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
ATT_REQUESTS,
|
||||
ATT_PDU,
|
||||
ATT_UNLIKELY_ERROR_ERROR,
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
ATT_Error,
|
||||
@@ -58,7 +60,7 @@ from .att import (
|
||||
ATT_Write_Response,
|
||||
Attribute,
|
||||
)
|
||||
from .gatt import (
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||
@@ -72,7 +74,10 @@ from .gatt import (
|
||||
Descriptor,
|
||||
Service,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -91,8 +96,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
attributes: List[Attribute]
|
||||
services: List[Service]
|
||||
attributes_by_handle: Dict[int, Attribute]
|
||||
subscribers: Dict[int, Dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.services = []
|
||||
@@ -107,16 +117,16 @@ class Server(EventEmitter):
|
||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||
self.pending_confirmations = defaultdict(lambda: None)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle, pdu):
|
||||
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||
|
||||
def next_handle(self):
|
||||
def next_handle(self) -> int:
|
||||
return 1 + len(self.attributes)
|
||||
|
||||
def get_advertising_service_data(self):
|
||||
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
|
||||
return {
|
||||
attribute: data
|
||||
for attribute in self.attributes
|
||||
@@ -124,7 +134,7 @@ class Server(EventEmitter):
|
||||
and (data := attribute.get_advertising_data())
|
||||
}
|
||||
|
||||
def get_attribute(self, handle):
|
||||
def get_attribute(self, handle: int) -> Optional[Attribute]:
|
||||
attribute = self.attributes_by_handle.get(handle)
|
||||
if attribute:
|
||||
return attribute
|
||||
@@ -173,12 +183,17 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
(attribute, self.get_attribute(attribute.characteristic.handle))
|
||||
(
|
||||
attribute,
|
||||
self.get_attribute(attribute.characteristic.handle),
|
||||
) # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(service_handle.handle, service_handle.end_group_handle + 1),
|
||||
)
|
||||
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
if attribute is not None
|
||||
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
and isinstance(attribute, CharacteristicDeclaration)
|
||||
and attribute.characteristic.uuid == characteristic_uuid
|
||||
),
|
||||
None,
|
||||
@@ -197,7 +212,7 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
attribute # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(
|
||||
@@ -205,12 +220,12 @@ class Server(EventEmitter):
|
||||
characteristic_value.end_group_handle + 1,
|
||||
),
|
||||
)
|
||||
if attribute.type == descriptor_uuid
|
||||
if attribute is not None and attribute.type == descriptor_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def add_attribute(self, attribute):
|
||||
def add_attribute(self, attribute: Attribute) -> None:
|
||||
# Assign a handle to this attribute
|
||||
attribute.handle = self.next_handle()
|
||||
attribute.end_group_handle = (
|
||||
@@ -220,7 +235,7 @@ class Server(EventEmitter):
|
||||
# Add this attribute to the list
|
||||
self.attributes.append(attribute)
|
||||
|
||||
def add_service(self, service: Service):
|
||||
def add_service(self, service: Service) -> None:
|
||||
# Add the service attribute to the DB
|
||||
self.add_attribute(service)
|
||||
|
||||
@@ -285,11 +300,13 @@ class Server(EventEmitter):
|
||||
service.end_group_handle = self.attributes[-1].handle
|
||||
self.services.append(service)
|
||||
|
||||
def add_services(self, services):
|
||||
def add_services(self, services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(self, connection, characteristic):
|
||||
def read_cccd(
|
||||
self, connection: Optional[Connection], characteristic: Characteristic
|
||||
) -> bytes:
|
||||
if connection is None:
|
||||
return bytes([0, 0])
|
||||
|
||||
@@ -300,13 +317,18 @@ class Server(EventEmitter):
|
||||
|
||||
return cccd or bytes([0, 0])
|
||||
|
||||
def write_cccd(self, connection, characteristic, value):
|
||||
def write_cccd(
|
||||
self,
|
||||
connection: Connection,
|
||||
characteristic: Characteristic,
|
||||
value: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'Subscription update for connection=0x{connection.handle:04X}, '
|
||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
# Check parameters
|
||||
if len(value) != 2:
|
||||
logger.warning('CCCD value not 2 bytes long')
|
||||
return
|
||||
@@ -327,13 +349,19 @@ class Server(EventEmitter):
|
||||
indicate_enabled,
|
||||
)
|
||||
|
||||
def send_response(self, connection, response):
|
||||
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -352,7 +380,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -370,7 +398,13 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -389,7 +423,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -411,15 +445,13 @@ class Server(EventEmitter):
|
||||
assert self.pending_confirmations[connection.handle] is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_confirmations[
|
||||
pending_confirmation = self.pending_confirmations[
|
||||
connection.handle
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
await asyncio.wait_for(
|
||||
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
||||
@@ -427,8 +459,12 @@ class Server(EventEmitter):
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
|
||||
async def notify_or_indicate_subscribers(
|
||||
self, indicate, attribute, value=None, force=False
|
||||
):
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection
|
||||
@@ -450,13 +486,23 @@ class Server(EventEmitter):
|
||||
]
|
||||
)
|
||||
|
||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection):
|
||||
def on_disconnection(self, connection: Connection) -> None:
|
||||
if connection.handle in self.subscribers:
|
||||
del self.subscribers[connection.handle]
|
||||
if connection.handle in self.indication_semaphores:
|
||||
@@ -464,7 +510,7 @@ class Server(EventEmitter):
|
||||
if connection.handle in self.pending_confirmations:
|
||||
del self.pending_confirmations[connection.handle]
|
||||
|
||||
def on_gatt_pdu(self, connection, att_pdu):
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
@@ -506,7 +552,7 @@ class Server(EventEmitter):
|
||||
#######################################################
|
||||
# ATT handlers
|
||||
#######################################################
|
||||
def on_att_request(self, connection, pdu):
|
||||
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
|
||||
'''
|
||||
Handler for requests without a more specific handler
|
||||
'''
|
||||
@@ -605,7 +651,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_find_by_type_value_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
@@ -613,13 +660,13 @@ class Server(EventEmitter):
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
for attribute in (
|
||||
async for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
and attribute.type == request.attribute_type
|
||||
and attribute.read_value(connection) == request.attribute_value
|
||||
and (await attribute.read_value(connection)) == request.attribute_value
|
||||
and pdu_space_available >= 4
|
||||
):
|
||||
# TODO: check permissions
|
||||
@@ -657,7 +704,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
@@ -679,9 +727,8 @@ class Server(EventEmitter):
|
||||
and attribute.handle <= request.ending_handle
|
||||
and pdu_space_available
|
||||
):
|
||||
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
@@ -723,14 +770,15 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -748,14 +796,15 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_blob_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -792,7 +841,8 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_group_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
'''
|
||||
@@ -820,7 +870,7 @@ class Server(EventEmitter):
|
||||
):
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
@@ -859,7 +909,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_write_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
'''
|
||||
@@ -892,12 +943,13 @@ class Server(EventEmitter):
|
||||
return
|
||||
|
||||
# Accept the value
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
|
||||
# Done
|
||||
self.send_response(connection, ATT_Write_Response())
|
||||
|
||||
def on_att_write_command(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
'''
|
||||
@@ -915,9 +967,9 @@ class Server(EventEmitter):
|
||||
|
||||
# Accept the value
|
||||
try:
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! ignoring exception: {error}')
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
'''
|
||||
|
||||
2154
bumble/hci.py
2154
bumble/hci.py
File diff suppressed because it is too large
Load Diff
@@ -15,30 +15,45 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import cast, Any, Optional
|
||||
import logging
|
||||
|
||||
from .colors import color
|
||||
from .att import ATT_CID, ATT_PDU
|
||||
from .smp import SMP_CID, SMP_Command
|
||||
from .core import name_or_number
|
||||
from .l2cap import (
|
||||
from bumble import avc
|
||||
from bumble import avctp
|
||||
from bumble import avdtp
|
||||
from bumble import avrcp
|
||||
from bumble import crypto
|
||||
from bumble import rfcomm
|
||||
from bumble import sdp
|
||||
from bumble.colors import color
|
||||
from bumble.att import ATT_CID, ATT_PDU
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
from bumble.core import name_or_number
|
||||
from bumble.l2cap import (
|
||||
L2CAP_PDU,
|
||||
L2CAP_CONNECTION_REQUEST,
|
||||
L2CAP_CONNECTION_RESPONSE,
|
||||
L2CAP_SIGNALING_CID,
|
||||
L2CAP_LE_SIGNALING_CID,
|
||||
L2CAP_Control_Frame,
|
||||
L2CAP_Connection_Request,
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from .hci import (
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Packet,
|
||||
HCI_Event,
|
||||
HCI_AclDataPacket,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
)
|
||||
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
|
||||
from .sdp import SDP_PDU, SDP_PSM
|
||||
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -48,26 +63,35 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
PSM_NAMES = {
|
||||
RFCOMM_PSM: 'RFCOMM',
|
||||
SDP_PSM: 'SDP',
|
||||
AVDTP_PSM: 'AVDTP'
|
||||
rfcomm.RFCOMM_PSM: 'RFCOMM',
|
||||
sdp.SDP_PSM: 'SDP',
|
||||
avdtp.AVDTP_PSM: 'AVDTP',
|
||||
avctp.AVCTP_PSM: 'AVCTP'
|
||||
# TODO: add more PSM values
|
||||
}
|
||||
|
||||
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketTracer:
|
||||
class AclStream:
|
||||
def __init__(self, analyzer):
|
||||
psms: MutableMapping[int, int]
|
||||
peer: Optional[PacketTracer.AclStream]
|
||||
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
|
||||
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
|
||||
|
||||
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
|
||||
self.analyzer = analyzer
|
||||
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
|
||||
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
|
||||
self.psms = {} # PSM, by source_cid
|
||||
self.peer = None # ACL stream in the other direction
|
||||
self.peer = None
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
def on_acl_pdu(self, pdu):
|
||||
def on_acl_pdu(self, pdu: bytes) -> None:
|
||||
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
if l2cap_pdu.cid == ATT_CID:
|
||||
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
|
||||
@@ -81,46 +105,59 @@ class PacketTracer:
|
||||
|
||||
# Check if this signals a new channel
|
||||
if control_frame.code == L2CAP_CONNECTION_REQUEST:
|
||||
self.psms[control_frame.source_cid] = control_frame.psm
|
||||
connection_request = cast(L2CAP_Connection_Request, control_frame)
|
||||
self.psms[connection_request.source_cid] = connection_request.psm
|
||||
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
|
||||
connection_response = cast(L2CAP_Connection_Response, control_frame)
|
||||
if (
|
||||
control_frame.result
|
||||
connection_response.result
|
||||
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
|
||||
):
|
||||
if self.peer:
|
||||
if psm := self.peer.psms.get(control_frame.source_cid):
|
||||
# Found a pending connection
|
||||
self.psms[control_frame.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
control_frame.source_cid
|
||||
] = AVDTP_MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
control_frame.destination_cid
|
||||
] = AVDTP_MessageAssembler(
|
||||
self.peer.on_avdtp_message
|
||||
)
|
||||
if self.peer and (
|
||||
psm := self.peer.psms.get(connection_response.source_cid)
|
||||
):
|
||||
# Found a pending connection
|
||||
self.psms[connection_response.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == avdtp.AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avdtp.MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.avctp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avctp.MessageAssembler(self.on_avctp_message)
|
||||
self.peer.avctp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avctp.MessageAssembler(self.peer.on_avctp_message)
|
||||
else:
|
||||
# Try to find the PSM associated with this PDU
|
||||
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
|
||||
if psm == SDP_PSM:
|
||||
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
if psm == sdp.SDP_PSM:
|
||||
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(sdp_pdu)
|
||||
elif psm == RFCOMM_PSM:
|
||||
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
elif psm == rfcomm.RFCOMM_PSM:
|
||||
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(rfcomm_frame)
|
||||
elif psm == AVDTP_PSM:
|
||||
elif psm == avdtp.AVDTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
|
||||
if assembler:
|
||||
assembler.on_pdu(l2cap_pdu.payload)
|
||||
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
|
||||
avdtp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
|
||||
avctp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
else:
|
||||
psm_string = name_or_number(PSM_NAMES, psm)
|
||||
self.analyzer.emit(
|
||||
@@ -130,22 +167,48 @@ class PacketTracer:
|
||||
else:
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
def on_avdtp_message(self, transaction_label, message):
|
||||
def on_avdtp_message(
|
||||
self, transaction_label: int, message: avdtp.Message
|
||||
) -> None:
|
||||
self.analyzer.emit(
|
||||
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet):
|
||||
def on_avctp_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
if pid == avrcp.AVRCP_PID:
|
||||
avc_frame = avc.Frame.from_bytes(payload)
|
||||
details = str(avc_frame)
|
||||
else:
|
||||
details = payload.hex()
|
||||
|
||||
c_r = 'Command' if is_command else 'Response'
|
||||
self.analyzer.emit(
|
||||
f'{color("AVCTP", "green")} '
|
||||
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
|
||||
f'{"#" if ipid else ""}'
|
||||
f'{details}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.packet_assembler.feed_packet(packet)
|
||||
|
||||
class Analyzer:
|
||||
def __init__(self, label, emit_message):
|
||||
acl_streams: MutableMapping[int, PacketTracer.AclStream]
|
||||
peer: PacketTracer.Analyzer
|
||||
|
||||
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
|
||||
self.label = label
|
||||
self.emit_message = emit_message
|
||||
self.acl_streams = {} # ACL streams, by connection handle
|
||||
self.peer = None # Analyzer in the other direction
|
||||
|
||||
def start_acl_stream(self, connection_handle):
|
||||
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
|
||||
logger.info(
|
||||
f'[{self.label}] +++ Creating ACL stream for connection '
|
||||
f'0x{connection_handle:04X}'
|
||||
@@ -160,7 +223,7 @@ class PacketTracer:
|
||||
|
||||
return stream
|
||||
|
||||
def end_acl_stream(self, connection_handle):
|
||||
def end_acl_stream(self, connection_handle: int) -> None:
|
||||
if connection_handle in self.acl_streams:
|
||||
logger.info(
|
||||
f'[{self.label}] --- Removing ACL stream for connection '
|
||||
@@ -171,23 +234,29 @@ class PacketTracer:
|
||||
# Let the other forwarder know so it can cleanup its stream as well
|
||||
self.peer.end_acl_stream(connection_handle)
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: HCI_Packet) -> None:
|
||||
self.emit(packet)
|
||||
|
||||
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
|
||||
acl_packet = cast(HCI_AclDataPacket, packet)
|
||||
# Look for an existing stream for this handle, create one if it is the
|
||||
# first ACL packet for that connection handle
|
||||
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
|
||||
stream = self.start_acl_stream(packet.connection_handle)
|
||||
stream.feed_packet(packet)
|
||||
if (
|
||||
stream := self.acl_streams.get(acl_packet.connection_handle)
|
||||
) is None:
|
||||
stream = self.start_acl_stream(acl_packet.connection_handle)
|
||||
stream.feed_packet(acl_packet)
|
||||
elif packet.hci_packet_type == HCI_EVENT_PACKET:
|
||||
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(packet.connection_handle)
|
||||
event_packet = cast(HCI_Event, packet)
|
||||
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(
|
||||
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
|
||||
)
|
||||
|
||||
def emit(self, message):
|
||||
def emit(self, message: Any) -> None:
|
||||
self.emit_message(f'[{self.label}] {message}')
|
||||
|
||||
def trace(self, packet, direction=0):
|
||||
def trace(self, packet: HCI_Packet, direction: int = 0) -> None:
|
||||
if direction == 0:
|
||||
self.host_to_controller_analyzer.on_packet(packet)
|
||||
else:
|
||||
@@ -195,10 +264,10 @@ class PacketTracer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message=logger.info,
|
||||
):
|
||||
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message: Callable[..., None] = logger.info,
|
||||
) -> None:
|
||||
self.host_to_controller_analyzer = PacketTracer.Analyzer(
|
||||
host_to_controller_label, emit_message
|
||||
)
|
||||
@@ -207,3 +276,15 @@ class PacketTracer:
|
||||
)
|
||||
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
|
||||
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
|
||||
|
||||
|
||||
def generate_irk() -> bytes:
|
||||
return crypto.r()
|
||||
|
||||
|
||||
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
|
||||
rpa_bytes = bytes(rpa)
|
||||
prand_given = rpa_bytes[3:]
|
||||
hash_given = rpa_bytes[:3]
|
||||
hash_local = crypto.ah(irk, prand_given)
|
||||
return hash_local[:3] == hash_given
|
||||
|
||||
1090
bumble/hfp.py
1090
bumble/hfp.py
File diff suppressed because it is too large
Load Diff
554
bumble/hid.py
Normal file
554
bumble/hid.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap, device
|
||||
from bumble.colors import color
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
from .hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
# fmt: on
|
||||
HID_CONTROL_PSM = 0x0011
|
||||
HID_INTERRUPT_PSM = 0x0013
|
||||
|
||||
|
||||
class Message:
|
||||
message_type: MessageType
|
||||
# Report types
|
||||
class ReportType(enum.IntEnum):
|
||||
OTHER_REPORT = 0x00
|
||||
INPUT_REPORT = 0x01
|
||||
OUTPUT_REPORT = 0x02
|
||||
FEATURE_REPORT = 0x03
|
||||
|
||||
# Handshake parameters
|
||||
class Handshake(enum.IntEnum):
|
||||
SUCCESSFUL = 0x00
|
||||
NOT_READY = 0x01
|
||||
ERR_INVALID_REPORT_ID = 0x02
|
||||
ERR_UNSUPPORTED_REQUEST = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
ERR_UNKNOWN = 0x0E
|
||||
ERR_FATAL = 0x0F
|
||||
|
||||
# Message Type
|
||||
class MessageType(enum.IntEnum):
|
||||
HANDSHAKE = 0x00
|
||||
CONTROL = 0x01
|
||||
GET_REPORT = 0x04
|
||||
SET_REPORT = 0x05
|
||||
GET_PROTOCOL = 0x06
|
||||
SET_PROTOCOL = 0x07
|
||||
DATA = 0x0A
|
||||
|
||||
# Protocol modes
|
||||
class ProtocolMode(enum.IntEnum):
|
||||
BOOT_PROTOCOL = 0x00
|
||||
REPORT_PROTOCOL = 0x01
|
||||
|
||||
# Control Operations
|
||||
class ControlCommand(enum.IntEnum):
|
||||
SUSPEND = 0x03
|
||||
EXIT_SUSPEND = 0x04
|
||||
VIRTUAL_CABLE_UNPLUG = 0x05
|
||||
|
||||
# Class Method to derive header
|
||||
@classmethod
|
||||
def header(cls, lower_bits: int = 0x00) -> bytes:
|
||||
return bytes([(cls.message_type << 4) | lower_bits])
|
||||
|
||||
|
||||
# HIDP messages
|
||||
@dataclass
|
||||
class GetReportMessage(Message):
|
||||
report_type: int
|
||||
report_id: int
|
||||
buffer_size: int
|
||||
message_type = Message.MessageType.GET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
packet_bytes = bytearray()
|
||||
packet_bytes.append(self.report_id)
|
||||
if self.buffer_size == 0:
|
||||
return self.header(self.report_type) + packet_bytes
|
||||
else:
|
||||
return (
|
||||
self.header(0x08 | self.report_type)
|
||||
+ packet_bytes
|
||||
+ struct.pack("<H", self.buffer_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetReportMessage(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.SET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendControlData(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetProtocolMessage(Message):
|
||||
message_type = Message.MessageType.GET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetProtocolMessage(Message):
|
||||
protocol_mode: int
|
||||
message_type = Message.MessageType.SET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.protocol_mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Suspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExitSuspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.EXIT_SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VirtualCableUnplug(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
|
||||
|
||||
|
||||
# Device sends input report, host sends output report.
|
||||
@dataclass
|
||||
class SendData(Message):
|
||||
data: bytes
|
||||
report_type: int
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHandshakeMessage(Message):
|
||||
result_code: int
|
||||
message_type = Message.MessageType.HANDSHAKE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.result_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HID(ABC, EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
|
||||
def __init__(self, device: device.Device, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.remote_device_bd_address: Optional[Address] = None
|
||||
self.device = device
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
try:
|
||||
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_CONTROL_PSM
|
||||
)
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_ctrl_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
|
||||
async def connect_interrupt_channel(self) -> None:
|
||||
# Create a new L2CAP connection - interrupt channel
|
||||
try:
|
||||
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_INTERRUPT_PSM
|
||||
)
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_intr_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
|
||||
async def disconnect_interrupt_channel(self) -> None:
|
||||
if self.l2cap_intr_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
async def disconnect_control_channel(self) -> None:
|
||||
if self.l2cap_ctrl_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = l2cap_channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
else:
|
||||
self.l2cap_intr_channel = l2cap_channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = None
|
||||
else:
|
||||
self.l2cap_intr_channel = None
|
||||
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
|
||||
|
||||
@abstractmethod
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
pass
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("interrupt_data", pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel.send_pdu(msg)
|
||||
|
||||
def send_pdu_on_intr(self, msg: bytes) -> None:
|
||||
assert self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel.send_pdu(msg)
|
||||
|
||||
def send_data(self, data: bytes) -> None:
|
||||
if self.role == HID.Role.HOST:
|
||||
report_type = Message.ReportType.OUTPUT_REPORT
|
||||
else:
|
||||
report_type = Message.ReportType.INPUT_REPORT
|
||||
msg = SendData(data, report_type)
|
||||
hid_message = bytes(msg)
|
||||
if self.l2cap_intr_channel is not None:
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
|
||||
def virtual_cable_unplug(self) -> None:
|
||||
msg = VirtualCableUnplug()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Device(HID):
|
||||
class GetSetReturn(enum.IntEnum):
|
||||
FAILURE = 0x00
|
||||
REPORT_ID_NOT_FOUND = 0x01
|
||||
ERR_UNSUPPORTED_REQUEST = 0x02
|
||||
ERR_UNKNOWN = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
class GetSetStatus:
|
||||
def __init__(self) -> None:
|
||||
self.data = bytearray()
|
||||
self.status = 0
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
|
||||
if message_type == Message.MessageType.GET_REPORT:
|
||||
logger.debug('<<< HID GET REPORT')
|
||||
self.handle_get_report(pdu)
|
||||
elif message_type == Message.MessageType.SET_REPORT:
|
||||
logger.debug('<<< HID SET REPORT')
|
||||
self.handle_set_report(pdu)
|
||||
elif message_type == Message.MessageType.GET_PROTOCOL:
|
||||
logger.debug('<<< HID GET PROTOCOL')
|
||||
self.handle_get_protocol(pdu)
|
||||
elif message_type == Message.MessageType.SET_PROTOCOL:
|
||||
logger.debug('<<< HID SET PROTOCOL')
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def send_handshake_message(self, result_code: int) -> None:
|
||||
msg = SendHandshakeMessage(result_code)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def send_control_data(self, report_type: int, data: bytes):
|
||||
msg = SendControlData(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def handle_get_report(self, pdu: bytes):
|
||||
if self.get_report_cb is None:
|
||||
logger.debug("GetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||
report_id = pdu[1]
|
||||
logger.debug(f"buffer_flag: {buffer_flag}")
|
||||
if buffer_flag == 1:
|
||||
buffer_size = (pdu[3] << 8) | pdu[2]
|
||||
else:
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
data = bytearray()
|
||||
data.append(report_id)
|
||||
data.extend(ret.data)
|
||||
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
|
||||
self.send_control_data(report_type=report_type, data=data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
def handle_set_report(self, pdu: bytes):
|
||||
if self.set_report_cb is None:
|
||||
logger.debug("SetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
report_id = pdu[1]
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], None]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
|
||||
def handle_get_protocol(self, pdu: bytes):
|
||||
if self.get_protocol_cb is None:
|
||||
logger.debug("GetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
def handle_set_protocol(self, pdu: bytes):
|
||||
if self.set_protocol_cb is None:
|
||||
logger.debug("SetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(HID):
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.HOST)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes) -> None:
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self) -> None:
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int) -> None:
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def suspend(self) -> None:
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def exit_suspend(self) -> None:
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
710
bumble/host.py
710
bumble/host.py
File diff suppressed because it is too large
Load Diff
881
bumble/l2cap.py
881
bumble/l2cap.py
File diff suppressed because it is too large
Load Diff
110
bumble/link.py
110
bumble/link.py
@@ -26,9 +26,13 @@ from bumble.hci import (
|
||||
HCI_SUCCESS,
|
||||
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_Connection_Complete_Event,
|
||||
)
|
||||
from bumble import controller
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -57,6 +61,8 @@ class LocalLink:
|
||||
Link bus for controllers to communicate with each other
|
||||
'''
|
||||
|
||||
controllers: Set[controller.Controller]
|
||||
|
||||
def __init__(self):
|
||||
self.controllers = set()
|
||||
self.pending_connection = None
|
||||
@@ -79,7 +85,9 @@ class LocalLink:
|
||||
return controller
|
||||
return None
|
||||
|
||||
def find_classic_controller(self, address):
|
||||
def find_classic_controller(
|
||||
self, address: Address
|
||||
) -> Optional[controller.Controller]:
|
||||
for controller in self.controllers:
|
||||
if controller.public_address == address:
|
||||
return controller
|
||||
@@ -188,6 +196,60 @@ class LocalLink:
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
|
||||
|
||||
def create_cis(
|
||||
self,
|
||||
central_controller: controller.Controller,
|
||||
peripheral_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
|
||||
)
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_request,
|
||||
central_controller.random_address,
|
||||
cig_id,
|
||||
cis_id,
|
||||
)
|
||||
|
||||
def accept_cis(
|
||||
self,
|
||||
peripheral_controller: controller.Controller,
|
||||
central_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
|
||||
)
|
||||
if central_controller := self.find_controller(central_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
central_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
|
||||
def disconnect_cis(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
peer_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
|
||||
)
|
||||
if peer_controller := self.find_controller(peer_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peer_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Classic handlers
|
||||
############################################################
|
||||
@@ -271,6 +333,52 @@ class LocalLink:
|
||||
initiator_controller.public_address, int(not (initiator_new_role))
|
||||
)
|
||||
|
||||
def classic_sco_connect(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
responder_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
|
||||
)
|
||||
responder_controller = self.find_classic_controller(responder_address)
|
||||
# Initiator controller should handle it.
|
||||
assert responder_controller
|
||||
|
||||
responder_controller.on_classic_connection_request(
|
||||
initiator_controller.public_address,
|
||||
link_type,
|
||||
)
|
||||
|
||||
def classic_accept_sco_connection(
|
||||
self,
|
||||
responder_controller: controller.Controller,
|
||||
initiator_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
|
||||
)
|
||||
initiator_controller = self.find_classic_controller(initiator_address)
|
||||
if initiator_controller is None:
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
link_type,
|
||||
)
|
||||
return
|
||||
|
||||
async def task():
|
||||
initiator_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
asyncio.create_task(task())
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
initiator_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RemoteLink:
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .hci import (
|
||||
Address,
|
||||
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
|
||||
HCI_DISPLAY_ONLY_IO_CAPABILITY,
|
||||
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
|
||||
@@ -34,7 +37,60 @@ from .smp import (
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
OobSharedData,
|
||||
)
|
||||
from .core import AdvertisingData, LeRole
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobData:
|
||||
"""OOB data that can be sent from one device to another."""
|
||||
|
||||
address: Optional[Address] = None
|
||||
role: Optional[LeRole] = None
|
||||
shared_data: Optional[OobSharedData] = None
|
||||
legacy_context: Optional[OobLegacyContext] = None
|
||||
|
||||
@classmethod
|
||||
def from_ad(cls, ad: AdvertisingData) -> OobData:
|
||||
instance = cls()
|
||||
shared_data_c: Optional[bytes] = None
|
||||
shared_data_r: Optional[bytes] = None
|
||||
for ad_type, ad_data in ad.ad_structures:
|
||||
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
|
||||
instance.address = Address(ad_data)
|
||||
elif ad_type == AdvertisingData.LE_ROLE:
|
||||
instance.role = LeRole(ad_data[0])
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
|
||||
shared_data_c = ad_data
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
|
||||
shared_data_r = ad_data
|
||||
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
|
||||
instance.legacy_context = OobLegacyContext(tk=ad_data)
|
||||
if shared_data_c and shared_data_r:
|
||||
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
|
||||
|
||||
return instance
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
ad_structures = []
|
||||
if self.address is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
|
||||
)
|
||||
if self.role is not None:
|
||||
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
|
||||
if self.shared_data is not None:
|
||||
ad_structures.extend(self.shared_data.to_ad().ad_structures)
|
||||
if self.legacy_context is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -168,21 +224,39 @@ class PairingDelegate:
|
||||
class PairingConfig:
|
||||
"""Configuration for the Pairing protocol."""
|
||||
|
||||
class AddressType(enum.IntEnum):
|
||||
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
|
||||
RANDOM = Address.RANDOM_DEVICE_ADDRESS
|
||||
|
||||
@dataclass
|
||||
class OobConfig:
|
||||
"""Config for OOB pairing."""
|
||||
|
||||
our_context: Optional[OobContext]
|
||||
peer_data: Optional[OobSharedData]
|
||||
legacy_context: Optional[OobLegacyContext]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sc: bool = True,
|
||||
mitm: bool = True,
|
||||
bonding: bool = True,
|
||||
delegate: Optional[PairingDelegate] = None,
|
||||
identity_address_type: Optional[AddressType] = None,
|
||||
oob: Optional[OobConfig] = None,
|
||||
) -> None:
|
||||
self.sc = sc
|
||||
self.mitm = mitm
|
||||
self.bonding = bonding
|
||||
self.delegate = delegate or PairingDelegate()
|
||||
self.identity_address_type = identity_address_type
|
||||
self.oob = oob
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'PairingConfig(sc={self.sc}, '
|
||||
f'mitm={self.mitm}, bonding={self.bonding}, '
|
||||
f'delegate[{self.delegate.io_capability}])'
|
||||
f'identity_address_type={self.identity_address_type}, '
|
||||
f'delegate[{self.delegate.io_capability}]), '
|
||||
f'oob[{self.oob}])'
|
||||
)
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from bumble.pairing import PairingDelegate
|
||||
from __future__ import annotations
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -20,6 +21,7 @@ from typing import Any, Dict
|
||||
@dataclass
|
||||
class Config:
|
||||
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
|
||||
identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
|
||||
pairing_sc_enable: bool = True
|
||||
pairing_mitm_enable: bool = True
|
||||
pairing_bonding_enable: bool = True
|
||||
@@ -35,6 +37,12 @@ class Config:
|
||||
'io_capability', 'no_output_no_input'
|
||||
).upper()
|
||||
self.io_capability = getattr(PairingDelegate, io_capability_name)
|
||||
identity_address_type_name: str = config.get(
|
||||
'identity_address_type', 'random'
|
||||
).upper()
|
||||
self.identity_address_type = getattr(
|
||||
PairingConfig.AddressType, identity_address_type_name
|
||||
)
|
||||
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
|
||||
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
|
||||
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
@@ -34,6 +35,10 @@ from bumble.sdp import (
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Default rootcanal HCI TCP address
|
||||
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
|
||||
|
||||
|
||||
class PandoraDevice:
|
||||
"""
|
||||
Small wrapper around a Bumble device and it's HCI transport.
|
||||
@@ -53,7 +58,9 @@ class PandoraDevice:
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.device = _make_device(config)
|
||||
self._hci_name = config.get('transport', '')
|
||||
self._hci_name = config.get(
|
||||
'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
|
||||
)
|
||||
self._hci = None
|
||||
|
||||
@property
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
@@ -112,7 +113,7 @@ class HostService(HostServicer):
|
||||
async def FactoryReset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('FactoryReset')
|
||||
self.log.debug('FactoryReset')
|
||||
|
||||
# delete all bonds
|
||||
if self.device.keystore is not None:
|
||||
@@ -126,7 +127,7 @@ class HostService(HostServicer):
|
||||
async def Reset(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info('Reset')
|
||||
self.log.debug('Reset')
|
||||
|
||||
# clear service.
|
||||
self.waited_connections.clear()
|
||||
@@ -139,7 +140,7 @@ class HostService(HostServicer):
|
||||
async def ReadLocalAddress(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> ReadLocalAddressResponse:
|
||||
self.log.info('ReadLocalAddress')
|
||||
self.log.debug('ReadLocalAddress')
|
||||
return ReadLocalAddressResponse(
|
||||
address=bytes(reversed(bytes(self.device.public_address)))
|
||||
)
|
||||
@@ -152,7 +153,7 @@ class HostService(HostServicer):
|
||||
address = Address(
|
||||
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
self.log.info(f"Connect to {address}")
|
||||
self.log.debug(f"Connect to {address}")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
@@ -167,7 +168,7 @@ class HostService(HostServicer):
|
||||
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"Connect to {address} done (handle={connection.handle})")
|
||||
self.log.debug(f"Connect to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectResponse(connection=Connection(cookie=cookie))
|
||||
@@ -186,7 +187,7 @@ class HostService(HostServicer):
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"WaitConnection from {address}...")
|
||||
self.log.debug(f"WaitConnection from {address}...")
|
||||
|
||||
connection = self.device.find_connection_by_bd_addr(
|
||||
address, transport=BT_BR_EDR_TRANSPORT
|
||||
@@ -201,7 +202,7 @@ class HostService(HostServicer):
|
||||
# save connection has waited and respond.
|
||||
self.waited_connections.add(id(connection))
|
||||
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"WaitConnection from {address} done (handle={connection.handle})"
|
||||
)
|
||||
|
||||
@@ -216,7 +217,7 @@ class HostService(HostServicer):
|
||||
if address in (Address.NIL, Address.ANY):
|
||||
raise ValueError('Invalid address')
|
||||
|
||||
self.log.info(f"ConnectLE to {address}...")
|
||||
self.log.debug(f"ConnectLE to {address}...")
|
||||
|
||||
try:
|
||||
connection = await self.device.connect(
|
||||
@@ -233,7 +234,7 @@ class HostService(HostServicer):
|
||||
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
|
||||
raise e
|
||||
|
||||
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
|
||||
self.log.debug(f"ConnectLE to {address} done (handle={connection.handle})")
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
return ConnectLEResponse(connection=Connection(cookie=cookie))
|
||||
@@ -243,12 +244,12 @@ class HostService(HostServicer):
|
||||
self, request: DisconnectRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Disconnect: {connection_handle}")
|
||||
self.log.debug(f"Disconnect: {connection_handle}")
|
||||
|
||||
self.log.info("Disconnecting...")
|
||||
self.log.debug("Disconnecting...")
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
|
||||
self.log.info("Disconnected")
|
||||
self.log.debug("Disconnected")
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@@ -257,7 +258,7 @@ class HostService(HostServicer):
|
||||
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitDisconnection: {connection_handle}")
|
||||
self.log.debug(f"WaitDisconnection: {connection_handle}")
|
||||
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
disconnection_future: asyncio.Future[
|
||||
@@ -270,7 +271,7 @@ class HostService(HostServicer):
|
||||
connection.on('disconnection', on_disconnection)
|
||||
try:
|
||||
await disconnection_future
|
||||
self.log.info("Disconnected")
|
||||
self.log.debug("Disconnected")
|
||||
finally:
|
||||
connection.remove_listener('disconnection', on_disconnection) # type: ignore
|
||||
|
||||
@@ -284,10 +285,11 @@ class HostService(HostServicer):
|
||||
raise NotImplementedError(
|
||||
"TODO: add support for extended advertising in Bumble"
|
||||
)
|
||||
if request.interval:
|
||||
raise NotImplementedError("TODO: add support for `request.interval`")
|
||||
if request.interval_range:
|
||||
raise NotImplementedError("TODO: add support for `request.interval_range`")
|
||||
if advertising_interval := request.interval:
|
||||
self.device.config.advertising_interval_min = int(advertising_interval)
|
||||
self.device.config.advertising_interval_max = int(advertising_interval)
|
||||
if interval_range := request.interval_range:
|
||||
self.device.config.advertising_interval_max += int(interval_range)
|
||||
if request.primary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.primary_phy`")
|
||||
if request.secondary_phy:
|
||||
@@ -378,7 +380,7 @@ class HostService(HostServicer):
|
||||
try:
|
||||
while True:
|
||||
if not self.device.is_advertising:
|
||||
self.log.info('Advertise')
|
||||
self.log.debug('Advertise')
|
||||
await self.device.start_advertising(
|
||||
target=target,
|
||||
advertising_type=advertising_type,
|
||||
@@ -393,10 +395,10 @@ class HostService(HostServicer):
|
||||
bumble.device.Connection
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
self.log.info('Wait for LE connection...')
|
||||
self.log.debug('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
|
||||
)
|
||||
|
||||
@@ -410,7 +412,7 @@ class HostService(HostServicer):
|
||||
self.device.remove_listener('connection', on_connection) # type: ignore
|
||||
|
||||
try:
|
||||
self.log.info('Stop advertising')
|
||||
self.log.debug('Stop advertising')
|
||||
await self.device.abort_on('flush', self.device.stop_advertising())
|
||||
except:
|
||||
pass
|
||||
@@ -423,7 +425,7 @@ class HostService(HostServicer):
|
||||
if request.phys:
|
||||
raise NotImplementedError("TODO: add support for `request.phys`")
|
||||
|
||||
self.log.info('Scan')
|
||||
self.log.debug('Scan')
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
@@ -470,7 +472,7 @@ class HostService(HostServicer):
|
||||
finally:
|
||||
self.device.remove_listener('advertisement', handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop scanning')
|
||||
self.log.debug('Stop scanning')
|
||||
await self.device.abort_on('flush', self.device.stop_scanning())
|
||||
except:
|
||||
pass
|
||||
@@ -479,7 +481,7 @@ class HostService(HostServicer):
|
||||
async def Inquiry(
|
||||
self, request: empty_pb2.Empty, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[InquiryResponse, None]:
|
||||
self.log.info('Inquiry')
|
||||
self.log.debug('Inquiry')
|
||||
|
||||
inquiry_queue: asyncio.Queue[
|
||||
Optional[Tuple[Address, int, AdvertisingData, int]]
|
||||
@@ -510,7 +512,7 @@ class HostService(HostServicer):
|
||||
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
|
||||
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
|
||||
try:
|
||||
self.log.info('Stop inquiry')
|
||||
self.log.debug('Stop inquiry')
|
||||
await self.device.abort_on('flush', self.device.stop_discovery())
|
||||
except:
|
||||
pass
|
||||
@@ -519,7 +521,7 @@ class HostService(HostServicer):
|
||||
async def SetDiscoverabilityMode(
|
||||
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetDiscoverabilityMode")
|
||||
self.log.debug("SetDiscoverabilityMode")
|
||||
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@@ -527,7 +529,7 @@ class HostService(HostServicer):
|
||||
async def SetConnectabilityMode(
|
||||
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
self.log.info("SetConnectabilityMode")
|
||||
self.log.debug("SetConnectabilityMode")
|
||||
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
|
||||
return empty_pb2.Empty()
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import grpc
|
||||
import logging
|
||||
|
||||
@@ -27,8 +29,8 @@ from bumble.core import (
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.utils import EventWatcher
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
from contextlib import suppress
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||
@@ -99,7 +101,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
return ev
|
||||
|
||||
async def confirm(self, auto: bool = False) -> bool:
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
@@ -108,13 +110,13 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
@@ -123,13 +125,13 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(numeric_comparison=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
|
||||
async def get_number(self) -> Optional[int]:
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
@@ -138,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -146,7 +148,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
return answer.passkey
|
||||
|
||||
async def get_string(self, max_length: int) -> Optional[str]:
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
@@ -155,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -177,7 +179,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
):
|
||||
return
|
||||
|
||||
self.log.info(
|
||||
self.log.debug(
|
||||
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
|
||||
)
|
||||
|
||||
@@ -232,6 +234,11 @@ class SecurityService(SecurityServicer):
|
||||
sc=config.pairing_sc_enable,
|
||||
mitm=config.pairing_mitm_enable,
|
||||
bonding=config.pairing_bonding_enable,
|
||||
identity_address_type=(
|
||||
PairingConfig.AddressType.PUBLIC
|
||||
if connection.self_address.is_public
|
||||
else config.identity_address_type
|
||||
),
|
||||
delegate=PairingDelegate(
|
||||
connection,
|
||||
self,
|
||||
@@ -247,7 +254,7 @@ class SecurityService(SecurityServicer):
|
||||
async def OnPairing(
|
||||
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[PairingEvent, None]:
|
||||
self.log.info('OnPairing')
|
||||
self.log.debug('OnPairing')
|
||||
|
||||
if self.event_queue is not None:
|
||||
raise RuntimeError('already streaming pairing events')
|
||||
@@ -273,7 +280,7 @@ class SecurityService(SecurityServicer):
|
||||
self, request: SecureRequest, context: grpc.ServicerContext
|
||||
) -> SecureResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"Secure: {connection_handle}")
|
||||
self.log.debug(f"Secure: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
@@ -291,25 +298,37 @@ class SecurityService(SecurityServicer):
|
||||
# trigger pairing if needed
|
||||
if self.need_pairing(connection, level):
|
||||
try:
|
||||
self.log.info('Pair...')
|
||||
self.log.debug('Pair...')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
wait_for_security: asyncio.Future[
|
||||
bool
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
||||
security_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
connection.request_pairing()
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
|
||||
await wait_for_security
|
||||
else:
|
||||
await connection.pair()
|
||||
@watcher.on(connection, 'pairing')
|
||||
def on_pairing(*_: Any) -> None:
|
||||
security_result.set_result('success')
|
||||
|
||||
self.log.info('Paired')
|
||||
@watcher.on(connection, 'pairing_failure')
|
||||
def on_pairing_failure(*_: Any) -> None:
|
||||
security_result.set_result('pairing_failure')
|
||||
|
||||
@watcher.on(connection, 'disconnection')
|
||||
def on_disconnection(*_: Any) -> None:
|
||||
security_result.set_result('connection_died')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
connection.request_pairing()
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
result = await security_result
|
||||
|
||||
self.log.debug(f'Pairing session complete, status={result}')
|
||||
if result != 'success':
|
||||
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
@@ -320,9 +339,9 @@ class SecurityService(SecurityServicer):
|
||||
# trigger authentication if needed
|
||||
if self.need_authentication(connection, level):
|
||||
try:
|
||||
self.log.info('Authenticate...')
|
||||
self.log.debug('Authenticate...')
|
||||
await connection.authenticate()
|
||||
self.log.info('Authenticated')
|
||||
self.log.debug('Authenticated')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during authentication")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
@@ -333,9 +352,9 @@ class SecurityService(SecurityServicer):
|
||||
# trigger encryption if needed
|
||||
if self.need_encryption(connection, level):
|
||||
try:
|
||||
self.log.info('Encrypt...')
|
||||
self.log.debug('Encrypt...')
|
||||
await connection.encrypt()
|
||||
self.log.info('Encrypted')
|
||||
self.log.debug('Encrypted')
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
@@ -353,7 +372,7 @@ class SecurityService(SecurityServicer):
|
||||
self, request: WaitSecurityRequest, context: grpc.ServicerContext
|
||||
) -> WaitSecurityResponse:
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
self.log.info(f"WaitSecurity: {connection_handle}")
|
||||
self.log.debug(f"WaitSecurity: {connection_handle}")
|
||||
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
assert connection
|
||||
@@ -368,6 +387,7 @@ class SecurityService(SecurityServicer):
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
pair_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
@@ -390,7 +410,7 @@ class SecurityService(SecurityServicer):
|
||||
|
||||
def set_failure(name: str) -> Callable[..., None]:
|
||||
def wrapper(*args: Any) -> None:
|
||||
self.log.info(f'Wait for security: error `{name}`: {args}')
|
||||
self.log.debug(f'Wait for security: error `{name}`: {args}')
|
||||
wait_for_security.set_result(name)
|
||||
|
||||
return wrapper
|
||||
@@ -398,13 +418,13 @@ class SecurityService(SecurityServicer):
|
||||
def try_set_success(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
self.log.debug('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
|
||||
def on_encryption_change(*_: Any) -> None:
|
||||
assert connection
|
||||
if self.reached_security_level(connection, level):
|
||||
self.log.info('Wait for security: done')
|
||||
self.log.debug('Wait for security: done')
|
||||
wait_for_security.set_result('success')
|
||||
elif (
|
||||
connection.transport == BT_BR_EDR_TRANSPORT
|
||||
@@ -414,6 +434,10 @@ class SecurityService(SecurityServicer):
|
||||
if authenticate_task is None:
|
||||
authenticate_task = asyncio.create_task(authenticate())
|
||||
|
||||
def pair(*_: Any) -> None:
|
||||
if self.need_pairing(connection, level):
|
||||
pair_task = asyncio.create_task(connection.pair())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
@@ -422,32 +446,41 @@ class SecurityService(SecurityServicer):
|
||||
'pairing': try_set_success,
|
||||
'connection_authentication': try_set_success,
|
||||
'connection_encryption_change': on_encryption_change,
|
||||
'classic_pairing': try_set_success,
|
||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
watcher.on(connection, event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.info('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
self.log.info('Wait for authentication...')
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
await authenticate_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.info('Authenticated')
|
||||
self.log.debug('Authenticated')
|
||||
|
||||
# wait for `pair` to finish if any
|
||||
if pair_task is not None:
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
await pair_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.debug('paired')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
@@ -503,7 +536,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
||||
self, request: IsBondedRequest, context: grpc.ServicerContext
|
||||
) -> wrappers_pb2.BoolValue:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"IsBonded: {address}")
|
||||
self.log.debug(f"IsBonded: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
is_bonded = await self.device.keystore.get(str(address)) is not None
|
||||
@@ -517,10 +550,10 @@ class SecurityStorageService(SecurityStorageServicer):
|
||||
self, request: DeleteBondRequest, context: grpc.ServicerContext
|
||||
) -> empty_pb2.Empty:
|
||||
address = utils.address_from_request(request, request.WhichOneof("address"))
|
||||
self.log.info(f"DeleteBond: {address}")
|
||||
self.log.debug(f"DeleteBond: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
with suppress(KeyError):
|
||||
with contextlib.suppress(KeyError):
|
||||
await self.device.keystore.delete(str(address))
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bumble import l2cap
|
||||
from ..core import AdvertisingData
|
||||
from ..device import Device, Connection
|
||||
from ..gatt import (
|
||||
@@ -65,7 +67,7 @@ class AshaService(TemplateService):
|
||||
self.emit('volume', connection, value[0])
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection: Connection, value):
|
||||
def on_audio_control_point_write(connection: Optional[Connection], value):
|
||||
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == AshaService.OPCODE_START:
|
||||
@@ -149,7 +151,10 @@ class AshaService(TemplateService):
|
||||
channel.sink = on_data
|
||||
|
||||
# let the server find a free PSM
|
||||
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
|
||||
self.psm = device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
|
||||
handler=on_coc,
|
||||
).psm
|
||||
self.le_psm_out_characteristic = Characteristic(
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
|
||||
1247
bumble/profiles/bap.py
Normal file
1247
bumble/profiles/bap.py
Normal file
File diff suppressed because it is too large
Load Diff
52
bumble/profiles/cap.py
Normal file
52
bumble/profiles/cap.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
|
||||
) -> None:
|
||||
self.coordinated_set_identification_service = (
|
||||
coordinated_set_identification_service
|
||||
)
|
||||
super().__init__(
|
||||
characteristics=[],
|
||||
included_services=[coordinated_set_identification_service],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CommonAudioServiceService
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
257
bumble/profiles/csip.py
Normal file
257
bumble/profiles/csip.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
|
||||
|
||||
|
||||
class SirkType(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
|
||||
|
||||
ENCRYPTED = 0x00
|
||||
PLAINTEXT = 0x01
|
||||
|
||||
|
||||
class MemberLock(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
|
||||
|
||||
UNLOCKED = 0x01
|
||||
LOCKED = 0x02
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Crypto Toolbox
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(m: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
|
||||
'''
|
||||
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
|
||||
|
||||
|
||||
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.4 k1 derivation function.
|
||||
'''
|
||||
t = crypto.aes_cmac(n[::-1], salt[::-1])
|
||||
return crypto.aes_cmac(p[::-1], t)[::-1]
|
||||
|
||||
|
||||
def sef(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
|
||||
|
||||
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
|
||||
* Plaintext in encryption
|
||||
* Cipher in decryption
|
||||
'''
|
||||
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
|
||||
|
||||
|
||||
def sih(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
|
||||
'''
|
||||
return crypto.e(k, r + bytes(13))[:3]
|
||||
|
||||
|
||||
def generate_rsi(sirk: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
|
||||
'''
|
||||
prand = crypto.generate_prand()
|
||||
return sih(sirk, prand) + prand
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
set_identity_resolving_key: bytes,
|
||||
set_identity_resolving_key_type: SirkType,
|
||||
coordinated_set_size: Optional[int] = None,
|
||||
set_member_lock: Optional[MemberLock] = None,
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||
raise ValueError(
|
||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||
)
|
||||
|
||||
characteristics = []
|
||||
|
||||
self.set_identity_resolving_key = set_identity_resolving_key
|
||||
self.set_identity_resolving_key_type = set_identity_resolving_key_type
|
||||
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self.on_sirk_read),
|
||||
)
|
||||
characteristics.append(self.set_identity_resolving_key_characteristic)
|
||||
|
||||
if coordinated_set_size is not None:
|
||||
self.coordinated_set_size_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', coordinated_set_size),
|
||||
)
|
||||
characteristics.append(self.coordinated_set_size_characteristic)
|
||||
|
||||
if set_member_lock is not None:
|
||||
self.set_member_lock_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=struct.pack('B', set_member_lock),
|
||||
)
|
||||
characteristics.append(self.set_member_lock_characteristic)
|
||||
|
||||
if set_member_rank is not None:
|
||||
self.set_member_rank_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', set_member_rank),
|
||||
)
|
||||
characteristics.append(self.set_member_rank_characteristic)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await connection.device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
|
||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||
|
||||
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
|
||||
generate_rsi(self.set_identity_resolving_key),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CoordinatedSetIdentificationService
|
||||
|
||||
set_identity_resolving_key: gatt_client.CharacteristicProxy
|
||||
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
|
||||
):
|
||||
self.coordinated_set_size = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_lock = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_rank = characteristics[0]
|
||||
|
||||
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
raise RuntimeError('Invalid SIRK value')
|
||||
|
||||
sirk_type = SirkType(response[0])
|
||||
if sirk_type == SirkType.PLAINTEXT:
|
||||
sirk = response[1:]
|
||||
else:
|
||||
connection = self.service_proxy.client.connection
|
||||
device = connection.device
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
|
||||
sirk = sef(key, response[1:])
|
||||
|
||||
return (sirk_type, sirk)
|
||||
@@ -42,12 +42,12 @@ class HeartRateService(TemplateService):
|
||||
RESET_ENERGY_EXPENDED = 0x01
|
||||
|
||||
class BodySensorLocation(IntEnum):
|
||||
OTHER = (0,)
|
||||
CHEST = (1,)
|
||||
WRIST = (2,)
|
||||
FINGER = (3,)
|
||||
HAND = (4,)
|
||||
EAR_LOBE = (5,)
|
||||
OTHER = 0
|
||||
CHEST = 1
|
||||
WRIST = 2
|
||||
FINGER = 3
|
||||
HAND = 4
|
||||
EAR_LOBE = 5
|
||||
FOOT = 6
|
||||
|
||||
class HeartRateMeasurement:
|
||||
|
||||
228
bumble/profiles/vcp.py
Normal file
228
bumble/profiles/vcp.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
|
||||
from bumble import att
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
MIN_VOLUME = 0
|
||||
MAX_VOLUME = 255
|
||||
|
||||
|
||||
class ErrorCode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service 1.6. Application error codes.
|
||||
'''
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
|
||||
|
||||
class VolumeFlags(enum.IntFlag):
|
||||
'''
|
||||
See Volume Control Service 3.3. Volume Flags.
|
||||
'''
|
||||
|
||||
VOLUME_SETTING_PERSISTED = 0x01
|
||||
# RFU
|
||||
|
||||
|
||||
class VolumeControlPointOpcode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
|
||||
'''
|
||||
|
||||
# fmt: off
|
||||
RELATIVE_VOLUME_DOWN = 0x00
|
||||
RELATIVE_VOLUME_UP = 0x01
|
||||
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
|
||||
UNMUTE_RELATIVE_VOLUME_UP = 0x03
|
||||
SET_ABSOLUTE_VOLUME = 0x04
|
||||
UNMUTE = 0x05
|
||||
MUTE = 0x06
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
|
||||
|
||||
volume_state: gatt.Characteristic
|
||||
volume_control_point: gatt.Characteristic
|
||||
volume_flags: gatt.Characteristic
|
||||
|
||||
volume_setting: int
|
||||
muted: int
|
||||
change_counter: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_size: int = 16,
|
||||
volume_setting: int = 0,
|
||||
muted: int = 0,
|
||||
change_counter: int = 0,
|
||||
volume_flags: int = 0,
|
||||
) -> None:
|
||||
self.step_size = step_size
|
||||
self.volume_setting = volume_setting
|
||||
self.muted = muted
|
||||
self.change_counter = change_counter
|
||||
|
||||
self.volume_state = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
|
||||
)
|
||||
self.volume_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
|
||||
)
|
||||
self.volume_flags = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes([volume_flags]),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.volume_state,
|
||||
self.volume_control_point,
|
||||
self.volume_flags,
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def volume_state_bytes(self) -> bytes:
|
||||
return bytes([self.volume_setting, self.muted, self.change_counter])
|
||||
|
||||
@volume_state_bytes.setter
|
||||
def volume_state_bytes(self, new_value: bytes) -> None:
|
||||
self.volume_setting, self.muted, self.change_counter = new_value
|
||||
|
||||
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
|
||||
return self.volume_state_bytes
|
||||
|
||||
def _on_write_volume_control_point(
|
||||
self, connection: Optional[device.Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = VolumeControlPointOpcode(value[0])
|
||||
change_counter = value[1]
|
||||
|
||||
if change_counter != self.change_counter:
|
||||
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
if handler(*value[2:]):
|
||||
self.change_counter = (self.change_counter + 1) % 256
|
||||
connection.abort_on(
|
||||
'disconnection',
|
||||
connection.device.notify_subscribers(
|
||||
attribute=self.volume_state,
|
||||
value=self.volume_state_bytes,
|
||||
),
|
||||
)
|
||||
self.emit(
|
||||
'volume_state', self.volume_setting, self.muted, self.change_counter
|
||||
)
|
||||
|
||||
def _on_relative_volume_down(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_relative_volume_up(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_unmute_relative_volume_down(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_unmute_relative_volume_up(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
|
||||
old_volume_setting = self.volume_setting
|
||||
self.volume_setting = volume_setting
|
||||
return old_volume_setting != self.volume_setting
|
||||
|
||||
def _on_unmute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 0
|
||||
return self.muted != old_muted_state
|
||||
|
||||
def _on_mute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 1
|
||||
return self.muted != old_muted_state
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeControlService
|
||||
|
||||
volume_control_point: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.volume_state = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
|
||||
)[0],
|
||||
'BBB',
|
||||
)
|
||||
|
||||
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
self.volume_flags = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
582
bumble/rfcomm.py
582
bumble/rfcomm.py
File diff suppressed because it is too large
Load Diff
144
bumble/sdp.py
144
bumble/sdp.py
@@ -18,13 +18,17 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import struct
|
||||
from typing import Dict, List, Type
|
||||
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import core
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import InvalidStateError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -94,6 +98,11 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
|
||||
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
|
||||
|
||||
|
||||
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
|
||||
# used by AVRCP, HFP and A2DP
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
|
||||
|
||||
SDP_ATTRIBUTE_ID_NAMES = {
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: 'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',
|
||||
@@ -108,7 +117,8 @@ SDP_ATTRIBUTE_ID_NAMES = {
|
||||
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
|
||||
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
|
||||
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
|
||||
}
|
||||
|
||||
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
|
||||
@@ -160,7 +170,7 @@ class DataElement:
|
||||
UUID: lambda x: DataElement(
|
||||
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
|
||||
),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
|
||||
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
|
||||
SEQUENCE: lambda x: DataElement(
|
||||
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
|
||||
@@ -222,7 +232,7 @@ class DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value: str) -> DataElement:
|
||||
def text_string(value: bytes) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -369,7 +379,7 @@ class DataElement:
|
||||
raise ValueError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
|
||||
elif self.type == DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
@@ -462,7 +472,7 @@ class ServiceAttribute:
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def list_from_data_elements(elements):
|
||||
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
|
||||
attribute_list = []
|
||||
for i in range(0, len(elements) // 2):
|
||||
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
||||
@@ -474,7 +484,9 @@ class ServiceAttribute:
|
||||
return attribute_list
|
||||
|
||||
@staticmethod
|
||||
def find_attribute_in_list(attribute_list, attribute_id):
|
||||
def find_attribute_in_list(
|
||||
attribute_list: List[ServiceAttribute], attribute_id: int
|
||||
) -> Optional[DataElement]:
|
||||
return next(
|
||||
(
|
||||
attribute.value
|
||||
@@ -489,7 +501,7 @@ class ServiceAttribute:
|
||||
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
|
||||
|
||||
@staticmethod
|
||||
def is_uuid_in_value(uuid, value):
|
||||
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
|
||||
# Find if a uuid matches a value, either directly or recursing into sequences
|
||||
if value.type == DataElement.UUID:
|
||||
return value.value == uuid
|
||||
@@ -543,7 +555,9 @@ class SDP_PDU:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def parse_service_record_handle_list_preceded_by_count(data, offset):
|
||||
def parse_service_record_handle_list_preceded_by_count(
|
||||
data: bytes, offset: int
|
||||
) -> Tuple[int, List[int]]:
|
||||
count = struct.unpack_from('>H', data, offset - 2)[0]
|
||||
handle_list = [
|
||||
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
||||
@@ -641,6 +655,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
|
||||
'''
|
||||
|
||||
service_search_pattern: DataElement
|
||||
maximum_service_record_count: int
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -659,6 +677,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
||||
'''
|
||||
|
||||
service_record_handle_list: List[int]
|
||||
total_service_record_count: int
|
||||
current_service_record_count: int
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -674,6 +697,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
|
||||
'''
|
||||
|
||||
service_record_handle: int
|
||||
maximum_attribute_byte_count: int
|
||||
attribute_id_list: DataElement
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -688,6 +716,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_list_byte_count: int
|
||||
attribute_list: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -703,6 +735,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
|
||||
'''
|
||||
|
||||
service_search_pattern: DataElement
|
||||
maximum_attribute_byte_count: int
|
||||
attribute_id_list: DataElement
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -717,26 +754,35 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_list_byte_count: int
|
||||
attribute_list: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
self.pending_request = None
|
||||
self.channel = None
|
||||
|
||||
async def connect(self, connection):
|
||||
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
|
||||
self.channel = result
|
||||
async def connect(self) -> None:
|
||||
self.channel = await self.connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
)
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.channel:
|
||||
await self.channel.disconnect()
|
||||
self.channel = None
|
||||
|
||||
async def search_services(self, uuids):
|
||||
async def search_services(self, uuids: List[core.UUID]) -> List[int]:
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
service_search_pattern = DataElement.sequence(
|
||||
[DataElement.uuid(uuid) for uuid in uuids]
|
||||
@@ -766,9 +812,13 @@ class Client:
|
||||
|
||||
return service_record_handle_list
|
||||
|
||||
async def search_attributes(self, uuids, attribute_ids):
|
||||
async def search_attributes(
|
||||
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
|
||||
) -> List[List[ServiceAttribute]]:
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
service_search_pattern = DataElement.sequence(
|
||||
[DataElement.uuid(uuid) for uuid in uuids]
|
||||
@@ -819,9 +869,15 @@ class Client:
|
||||
if sequence.type == DataElement.SEQUENCE
|
||||
]
|
||||
|
||||
async def get_attributes(self, service_record_handle, attribute_ids):
|
||||
async def get_attributes(
|
||||
self,
|
||||
service_record_handle: int,
|
||||
attribute_ids: List[Union[int, Tuple[int, int]]],
|
||||
) -> List[ServiceAttribute]:
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
@@ -865,25 +921,38 @@ class Client:
|
||||
|
||||
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', List[ServiceAttribute])
|
||||
service_records: Dict[int, Service]
|
||||
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
self.service_records = {} # Service records maps, by record handle
|
||||
self.channel = None
|
||||
self.current_response = None
|
||||
|
||||
def register(self, l2cap_channel_manager):
|
||||
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
|
||||
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||
l2cap_channel_manager.create_classic_server(
|
||||
spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
|
||||
)
|
||||
|
||||
def send_response(self, response):
|
||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||
self.channel.send_pdu(response)
|
||||
|
||||
def match_services(self, search_pattern):
|
||||
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
|
||||
# Find the services for which the attributes in the pattern is a subset of the
|
||||
# service's attribute values (NOTE: the value search recurses into sequences)
|
||||
matching_services = {}
|
||||
@@ -953,7 +1022,9 @@ class Server:
|
||||
return (payload, continuation_state)
|
||||
|
||||
@staticmethod
|
||||
def get_service_attributes(service, attribute_ids):
|
||||
def get_service_attributes(
|
||||
service: Service, attribute_ids: List[DataElement]
|
||||
) -> DataElement:
|
||||
attributes = []
|
||||
for attribute_id in attribute_ids:
|
||||
if attribute_id.value_size == 4:
|
||||
@@ -978,10 +1049,10 @@ class Server:
|
||||
|
||||
return attribute_list
|
||||
|
||||
def on_sdp_service_search_request(self, request):
|
||||
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
@@ -1010,6 +1081,7 @@ class Server:
|
||||
)
|
||||
|
||||
# Respond, keeping any unsent handles for later
|
||||
assert isinstance(self.current_response, tuple)
|
||||
service_record_handles = self.current_response[1][
|
||||
: request.maximum_service_record_count
|
||||
]
|
||||
@@ -1033,10 +1105,12 @@ class Server:
|
||||
)
|
||||
)
|
||||
|
||||
def on_sdp_service_attribute_request(self, request):
|
||||
def on_sdp_service_attribute_request(
|
||||
self, request: SDP_ServiceAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
@@ -1069,22 +1143,24 @@ class Server:
|
||||
self.current_response = bytes(attribute_list)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
attribute_list, continuation_state = self.get_next_response_payload(
|
||||
attribute_list_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_list_byte_count=len(attribute_list),
|
||||
attribute_list_byte_count=len(attribute_list_response),
|
||||
attribute_list=attribute_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
def on_sdp_service_search_attribute_request(self, request):
|
||||
def on_sdp_service_search_attribute_request(
|
||||
self, request: SDP_ServiceSearchAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
@@ -1114,13 +1190,13 @@ class Server:
|
||||
self.current_response = bytes(attribute_lists)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
attribute_lists, continuation_state = self.get_next_response_payload(
|
||||
attribute_lists_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceSearchAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_lists_byte_count=len(attribute_lists),
|
||||
attribute_lists_byte_count=len(attribute_lists_response),
|
||||
attribute_lists=attribute_lists,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
|
||||
531
bumble/smp.py
531
bumble/smp.py
@@ -25,7 +25,9 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -36,6 +38,7 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
@@ -51,6 +54,7 @@ from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_CENTRAL_ROLE,
|
||||
BT_LE_TRANSPORT,
|
||||
AdvertisingData,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
@@ -183,8 +187,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
|
||||
SMP_CT2_AUTHREQ = 0b00100000
|
||||
|
||||
# Crypto salt
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -553,20 +557,64 @@ class AddressResolver:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Session:
|
||||
# Pairing methods
|
||||
class PairingMethod(enum.IntEnum):
|
||||
JUST_WORKS = 0
|
||||
NUMERIC_COMPARISON = 1
|
||||
PASSKEY = 2
|
||||
OOB = 3
|
||||
CTKD_OVER_CLASSIC = 4
|
||||
|
||||
PAIRING_METHOD_NAMES = {
|
||||
JUST_WORKS: 'JUST_WORKS',
|
||||
NUMERIC_COMPARISON: 'NUMERIC_COMPARISON',
|
||||
PASSKEY: 'PASSKEY',
|
||||
OOB: 'OOB',
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobContext:
|
||||
"""Cryptographic context for LE SC OOB pairing."""
|
||||
|
||||
ecc_key: crypto.EccKey
|
||||
r: bytes
|
||||
|
||||
def __init__(
|
||||
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
|
||||
) -> None:
|
||||
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
|
||||
self.r = crypto.r() if r is None else r
|
||||
|
||||
def share(self) -> OobSharedData:
|
||||
pkx = self.ecc_key.x[::-1]
|
||||
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobLegacyContext:
|
||||
"""Cryptographic context for LE Legacy OOB pairing."""
|
||||
|
||||
tk: bytes
|
||||
|
||||
def __init__(self, tk: Optional[bytes] = None) -> None:
|
||||
self.tk = crypto.r() if tk is None else tk
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobSharedData:
|
||||
"""Shareable data for LE SC OOB pairing."""
|
||||
|
||||
c: bytes
|
||||
r: bytes
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
return AdvertisingData(
|
||||
[
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Session:
|
||||
# I/O Capability to pairing method decision matrix
|
||||
#
|
||||
# See Bluetooth spec @ Vol 3, part H - Table 2.8: Mapping of IO Capabilities to Key
|
||||
@@ -581,51 +629,61 @@ class Session:
|
||||
# (False).
|
||||
PAIRING_METHODS = {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, True, False),
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
|
||||
},
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (JUST_WORKS, NUMERIC_COMPARISON),
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
),
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
|
||||
(PASSKEY, True, False),
|
||||
NUMERIC_COMPARISON,
|
||||
(PairingMethod.PASSKEY, True, False),
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
),
|
||||
},
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PASSKEY, False, True),
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, False, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, False, True),
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
|
||||
},
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
},
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: {
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
|
||||
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
|
||||
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
|
||||
(PASSKEY, False, True),
|
||||
NUMERIC_COMPARISON,
|
||||
(PairingMethod.PASSKEY, False, True),
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
),
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
|
||||
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
|
||||
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
|
||||
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
|
||||
(PASSKEY, True, False),
|
||||
NUMERIC_COMPARISON,
|
||||
(PairingMethod.PASSKEY, True, False),
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
ea: bytes
|
||||
eb: bytes
|
||||
ltk: bytes
|
||||
preq: bytes
|
||||
pres: bytes
|
||||
tk: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: Manager,
|
||||
@@ -635,17 +693,10 @@ class Session:
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.preq: Optional[bytes] = None
|
||||
self.pres: Optional[bytes] = None
|
||||
self.ea = None
|
||||
self.eb = None
|
||||
self.tk = bytes(16)
|
||||
self.r = bytes(16)
|
||||
self.stk = None
|
||||
self.ltk = None
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key = None
|
||||
self.link_key: Optional[bytes] = None
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
@@ -658,13 +709,13 @@ class Session:
|
||||
self.peer_bd_addr: Optional[Address] = None
|
||||
self.peer_signature_key = None
|
||||
self.peer_expected_distributions: List[Type[SMP_Command]] = []
|
||||
self.dh_key = None
|
||||
self.dh_key = b''
|
||||
self.confirm_value = None
|
||||
self.passkey: Optional[int] = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
self.passkey_step = 0
|
||||
self.passkey_display = False
|
||||
self.pairing_method = 0
|
||||
self.pairing_method: PairingMethod = PairingMethod.JUST_WORKS
|
||||
self.pairing_config = pairing_config
|
||||
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
|
||||
self.completed = False
|
||||
@@ -711,8 +762,8 @@ class Session:
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
|
||||
# OOB (not supported yet)
|
||||
self.oob = False
|
||||
# OOB
|
||||
self.oob_data_flag = 0 if pairing_config.oob is None else 1
|
||||
|
||||
# Set up addresses
|
||||
self_address = connection.self_address
|
||||
@@ -728,9 +779,35 @@ class Session:
|
||||
self.ia = bytes(peer_address)
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
# Select the ECC key, TK and r initial value
|
||||
if pairing_config.oob:
|
||||
self.peer_oob_data = pairing_config.oob.peer_data
|
||||
if pairing_config.sc:
|
||||
if pairing_config.oob.our_context is None:
|
||||
raise ValueError(
|
||||
"oob pairing config requires a context when sc is True"
|
||||
)
|
||||
self.r = pairing_config.oob.our_context.r
|
||||
self.ecc_key = pairing_config.oob.our_context.ecc_key
|
||||
if pairing_config.oob.legacy_context is not None:
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
if pairing_config.oob.legacy_context is None:
|
||||
raise ValueError(
|
||||
"oob pairing config requires a legacy context when sc is False"
|
||||
)
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
self.peer_oob_data = None
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = bytes(16)
|
||||
|
||||
@property
|
||||
def pkx(self) -> Tuple[bytes, bytes]:
|
||||
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
|
||||
return (self.ecc_key.x[::-1], self.peer_public_key_x)
|
||||
|
||||
@property
|
||||
def pka(self) -> bytes:
|
||||
@@ -767,21 +844,28 @@ class Session:
|
||||
return None
|
||||
|
||||
def decide_pairing_method(
|
||||
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
|
||||
self,
|
||||
auth_req: int,
|
||||
initiator_io_capability: int,
|
||||
responder_io_capability: int,
|
||||
) -> None:
|
||||
if self.connection.transport == BT_BR_EDR_TRANSPORT:
|
||||
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
||||
return
|
||||
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
|
||||
self.pairing_method = self.JUST_WORKS
|
||||
self.pairing_method = PairingMethod.JUST_WORKS
|
||||
return
|
||||
|
||||
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
|
||||
if isinstance(details, tuple) and len(details) == 2:
|
||||
# One entry for legacy pairing and one for secure connections
|
||||
details = details[1 if self.sc else 0]
|
||||
if isinstance(details, int):
|
||||
if isinstance(details, PairingMethod):
|
||||
# Just a method ID
|
||||
self.pairing_method = details
|
||||
else:
|
||||
# PASSKEY method, with a method ID and display/input flags
|
||||
assert isinstance(details[0], PairingMethod)
|
||||
self.pairing_method = details[0]
|
||||
self.passkey_display = details[1 if self.is_initiator else 2]
|
||||
|
||||
@@ -858,10 +942,13 @@ class Session:
|
||||
self.tk = self.passkey.to_bytes(16, byteorder='little')
|
||||
logger.debug(f'TK from passkey = {self.tk.hex()}')
|
||||
|
||||
self.connection.abort_on(
|
||||
'disconnection',
|
||||
self.pairing_config.delegate.display_number(self.passkey, digits=6),
|
||||
)
|
||||
try:
|
||||
self.connection.abort_on(
|
||||
'disconnection',
|
||||
self.pairing_config.delegate.display_number(self.passkey, digits=6),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while displaying number: {error}')
|
||||
|
||||
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
|
||||
# Prompt the user for the passkey displayed on the peer
|
||||
@@ -901,7 +988,7 @@ class Session:
|
||||
|
||||
command = SMP_Pairing_Request_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
@@ -913,7 +1000,7 @@ class Session:
|
||||
def send_pairing_response_command(self) -> None:
|
||||
response = SMP_Pairing_Response_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
@@ -929,9 +1016,12 @@ class Session:
|
||||
if self.sc:
|
||||
|
||||
async def next_steps() -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
):
|
||||
z = 0
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
# We need a passkey
|
||||
await self.passkey_ready.wait()
|
||||
assert self.passkey
|
||||
@@ -971,8 +1061,8 @@ class Session:
|
||||
def send_public_key_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_Public_Key_Command(
|
||||
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
|
||||
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
|
||||
public_key_x=self.ecc_key.x[::-1],
|
||||
public_key_y=self.ecc_key.y[::-1],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -983,11 +1073,24 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
def send_identity_address_command(self) -> None:
|
||||
identity_address = {
|
||||
None: self.connection.self_address,
|
||||
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
|
||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
|
||||
}[self.pairing_config.identity_address_type]
|
||||
self.send_command(
|
||||
SMP_Identity_Address_Information_Command(
|
||||
addr_type=identity_address.address_type,
|
||||
bd_addr=identity_address,
|
||||
)
|
||||
)
|
||||
|
||||
def start_encryption(self, key: bytes) -> None:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
@@ -995,15 +1098,54 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
async def derive_ltk(self) -> None:
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
@classmethod
|
||||
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Long Term Key from Link Key.
|
||||
|
||||
Args:
|
||||
link_key: BR/EDR Link Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
LE Long Tern Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
|
||||
if self.ct2
|
||||
if ct2
|
||||
else crypto.h6(link_key, b'tmp2')
|
||||
)
|
||||
self.ltk = crypto.h6(ilk, b'brle')
|
||||
return crypto.h6(ilk, b'brle')
|
||||
|
||||
@classmethod
|
||||
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Link Key from Long Term Key.
|
||||
|
||||
Args:
|
||||
ltk: LE Long Term Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
BR/EDR Link Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
|
||||
if ct2
|
||||
else crypto.h6(ltk, b'tmp1')
|
||||
)
|
||||
return crypto.h6(ilk, b'lebr')
|
||||
|
||||
async def get_link_key_and_derive_ltk(self) -> None:
|
||||
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
|
||||
self.link_key = await self.manager.device.get_link_key(
|
||||
self.connection.peer_address
|
||||
)
|
||||
if self.link_key is None:
|
||||
logging.warning(
|
||||
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
||||
)
|
||||
self.send_pairing_failed(
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
|
||||
)
|
||||
else:
|
||||
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
||||
|
||||
def distribute_keys(self) -> None:
|
||||
# Distribute the keys as required
|
||||
@@ -1014,7 +1156,7 @@ class Session:
|
||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
elif not self.sc:
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
@@ -1035,12 +1177,7 @@ class Session:
|
||||
identity_resolving_key=self.manager.device.irk
|
||||
)
|
||||
)
|
||||
self.send_command(
|
||||
SMP_Identity_Address_Information_Command(
|
||||
addr_type=self.connection.self_address.address_type,
|
||||
bd_addr=self.connection.self_address,
|
||||
)
|
||||
)
|
||||
self.send_identity_address_command()
|
||||
|
||||
# Distribute CSRK
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
@@ -1049,12 +1186,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
else:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1063,7 +1195,7 @@ class Session:
|
||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
elif not self.sc:
|
||||
@@ -1084,12 +1216,7 @@ class Session:
|
||||
identity_resolving_key=self.manager.device.irk
|
||||
)
|
||||
)
|
||||
self.send_command(
|
||||
SMP_Identity_Address_Information_Command(
|
||||
addr_type=self.connection.self_address.address_type,
|
||||
bd_addr=self.connection.self_address,
|
||||
)
|
||||
)
|
||||
self.send_identity_address_command()
|
||||
|
||||
# Distribute CSRK
|
||||
csrk = bytes(16) # FIXME: testing
|
||||
@@ -1098,12 +1225,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
@@ -1224,7 +1346,7 @@ class Session:
|
||||
# Create an object to hold the keys
|
||||
keys = PairingKeys()
|
||||
keys.address_type = peer_address.address_type
|
||||
authenticated = self.pairing_method != self.JUST_WORKS
|
||||
authenticated = self.pairing_method != PairingMethod.JUST_WORKS
|
||||
if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT:
|
||||
keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated)
|
||||
else:
|
||||
@@ -1258,7 +1380,7 @@ class Session:
|
||||
keys.link_key = PairingKeys.Key(
|
||||
value=self.link_key, authenticated=authenticated
|
||||
)
|
||||
self.manager.on_pairing(self, peer_address, keys)
|
||||
await self.manager.on_pairing(self, peer_address, keys)
|
||||
|
||||
def on_pairing_failure(self, reason: int) -> None:
|
||||
logger.warning(f'pairing failure ({error_name(reason)})')
|
||||
@@ -1281,7 +1403,7 @@ class Session:
|
||||
try:
|
||||
handler(command)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
response = SMP_Pairing_Failed_Command(
|
||||
reason=SMP_UNSPECIFIED_REASON_ERROR
|
||||
)
|
||||
@@ -1300,7 +1422,11 @@ class Session:
|
||||
self, command: SMP_Pairing_Request_Command
|
||||
) -> None:
|
||||
# Check if the request should proceed
|
||||
accepted = await self.pairing_config.delegate.accept()
|
||||
try:
|
||||
accepted = await self.pairing_config.delegate.accept()
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while accepting: {error}')
|
||||
accepted = False
|
||||
if not accepted:
|
||||
logger.debug('pairing rejected by delegate')
|
||||
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
|
||||
@@ -1314,18 +1440,29 @@ class Session:
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if command.oob_data_flag != 0:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req,
|
||||
command.io_capability,
|
||||
self.io_capability,
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, command.io_capability, self.io_capability
|
||||
)
|
||||
logger.debug(
|
||||
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
(
|
||||
@@ -1341,7 +1478,7 @@ class Session:
|
||||
|
||||
# Display a passkey if we need to
|
||||
if not self.sc:
|
||||
if self.pairing_method == self.PASSKEY and self.passkey_display:
|
||||
if self.pairing_method == PairingMethod.PASSKEY and self.passkey_display:
|
||||
self.display_passkey()
|
||||
|
||||
# Respond
|
||||
@@ -1373,18 +1510,27 @@ class Session:
|
||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if self.sc and command.oob_data_flag:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
logger.debug(
|
||||
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
if (
|
||||
@@ -1400,13 +1546,16 @@ class Session:
|
||||
self.compute_peer_expected_distributions(self.responder_key_distribution)
|
||||
|
||||
# Start phase 2
|
||||
if self.sc:
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
if self.pairing_method == PairingMethod.CTKD_OVER_CLASSIC:
|
||||
# Authentication is already done in SMP, so remote shall start keys distribution immediately
|
||||
return
|
||||
elif self.sc:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.display_or_input_passkey()
|
||||
|
||||
self.send_public_key_command()
|
||||
else:
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.display_or_input_passkey(self.send_pairing_confirm_command)
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
@@ -1418,7 +1567,10 @@ class Session:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
# If the method is PASSKEY, now is the time to input the code
|
||||
if self.pairing_method == self.PASSKEY and not self.passkey_display:
|
||||
if (
|
||||
self.pairing_method == PairingMethod.PASSKEY
|
||||
and not self.passkey_display
|
||||
):
|
||||
self.input_passkey(self.send_pairing_confirm_command)
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
@@ -1426,11 +1578,14 @@ class Session:
|
||||
def on_smp_pairing_confirm_command_secure_connections(
|
||||
self, _: SMP_Pairing_Confirm_Command
|
||||
) -> None:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
):
|
||||
if self.is_initiator:
|
||||
self.r = crypto.r()
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
if self.is_initiator:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
@@ -1486,13 +1641,16 @@ class Session:
|
||||
def on_smp_pairing_random_command_secure_connections(
|
||||
self, command: SMP_Pairing_Random_Command
|
||||
) -> None:
|
||||
if self.pairing_method == self.PASSKEY and self.passkey is None:
|
||||
if self.pairing_method == PairingMethod.PASSKEY and self.passkey is None:
|
||||
logger.warning('no passkey entered, ignoring command')
|
||||
return
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
if self.is_initiator:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
):
|
||||
assert self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
@@ -1502,7 +1660,7 @@ class Session:
|
||||
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
|
||||
):
|
||||
return
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
@@ -1522,12 +1680,16 @@ class Session:
|
||||
if self.passkey_step < 20:
|
||||
self.send_pairing_confirm_command()
|
||||
return
|
||||
else:
|
||||
elif self.pairing_method != PairingMethod.OOB:
|
||||
return
|
||||
else:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
assert self.passkey and self.confirm_value
|
||||
# Check that the random value matches what was committed to earlier
|
||||
confirm_verifier = crypto.f4(
|
||||
@@ -1558,15 +1720,18 @@ class Session:
|
||||
(mac_key, self.ltk) = crypto.f5(self.dh_key, self.na, self.nb, a, b)
|
||||
|
||||
# Compute the DH Key checks
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
ra = bytes(16)
|
||||
rb = ra
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
assert self.passkey
|
||||
ra = self.passkey.to_bytes(16, byteorder='little')
|
||||
rb = ra
|
||||
else:
|
||||
# OOB not implemented yet
|
||||
return
|
||||
|
||||
assert self.preq and self.pres
|
||||
@@ -1585,13 +1750,16 @@ class Session:
|
||||
self.wait_before_continuing.set_result(None)
|
||||
|
||||
# Prompt the user for confirmation if needed
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
):
|
||||
# Compute the 6-digit code
|
||||
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
|
||||
|
||||
# Ask for user confirmation
|
||||
self.wait_before_continuing = asyncio.get_running_loop().create_future()
|
||||
if self.pairing_method == self.JUST_WORKS:
|
||||
if self.pairing_method == PairingMethod.JUST_WORKS:
|
||||
self.prompt_user_for_confirmation(next_steps)
|
||||
else:
|
||||
self.prompt_user_for_numeric_comparison(code, next_steps)
|
||||
@@ -1615,26 +1783,45 @@ class Session:
|
||||
self.peer_public_key_y = command.public_key_y
|
||||
|
||||
# Compute the DH key
|
||||
self.dh_key = bytes(
|
||||
reversed(
|
||||
self.manager.ecc_key.dh(
|
||||
bytes(reversed(command.public_key_x)),
|
||||
bytes(reversed(command.public_key_y)),
|
||||
)
|
||||
)
|
||||
)
|
||||
self.dh_key = self.ecc_key.dh(
|
||||
command.public_key_x[::-1],
|
||||
command.public_key_y[::-1],
|
||||
)[::-1]
|
||||
logger.debug(f'DH key: {self.dh_key.hex()}')
|
||||
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
# Check against shared OOB data
|
||||
if self.peer_oob_data:
|
||||
confirm_verifier = crypto.f4(
|
||||
self.peer_public_key_x,
|
||||
self.peer_public_key_x,
|
||||
self.peer_oob_data.r,
|
||||
bytes(1),
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.peer_oob_data.c,
|
||||
confirm_verifier,
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
):
|
||||
return
|
||||
|
||||
if self.is_initiator:
|
||||
self.send_pairing_confirm_command()
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.display_or_input_passkey()
|
||||
|
||||
# Send our public key back to the initiator
|
||||
self.send_public_key_command()
|
||||
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
# We can now send the confirmation value
|
||||
self.send_pairing_confirm_command()
|
||||
|
||||
@@ -1662,7 +1849,6 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
assert self.ltk
|
||||
self.start_encryption(self.ltk)
|
||||
|
||||
def on_smp_pairing_failed_command(
|
||||
@@ -1712,6 +1898,7 @@ class Manager(EventEmitter):
|
||||
sessions: Dict[int, Session]
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
session_proxy: Type[Session]
|
||||
_ecc_key: Optional[crypto.EccKey]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1733,7 +1920,26 @@ class Manager(EventEmitter):
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
|
||||
def on_smp_security_request_command(
|
||||
self, connection: Connection, request: SMP_Security_Request_Command
|
||||
) -> None:
|
||||
connection.emit('security_request', request.auth_req)
|
||||
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Security request is more than just pairing, so let applications handle them
|
||||
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||
self.on_smp_security_request_command(
|
||||
connection, cast(SMP_Security_Request_Command, command)
|
||||
)
|
||||
return
|
||||
|
||||
# Look for a session with this connection, and create one if none exists
|
||||
if not (session := self.sessions.get(connection.handle)):
|
||||
if connection.role == BT_CENTRAL_ROLE:
|
||||
@@ -1744,13 +1950,6 @@ class Manager(EventEmitter):
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Delegate the handling of the command to the session
|
||||
session.on_smp_command(command)
|
||||
|
||||
@@ -1789,21 +1988,13 @@ class Manager(EventEmitter):
|
||||
def on_session_start(self, session: Session) -> None:
|
||||
self.device.on_pairing_start(session.connection)
|
||||
|
||||
def on_pairing(
|
||||
async def on_pairing(
|
||||
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
|
||||
) -> None:
|
||||
# Store the keys in the key store
|
||||
if self.device.keystore and identity_address is not None:
|
||||
|
||||
async def store_keys():
|
||||
try:
|
||||
assert self.device.keystore
|
||||
await self.device.keystore.update(str(identity_address), keys)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! error while storing keys: {error}')
|
||||
|
||||
self.device.abort_on('flush', store_keys())
|
||||
|
||||
# Make sure on_pairing emits after key update.
|
||||
await self.device.update_keys(str(identity_address), keys)
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from ..controller import Controller
|
||||
from ..snoop import create_snooper
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -53,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
|
||||
async def open_transport(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The name must be <type>:<metadata><parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types), and
|
||||
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
|
||||
enclosed in [].
|
||||
If there are not metadata or parameter, the : after the <type> may be omitted.
|
||||
Examples:
|
||||
* usb:0
|
||||
* usb:[driver=rtk]0
|
||||
* android-netsim
|
||||
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
@@ -69,86 +77,108 @@ async def open_transport(name: str) -> Transport:
|
||||
* usb
|
||||
* pyusb
|
||||
* android-emulator
|
||||
* android-netsim
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
scheme, *tail = name.split(':', 1)
|
||||
spec = tail[0] if tail else None
|
||||
metadata = None
|
||||
if spec:
|
||||
# Metadata may precede the spec
|
||||
if spec.startswith('['):
|
||||
metadata_str, *tail = spec[1:].split(']')
|
||||
spec = tail[0] if tail else None
|
||||
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
|
||||
|
||||
transport = await _open_transport(scheme, spec)
|
||||
if metadata:
|
||||
transport.source.metadata = { # type: ignore[attr-defined]
|
||||
**metadata,
|
||||
**getattr(transport.source, 'metadata', {}),
|
||||
}
|
||||
# pylint: disable=line-too-long
|
||||
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
|
||||
|
||||
return _wrap_transport(transport)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
scheme, *spec = name.split(':', 1)
|
||||
if scheme == 'serial' and spec:
|
||||
from .serial import open_serial_transport
|
||||
|
||||
return await open_serial_transport(spec[0])
|
||||
return await open_serial_transport(spec)
|
||||
|
||||
if scheme == 'udp' and spec:
|
||||
from .udp import open_udp_transport
|
||||
|
||||
return await open_udp_transport(spec[0])
|
||||
return await open_udp_transport(spec)
|
||||
|
||||
if scheme == 'tcp-client' and spec:
|
||||
from .tcp_client import open_tcp_client_transport
|
||||
|
||||
return await open_tcp_client_transport(spec[0])
|
||||
return await open_tcp_client_transport(spec)
|
||||
|
||||
if scheme == 'tcp-server' and spec:
|
||||
from .tcp_server import open_tcp_server_transport
|
||||
|
||||
return await open_tcp_server_transport(spec[0])
|
||||
return await open_tcp_server_transport(spec)
|
||||
|
||||
if scheme == 'ws-client' and spec:
|
||||
from .ws_client import open_ws_client_transport
|
||||
|
||||
return await open_ws_client_transport(spec[0])
|
||||
return await open_ws_client_transport(spec)
|
||||
|
||||
if scheme == 'ws-server' and spec:
|
||||
from .ws_server import open_ws_server_transport
|
||||
|
||||
return await open_ws_server_transport(spec[0])
|
||||
return await open_ws_server_transport(spec)
|
||||
|
||||
if scheme == 'pty':
|
||||
from .pty import open_pty_transport
|
||||
|
||||
return await open_pty_transport(spec[0] if spec else None)
|
||||
return await open_pty_transport(spec)
|
||||
|
||||
if scheme == 'file':
|
||||
from .file import open_file_transport
|
||||
|
||||
return await open_file_transport(spec[0] if spec else None)
|
||||
assert spec is not None
|
||||
return await open_file_transport(spec)
|
||||
|
||||
if scheme == 'vhci':
|
||||
from .vhci import open_vhci_transport
|
||||
|
||||
return await open_vhci_transport(spec[0] if spec else None)
|
||||
return await open_vhci_transport(spec)
|
||||
|
||||
if scheme == 'hci-socket':
|
||||
from .hci_socket import open_hci_socket_transport
|
||||
|
||||
return await open_hci_socket_transport(spec[0] if spec else None)
|
||||
return await open_hci_socket_transport(spec)
|
||||
|
||||
if scheme == 'usb':
|
||||
from .usb import open_usb_transport
|
||||
|
||||
return await open_usb_transport(spec[0] if spec else None)
|
||||
assert spec
|
||||
return await open_usb_transport(spec)
|
||||
|
||||
if scheme == 'pyusb':
|
||||
from .pyusb import open_pyusb_transport
|
||||
|
||||
return await open_pyusb_transport(spec[0] if spec else None)
|
||||
assert spec
|
||||
return await open_pyusb_transport(spec)
|
||||
|
||||
if scheme == 'android-emulator':
|
||||
from .android_emulator import open_android_emulator_transport
|
||||
|
||||
return await open_android_emulator_transport(spec[0] if spec else None)
|
||||
return await open_android_emulator_transport(spec)
|
||||
|
||||
if scheme == 'android-netsim':
|
||||
from .android_netsim import open_android_netsim_transport
|
||||
|
||||
return await open_android_netsim_transport(spec[0] if spec else None)
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
|
||||
@@ -167,11 +197,13 @@ async def open_transport_or_link(name: str) -> Transport:
|
||||
|
||||
"""
|
||||
if name.startswith('link-relay:'):
|
||||
logger.warning('Link Relay has been deprecated.')
|
||||
from ..controller import Controller
|
||||
from ..link import RemoteLink # lazy import
|
||||
|
||||
link = RemoteLink(name[11:])
|
||||
await link.wait_until_connected()
|
||||
controller = Controller('remote', link=link)
|
||||
controller = Controller('remote', link=link) # type:ignore[arg-type]
|
||||
|
||||
class LinkTransport(Transport):
|
||||
async def close(self):
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
import logging
|
||||
import grpc.aio
|
||||
|
||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
|
||||
from typing import Optional, Union
|
||||
|
||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_emulator_transport(spec):
|
||||
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a transport connection to an Android emulator via its gRPC interface.
|
||||
The parameter string has this syntax:
|
||||
@@ -66,8 +68,8 @@ async def open_android_emulator_transport(spec):
|
||||
# Parse the parameters
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = 8554
|
||||
if spec is not None:
|
||||
server_port = '8554'
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
@@ -82,6 +84,7 @@ async def open_android_emulator_transport(spec):
|
||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||
if mode == 'host':
|
||||
# Connect as a host
|
||||
service = EmulatedBluetoothServiceStub(channel)
|
||||
@@ -94,10 +97,13 @@ async def open_android_emulator_transport(spec):
|
||||
raise ValueError('invalid mode')
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
class EmulatorTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = EmulatorTransport(
|
||||
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
|
||||
)
|
||||
transport.start()
|
||||
|
||||
|
||||
@@ -18,11 +18,12 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import grpc.aio
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import grpc.aio
|
||||
|
||||
from .common import (
|
||||
ParserSource,
|
||||
@@ -33,8 +34,8 @@ from .common import (
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import (
|
||||
PacketStreamerStub,
|
||||
PacketStreamerServicer,
|
||||
add_PacketStreamerServicer_to_server,
|
||||
)
|
||||
@@ -43,6 +44,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket
|
||||
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
|
||||
from .grpc_protobuf.common_pb2 import ChipKind
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -74,14 +76,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port() -> int:
|
||||
def ini_file_name(instance_number: int) -> str:
|
||||
suffix = f'_{instance_number}' if instance_number > 0 else ''
|
||||
return f'netsim{suffix}.ini'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port(instance_number: int) -> int:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return 0
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
logger.debug(f'Looking for .ini file at {ini_file}')
|
||||
if ini_file.is_file():
|
||||
logger.debug(f'Found .ini file at {ini_file}')
|
||||
with open(ini_file, 'r') as ini_file_data:
|
||||
for line in ini_file_data.readlines():
|
||||
if '=' in line:
|
||||
@@ -90,12 +98,14 @@ def find_grpc_port() -> int:
|
||||
logger.debug(f'gRPC port = {value}')
|
||||
return int(value)
|
||||
|
||||
logger.debug('no grpc.port property found in .ini file')
|
||||
|
||||
# Not found
|
||||
return 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def publish_grpc_port(grpc_port) -> bool:
|
||||
def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return False
|
||||
@@ -104,7 +114,7 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
logger.debug('ini directory does not exist')
|
||||
return False
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
try:
|
||||
ini_file.write_text(f'grpc.port={grpc_port}\n')
|
||||
logger.debug(f"published gRPC port at {ini_file}")
|
||||
@@ -121,13 +131,16 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise ValueError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not publish_grpc_port(server_port):
|
||||
instance_number = int(options.get('instance', "0"))
|
||||
if not publish_grpc_port(server_port, instance_number):
|
||||
logger.warning("unable to publish gRPC port")
|
||||
|
||||
class HciDevice:
|
||||
@@ -184,15 +197,12 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
logger.debug(f'<<< PACKET: {data.hex()}')
|
||||
self.on_data_received(data)
|
||||
|
||||
def send_packet(self, data):
|
||||
async def send():
|
||||
await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
async def send_packet(self, data):
|
||||
return await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
|
||||
self.loop.create_task(send())
|
||||
)
|
||||
|
||||
def terminate(self):
|
||||
self.task.cancel()
|
||||
@@ -226,17 +236,17 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
logger.debug('gRPC server cancelled')
|
||||
await self.grpc_server.stop(None)
|
||||
|
||||
def on_packet(self, packet):
|
||||
async def send_packet(self, packet):
|
||||
if not self.device:
|
||||
logger.debug('no device, dropping packet')
|
||||
return
|
||||
|
||||
self.device.send_packet(packet)
|
||||
return await self.device.send_packet(packet)
|
||||
|
||||
async def StreamPackets(self, _request_iterator, context):
|
||||
logger.debug('StreamPackets request')
|
||||
|
||||
# Check that we won't already have a device
|
||||
# Check that we don't already have a device
|
||||
if self.device:
|
||||
logger.debug('busy, already serving a device')
|
||||
return PacketResponse(error='Busy')
|
||||
@@ -259,15 +269,42 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
await server.start()
|
||||
asyncio.get_running_loop().create_task(server.serve())
|
||||
|
||||
class GrpcServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
return GrpcServerTransport(server, server)
|
||||
sink = PumpedPacketSink(server.send_packet)
|
||||
sink.start()
|
||||
return Transport(server, sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
async def open_android_netsim_host_transport_with_address(
|
||||
server_host: Optional[str],
|
||||
server_port: int,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||
server_port = find_grpc_port(instance_number)
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
return await open_android_netsim_host_transport_with_channel(
|
||||
channel,
|
||||
options,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport_with_channel(
|
||||
channel, options: Optional[Dict[str, str]] = None
|
||||
):
|
||||
# Wrapper for I/O operations
|
||||
class HciDevice:
|
||||
def __init__(self, name, manufacturer, hci_device):
|
||||
@@ -286,10 +323,12 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
async def read(self):
|
||||
response = await self.hci_device.read()
|
||||
response_type = response.WhichOneof('response_type')
|
||||
|
||||
if response_type == 'error':
|
||||
logger.warning(f'received error: {response.error}')
|
||||
raise RuntimeError(response.error)
|
||||
elif response_type == 'hci_packet':
|
||||
|
||||
if response_type == 'hci_packet':
|
||||
return (
|
||||
bytes([response.hci_packet.packet_type])
|
||||
+ response.hci_packet.packet
|
||||
@@ -304,24 +343,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
)
|
||||
)
|
||||
|
||||
name = options.get('name', DEFAULT_NAME)
|
||||
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
|
||||
manufacturer = DEFAULT_MANUFACTURER
|
||||
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
server_host = 'localhost'
|
||||
server_port = find_grpc_port()
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
# Connect as a host
|
||||
service = PacketStreamerStub(channel)
|
||||
hci_device = HciDevice(
|
||||
@@ -332,10 +356,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
await hci_device.start()
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
class GrpcTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = GrpcTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
)
|
||||
transport.start()
|
||||
|
||||
@@ -343,7 +371,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_transport(spec):
|
||||
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a transport connection as a client or server, implementing Android's `netsim`
|
||||
simulator protocol over gRPC.
|
||||
@@ -357,6 +385,11 @@ async def open_android_netsim_transport(spec):
|
||||
to connect *to* a netsim server (netsim is the controller), or accept
|
||||
connections *as* a netsim-compatible server.
|
||||
|
||||
instance=<n>
|
||||
Specifies an instance number, with <n> > 0. This is used to determine which
|
||||
.init file to use. In `host` mode, it is ignored when the <host>:<port>
|
||||
specifier is present, since in that case no .ini file is used.
|
||||
|
||||
In `host` mode:
|
||||
The <host>:<port> part is optional. When not specified, the transport
|
||||
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
|
||||
@@ -385,14 +418,15 @@ async def open_android_netsim_transport(spec):
|
||||
params = spec.split(',') if spec else []
|
||||
if params and ':' in params[0]:
|
||||
# Explicit <host>:<port>
|
||||
host, port = params[0].split(':')
|
||||
host, port_str = params[0].split(':')
|
||||
port = int(port_str)
|
||||
params_offset = 1
|
||||
else:
|
||||
host = None
|
||||
port = 0
|
||||
params_offset = 0
|
||||
|
||||
options = {}
|
||||
options: Dict[str, str] = {}
|
||||
for param in params[params_offset:]:
|
||||
if '=' not in param:
|
||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
||||
@@ -401,10 +435,12 @@ async def open_android_netsim_transport(spec):
|
||||
|
||||
mode = options.get('mode', 'host')
|
||||
if mode == 'host':
|
||||
return await open_android_netsim_host_transport(host, port, options)
|
||||
return await open_android_netsim_host_transport_with_address(
|
||||
host, port, options
|
||||
)
|
||||
if mode == 'controller':
|
||||
if host is None:
|
||||
raise ValueError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port)
|
||||
return await open_android_netsim_controller_transport(host, port, options)
|
||||
|
||||
raise ValueError('invalid mode option')
|
||||
|
||||
@@ -20,11 +20,12 @@ import contextlib
|
||||
import struct
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import ContextManager
|
||||
import io
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
from ..snoop import Snooper
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.snoop import Snooper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -36,42 +37,64 @@ logger = logging.getLogger(__name__)
|
||||
# Information needed to parse HCI packets with a generic parser:
|
||||
# For each packet type, the info represents:
|
||||
# (length-size, length-offset, unpack-type)
|
||||
HCI_PACKET_INFO = {
|
||||
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
||||
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
|
||||
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketPump:
|
||||
'''
|
||||
Pump HCI packets from a reader to a sink
|
||||
'''
|
||||
# Errors
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportLostError(Exception):
|
||||
"""
|
||||
The Transport has been lost/disconnected.
|
||||
"""
|
||||
|
||||
def __init__(self, reader, sink):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing Protocols
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportSink(Protocol):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
class TransportSource(Protocol):
|
||||
terminated: asyncio.Future[None]
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
...
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketPump:
|
||||
"""
|
||||
Pump HCI packets from a reader to a sink.
|
||||
"""
|
||||
|
||||
def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
|
||||
self.reader = reader
|
||||
self.sink = sink
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
# Get a packet from the source
|
||||
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
|
||||
|
||||
# Deliver the packet to the sink
|
||||
self.sink.on_packet(packet)
|
||||
self.sink.on_packet(await self.reader.next_packet())
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! {error}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketParser:
|
||||
'''
|
||||
"""
|
||||
In-line parser that accepts data and emits 'on_packet' when a full packet has been
|
||||
parsed
|
||||
'''
|
||||
parsed.
|
||||
"""
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
||||
@@ -79,18 +102,22 @@ class PacketParser:
|
||||
NEED_LENGTH = 1
|
||||
NEED_BODY = 2
|
||||
|
||||
def __init__(self, sink=None):
|
||||
sink: Optional[TransportSink]
|
||||
extended_packet_info: Dict[int, Tuple[int, int, str]]
|
||||
packet_info: Optional[Tuple[int, int, str]] = None
|
||||
|
||||
def __init__(self, sink: Optional[TransportSink] = None) -> None:
|
||||
self.sink = sink
|
||||
self.extended_packet_info = {}
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.state = PacketParser.NEED_TYPE
|
||||
self.bytes_needed = 1
|
||||
self.packet = bytearray()
|
||||
self.packet_info = None
|
||||
|
||||
def feed_data(self, data):
|
||||
def feed_data(self, data: bytes) -> None:
|
||||
data_offset = 0
|
||||
data_left = len(data)
|
||||
while data_left and self.bytes_needed:
|
||||
@@ -111,6 +138,7 @@ class PacketParser:
|
||||
self.state = PacketParser.NEED_LENGTH
|
||||
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
||||
elif self.state == PacketParser.NEED_LENGTH:
|
||||
assert self.packet_info is not None
|
||||
body_length = struct.unpack_from(
|
||||
self.packet_info[2], self.packet, 1 + self.packet_info[1]
|
||||
)[0]
|
||||
@@ -123,25 +151,25 @@ class PacketParser:
|
||||
try:
|
||||
self.sink.on_packet(bytes(self.packet))
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
color(f'!!! Exception in on_packet: {error}', 'red')
|
||||
)
|
||||
self.reset()
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketReader:
|
||||
'''
|
||||
Reader that reads HCI packets from a sync source
|
||||
'''
|
||||
"""
|
||||
Reader that reads HCI packets from a sync source.
|
||||
"""
|
||||
|
||||
def __init__(self, source):
|
||||
def __init__(self, source: io.BufferedReader) -> None:
|
||||
self.source = source
|
||||
|
||||
def next_packet(self):
|
||||
def next_packet(self) -> Optional[bytes]:
|
||||
# Get the packet type
|
||||
packet_type = self.source.read(1)
|
||||
if len(packet_type) != 1:
|
||||
@@ -150,7 +178,7 @@ class PacketReader:
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type} found')
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
@@ -169,21 +197,21 @@ class PacketReader:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AsyncPacketReader:
|
||||
'''
|
||||
Reader that reads HCI packets from an async source
|
||||
'''
|
||||
"""
|
||||
Reader that reads HCI packets from an async source.
|
||||
"""
|
||||
|
||||
def __init__(self, source):
|
||||
def __init__(self, source: asyncio.StreamReader) -> None:
|
||||
self.source = source
|
||||
|
||||
async def next_packet(self):
|
||||
async def next_packet(self) -> bytes:
|
||||
# Get the packet type
|
||||
packet_type = await self.source.readexactly(1)
|
||||
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type} found')
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
@@ -198,15 +226,15 @@ class AsyncPacketReader:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AsyncPipeSink:
|
||||
'''
|
||||
Sink that forwards packets asynchronously to another sink
|
||||
'''
|
||||
"""
|
||||
Sink that forwards packets asynchronously to another sink.
|
||||
"""
|
||||
|
||||
def __init__(self, sink):
|
||||
def __init__(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
self.loop = asyncio.get_running_loop()
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.loop.call_soon(self.sink.on_packet, packet)
|
||||
|
||||
|
||||
@@ -216,35 +244,48 @@ class ParserSource:
|
||||
Base class designed to be subclassed by transport-specific source classes
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
terminated: asyncio.Future[None]
|
||||
parser: PacketParser
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.parser = PacketParser()
|
||||
self.terminated = asyncio.get_running_loop().create_future()
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.parser.set_packet_sink(sink)
|
||||
|
||||
async def wait_for_termination(self):
|
||||
def on_transport_lost(self) -> None:
|
||||
self.terminated.set_result(None)
|
||||
if self.parser.sink:
|
||||
if hasattr(self.parser.sink, 'on_transport_lost'):
|
||||
self.parser.sink.on_transport_lost()
|
||||
|
||||
async def wait_for_termination(self) -> None:
|
||||
"""
|
||||
Convenience method for backward compatibility. Prefer using the `terminated`
|
||||
attribute instead.
|
||||
"""
|
||||
return await self.terminated
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamPacketSource(asyncio.Protocol, ParserSource):
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.parser.feed_data(data)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamPacketSink:
|
||||
def __init__(self, transport):
|
||||
def __init__(self, transport: asyncio.WriteTransport) -> None:
|
||||
self.transport = transport
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.transport.write(packet)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
@@ -264,7 +305,7 @@ class Transport:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, source, sink):
|
||||
def __init__(self, source: TransportSource, sink: TransportSink) -> None:
|
||||
self.source = source
|
||||
self.sink = sink
|
||||
|
||||
@@ -278,34 +319,39 @@ class Transport:
|
||||
return iter((self.source, self.sink))
|
||||
|
||||
async def close(self) -> None:
|
||||
self.source.close()
|
||||
self.sink.close()
|
||||
if hasattr(self.source, 'close'):
|
||||
self.source.close()
|
||||
if hasattr(self.sink, 'close'):
|
||||
self.sink.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PumpedPacketSource(ParserSource):
|
||||
def __init__(self, receive):
|
||||
pump_task: Optional[asyncio.Task[None]]
|
||||
|
||||
def __init__(self, receive) -> None:
|
||||
super().__init__()
|
||||
self.receive_function = receive
|
||||
self.pump_task = None
|
||||
|
||||
def start(self):
|
||||
async def pump_packets():
|
||||
def start(self) -> None:
|
||||
async def pump_packets() -> None:
|
||||
while True:
|
||||
try:
|
||||
packet = await self.receive_function()
|
||||
self.parser.feed_data(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('source pump task done')
|
||||
self.terminated.set_result(None)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while waiting for packet: {error}')
|
||||
self.terminated.set_result(error)
|
||||
self.terminated.set_exception(error)
|
||||
break
|
||||
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self.pump_task:
|
||||
self.pump_task.cancel()
|
||||
|
||||
@@ -317,7 +363,7 @@ class PumpedPacketSink:
|
||||
self.packet_queue = asyncio.Queue()
|
||||
self.pump_task = None
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.packet_queue.put_nowait(packet)
|
||||
|
||||
def start(self):
|
||||
@@ -326,7 +372,7 @@ class PumpedPacketSink:
|
||||
try:
|
||||
packet = await self.packet_queue.get()
|
||||
await self.send_function(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('sink pump task done')
|
||||
break
|
||||
except Exception as error:
|
||||
@@ -342,18 +388,20 @@ class PumpedPacketSink:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PumpedTransport(Transport):
|
||||
def __init__(self, source, sink, close_function):
|
||||
super().__init__(source, sink)
|
||||
self.close_function = close_function
|
||||
source: PumpedPacketSource
|
||||
sink: PumpedPacketSink
|
||||
|
||||
def start(self):
|
||||
def __init__(
|
||||
self,
|
||||
source: PumpedPacketSource,
|
||||
sink: PumpedPacketSink,
|
||||
) -> None:
|
||||
super().__init__(source, sink)
|
||||
|
||||
def start(self) -> None:
|
||||
self.source.start()
|
||||
self.sink.start()
|
||||
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await self.close_function()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopingTransport(Transport):
|
||||
@@ -375,31 +423,38 @@ class SnoopingTransport(Transport):
|
||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
||||
|
||||
class Source:
|
||||
def __init__(self, source, snooper):
|
||||
sink: TransportSink
|
||||
|
||||
def __init__(self, source: TransportSource, snooper: Snooper):
|
||||
self.source = source
|
||||
self.snooper = snooper
|
||||
self.sink = None
|
||||
self.terminated = source.terminated
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
self.source.set_packet_sink(self)
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
class Sink:
|
||||
def __init__(self, sink, snooper):
|
||||
def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
|
||||
self.sink = sink
|
||||
self.snooper = snooper
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
def __init__(self, transport, snooper, close_snooper=None):
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
snooper: Snooper,
|
||||
close_snooper=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_file_transport(spec):
|
||||
async def open_file_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a File transport (typically not for a real file, but for a PTY or other unix
|
||||
virtual files).
|
||||
|
||||
@@ -23,6 +23,8 @@ import socket
|
||||
import ctypes
|
||||
import collections
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, ParserSource
|
||||
|
||||
|
||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_hci_socket_transport(spec):
|
||||
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open an HCI Socket (only available on some platforms).
|
||||
The parameter string is either empty (to use the first/default Bluetooth adapter)
|
||||
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec):
|
||||
# Create a raw HCI socket
|
||||
try:
|
||||
hci_socket = socket.socket(
|
||||
socket.AF_BLUETOOTH,
|
||||
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
|
||||
socket.BTPROTO_HCI,
|
||||
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||
socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
|
||||
socket.BTPROTO_HCI, # type: ignore[attr-defined]
|
||||
)
|
||||
except AttributeError as error:
|
||||
# Not supported on this platform
|
||||
@@ -57,10 +59,7 @@ async def open_hci_socket_transport(spec):
|
||||
) from error
|
||||
|
||||
# Compute the adapter index
|
||||
if spec is None:
|
||||
adapter_index = 0
|
||||
else:
|
||||
adapter_index = int(spec)
|
||||
adapter_index = int(spec) if spec else 0
|
||||
|
||||
# Bind the socket
|
||||
# NOTE: since Python doesn't support binding with the required address format (yet),
|
||||
@@ -78,7 +77,7 @@ async def open_hci_socket_transport(spec):
|
||||
bind_address = struct.pack(
|
||||
# pylint: disable=no-member
|
||||
'<HHH',
|
||||
socket.AF_BLUETOOTH,
|
||||
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||
adapter_index,
|
||||
HCI_CHANNEL_USER,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,8 @@ import atexit
|
||||
import os
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, StreamPacketSource, StreamPacketSink
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_pty_transport(spec):
|
||||
async def open_pty_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a PTY transport.
|
||||
The parameter string may be empty, or a path name where a symbolic link
|
||||
|
||||
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_pyusb_transport(spec):
|
||||
async def open_pyusb_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a USB transport. [Implementation based on PyUSB]
|
||||
The parameter string has this syntax:
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_serial_transport(spec):
|
||||
async def open_serial_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a serial port transport.
|
||||
The parameter string has this syntax:
|
||||
|
||||
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_tcp_client_transport(spec):
|
||||
async def open_tcp_client_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a TCP client transport.
|
||||
The parameter string has this syntax:
|
||||
@@ -39,7 +39,7 @@ async def open_tcp_client_transport(spec):
|
||||
class TcpPacketSource(StreamPacketSource):
|
||||
def connection_lost(self, exc):
|
||||
logger.debug(f'connection lost: {exc}')
|
||||
self.terminated.set_result(exc)
|
||||
self.on_transport_lost()
|
||||
|
||||
remote_host, remote_port = spec.split(':')
|
||||
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_tcp_server_transport(spec):
|
||||
async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport.
|
||||
The parameter string has this syntax:
|
||||
@@ -42,7 +43,7 @@ async def open_tcp_server_transport(spec):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
class TcpServerProtocol:
|
||||
class TcpServerProtocol(asyncio.BaseProtocol):
|
||||
def __init__(self, packet_source, packet_sink):
|
||||
self.packet_source = packet_source
|
||||
self.packet_sink = packet_sink
|
||||
|
||||
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_udp_transport(spec):
|
||||
async def open_udp_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a UDP transport.
|
||||
The parameter string has this syntax:
|
||||
|
||||
@@ -24,9 +24,10 @@ import platform
|
||||
|
||||
import usb1
|
||||
|
||||
from .common import Transport, ParserSource
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
from bumble.transport.common import Transport, ParserSource
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -60,7 +61,7 @@ def load_libusb():
|
||||
usb1.loadLibrary(libusb_dll)
|
||||
|
||||
|
||||
async def open_usb_transport(spec):
|
||||
async def open_usb_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a USB transport.
|
||||
The moniker string has this syntax:
|
||||
@@ -107,13 +108,13 @@ async def open_usb_transport(spec):
|
||||
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
|
||||
)
|
||||
|
||||
READ_SIZE = 1024
|
||||
READ_SIZE = 4096
|
||||
|
||||
class UsbPacketSink:
|
||||
def __init__(self, device, acl_out):
|
||||
self.device = device
|
||||
self.acl_out = acl_out
|
||||
self.transfer = device.getTransfer()
|
||||
self.acl_out_transfer = device.getTransfer()
|
||||
self.packets = collections.deque() # Queue of packets waiting to be sent
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.cancel_done = self.loop.create_future()
|
||||
@@ -137,21 +138,20 @@ async def open_usb_transport(spec):
|
||||
# The queue was previously empty, re-prime the pump
|
||||
self.process_queue()
|
||||
|
||||
def on_packet_sent(self, transfer):
|
||||
def transfer_callback(self, transfer):
|
||||
status = transfer.getStatus()
|
||||
# logger.debug(f'<<< USB out transfer callback: status={status}')
|
||||
|
||||
# pylint: disable=no-member
|
||||
if status == usb1.TRANSFER_COMPLETED:
|
||||
self.loop.call_soon_threadsafe(self.on_packet_sent_)
|
||||
self.loop.call_soon_threadsafe(self.on_packet_sent)
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! out transfer not completed: status={status}', 'red')
|
||||
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
||||
)
|
||||
|
||||
def on_packet_sent_(self):
|
||||
def on_packet_sent(self):
|
||||
if self.packets:
|
||||
self.packets.popleft()
|
||||
self.process_queue()
|
||||
@@ -163,22 +163,20 @@ async def open_usb_transport(spec):
|
||||
packet = self.packets[0]
|
||||
packet_type = packet[0]
|
||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.on_packet_sent
|
||||
self.acl_out_transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.transfer_callback
|
||||
)
|
||||
logger.debug('submit ACL')
|
||||
self.transfer.submit()
|
||||
self.acl_out_transfer.submit()
|
||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.transfer.setControl(
|
||||
self.acl_out_transfer.setControl(
|
||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
packet[1:],
|
||||
callback=self.on_packet_sent,
|
||||
callback=self.transfer_callback,
|
||||
)
|
||||
logger.debug('submit COMMAND')
|
||||
self.transfer.submit()
|
||||
self.acl_out_transfer.submit()
|
||||
else:
|
||||
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
|
||||
|
||||
@@ -193,11 +191,11 @@ async def open_usb_transport(spec):
|
||||
self.packets.clear()
|
||||
|
||||
# If we have a transfer in flight, cancel it
|
||||
if self.transfer.isSubmitted():
|
||||
if self.acl_out_transfer.isSubmitted():
|
||||
# Try to cancel the transfer, but that may fail because it may have
|
||||
# already completed
|
||||
try:
|
||||
self.transfer.cancel()
|
||||
self.acl_out_transfer.cancel()
|
||||
|
||||
logger.debug('waiting for OUT transfer cancellation to be done...')
|
||||
await self.cancel_done
|
||||
@@ -206,26 +204,22 @@ async def open_usb_transport(spec):
|
||||
logger.debug('OUT transfer likely already completed')
|
||||
|
||||
class UsbPacketSource(asyncio.Protocol, ParserSource):
|
||||
def __init__(self, context, device, acl_in, events_in):
|
||||
def __init__(self, device, metadata, acl_in, events_in):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.device = device
|
||||
self.metadata = metadata
|
||||
self.acl_in = acl_in
|
||||
self.acl_in_transfer = None
|
||||
self.events_in = events_in
|
||||
self.events_in_transfer = None
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue = asyncio.Queue()
|
||||
self.dequeue_task = None
|
||||
self.closed = False
|
||||
self.event_loop_done = self.loop.create_future()
|
||||
self.cancel_done = {
|
||||
hci.HCI_EVENT_PACKET: self.loop.create_future(),
|
||||
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
|
||||
}
|
||||
self.events_in_transfer = None
|
||||
self.acl_in_transfer = None
|
||||
|
||||
# Create a thread to process events
|
||||
self.event_thread = threading.Thread(target=self.run)
|
||||
self.closed = False
|
||||
|
||||
def start(self):
|
||||
# Set up transfer objects for input
|
||||
@@ -233,7 +227,7 @@ async def open_usb_transport(spec):
|
||||
self.events_in_transfer.setInterrupt(
|
||||
self.events_in,
|
||||
READ_SIZE,
|
||||
callback=self.on_packet_received,
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_EVENT_PACKET,
|
||||
)
|
||||
self.events_in_transfer.submit()
|
||||
@@ -242,22 +236,23 @@ async def open_usb_transport(spec):
|
||||
self.acl_in_transfer.setBulk(
|
||||
self.acl_in,
|
||||
READ_SIZE,
|
||||
callback=self.on_packet_received,
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_ACL_DATA_PACKET,
|
||||
)
|
||||
self.acl_in_transfer.submit()
|
||||
|
||||
self.dequeue_task = self.loop.create_task(self.dequeue())
|
||||
self.event_thread.start()
|
||||
|
||||
def on_packet_received(self, transfer):
|
||||
@property
|
||||
def usb_transfer_submitted(self):
|
||||
return (
|
||||
self.events_in_transfer.isSubmitted()
|
||||
or self.acl_in_transfer.isSubmitted()
|
||||
)
|
||||
|
||||
def transfer_callback(self, transfer):
|
||||
packet_type = transfer.getUserData()
|
||||
status = transfer.getStatus()
|
||||
# logger.debug(
|
||||
# f'<<< USB IN transfer callback: status={status} '
|
||||
# f'packet_type={packet_type} '
|
||||
# f'length={transfer.getActualLength()}'
|
||||
# )
|
||||
|
||||
# pylint: disable=no-member
|
||||
if status == usb1.TRANSFER_COMPLETED:
|
||||
@@ -266,18 +261,18 @@ async def open_usb_transport(spec):
|
||||
+ transfer.getBuffer()[: transfer.getActualLength()]
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
transfer.submit()
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.cancel_done[packet_type].set_result, None
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! transfer not completed: status={status}', 'red')
|
||||
color(f'!!! IN transfer not completed: status={status}', 'red')
|
||||
)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
transfer.submit()
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
|
||||
async def dequeue(self):
|
||||
while not self.closed:
|
||||
@@ -287,21 +282,6 @@ async def open_usb_transport(spec):
|
||||
return
|
||||
self.parser.feed_data(packet)
|
||||
|
||||
def run(self):
|
||||
logger.debug('starting USB event loop')
|
||||
while (
|
||||
self.events_in_transfer.isSubmitted()
|
||||
or self.acl_in_transfer.isSubmitted()
|
||||
):
|
||||
# pylint: disable=no-member
|
||||
try:
|
||||
self.context.handleEvents()
|
||||
except usb1.USBErrorInterrupted:
|
||||
pass
|
||||
|
||||
logger.debug('USB event loop done')
|
||||
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
@@ -330,15 +310,14 @@ async def open_usb_transport(spec):
|
||||
f'IN[{packet_type}] transfer likely already completed'
|
||||
)
|
||||
|
||||
# Wait for the thread to terminate
|
||||
await self.event_loop_done
|
||||
|
||||
class UsbTransport(Transport):
|
||||
def __init__(self, context, device, interface, setting, source, sink):
|
||||
super().__init__(source, sink)
|
||||
self.context = context
|
||||
self.device = device
|
||||
self.interface = interface
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.event_loop_done = self.loop.create_future()
|
||||
|
||||
# Get exclusive access
|
||||
device.claimInterface(interface)
|
||||
@@ -351,6 +330,22 @@ async def open_usb_transport(spec):
|
||||
source.start()
|
||||
sink.start()
|
||||
|
||||
# Create a thread to process events
|
||||
self.event_thread = threading.Thread(target=self.run)
|
||||
self.event_thread.start()
|
||||
|
||||
def run(self):
|
||||
logger.debug('starting USB event loop')
|
||||
while self.source.usb_transfer_submitted:
|
||||
# pylint: disable=no-member
|
||||
try:
|
||||
self.context.handleEvents()
|
||||
except usb1.USBErrorInterrupted:
|
||||
pass
|
||||
|
||||
logger.debug('USB event loop done')
|
||||
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
|
||||
|
||||
async def close(self):
|
||||
self.source.close()
|
||||
self.sink.close()
|
||||
@@ -360,6 +355,9 @@ async def open_usb_transport(spec):
|
||||
self.device.close()
|
||||
self.context.close()
|
||||
|
||||
# Wait for the thread to terminate
|
||||
await self.event_loop_done
|
||||
|
||||
# Find the device according to the spec moniker
|
||||
load_libusb()
|
||||
context = usb1.USBContext()
|
||||
@@ -510,6 +508,10 @@ async def open_usb_transport(spec):
|
||||
f'events_in=0x{events_in:02X}, '
|
||||
)
|
||||
|
||||
device_metadata = {
|
||||
'vendor_id': found.getVendorID(),
|
||||
'product_id': found.getProductID(),
|
||||
}
|
||||
device = found.open()
|
||||
|
||||
# Auto-detach the kernel driver if supported
|
||||
@@ -535,7 +537,7 @@ async def open_usb_transport(spec):
|
||||
except usb1.USBError:
|
||||
logger.warning('failed to set configuration')
|
||||
|
||||
source = UsbPacketSource(context, device, acl_in, events_in)
|
||||
source = UsbPacketSource(device, device_metadata, acl_in, events_in)
|
||||
sink = UsbPacketSink(device, acl_out)
|
||||
return UsbTransport(context, device, interface, setting, source, sink)
|
||||
except usb1.USBError as error:
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport
|
||||
from .file import open_file_transport
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -26,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_vhci_transport(spec):
|
||||
async def open_vhci_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a VHCI transport (only available on some platforms).
|
||||
The parameter string is either empty (to use the default VHCI device
|
||||
@@ -42,15 +45,15 @@ async def open_vhci_transport(spec):
|
||||
# Override the source's `data_received` method so that we can
|
||||
# filter out the vendor packet that is received just after the
|
||||
# initial open
|
||||
def vhci_data_received(data):
|
||||
def vhci_data_received(data: bytes) -> None:
|
||||
if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
|
||||
if len(data) == 4:
|
||||
hci_index = data[2] << 8 | data[3]
|
||||
logger.info(f'HCI index {hci_index}')
|
||||
else:
|
||||
transport.source.parser.feed_data(data)
|
||||
transport.source.parser.feed_data(data) # type: ignore
|
||||
|
||||
transport.source.data_received = vhci_data_received
|
||||
transport.source.data_received = vhci_data_received # type: ignore
|
||||
|
||||
# Write the initial config
|
||||
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import websockets
|
||||
import websockets.client
|
||||
|
||||
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport
|
||||
from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -27,23 +27,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_ws_client_transport(spec):
|
||||
async def open_ws_client_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a WebSocket client transport.
|
||||
The parameter string has this syntax:
|
||||
<remote-host>:<remote-port>
|
||||
<websocket-url>
|
||||
|
||||
Example: 127.0.0.1:9001
|
||||
Example: ws://localhost:7681/v1/websocket/bt
|
||||
'''
|
||||
|
||||
remote_host, remote_port = spec.split(':')
|
||||
uri = f'ws://{remote_host}:{remote_port}'
|
||||
websocket = await websockets.connect(uri)
|
||||
websocket = await websockets.client.connect(spec)
|
||||
|
||||
transport = PumpedTransport(
|
||||
class WsTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await websocket.close()
|
||||
|
||||
transport = WsTransport(
|
||||
PumpedPacketSource(websocket.recv),
|
||||
PumpedPacketSink(websocket.send),
|
||||
websocket.close,
|
||||
)
|
||||
transport.start()
|
||||
return transport
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_ws_server_transport(spec):
|
||||
async def open_ws_server_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a WebSocket server transport.
|
||||
The parameter string has this syntax:
|
||||
@@ -43,7 +42,7 @@ async def open_ws_server_transport(spec):
|
||||
def __init__(self):
|
||||
source = ParserSource()
|
||||
sink = PumpedPacketSink(self.send_packet)
|
||||
self.connection = asyncio.get_running_loop().create_future()
|
||||
self.connection = None
|
||||
self.server = None
|
||||
|
||||
super().__init__(source, sink)
|
||||
@@ -63,7 +62,7 @@ async def open_ws_server_transport(spec):
|
||||
f'new connection on {connection.local_address} '
|
||||
f'from {connection.remote_address}'
|
||||
)
|
||||
self.connection.set_result(connection)
|
||||
self.connection = connection
|
||||
# pylint: disable=no-member
|
||||
try:
|
||||
async for packet in connection:
|
||||
@@ -74,12 +73,14 @@ async def open_ws_server_transport(spec):
|
||||
except websockets.WebSocketException as error:
|
||||
logger.debug(f'exception while receiving packet: {error}')
|
||||
|
||||
# Wait for a new connection
|
||||
self.connection = asyncio.get_running_loop().create_future()
|
||||
# We're now disconnected
|
||||
self.connection = None
|
||||
|
||||
async def send_packet(self, packet):
|
||||
connection = await self.connection
|
||||
return await connection.send(packet)
|
||||
if self.connection is None:
|
||||
logger.debug('no connection, dropping packet')
|
||||
return
|
||||
return await self.connection.send(packet)
|
||||
|
||||
local_host, local_port = spec.split(':')
|
||||
transport = WsServerTransport()
|
||||
|
||||
211
bumble/utils.py
211
bumble/utils.py
@@ -15,13 +15,27 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import collections
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
from typing import Awaitable, Set, TypeVar
|
||||
from functools import wraps
|
||||
import warnings
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Set,
|
||||
TypeVar,
|
||||
List,
|
||||
Tuple,
|
||||
Callable,
|
||||
Any,
|
||||
Optional,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
@@ -64,6 +78,104 @@ def composite_listener(cls):
|
||||
return cls
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_Handler = TypeVar('_Handler', bound=Callable)
|
||||
|
||||
|
||||
class EventWatcher:
|
||||
'''A wrapper class to control the lifecycle of event handlers better.
|
||||
|
||||
Usage:
|
||||
```
|
||||
watcher = EventWatcher()
|
||||
|
||||
def on_foo():
|
||||
...
|
||||
watcher.on(emitter, 'foo', on_foo)
|
||||
|
||||
@watcher.on(emitter, 'bar')
|
||||
def on_bar():
|
||||
...
|
||||
|
||||
# Close all event handlers watching through this watcher
|
||||
watcher.close()
|
||||
```
|
||||
|
||||
As context:
|
||||
```
|
||||
with contextlib.closing(EventWatcher()) as context:
|
||||
@context.on(emitter, 'foo')
|
||||
def on_foo():
|
||||
...
|
||||
# on_foo() has been removed here!
|
||||
```
|
||||
'''
|
||||
|
||||
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handlers = []
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
|
||||
def on(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||
'''Watch an event until the context is closed.
|
||||
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing is passed, this method
|
||||
works as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(wrapped: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, wrapped))
|
||||
emitter.on(event, wrapped)
|
||||
return wrapped
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
|
||||
def once(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||
'''Watch an event for once.
|
||||
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing passed, this method works
|
||||
as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(wrapped: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, wrapped))
|
||||
emitter.once(event, wrapped)
|
||||
return wrapped
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
def close(self) -> None:
|
||||
for emitter, event, handler in self.handlers:
|
||||
if handler in emitter.listeners(event):
|
||||
emitter.remove_listener(event, handler)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_T = TypeVar('_T')
|
||||
|
||||
@@ -114,13 +226,13 @@ class CompositeEventEmitter(AbortableEventEmitter):
|
||||
if self._listener:
|
||||
# Call the deregistration methods for each base class that has them
|
||||
for cls in self._listener.__class__.mro():
|
||||
if hasattr(cls, '_bumble_register_composite'):
|
||||
cls._bumble_deregister_composite(listener, self)
|
||||
if '_bumble_register_composite' in cls.__dict__:
|
||||
cls._bumble_deregister_composite(self._listener, self)
|
||||
self._listener = listener
|
||||
if listener:
|
||||
# Call the registration methods for each base class that has them
|
||||
for cls in listener.__class__.mro():
|
||||
if hasattr(cls, '_bumble_deregister_composite'):
|
||||
if '_bumble_deregister_composite' in cls.__dict__:
|
||||
cls._bumble_register_composite(listener, self)
|
||||
|
||||
|
||||
@@ -167,21 +279,18 @@ class AsyncRunner:
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
coroutine = func(*args, **kwargs)
|
||||
if queue is None:
|
||||
# Create a task to run the coroutine
|
||||
# Spawn the coroutine as a task
|
||||
async def run():
|
||||
try:
|
||||
await coroutine
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'{color("!!! Exception in wrapper:", "red")} '
|
||||
f'{traceback.format_exc()}'
|
||||
)
|
||||
logger.exception(color("!!! Exception in wrapper:", "red"))
|
||||
|
||||
asyncio.create_task(run())
|
||||
AsyncRunner.spawn(run())
|
||||
else:
|
||||
# Queue the coroutine to be awaited by the work queue
|
||||
queue.enqueue(coroutine)
|
||||
@@ -302,3 +411,77 @@ class FlowControlAsyncPipe:
|
||||
self.resume_source()
|
||||
|
||||
self.check_pump()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_call(function, *args, **kwargs):
|
||||
"""
|
||||
Immediately calls the function with provided args and kwargs, wrapping it in an
|
||||
async function.
|
||||
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
|
||||
a running loop.
|
||||
|
||||
result = await async_call(some_function, ...)
|
||||
"""
|
||||
return function(*args, **kwargs)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def wrap_async(function):
|
||||
"""
|
||||
Wraps the provided function in an async function.
|
||||
"""
|
||||
return functools.partial(async_call, function)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def deprecated(msg: str):
|
||||
"""
|
||||
Throw deprecation warning before execution.
|
||||
"""
|
||||
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def experimental(msg: str):
|
||||
"""
|
||||
Throws a future warning before execution.
|
||||
"""
|
||||
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, FutureWarning)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpenIntEnum(enum.IntEnum):
|
||||
"""
|
||||
Subclass of enum.IntEnum that can hold integer values outside the set of
|
||||
predefined values. This is convenient for implementing protocols where some
|
||||
integer constants may be added over time.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if not isinstance(value, int):
|
||||
return None
|
||||
|
||||
obj = int.__new__(cls, value)
|
||||
obj._value_ = value
|
||||
obj._name_ = f"{cls.__name__}[{value}]"
|
||||
return obj
|
||||
|
||||
0
bumble/vendor/__init__.py
vendored
Normal file
0
bumble/vendor/__init__.py
vendored
Normal file
0
bumble/vendor/android/__init__.py
vendored
Normal file
0
bumble/vendor/android/__init__.py
vendored
Normal file
318
bumble/vendor/android/hci.py
vendored
Normal file
318
bumble/vendor/android/hci.py
vendored
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
|
||||
from bumble.hci import (
|
||||
name_or_number,
|
||||
hci_vendor_command_op_code,
|
||||
Address,
|
||||
HCI_Constant,
|
||||
HCI_Object,
|
||||
HCI_Command,
|
||||
HCI_Vendor_Event,
|
||||
STATUS_SPEC,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Android Vendor Specific Commands and Events.
|
||||
# Only a subset of the commands are implemented here currently.
|
||||
#
|
||||
# pylint: disable-next=line-too-long
|
||||
# See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#chip-capabilities-and-configuration
|
||||
HCI_LE_GET_VENDOR_CAPABILITIES_COMMAND = hci_vendor_command_op_code(0x153)
|
||||
HCI_LE_APCF_COMMAND = hci_vendor_command_op_code(0x157)
|
||||
HCI_GET_CONTROLLER_ACTIVITY_ENERGY_INFO_COMMAND = hci_vendor_command_op_code(0x159)
|
||||
HCI_A2DP_HARDWARE_OFFLOAD_COMMAND = hci_vendor_command_op_code(0x15D)
|
||||
HCI_BLUETOOTH_QUALITY_REPORT_COMMAND = hci_vendor_command_op_code(0x15E)
|
||||
HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
|
||||
|
||||
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
HCI_Vendor_Event.register_subevents(globals())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
('max_advt_instances', 1),
|
||||
('offloaded_resolution_of_private_address', 1),
|
||||
('total_scan_results_storage', 2),
|
||||
('max_irk_list_sz', 1),
|
||||
('filtering_support', 1),
|
||||
('max_filter', 1),
|
||||
('activity_energy_info_support', 1),
|
||||
('version_supported', 2),
|
||||
('total_num_of_advt_tracked', 2),
|
||||
('extended_scan_support', 1),
|
||||
('debug_logging_supported', 1),
|
||||
('le_address_generation_offloading_support', 1),
|
||||
('a2dp_source_offload_capability_mask', 4),
|
||||
('bluetooth_quality_report_support', 1),
|
||||
('dynamic_audio_buffer_support', 4),
|
||||
]
|
||||
)
|
||||
class HCI_LE_Get_Vendor_Capabilities_Command(HCI_Command):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
|
||||
'''
|
||||
|
||||
@classmethod
|
||||
def parse_return_parameters(cls, parameters):
|
||||
# There are many versions of this data structure, so we need to parse until
|
||||
# there are no more bytes to parse, and leave un-signal parameters set to
|
||||
# None (older versions)
|
||||
nones = {field: None for field, _ in cls.return_parameters_fields}
|
||||
return_parameters = HCI_Object(cls.return_parameters_fields, **nones)
|
||||
|
||||
try:
|
||||
offset = 0
|
||||
for field in cls.return_parameters_fields:
|
||||
field_name, field_type = field
|
||||
field_value, field_size = HCI_Object.parse_field(
|
||||
parameters, offset, field_type
|
||||
)
|
||||
setattr(return_parameters, field_name, field_value)
|
||||
offset += field_size
|
||||
except struct.error:
|
||||
pass
|
||||
|
||||
return return_parameters
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
fields=[
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
)
|
||||
class HCI_LE_APCF_Command(HCI_Command):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
|
||||
|
||||
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
|
||||
implementation. A future enhancement may define subcommand-specific data structures.
|
||||
'''
|
||||
|
||||
# APCF Subcommands
|
||||
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
|
||||
APCF_ENABLE = 0x00
|
||||
APCF_SET_FILTERING_PARAMETERS = 0x01
|
||||
APCF_BROADCASTER_ADDRESS = 0x02
|
||||
APCF_SERVICE_UUID = 0x03
|
||||
APCF_SERVICE_SOLICITATION_UUID = 0x04
|
||||
APCF_LOCAL_NAME = 0x05
|
||||
APCF_MANUFACTURER_DATA = 0x06
|
||||
APCF_SERVICE_DATA = 0x07
|
||||
APCF_TRANSPORT_DISCOVERY_SERVICE = 0x08
|
||||
APCF_AD_TYPE_FILTER = 0x09
|
||||
APCF_READ_EXTENDED_FEATURES = 0xFF
|
||||
|
||||
OPCODE_NAMES = {
|
||||
APCF_ENABLE: 'APCF_ENABLE',
|
||||
APCF_SET_FILTERING_PARAMETERS: 'APCF_SET_FILTERING_PARAMETERS',
|
||||
APCF_BROADCASTER_ADDRESS: 'APCF_BROADCASTER_ADDRESS',
|
||||
APCF_SERVICE_UUID: 'APCF_SERVICE_UUID',
|
||||
APCF_SERVICE_SOLICITATION_UUID: 'APCF_SERVICE_SOLICITATION_UUID',
|
||||
APCF_LOCAL_NAME: 'APCF_LOCAL_NAME',
|
||||
APCF_MANUFACTURER_DATA: 'APCF_MANUFACTURER_DATA',
|
||||
APCF_SERVICE_DATA: 'APCF_SERVICE_DATA',
|
||||
APCF_TRANSPORT_DISCOVERY_SERVICE: 'APCF_TRANSPORT_DISCOVERY_SERVICE',
|
||||
APCF_AD_TYPE_FILTER: 'APCF_AD_TYPE_FILTER',
|
||||
APCF_READ_EXTENDED_FEATURES: 'APCF_READ_EXTENDED_FEATURES',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def opcode_name(cls, opcode):
|
||||
return name_or_number(cls.OPCODE_NAMES, opcode)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
('total_tx_time_ms', 4),
|
||||
('total_rx_time_ms', 4),
|
||||
('total_idle_time_ms', 4),
|
||||
('total_energy_used', 4),
|
||||
],
|
||||
)
|
||||
class HCI_Get_Controller_Activity_Energy_Info_Command(HCI_Command):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
fields=[
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
)
|
||||
class HCI_A2DP_Hardware_Offload_Command(HCI_Command):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
|
||||
|
||||
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
|
||||
implementation. A future enhancement may define subcommand-specific data structures.
|
||||
'''
|
||||
|
||||
# A2DP Hardware Offload Subcommands
|
||||
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
|
||||
START_A2DP_OFFLOAD = 0x01
|
||||
STOP_A2DP_OFFLOAD = 0x02
|
||||
|
||||
OPCODE_NAMES = {
|
||||
START_A2DP_OFFLOAD: 'START_A2DP_OFFLOAD',
|
||||
STOP_A2DP_OFFLOAD: 'STOP_A2DP_OFFLOAD',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def opcode_name(cls, opcode):
|
||||
return name_or_number(cls.OPCODE_NAMES, opcode)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
fields=[
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
(
|
||||
'opcode',
|
||||
{
|
||||
'size': 1,
|
||||
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
|
||||
},
|
||||
),
|
||||
('payload', '*'),
|
||||
],
|
||||
)
|
||||
class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
|
||||
|
||||
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
|
||||
implementation. A future enhancement may define subcommand-specific data structures.
|
||||
'''
|
||||
|
||||
# Dynamic Audio Buffer Subcommands
|
||||
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
|
||||
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
|
||||
|
||||
OPCODE_NAMES = {
|
||||
GET_AUDIO_BUFFER_TIME_CAPABILITY: 'GET_AUDIO_BUFFER_TIME_CAPABILITY',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def opcode_name(cls, opcode):
|
||||
return name_or_number(cls.OPCODE_NAMES, opcode)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Vendor_Event.event(
|
||||
fields=[
|
||||
('quality_report_id', 1),
|
||||
('packet_types', 1),
|
||||
('connection_handle', 2),
|
||||
('connection_role', {'size': 1, 'mapper': HCI_Constant.role_name}),
|
||||
('tx_power_level', -1),
|
||||
('rssi', -1),
|
||||
('snr', 1),
|
||||
('unused_afh_channel_count', 1),
|
||||
('afh_select_unideal_channel_count', 1),
|
||||
('lsto', 2),
|
||||
('connection_piconet_clock', 4),
|
||||
('retransmission_count', 4),
|
||||
('no_rx_count', 4),
|
||||
('nak_count', 4),
|
||||
('last_tx_ack_timestamp', 4),
|
||||
('flow_off_count', 4),
|
||||
('last_flow_on_timestamp', 4),
|
||||
('buffer_overflow_bytes', 4),
|
||||
('buffer_underflow_bytes', 4),
|
||||
('bdaddr', Address.parse_address),
|
||||
('cal_failed_item_count', 1),
|
||||
('tx_total_packets', 4),
|
||||
('tx_unacked_packets', 4),
|
||||
('tx_flushed_packets', 4),
|
||||
('tx_last_subevent_packets', 4),
|
||||
('crc_error_packets', 4),
|
||||
('rx_duplicate_packets', 4),
|
||||
('vendor_specific_parameters', '*'),
|
||||
]
|
||||
)
|
||||
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
|
||||
'''
|
||||
0
bumble/vendor/zephyr/__init__.py
vendored
Normal file
0
bumble/vendor/zephyr/__init__.py
vendored
Normal file
88
bumble/vendor/zephyr/hci.py
vendored
Normal file
88
bumble/vendor/zephyr/hci.py
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code,
|
||||
HCI_Command,
|
||||
STATUS_SPEC,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Zephyr RTOS Vendor Specific Commands and Events.
|
||||
# Only a subset of the commands are implemented here currently.
|
||||
#
|
||||
# pylint: disable-next=line-too-long
|
||||
# See https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
|
||||
HCI_WRITE_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000E)
|
||||
HCI_READ_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000F)
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TX_Power_Level_Command:
|
||||
'''
|
||||
Base class for read and write TX power level HCI commands
|
||||
'''
|
||||
|
||||
TX_POWER_HANDLE_TYPE_ADV = 0x00
|
||||
TX_POWER_HANDLE_TYPE_SCAN = 0x01
|
||||
TX_POWER_HANDLE_TYPE_CONN = 0x02
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
fields=[('handle_type', 1), ('connection_handle', 2), ('tx_power_level', -1)],
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
('handle_type', 1),
|
||||
('connection_handle', 2),
|
||||
('selected_tx_power_level', -1),
|
||||
],
|
||||
)
|
||||
class HCI_Write_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
|
||||
'''
|
||||
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
|
||||
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
|
||||
|
||||
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
|
||||
TX_POWER_HANDLE_TYPE_SCAN should be zero.
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command(
|
||||
fields=[('handle_type', 1), ('connection_handle', 2)],
|
||||
return_parameters_fields=[
|
||||
('status', STATUS_SPEC),
|
||||
('handle_type', 1),
|
||||
('connection_handle', 2),
|
||||
('tx_power_level', -1),
|
||||
],
|
||||
)
|
||||
class HCI_Read_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
|
||||
'''
|
||||
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
|
||||
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
|
||||
|
||||
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
|
||||
TX_POWER_HANDLE_TYPE_SCAN should be zero.
|
||||
'''
|
||||
@@ -10,7 +10,7 @@ nav:
|
||||
- Contributing: development/contributing.md
|
||||
- Code Style: development/code_style.md
|
||||
- Use Cases:
|
||||
- Overview: use_cases/index.md
|
||||
- use_cases/index.md
|
||||
- Use Case 1: use_cases/use_case_1.md
|
||||
- Use Case 2: use_cases/use_case_2.md
|
||||
- Use Case 3: use_cases/use_case_3.md
|
||||
@@ -23,7 +23,7 @@ nav:
|
||||
- GATT: components/gatt.md
|
||||
- Security Manager: components/security_manager.md
|
||||
- Transports:
|
||||
- Overview: transports/index.md
|
||||
- transports/index.md
|
||||
- Serial: transports/serial.md
|
||||
- USB: transports/usb.md
|
||||
- PTY: transports/pty.md
|
||||
@@ -36,12 +36,15 @@ nav:
|
||||
- HCI Socket: transports/hci_socket.md
|
||||
- Android Emulator: transports/android_emulator.md
|
||||
- File: transports/file.md
|
||||
- Drivers:
|
||||
- drivers/index.md
|
||||
- Realtek: drivers/realtek.md
|
||||
- API:
|
||||
- Guide: api/guide.md
|
||||
- Examples: api/examples.md
|
||||
- Reference: api/reference.md
|
||||
- Apps & Tools:
|
||||
- Overview: apps_and_tools/index.md
|
||||
- apps_and_tools/index.md
|
||||
- Console: apps_and_tools/console.md
|
||||
- Bench: apps_and_tools/bench.md
|
||||
- Speaker: apps_and_tools/speaker.md
|
||||
@@ -54,15 +57,25 @@ nav:
|
||||
- USB Probe: apps_and_tools/usb_probe.md
|
||||
- Link Relay: apps_and_tools/link_relay.md
|
||||
- Hardware:
|
||||
- Overview: hardware/index.md
|
||||
- hardware/index.md
|
||||
- Platforms:
|
||||
- Overview: platforms/index.md
|
||||
- platforms/index.md
|
||||
- macOS: platforms/macos.md
|
||||
- Linux: platforms/linux.md
|
||||
- Windows: platforms/windows.md
|
||||
- Android: platforms/android.md
|
||||
- Zephyr: platforms/zephyr.md
|
||||
- Examples:
|
||||
- Overview: examples/index.md
|
||||
- examples/index.md
|
||||
- Extras:
|
||||
- extras/index.md
|
||||
- Android Remote HCI: extras/android_remote_hci.md
|
||||
- Android BT Bench: extras/android_bt_bench.md
|
||||
- Hive:
|
||||
- hive/index.md
|
||||
- Speaker: hive/web/speaker/speaker.html
|
||||
- Scanner: hive/web/scanner/scanner.html
|
||||
- Heart Rate Monitor: hive/web/heart_rate_monitor/heart_rate_monitor.html
|
||||
|
||||
copyright: Copyright 2021-2023 Google LLC
|
||||
|
||||
@@ -71,6 +84,8 @@ theme:
|
||||
logo: 'images/logo.png'
|
||||
favicon: 'images/favicon.ico'
|
||||
custom_dir: 'theme'
|
||||
features:
|
||||
- navigation.indexes
|
||||
|
||||
plugins:
|
||||
- mkdocstrings:
|
||||
@@ -95,6 +110,8 @@ markdown_extensions:
|
||||
- pymdownx.emoji:
|
||||
emoji_index: !!python/name:materialx.emoji.twemoji
|
||||
emoji_generator: !!python/name:materialx.emoji.to_svg
|
||||
- pymdownx.tabbed:
|
||||
alternate_style: true
|
||||
- codehilite:
|
||||
guess_lang: false
|
||||
- toc:
|
||||
|
||||
@@ -7,16 +7,36 @@ throughput and/or latency between two devices.
|
||||
# General Usage
|
||||
|
||||
```
|
||||
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
|
||||
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--device-config FILENAME Device configuration file
|
||||
--role [sender|receiver|ping|pong]
|
||||
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
|
||||
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
|
||||
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (server role)
|
||||
-sd, --start-delay SECONDS Start delay (server role)
|
||||
--extended-data-length TEXT Request a data length upon connection,
|
||||
specified as tx_octets/tx_time
|
||||
--rfcomm-channel INTEGER RFComm channel to use
|
||||
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
|
||||
--rfcomm-channel is not 0)
|
||||
--l2cap-psm INTEGER L2CAP PSM to use
|
||||
--l2cap-mtu INTEGER L2CAP MTU to use
|
||||
--l2cap-mps INTEGER L2CAP MPS to use
|
||||
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
|
||||
the peer
|
||||
-s, --packet-size SIZE Packet size (client or ping role)
|
||||
[8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (client or ping role)
|
||||
-sd, --start-delay SECONDS Start delay (client or ping role)
|
||||
--repeat N Repeat the run N times (client and ping
|
||||
roles)(0, which is the fault, to run just
|
||||
once)
|
||||
--repeat-delay SECONDS Delay, in seconds, between repeats
|
||||
--pace MILLISECONDS Wait N milliseconds between packets (0,
|
||||
which is the fault, to send as fast as
|
||||
possible)
|
||||
--linger Don't exit at the end of a run (server and
|
||||
pong roles)
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
@@ -35,17 +55,18 @@ Options:
|
||||
--connection-interval, --ci CONNECTION_INTERVAL
|
||||
Connection interval (in ms)
|
||||
--phy [1m|2m|coded] PHY to use
|
||||
--authenticate Authenticate (RFComm only)
|
||||
--encrypt Encrypt the connection (RFComm only)
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
||||
To test once device against another, one of the two devices must be running
|
||||
To test once device against another, one of the two devices must be running
|
||||
the ``peripheral`` command and the other the ``central`` command. The device
|
||||
running the ``peripheral`` command will accept connections from the device
|
||||
running the ``central`` command.
|
||||
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
using the ``--peripheral`` option. The address will be printed by the Peripheral when
|
||||
it starts.
|
||||
|
||||
@@ -83,7 +104,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
$ bumble-bench central usb:1
|
||||
```
|
||||
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
connecting to the Peripheral running a Receiver, as a GATT server.
|
||||
|
||||
!!! example "L2CAP Throughput"
|
||||
|
||||
BIN
docs/mkdocs/src/downloads/zephyr/hci_usb.zip
Normal file
BIN
docs/mkdocs/src/downloads/zephyr/hci_usb.zip
Normal file
Binary file not shown.
19
docs/mkdocs/src/drivers/index.md
Normal file
19
docs/mkdocs/src/drivers/index.md
Normal file
@@ -0,0 +1,19 @@
|
||||
DRIVERS
|
||||
=======
|
||||
|
||||
Some Bluetooth controllers require a driver to function properly.
|
||||
This may include, for instance, loading a Firmware image or patch,
|
||||
loading a configuration.
|
||||
|
||||
By default, drivers will be automatically probed to determine if they should be
|
||||
used with particular HCI controller.
|
||||
When the transport for an HCI controller is instantiated from a transport name,
|
||||
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
|
||||
metadata portion of the transport name. For example,
|
||||
``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the
|
||||
first USB device, even if a normal probe would not have selected it based on the
|
||||
USB vendor ID and product ID.
|
||||
|
||||
Drivers included in the module are:
|
||||
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
65
docs/mkdocs/src/drivers/realtek.md
Normal file
65
docs/mkdocs/src/drivers/realtek.md
Normal file
@@ -0,0 +1,65 @@
|
||||
REALTEK DRIVER
|
||||
==============
|
||||
|
||||
This driver supports loading firmware images and optional config data to
|
||||
USB dongles with a Realtek chipset.
|
||||
A number of USB dongles are supported, but likely not all.
|
||||
When using a USB dongle, the USB product ID and vendor ID are used
|
||||
to find whether a matching set of firmware image and config data
|
||||
is needed for that specific model. If a match exists, the driver will try
|
||||
load the firmware image and, if needed, config data.
|
||||
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
|
||||
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
|
||||
``usb:0`` for the first USB device).
|
||||
The driver will look for those files by name, in order, in:
|
||||
|
||||
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
|
||||
if set.
|
||||
* The directory `<package-dir>/drivers/rtk_fw` where `<package-dir>` is the directory
|
||||
where the `bumble` package is installed.
|
||||
* The current directory.
|
||||
|
||||
|
||||
Obtaining Firmware Images and Config Data
|
||||
-----------------------------------------
|
||||
|
||||
Firmware images and config data may be obtained from a variety of online
|
||||
sources.
|
||||
To facilitate finding a downloading the, the utility program `bumble-rtk-fw-download`
|
||||
may be used.
|
||||
|
||||
```
|
||||
Usage: bumble-rtk-fw-download [OPTIONS]
|
||||
|
||||
Download RTK firmware images and configs.
|
||||
|
||||
Options:
|
||||
--output-dir TEXT Output directory where the files will be
|
||||
saved [default: .]
|
||||
--source [linux-kernel|realtek-opensource|linux-from-scratch]
|
||||
[default: linux-kernel]
|
||||
--single TEXT Only download a single image set, by its
|
||||
base name
|
||||
--force Overwrite files if they already exist
|
||||
--parse Parse the FW image after saving
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
Utility
|
||||
-------
|
||||
|
||||
The `bumble-rtk-util` utility may be used to interact with a Realtek USB dongle
|
||||
and/or firmware images.
|
||||
|
||||
```
|
||||
Usage: bumble-rtk-util [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
drop Drop a firmware image from the USB dongle.
|
||||
info Get the firmware info from a USB dongle.
|
||||
load Load a firmware image into the USB dongle.
|
||||
parse Parse a firmware image.
|
||||
```
|
||||
64
docs/mkdocs/src/extras/android_bt_bench.md
Normal file
64
docs/mkdocs/src/extras/android_bt_bench.md
Normal file
@@ -0,0 +1,64 @@
|
||||
ANDROID BENCH APP
|
||||
=================
|
||||
|
||||
This Android app that is compatible with the Bumble `bench` command line app.
|
||||
This app can be used to test the throughput and latency between two Android
|
||||
devices, or between an Android device and another device running the Bumble
|
||||
`bench` app.
|
||||
Only the RFComm Client, RFComm Server, L2CAP Client and L2CAP Server modes are
|
||||
supported.
|
||||
|
||||
Building
|
||||
--------
|
||||
|
||||
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `BtBench` top level directory.
|
||||
You can also build with Android Studio: open the `BtBench` project. You can build and/or debug from there.
|
||||
|
||||
If the build succeeds, you can find the app APKs (debug and release) at:
|
||||
|
||||
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
|
||||
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
|
||||
|
||||
|
||||
Running
|
||||
-------
|
||||
|
||||
### Starting the app
|
||||
You can start the app from the Android launcher, from Android Studio, or with `adb`
|
||||
|
||||
#### Launching from the launcher
|
||||
Just tap the app icon on the launcher, check the parameters, and tap
|
||||
one of the benchmark action buttons.
|
||||
|
||||
#### Launching with `adb`
|
||||
Using the `am` command, you can start the activity, and pass it arguments so that you can
|
||||
automatically start the benchmark test, and/or set the parameters.
|
||||
|
||||
| Parameter Name | Parameter Type | Description
|
||||
|------------------------|----------------|------------
|
||||
| autostart | String | Benchmark to start. (rfcomm-client, rfcomm-server, l2cap-client or l2cap-server)
|
||||
| packet-count | Integer | Number of packets to send (rfcomm-client and l2cap-client only)
|
||||
| packet-size | Integer | Number of bytes per packet (rfcomm-client and l2cap-client only)
|
||||
| peer-bluetooth-address | Integer | Peer Bluetooth address to connect to (rfcomm-client and l2cap-client | only)
|
||||
|
||||
|
||||
!!! tip "Launching from adb with auto-start"
|
||||
In this example, we auto-start the Rfcomm Server bench action.
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-server
|
||||
```
|
||||
|
||||
!!! tip "Launching from adb with auto-start and some parameters"
|
||||
In this example, we auto-start the Rfcomm Client bench action, set the packet count to 100,
|
||||
and the packet size to 1024, and connect to DA:4C:10:DE:17:02
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-client --ei packet-count 100 --ei packet-size 1024 --es peer-bluetooth-address DA:4C:10:DE:17:02
|
||||
```
|
||||
|
||||
#### Selecting a Peer Bluetooth Address
|
||||
The app's main activity has a "Peer Bluetooth Address" setting where you can change the address.
|
||||
|
||||
!!! note "Bluetooth Address for L2CAP vs RFComm"
|
||||
For BLE (L2CAP mode), the address of a device typically changes regularly (it is randomized for privacy), whereas the Bluetooth Classic addresses will remain the same (RFComm mode).
|
||||
If two devices are paired and bonded, then they will each "see" a non-changing address for each other even with BLE (Resolvable Private Address)
|
||||
|
||||
181
docs/mkdocs/src/extras/android_remote_hci.md
Normal file
181
docs/mkdocs/src/extras/android_remote_hci.md
Normal file
@@ -0,0 +1,181 @@
|
||||
ANDROID REMOTE HCI APP
|
||||
======================
|
||||
|
||||
This application allows using an android phone's built-in Bluetooth controller with
|
||||
a Bumble host stack running outside the phone (typically a development laptop or desktop).
|
||||
The app runs an HCI proxy between a TCP socket on the "outside" and the Bluetooth HCI HAL
|
||||
on the "inside". (See [this page](https://source.android.com/docs/core/connect/bluetooth) for a high level
|
||||
description of the Android Bluetooth HCI HAL).
|
||||
The HCI packets received on the TCP socket are forwarded to the phone's controller, and the
|
||||
packets coming from the controller are forwarded to the TCP socket.
|
||||
|
||||
|
||||
Building
|
||||
--------
|
||||
|
||||
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `extras/android/RemoteHCI` top level directory.
|
||||
You can also build with Android Studio: open the `RemoteHCI` project. You can build and/or debug from there.
|
||||
|
||||
If the build succeeds, you can find the app APKs (debug and release) at:
|
||||
|
||||
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
|
||||
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
|
||||
|
||||
|
||||
Running
|
||||
-------
|
||||
|
||||
!!! note
|
||||
In the following examples, it is assumed that shell commands are executed while in the
|
||||
app's root directory, `extras/android/RemoteHCI`. If you are in a different directory,
|
||||
adjust the relative paths accordingly.
|
||||
|
||||
### Preconditions
|
||||
When the proxy starts (tapping the "Start" button in the app's main activity, or running the proxy
|
||||
from an `adb shell` command line), it will try to bind to the Bluetooth HAL.
|
||||
This requires that there is no other HAL client, and requires certain privileges.
|
||||
For running as a regular app, this requires disabling SELinux temporarily.
|
||||
For running as a command-line executable, this just requires a root shell.
|
||||
|
||||
#### Root Shell
|
||||
!!! tip "Restart `adb` as root"
|
||||
```bash
|
||||
$ adb root
|
||||
```
|
||||
|
||||
#### Disabling SELinux
|
||||
Binding to the Bluetooth HCI HAL requires certain SELinux permissions that can't simply be changed
|
||||
on a device without rebuilding its system image. To bypass these restrictions, you will need
|
||||
to disable SELinux on your phone (please be aware that this is global, not just for the proxy app,
|
||||
so proceed with caution).
|
||||
In order to disable SELinux, you need to root the phone (it may be advisable to do this on a
|
||||
development phone).
|
||||
|
||||
!!! tip "Disabling SELinux Temporarily"
|
||||
Restart `adb` as root:
|
||||
```bash
|
||||
$ adb root
|
||||
```
|
||||
|
||||
Then disable SELinux
|
||||
```bash
|
||||
$ adb shell setenforce 0
|
||||
```
|
||||
|
||||
Once you're done using the proxy, you can restore SELinux, if you need to, with
|
||||
```bash
|
||||
$ adb shell setenforce 1
|
||||
```
|
||||
|
||||
This state will also reset to the normal SELinux enforcement when you reboot.
|
||||
|
||||
#### Stopping the bluetooth process
|
||||
Since the Bluetooth HAL service can only accept one client, and that in normal conditions
|
||||
that client is the Android's bluetooth stack, it is required to first shut down the
|
||||
Android bluetooth stack process.
|
||||
|
||||
!!! tip "Checking if the Bluetooth process is running"
|
||||
```bash
|
||||
$ adb shell "ps -A | grep com.google.android.bluetooth"
|
||||
```
|
||||
If the process is running, you will get a line like:
|
||||
```
|
||||
bluetooth 10759 876 17455796 136620 do_epoll_wait 0 S com.google.android.bluetooth
|
||||
```
|
||||
If you don't, it means that the process is not running and you are clear to proceed.
|
||||
|
||||
Simply turning Bluetooth off from the phone's settings does not ensure that the bluetooth process will exit.
|
||||
If the bluetooth process is still running after toggling Bluetooth off from the settings, you may try enabling
|
||||
Airplane Mode, then rebooting. The bluetooth process should, in theory, not restart after the reboot.
|
||||
|
||||
!!! tip "Stopping the bluetooth process with adb"
|
||||
```bash
|
||||
$ adb shell cmd bluetooth_manager disable
|
||||
```
|
||||
|
||||
### Running as a command line app
|
||||
|
||||
You push the built APK to a temporary location on the phone's filesystem, then launch the command
|
||||
line executable with an `adb shell` command.
|
||||
|
||||
!!! tip "Pushing the executable"
|
||||
```bash
|
||||
$ adb push app/build/outputs/apk/release/app-release-unsigned.apk /data/local/tmp/remotehci.apk
|
||||
```
|
||||
Do this every time you rebuild. Alternatively, you can push the `debug` APK instead:
|
||||
```bash
|
||||
$ adb push app/build/outputs/apk/debug/app-debug.apk /data/local/tmp/remotehci.apk
|
||||
```
|
||||
|
||||
!!! tip "Start the proxy from the command line"
|
||||
```bash
|
||||
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface"
|
||||
```
|
||||
This will run the proxy, listening on the default TCP port.
|
||||
If you want a different port, pass it as a command line parameter
|
||||
|
||||
!!! tip "Start the proxy from the command line with a specific TCP port"
|
||||
```bash
|
||||
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface 12345"
|
||||
```
|
||||
|
||||
### Running as a normal app
|
||||
You can start the app from the Android launcher, from Android Studio, or with `adb`
|
||||
|
||||
#### Launching from the launcher
|
||||
Just tap the app icon on the launcher, check the TCP port that is configured, and tap
|
||||
the "Start" button.
|
||||
|
||||
#### Launching with `adb`
|
||||
Using the `am` command, you can start the activity, and pass it arguments so that you can
|
||||
automatically start the proxy, and/or set the port number.
|
||||
|
||||
!!! tip "Launching from adb with auto-start"
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true
|
||||
```
|
||||
|
||||
!!! tip "Launching from adb with auto-start and a port"
|
||||
In this example, we auto-start the proxy upon launch, with the port set to 9995
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true --ei port 9995
|
||||
```
|
||||
|
||||
#### Selecting a TCP port
|
||||
The RemoteHCI app's main activity has a "TCP Port" setting where you can change the port on
|
||||
which the proxy is accepting connections. If the default value isn't suitable, you can
|
||||
change it there (you can also use the special value 0 to let the OS assign a port number for you).
|
||||
|
||||
### Connecting to the proxy
|
||||
To connect the Bumble stack to the proxy, you need to be able to reach the phone's network
|
||||
stack. This can be done over the phone's WiFi connection, or, alternatively, using an `adb`
|
||||
TCP forward (which should be faster than over WiFi).
|
||||
|
||||
!!! tip "Forwarding TCP with `adb`"
|
||||
To connect to the proxy via an `adb` TCP forward, use:
|
||||
```bash
|
||||
$ adb forward tcp:<outside-port> tcp:<inside-port>
|
||||
```
|
||||
Where ``<outside-port>`` is the port number for a listening socket on your laptop or
|
||||
desktop machine, and <inside-port> is the TCP port selected in the app's user interface.
|
||||
Those two ports may be the same, of course.
|
||||
For example, with the default TCP port 9993:
|
||||
```bash
|
||||
$ adb forward tcp:9993 tcp:9993
|
||||
```
|
||||
|
||||
Once you've ensured that you can reach the proxy's TCP port on the phone, either directly or
|
||||
via an `adb` forward, you can then use it as a Bumble transport, using the transport name:
|
||||
``tcp-client:<host>:<port>`` syntax.
|
||||
|
||||
!!! example "Connecting a Bumble client"
|
||||
Connecting the `bumble-controller-info` app to the phone's controller.
|
||||
Assuming you have set up an `adb` forward on port 9993:
|
||||
```bash
|
||||
$ bumble-controller-info tcp-client:localhost:9993
|
||||
```
|
||||
|
||||
Or over WiFi with, in this example, the IP address of the phone being ```192.168.86.27```
|
||||
```bash
|
||||
$ bumble-controller-info tcp-client:192.168.86.27:9993
|
||||
```
|
||||
19
docs/mkdocs/src/extras/index.md
Normal file
19
docs/mkdocs/src/extras/index.md
Normal file
@@ -0,0 +1,19 @@
|
||||
EXTRAS
|
||||
======
|
||||
|
||||
A collection of add-ons, apps and tools, to the Bumble project.
|
||||
|
||||
Android Remote HCI
|
||||
------------------
|
||||
|
||||
Allows using an Android phone's built-in Bluetooth controller with a Bumble
|
||||
stack running on a development machine.
|
||||
See [Android Remote HCI](android_remote_hci.md) for details.
|
||||
|
||||
Android BT Bench
|
||||
----------------
|
||||
|
||||
An Android app that is compatible with the Bumble `bench` command line app.
|
||||
This app can be used to test the throughput and latency between two Android
|
||||
devices, or between an Android device and another device running the Bumble
|
||||
`bench` app.
|
||||
@@ -3,7 +3,7 @@ HARDWARE
|
||||
|
||||
The Bumble Host connects to a controller over an [HCI Transport](../transports/index.md).
|
||||
To use a hardware controller attached to the host on which the host application is running, the transport is typically either [HCI over UART](../transports/serial.md) or [HCI over USB](../transports/usb.md).
|
||||
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge) bridging the network transport to a physical controller on a remote host.
|
||||
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge.md) bridging the network transport to a physical controller on a remote host.
|
||||
|
||||
In theory, any controller that is compliant with the HCI over UART or HCI over USB protocols can be used.
|
||||
|
||||
|
||||
59
docs/mkdocs/src/hive/index.md
Normal file
59
docs/mkdocs/src/hive/index.md
Normal file
@@ -0,0 +1,59 @@
|
||||
HIVE
|
||||
====
|
||||
|
||||
Welcome to the Bumble Hive.
|
||||
This is a collection of apps and virtual devices that can run entirely in a browser page.
|
||||
The code for the apps and devices, as well as the Bumble runtime code, runs via [Pyodide](https://pyodide.org/).
|
||||
Pyodide is a Python distribution for the browser and Node.js based on WebAssembly.
|
||||
|
||||
The Bumble stack uses a WebSocket to exchange HCI packets with a virtual or physical
|
||||
Bluetooth controller.
|
||||
|
||||
The apps and devices in the hive can be accessed by following the links below. Each
|
||||
page has a settings button that may be used to configure the WebSocket URL to use for
|
||||
the virtual HCI connection. This will typically be the WebSocket URL for a `netsim`
|
||||
daemon.
|
||||
There is also a [TOML index](index.toml) that can be used by tools to know at which URL to access
|
||||
each of the apps and devices, as well as their names and short descriptions.
|
||||
|
||||
!!! tip "Using `netsim`"
|
||||
When the `netsimd` daemon is running (for example when using the Android Emulator that
|
||||
is included in Android Studio), the daemon listens for connections on a TCP port.
|
||||
To find out what this TCP port is, you can read the `netsim.ini` file that `netsimd`
|
||||
creates, it includes a line with `web.port=<tcp-port>` (for example `web.port=7681`).
|
||||
The location of the `netsim.ini` file is platform-specific.
|
||||
|
||||
=== "macOS"
|
||||
On macOS, the directory where `netsim.ini` is stored is $TMPDIR
|
||||
```bash
|
||||
$ cat $TMPDIR/netsim.ini
|
||||
```
|
||||
|
||||
=== "Linux"
|
||||
On Linux, the directory where `netsim.ini` is stored is $XDG_RUNTIME_DIR
|
||||
```bash
|
||||
$ cat $XDG_RUNTIME_DIR/netsim.ini
|
||||
```
|
||||
|
||||
|
||||
!!! tip "Using a local radio"
|
||||
You can connect the hive virtual apps and devices to a local Bluetooth radio, like,
|
||||
for example, a USB dongle.
|
||||
For that, you need to run a local HCI bridge to bridge a local HCI device to a WebSocket
|
||||
that a web page can connect to.
|
||||
Use the `bumble-hci-bridge` app, with the host transport set to a WebSocket server on an
|
||||
available port (ex: `ws-server:_:7682`) and the controller transport set to the transport
|
||||
name for the radio you want to use (ex: `usb:0` for the first USB dongle)
|
||||
|
||||
|
||||
Applications
|
||||
------------
|
||||
|
||||
* [Scanner](web/scanner/scanner.html) - Scans for BLE devices.
|
||||
|
||||
Virtual Devices
|
||||
---------------
|
||||
|
||||
* [Speaker](web/speaker/speaker.html) - Virtual speaker that plays audio in a browser page.
|
||||
* [Heart Rate Monitor](web/heart_rate_monitor/heart_rate_monitor.html) - Virtual heart rate monitor.
|
||||
|
||||
21
docs/mkdocs/src/hive/index.toml
Normal file
21
docs/mkdocs/src/hive/index.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
version = "1.0.0"
|
||||
base_url = "https://google.github.io/bumble/hive/web"
|
||||
default_hci_query_param = "hci"
|
||||
|
||||
[[index]]
|
||||
name = "speaker"
|
||||
description = "Bumble Virtual Speaker"
|
||||
type = "Device"
|
||||
url = "speaker/speaker.html"
|
||||
|
||||
[[index]]
|
||||
name = "scanner"
|
||||
description = "Simple Scanner Application"
|
||||
type = "Application"
|
||||
url = "scanner/scanner.html"
|
||||
|
||||
[[index]]
|
||||
name = "heart-rate-monitor"
|
||||
description = "Virtual Heart Rate Monitor"
|
||||
type = "Device"
|
||||
url = "heart_rate_monitor/heart_rate_monitor.html"
|
||||
1
docs/mkdocs/src/hive/web/bumble.js
Symbolic link
1
docs/mkdocs/src/hive/web/bumble.js
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../web/bumble.js
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.html
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.js
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.py
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.css
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.css
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.css
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.html
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.html
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.html
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.js
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.js
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.js
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.py
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.py
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.py
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user