forked from auracaster/bumble_mirror
Compare commits
410 Commits
v0.0.173
...
gbg/usb-qu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a51166af7 | ||
|
|
5aae44b610 | ||
|
|
e3ea167827 | ||
|
|
eec145e095 | ||
|
|
87fa02d6e5 | ||
|
|
ad94c1e1f3 | ||
|
|
546a0bce8d | ||
|
|
cb7ca44a1c | ||
|
|
4081b93407 | ||
|
|
26203ebaad | ||
|
|
3389e3e1ed | ||
|
|
7e1f01c01e | ||
|
|
613e15548a | ||
|
|
e09c91df8e | ||
|
|
df206667b6 | ||
|
|
0f19dd5263 | ||
|
|
b98e4937f3 | ||
|
|
c2c46e9ace | ||
|
|
27791cf218 | ||
|
|
32a41a815d | ||
|
|
df5fc2ddfe | ||
|
|
79122313a6 | ||
|
|
d7d03e2e92 | ||
|
|
ea493480a9 | ||
|
|
658f641a53 | ||
|
|
f8a2d4f0e0 | ||
|
|
00edd1fbf8 | ||
|
|
999d7b07e1 | ||
|
|
2e3aeb8648 | ||
|
|
f910a696ad | ||
|
|
e1d10bc482 | ||
|
|
181467f11b | ||
|
|
394137b6f7 | ||
|
|
dea907be86 | ||
|
|
f5baf51132 | ||
|
|
f2dc8bd84e | ||
|
|
090309302f | ||
|
|
28e6229b24 | ||
|
|
1b66f03dbe | ||
|
|
e34f6b5fd3 | ||
|
|
8a0482c947 | ||
|
|
938a189f3f | ||
|
|
2005b4a11b | ||
|
|
951fdc8bdd | ||
|
|
12af7a526c | ||
|
|
8781943646 | ||
|
|
7fbfdb634c | ||
|
|
9682077f6b | ||
|
|
22eb405fde | ||
|
|
593c61973f | ||
|
|
ccff32102f | ||
|
|
851d62c6c9 | ||
|
|
a5ac5f26e2 | ||
|
|
090158820f | ||
|
|
26e6650038 | ||
|
|
c48568aabe | ||
|
|
1b33c9eb74 | ||
|
|
6633228975 | ||
|
|
e9cba788a4 | ||
|
|
98822cfc6b | ||
|
|
97ad7e5741 | ||
|
|
71df062e07 | ||
|
|
049f9021e9 | ||
|
|
50eae2ef54 | ||
|
|
c8883a7d0f | ||
|
|
51321caf5b | ||
|
|
51a94288e2 | ||
|
|
8758856e8c | ||
|
|
deba181857 | ||
|
|
c65188dcbf | ||
|
|
21d607898d | ||
|
|
2698d4534e | ||
|
|
bbcd64286a | ||
|
|
9140afbf8c | ||
|
|
90a682c71b | ||
|
|
e8737a8243 | ||
|
|
72fceca72e | ||
|
|
732294abbc | ||
|
|
dc1204531e | ||
|
|
962114379c | ||
|
|
e6913a3055 | ||
|
|
e21d122aef | ||
|
|
58d4ab913a | ||
|
|
76bca03fe3 | ||
|
|
f1e5c9e59e | ||
|
|
ec82242462 | ||
|
|
a4efdd3f3e | ||
|
|
69c6643bb8 | ||
|
|
b8214bf948 | ||
|
|
a9c62c44b3 | ||
|
|
7d0b4ef4e0 | ||
|
|
313340f1c6 | ||
|
|
e8ed69fb09 | ||
|
|
16d5cf6770 | ||
|
|
a2caf1deb2 | ||
|
|
01bfdd2c98 | ||
|
|
4a60df108a | ||
|
|
ad48109748 | ||
|
|
1ceeccbbc0 | ||
|
|
44c51c13ac | ||
|
|
7507be1eab | ||
|
|
cbe9446dcf | ||
|
|
174930399a | ||
|
|
35db4a4c93 | ||
|
|
1f3aee5566 | ||
|
|
256044a789 | ||
|
|
6205199d7f | ||
|
|
e554bd1033 | ||
|
|
38981cefa1 | ||
|
|
f2d601f411 | ||
|
|
6e7c64c1de | ||
|
|
565d51f4db | ||
|
|
de8f3d9c1e | ||
|
|
cde6d48690 | ||
|
|
02180088b3 | ||
|
|
90f49267d1 | ||
|
|
0e6d69cd7b | ||
|
|
9eccc583d5 | ||
|
|
f4aeaa6eb3 | ||
|
|
d7489a644a | ||
|
|
a877283360 | ||
|
|
6d91e7e79b | ||
|
|
567146b143 | ||
|
|
1a3272d7ca | ||
|
|
1ee1ff0b62 | ||
|
|
729fd97748 | ||
|
|
e308051885 | ||
|
|
10e53553d7 | ||
|
|
ef0b30d059 | ||
|
|
e7e9f9509a | ||
|
|
c6cfd101df | ||
|
|
d2dcf063ee | ||
|
|
d15bc7d664 | ||
|
|
e4364d18a7 | ||
|
|
6a34c9f224 | ||
|
|
2a764fd6bb | ||
|
|
3e8ce38eba | ||
|
|
8d2f37aa7a | ||
|
|
b7b70ebcbb | ||
|
|
8ba91f4986 | ||
|
|
79a5e953bc | ||
|
|
20de5ea250 | ||
|
|
bad9ce272c | ||
|
|
d3273ffa8c | ||
|
|
071fc2723a | ||
|
|
ef4ea86f58 | ||
|
|
dfdaa149d0 | ||
|
|
986343a807 | ||
|
|
5211d7ba96 | ||
|
|
a167342778 | ||
|
|
1efb8cdbee | ||
|
|
80d83e6a70 | ||
|
|
31ec1c41ce | ||
|
|
aba1ac0cea | ||
|
|
c40824e51c | ||
|
|
2920f05dae | ||
|
|
bc911d6da0 | ||
|
|
4f87f587e4 | ||
|
|
3e38ab3638 | ||
|
|
21bb911fea | ||
|
|
744dfa33a2 | ||
|
|
ec5f8535a8 | ||
|
|
5a83734a00 | ||
|
|
b4ae8af3a7 | ||
|
|
da60386385 | ||
|
|
45c4c4f4c5 | ||
|
|
9187c75d68 | ||
|
|
abeec22546 | ||
|
|
a6bab755cf | ||
|
|
acd9d994c3 | ||
|
|
37afda3ed3 | ||
|
|
54f2981267 | ||
|
|
bb025514e7 | ||
|
|
e228597269 | ||
|
|
95b0d6c6f2 | ||
|
|
fa4df6e3a2 | ||
|
|
46ceea7ecd | ||
|
|
30f89d5739 | ||
|
|
481cf40831 | ||
|
|
eff05afb7a | ||
|
|
d8e6700611 | ||
|
|
56eb5a933b | ||
|
|
caacc0c133 | ||
|
|
5f377c024b | ||
|
|
00cd8fbdd0 | ||
|
|
aeeff18428 | ||
|
|
c48e3f5e9c | ||
|
|
d6bbc1145a | ||
|
|
e2fec67bd9 | ||
|
|
88cb3b2a4d | ||
|
|
9ebb03be46 | ||
|
|
80d84af76c | ||
|
|
8f4721758f | ||
|
|
8864af4acd | ||
|
|
8980fb8cc7 | ||
|
|
2c5f3472a9 | ||
|
|
f18277ac78 | ||
|
|
8d46bc04d2 | ||
|
|
09e5ea5dec | ||
|
|
d43281c57e | ||
|
|
6810865670 | ||
|
|
3e9e06a02c | ||
|
|
ccd12f6591 | ||
|
|
f9a7843f7e | ||
|
|
210c334db7 | ||
|
|
f297cdfcce | ||
|
|
5b536d00ab | ||
|
|
b4af46ebd5 | ||
|
|
c08da3193e | ||
|
|
f2925ca647 | ||
|
|
fd4d68e5c0 | ||
|
|
5d83deffa4 | ||
|
|
2878cca478 | ||
|
|
53934716db | ||
|
|
d885d45824 | ||
|
|
b90d0f8710 | ||
|
|
8ccfc90fe6 | ||
|
|
92aa7e9e2a | ||
|
|
afc6d19e04 | ||
|
|
c05f073b33 | ||
|
|
2b4c2a22f4 | ||
|
|
47fe93a148 | ||
|
|
6139ca8045 | ||
|
|
87c76a4a0e | ||
|
|
f7b66db873 | ||
|
|
0b314bd7f7 | ||
|
|
9da2e32ad7 | ||
|
|
93c0875740 | ||
|
|
a286700239 | ||
|
|
98ed772e8a | ||
|
|
f0b55a4f97 | ||
|
|
b74503d345 | ||
|
|
f911163e49 | ||
|
|
b083cc99ad | ||
|
|
d35643524e | ||
|
|
62a8ced447 | ||
|
|
085f163c92 | ||
|
|
81a6b1e097 | ||
|
|
dd090c9e6b | ||
|
|
11faa48422 | ||
|
|
55596176c2 | ||
|
|
4d6822d312 | ||
|
|
985c365e6d | ||
|
|
af57762227 | ||
|
|
3575f9030e | ||
|
|
698d947d85 | ||
|
|
ff6528d2bf | ||
|
|
72ac75a98d | ||
|
|
5e3ecb74e4 | ||
|
|
c59be293c8 | ||
|
|
88b4cbdf1a | ||
|
|
d6afbc6f4e | ||
|
|
fc90de3e7b | ||
|
|
847c2ef114 | ||
|
|
a0bf0c1f4d | ||
|
|
8400ff0802 | ||
|
|
0ed6aa230b | ||
|
|
6d22ed80ec | ||
|
|
72d5360af9 | ||
|
|
ac3961e763 | ||
|
|
843466c822 | ||
|
|
8385035400 | ||
|
|
3adcc8be09 | ||
|
|
c853d56302 | ||
|
|
dc97be5b35 | ||
|
|
73dbdfff9f | ||
|
|
dff14e1258 | ||
|
|
10a3833893 | ||
|
|
247cb89332 | ||
|
|
3fc71a0266 | ||
|
|
392dcc3a05 | ||
|
|
f27015d1b7 | ||
|
|
86a19b41aa | ||
|
|
320164d476 | ||
|
|
40ae661ee5 | ||
|
|
ffb3eca68b | ||
|
|
c5def93bb8 | ||
|
|
a9c4c5833d | ||
|
|
58c9c4f590 | ||
|
|
24524d88cb | ||
|
|
b8849ab311 | ||
|
|
f3cd8f8ed0 | ||
|
|
2b26de3f3a | ||
|
|
0149c4c212 | ||
|
|
f2ed898784 | ||
|
|
464a476f9f | ||
|
|
e85d067fb5 | ||
|
|
7eb493990f | ||
|
|
04d5bf3afc | ||
|
|
403a13e4c6 | ||
|
|
ad0f035df5 | ||
|
|
a13e193d3b | ||
|
|
28a1a5ebc2 | ||
|
|
6310dc777f | ||
|
|
07f71fc895 | ||
|
|
f47b9178ad | ||
|
|
863de18877 | ||
|
|
4f399249bd | ||
|
|
f0e5cdee1a | ||
|
|
7bc7d0f5af | ||
|
|
a65a215fd7 | ||
|
|
80d34a226d | ||
|
|
a9628f73e3 | ||
|
|
9324237828 | ||
|
|
d1033c018a | ||
|
|
0f29052ade | ||
|
|
0578e84586 | ||
|
|
6ab41c466f | ||
|
|
98a1093ebf | ||
|
|
caf04373f3 | ||
|
|
d4e8526766 | ||
|
|
515b83a8c7 | ||
|
|
9bf2e03354 | ||
|
|
dc18595c8a | ||
|
|
488bcfe9c6 | ||
|
|
2900b93bb3 | ||
|
|
284cc8a321 | ||
|
|
3dc2e4036c | ||
|
|
268f6b0d51 | ||
|
|
46239b321b | ||
|
|
8a536cd522 | ||
|
|
f9f5d7ccbd | ||
|
|
d6cefdff8e | ||
|
|
dc410b14c4 | ||
|
|
4c49ef9403 | ||
|
|
ba85dcbda5 | ||
|
|
e08c84dd20 | ||
|
|
8b46136703 | ||
|
|
9c7089c8ff | ||
|
|
aac8d89cd0 | ||
|
|
24e75bfeab | ||
|
|
42868b08d3 | ||
|
|
19b61d9ac0 | ||
|
|
db2a2e2bb9 | ||
|
|
e1fdb12647 | ||
|
|
a8ec1b0949 | ||
|
|
2e30b2de77 | ||
|
|
7e407ccae1 | ||
|
|
0667e83919 | ||
|
|
1a6c9a4d04 | ||
|
|
14f5b912ad | ||
|
|
46d6242171 | ||
|
|
753b966148 | ||
|
|
5a307c19b8 | ||
|
|
2cd4f84800 | ||
|
|
4ae612090b | ||
|
|
c67ca4a09e | ||
|
|
94506220d3 | ||
|
|
dbd865a484 | ||
|
|
9d2f3e932a | ||
|
|
49d32f5b5b | ||
|
|
f7b74c0bcb | ||
|
|
c75cb0c7b7 | ||
|
|
a63b335149 | ||
|
|
d8517ce407 | ||
|
|
ad13b11464 | ||
|
|
99bc92d53d | ||
|
|
72199f5615 | ||
|
|
78b8b50082 | ||
|
|
3ab64ce00d | ||
|
|
651e44e0b6 | ||
|
|
963fa41a49 | ||
|
|
493f4f8b95 | ||
|
|
fc1bf36ace | ||
|
|
5ddee17411 | ||
|
|
5ce353bcde | ||
|
|
16d33199eb | ||
|
|
e02303a448 | ||
|
|
36fc966ad6 | ||
|
|
644f74400d | ||
|
|
b7cd451ddb | ||
|
|
59d7717963 | ||
|
|
88392efca4 | ||
|
|
907f2acc7e | ||
|
|
6616477bcf | ||
|
|
5b173cb879 | ||
|
|
dc6b466a42 | ||
|
|
8b04161da3 | ||
|
|
5a85765360 | ||
|
|
333940919b | ||
|
|
b9476be9ad | ||
|
|
704c60491c | ||
|
|
4a8e612c6e | ||
|
|
5e5c9c2580 | ||
|
|
4e71ec5738 | ||
|
|
1004f10384 | ||
|
|
1051648ffb | ||
|
|
7255a09705 | ||
|
|
c2bf6b5f13 | ||
|
|
d8e699b588 | ||
|
|
3e4d4705f5 | ||
|
|
c8b2804446 | ||
|
|
e732f2589f | ||
|
|
aec5543081 | ||
|
|
e03d90ca57 | ||
|
|
495ce62d9c | ||
|
|
fbc3959a5a | ||
|
|
246b11925c | ||
|
|
dfa9131192 | ||
|
|
88c801b4c2 | ||
|
|
a1b55b94e0 | ||
|
|
80db9e2e2f | ||
|
|
ce74690420 | ||
|
|
50de4dfb5d | ||
|
|
9bcdf860f4 | ||
|
|
511ab4b630 | ||
|
|
45edcafb06 | ||
|
|
9f0bcc131f | ||
|
|
7e331c2944 | ||
|
|
50fd2218fa |
30
.devcontainer/devcontainer.json
Normal file
30
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,30 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
||||
{
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/universal:2",
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
"postCreateCommand":
|
||||
"python -m pip install '.[build,test,development,documentation]'",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
"customizations": {
|
||||
// Configure properties specific to VS Code.
|
||||
"vscode": {
|
||||
// Add the IDs of extensions you want installed when the container is created.
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
6
.github/workflows/code-check.yml
vendored
6
.github/workflows/code-check.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -29,11 +29,11 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development]"
|
||||
python -m pip install ".[build,test,development,pandora]"
|
||||
- name: Check
|
||||
run: |
|
||||
invoke project.pre-commit
|
||||
|
||||
43
.github/workflows/python-avatar.yml
vendored
Normal file
43
.github/workflows/python-avatar.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Python Avatar
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Avatar [${{ matrix.shard }}]
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
shard: [
|
||||
1/24, 2/24, 3/24, 4/24,
|
||||
5/24, 6/24, 7/24, 8/24,
|
||||
9/24, 10/24, 11/24, 12/24,
|
||||
13/24, 14/24, 15/24, 16/24,
|
||||
17/24, 18/24, 19/24, 20/24,
|
||||
21/24, 22/24, 23/24, 24/24,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set Up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar,pandora]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
run: |
|
||||
avatar --list | grep -Ev '^=' > test-names.txt
|
||||
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
|
||||
- name: Rootcanal Logs
|
||||
run: cat rootcanal.log
|
||||
14
.github/workflows/python-build-test.yml
vendored
14
.github/workflows/python-build-test.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -46,8 +46,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
|
||||
rust-version: [ "1.70.0", "stable" ]
|
||||
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
|
||||
rust-version: [ "1.76.0", "stable" ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,documentation]"
|
||||
@@ -65,15 +65,17 @@ jobs:
|
||||
with:
|
||||
components: clippy,rustfmt
|
||||
toolchain: ${{ matrix.rust-version }}
|
||||
- name: Install Rust dependencies
|
||||
run: cargo install cargo-all-features # allows building/testing combinations of features
|
||||
- name: Check License Headers
|
||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||
- name: Rust Build
|
||||
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
|
||||
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
|
||||
# Lints after build so what clippy needs is already built
|
||||
- name: Rust Lints
|
||||
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
|
||||
- name: Rust Tests
|
||||
run: cd rust && cargo test
|
||||
run: cd rust && cargo test-all-features
|
||||
# At some point, hook up publishing the binary. For now, just make sure it builds.
|
||||
# Once we're ready to publish binaries, this should be built with `--release`.
|
||||
- name: Build Bumble CLI
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -6,7 +6,14 @@ dist/
|
||||
docs/mkdocs/site
|
||||
test-results.xml
|
||||
__pycache__
|
||||
# Vim
|
||||
.*.sw*
|
||||
# generated by setuptools_scm
|
||||
bumble/_version.py
|
||||
.vscode/launch.json
|
||||
.vscode/settings.json
|
||||
/.idea
|
||||
venv/
|
||||
.venv/
|
||||
# snoop logs
|
||||
out/
|
||||
|
||||
17
.vscode/settings.json
vendored
17
.vscode/settings.json
vendored
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"cSpell.words": [
|
||||
"Abortable",
|
||||
"aiohttp",
|
||||
"altsetting",
|
||||
"ansiblue",
|
||||
"ansicyan",
|
||||
@@ -9,10 +10,13 @@
|
||||
"ansired",
|
||||
"ansiyellow",
|
||||
"appendleft",
|
||||
"ascs",
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -21,7 +25,10 @@
|
||||
"cccds",
|
||||
"cmac",
|
||||
"CONNECTIONLESS",
|
||||
"csip",
|
||||
"csis",
|
||||
"csrcs",
|
||||
"CVSD",
|
||||
"datagram",
|
||||
"DATALINK",
|
||||
"delayreport",
|
||||
@@ -29,6 +36,8 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
"GATTLINK",
|
||||
"HANDSFREE",
|
||||
@@ -36,17 +45,21 @@
|
||||
"keyup",
|
||||
"levelname",
|
||||
"libc",
|
||||
"liblc",
|
||||
"libusb",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
"netsim",
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
"Pyodide",
|
||||
"pyusb",
|
||||
"rfcomm",
|
||||
"ROHC",
|
||||
@@ -54,6 +67,7 @@
|
||||
"SEID",
|
||||
"seids",
|
||||
"SERV",
|
||||
"SIRK",
|
||||
"ssrc",
|
||||
"strerror",
|
||||
"subband",
|
||||
@@ -63,8 +77,11 @@
|
||||
"substates",
|
||||
"tobytes",
|
||||
"tsep",
|
||||
"UNMUTE",
|
||||
"unmuted",
|
||||
"usbmodem",
|
||||
"vhci",
|
||||
"wasmtime",
|
||||
"websockets",
|
||||
"xcursor",
|
||||
"ycursor"
|
||||
|
||||
407
apps/auracast.py
Normal file
407
apps/auracast.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from typing import cast, Dict, Optional, Tuple
|
||||
|
||||
import click
|
||||
import pyee
|
||||
|
||||
from bumble.colors import color
|
||||
import bumble.company_ids
|
||||
import bumble.core
|
||||
import bumble.device
|
||||
import bumble.gatt
|
||||
import bumble.hci
|
||||
import bumble.profiles.bap
|
||||
import bumble.profiles.pbp
|
||||
import bumble.transport
|
||||
import bumble.utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Discover Broadcasts
|
||||
# -----------------------------------------------------------------------------
|
||||
class BroadcastDiscoverer:
|
||||
@dataclasses.dataclass
|
||||
class Broadcast(pyee.EventEmitter):
|
||||
name: str
|
||||
sync: bumble.device.PeriodicAdvertisingSync
|
||||
rssi: int = 0
|
||||
public_broadcast_announcement: Optional[
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement
|
||||
] = None
|
||||
broadcast_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement
|
||||
] = None
|
||||
basic_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BasicAudioAnnouncement
|
||||
] = None
|
||||
appearance: Optional[bumble.core.Appearance] = None
|
||||
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
|
||||
manufacturer_data: Optional[Tuple[str, bytes]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.sync.on('establishment', self.on_sync_establishment)
|
||||
self.sync.on('loss', self.on_sync_loss)
|
||||
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
|
||||
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
|
||||
|
||||
self.establishment_timeout_task = asyncio.create_task(
|
||||
self.wait_for_establishment()
|
||||
)
|
||||
|
||||
async def wait_for_establishment(self) -> None:
|
||||
await asyncio.sleep(5.0)
|
||||
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
|
||||
print(
|
||||
color(
|
||||
'!!! Periodic advertisement sync not established in time, '
|
||||
'canceling',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
await self.sync.terminate()
|
||||
|
||||
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
self.rssi = advertisement.rssi
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
self.public_broadcast_announcement = (
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
self.broadcast_audio_announcement = (
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
self.appearance = advertisement.data.get( # type: ignore[assignment]
|
||||
bumble.core.AdvertisingData.APPEARANCE
|
||||
)
|
||||
|
||||
if manufacturer_data := advertisement.data.get(
|
||||
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
):
|
||||
assert isinstance(manufacturer_data, tuple)
|
||||
company_id = cast(int, manufacturer_data[0])
|
||||
data = cast(bytes, manufacturer_data[1])
|
||||
self.manufacturer_data = (
|
||||
bumble.company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_id, f'0x{company_id:04X}'
|
||||
),
|
||||
data,
|
||||
)
|
||||
|
||||
def print(self) -> None:
|
||||
print(
|
||||
color('Broadcast:', 'yellow'),
|
||||
self.sync.advertiser_address,
|
||||
color(self.sync.state.name, 'green'),
|
||||
)
|
||||
print(f' {color("Name", "cyan")}: {self.name}')
|
||||
if self.appearance:
|
||||
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
|
||||
print(f' {color("RSSI", "cyan")}: {self.rssi}')
|
||||
print(f' {color("SID", "cyan")}: {self.sync.sid}')
|
||||
|
||||
if self.manufacturer_data:
|
||||
print(
|
||||
f' {color("Manufacturer Data", "cyan")}: '
|
||||
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
|
||||
)
|
||||
|
||||
if self.broadcast_audio_announcement:
|
||||
print(
|
||||
f' {color("Broadcast ID", "cyan")}: '
|
||||
f'{self.broadcast_audio_announcement.broadcast_id}'
|
||||
)
|
||||
|
||||
if self.public_broadcast_announcement:
|
||||
print(
|
||||
f' {color("Features", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.features}'
|
||||
)
|
||||
print(
|
||||
f' {color("Metadata", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.metadata}'
|
||||
)
|
||||
|
||||
if self.basic_audio_announcement:
|
||||
print(color(' Audio:', 'cyan'))
|
||||
print(
|
||||
color(' Presentation Delay:', 'magenta'),
|
||||
self.basic_audio_announcement.presentation_delay,
|
||||
)
|
||||
for subgroup in self.basic_audio_announcement.subgroups:
|
||||
print(color(' Subgroup:', 'magenta'))
|
||||
print(color(' Codec ID:', 'yellow'))
|
||||
print(
|
||||
color(' Coding Format: ', 'green'),
|
||||
subgroup.codec_id.coding_format.name,
|
||||
)
|
||||
print(
|
||||
color(' Company ID: ', 'green'),
|
||||
subgroup.codec_id.company_id,
|
||||
)
|
||||
print(
|
||||
color(' Vendor Specific Codec ID:', 'green'),
|
||||
subgroup.codec_id.vendor_specific_codec_id,
|
||||
)
|
||||
print(
|
||||
color(' Codec Config:', 'yellow'),
|
||||
subgroup.codec_specific_configuration,
|
||||
)
|
||||
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
|
||||
|
||||
for bis in subgroup.bis:
|
||||
print(color(f' BIS [{bis.index}]:', 'yellow'))
|
||||
print(
|
||||
color(' Codec Config:', 'green'),
|
||||
bis.codec_specific_configuration,
|
||||
)
|
||||
|
||||
if self.biginfo:
|
||||
print(color(' BIG:', 'cyan'))
|
||||
print(
|
||||
color(' Number of BIS:', 'magenta'),
|
||||
self.biginfo.num_bis,
|
||||
)
|
||||
print(
|
||||
color(' PHY: ', 'magenta'),
|
||||
self.biginfo.phy.name,
|
||||
)
|
||||
print(
|
||||
color(' Framed: ', 'magenta'),
|
||||
self.biginfo.framed,
|
||||
)
|
||||
print(
|
||||
color(' Encrypted: ', 'magenta'),
|
||||
self.biginfo.encrypted,
|
||||
)
|
||||
|
||||
def on_sync_establishment(self) -> None:
|
||||
self.establishment_timeout_task.cancel()
|
||||
self.emit('change')
|
||||
|
||||
def on_sync_loss(self) -> None:
|
||||
self.basic_audio_announcement = None
|
||||
self.biginfo = None
|
||||
self.emit('change')
|
||||
|
||||
def on_periodic_advertisement(
|
||||
self, advertisement: bumble.device.PeriodicAdvertisement
|
||||
) -> None:
|
||||
if advertisement.data is None:
|
||||
return
|
||||
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.basic_audio_announcement = (
|
||||
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
break
|
||||
|
||||
self.emit('change')
|
||||
|
||||
def on_biginfo_advertisement(
|
||||
self, advertisement: bumble.device.BIGInfoAdvertisement
|
||||
) -> None:
|
||||
self.biginfo = advertisement
|
||||
self.emit('change')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: bumble.device.Device,
|
||||
filter_duplicates: bool,
|
||||
sync_timeout: float,
|
||||
):
|
||||
self.device = device
|
||||
self.filter_duplicates = filter_duplicates
|
||||
self.sync_timeout = sync_timeout
|
||||
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
|
||||
self.status_message = ''
|
||||
device.on('advertisement', self.on_advertisement)
|
||||
|
||||
async def run(self) -> None:
|
||||
self.status_message = color('Scanning...', 'green')
|
||||
await self.device.start_scanning(
|
||||
active=False,
|
||||
filter_duplicates=False,
|
||||
)
|
||||
|
||||
def refresh(self) -> None:
|
||||
# Clear the screen from the top
|
||||
print('\033[H')
|
||||
print('\033[0J')
|
||||
print('\033[H')
|
||||
|
||||
# Print the status message
|
||||
print(self.status_message)
|
||||
print("==========================================")
|
||||
|
||||
# Print all broadcasts
|
||||
for broadcast in self.broadcasts.values():
|
||||
broadcast.print()
|
||||
print('------------------------------------------')
|
||||
|
||||
# Clear the screen to the bottom
|
||||
print('\033[0J')
|
||||
|
||||
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
if (
|
||||
broadcast_name := advertisement.data.get(
|
||||
bumble.core.AdvertisingData.BROADCAST_NAME
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
assert isinstance(broadcast_name, str)
|
||||
|
||||
if broadcast := self.broadcasts.get(advertisement.address):
|
||||
broadcast.update(advertisement)
|
||||
self.refresh()
|
||||
return
|
||||
|
||||
bumble.utils.AsyncRunner.spawn(
|
||||
self.on_new_broadcast(broadcast_name, advertisement)
|
||||
)
|
||||
|
||||
async def on_new_broadcast(
|
||||
self, name: str, advertisement: bumble.device.Advertisement
|
||||
) -> None:
|
||||
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
|
||||
advertiser_address=advertisement.address,
|
||||
sid=advertisement.sid,
|
||||
sync_timeout=self.sync_timeout,
|
||||
filter_duplicates=self.filter_duplicates,
|
||||
)
|
||||
broadcast = self.Broadcast(
|
||||
name,
|
||||
periodic_advertising_sync,
|
||||
)
|
||||
broadcast.on('change', self.refresh)
|
||||
broadcast.update(advertisement)
|
||||
self.broadcasts[advertisement.address] = broadcast
|
||||
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
||||
self.status_message = color(
|
||||
f'+Found {len(self.broadcasts)} broadcasts', 'green'
|
||||
)
|
||||
self.refresh()
|
||||
|
||||
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
|
||||
del self.broadcasts[broadcast.sync.advertiser_address]
|
||||
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
|
||||
self.status_message = color(
|
||||
f'-Found {len(self.broadcasts)} broadcasts', 'green'
|
||||
)
|
||||
self.refresh()
|
||||
|
||||
|
||||
async def run_discover_broadcasts(
|
||||
filter_duplicates: bool, sync_timeout: float, transport: str
|
||||
) -> None:
|
||||
async with await bumble.transport.open_transport(transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
device = bumble.device.Device.with_hci(
|
||||
AURACAST_DEFAULT_DEVICE_NAME,
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS,
|
||||
hci_source,
|
||||
hci_sink,
|
||||
)
|
||||
await device.power_on()
|
||||
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
|
||||
await discoverer.run()
|
||||
await hci_source.terminated
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Main
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
def auracast(
|
||||
ctx,
|
||||
):
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
|
||||
@auracast.command('discover-broadcasts')
|
||||
@click.option(
|
||||
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
|
||||
)
|
||||
@click.option(
|
||||
'--sync-timeout',
|
||||
metavar='SYNC_TIMEOUT',
|
||||
type=float,
|
||||
default=5.0,
|
||||
help='Sync timeout (in seconds)',
|
||||
)
|
||||
@click.argument('transport')
|
||||
@click.pass_context
|
||||
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
|
||||
"""Discover public broadcasts"""
|
||||
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
auracast()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
911
apps/bench.py
911
apps/bench.py
File diff suppressed because it is too large
Load Diff
63
apps/ble_rpa_tool.py
Normal file
63
apps/ble_rpa_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
'''
|
||||
This is a tool for generating IRK, RPA,
|
||||
and verifying IRK/RPA pairs
|
||||
'''
|
||||
|
||||
|
||||
@click.command()
|
||||
def gen_irk() -> None:
|
||||
print(generate_irk().hex())
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
def gen_rpa(irk: str) -> None:
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
rpa = Address.generate_private_address(irk_bytes)
|
||||
print(rpa.to_string(with_type_qualifier=False))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
@click.argument("rpa", type=str)
|
||||
def verify_rpa(irk: str, rpa: str) -> None:
|
||||
address = Address(rpa)
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
if verify_rpa_with_irk(address, irk_bytes):
|
||||
print(color("Verified", "green"))
|
||||
else:
|
||||
print(color("Not Verified", "red"))
|
||||
|
||||
|
||||
def main():
|
||||
cli.add_command(gen_irk)
|
||||
cli.add_command(gen_rpa)
|
||||
cli.add_command(verify_rpa)
|
||||
cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_Constant,
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_LE_2M_PHY,
|
||||
@@ -289,11 +290,7 @@ class ConsoleApp:
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
random_address = (
|
||||
f"{random.randint(192,255):02X}" # address is static random
|
||||
)
|
||||
for random_byte in random.sample(range(255), 5):
|
||||
random_address += f":{random_byte:02X}"
|
||||
random_address = Address.generate_static_address()
|
||||
self.append_to_log(f"Setting random address: {random_address}")
|
||||
self.device = Device.with_hci(
|
||||
'Bumble', random_address, hci_source, hci_sink
|
||||
@@ -503,21 +500,9 @@ class ConsoleApp:
|
||||
self.show_error('not connected')
|
||||
return
|
||||
|
||||
# Discover all services, characteristics and descriptors
|
||||
self.append_to_output('discovering services...')
|
||||
await self.connected_peer.discover_services()
|
||||
self.append_to_output(
|
||||
f'found {len(self.connected_peer.services)} services,'
|
||||
' discovering characteristics...'
|
||||
)
|
||||
await self.connected_peer.discover_characteristics()
|
||||
self.append_to_output('found characteristics, discovering descriptors...')
|
||||
for service in self.connected_peer.services:
|
||||
for characteristic in service.characteristics:
|
||||
await self.connected_peer.discover_descriptors(characteristic)
|
||||
self.append_to_output('discovery completed')
|
||||
|
||||
self.show_remote_services(self.connected_peer.services)
|
||||
self.append_to_output('Service Discovery starting...')
|
||||
await self.connected_peer.discover_all()
|
||||
self.append_to_output('Service Discovery done!')
|
||||
|
||||
async def discover_attributes(self):
|
||||
if not self.connected_peer:
|
||||
@@ -777,7 +762,7 @@ class ConsoleApp:
|
||||
if not service:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
@@ -796,11 +781,11 @@ class ConsoleApp:
|
||||
if not characteristic:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
values = [attribute.read_value(None)]
|
||||
values = [await attribute.read_value(None)]
|
||||
|
||||
# TODO: future optimization: convert CCCD value to human readable string
|
||||
|
||||
@@ -944,7 +929,7 @@ class ConsoleApp:
|
||||
|
||||
# send data to any subscribers
|
||||
if isinstance(attribute, Characteristic):
|
||||
attribute.write_value(None, value)
|
||||
await attribute.write_value(None, value)
|
||||
if attribute.has_properties(Characteristic.NOTIFY):
|
||||
await self.device.gatt_server.notify_subscribers(attribute)
|
||||
if attribute.has_properties(Characteristic.INDICATE):
|
||||
|
||||
@@ -18,30 +18,39 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.colors import color
|
||||
from bumble.core import name_or_number
|
||||
from bumble.hci import (
|
||||
map_null_terminated_utf8_string,
|
||||
LeFeature,
|
||||
HCI_SUCCESS,
|
||||
HCI_LE_SUPPORTED_FEATURES_NAMES,
|
||||
HCI_VERSION_NAMES,
|
||||
LMP_VERSION_NAMES,
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
|
||||
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
@@ -57,7 +66,7 @@ def command_succeeded(response):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_classic_info(host):
|
||||
async def get_classic_info(host: Host) -> None:
|
||||
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
|
||||
response = await host.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if command_succeeded(response):
|
||||
@@ -78,7 +87,7 @@ async def get_classic_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_le_info(host):
|
||||
async def get_le_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
|
||||
@@ -117,13 +126,50 @@ async def get_le_info(host):
|
||||
'\n',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command()
|
||||
)
|
||||
if command_succeeded(response):
|
||||
print(
|
||||
color('Suggested Default Data Length:', 'yellow'),
|
||||
f'{response.return_parameters.suggested_max_tx_octets}/'
|
||||
f'{response.return_parameters.suggested_max_tx_time}',
|
||||
'\n',
|
||||
)
|
||||
|
||||
print(color('LE Features:', 'yellow'))
|
||||
for feature in host.supported_le_features:
|
||||
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
|
||||
print(f' {LeFeature(feature).name}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(transport):
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(latency_probes, transport):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
@@ -131,6 +177,23 @@ async def async_main(transport):
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# Measure the latency if requested
|
||||
latencies = []
|
||||
if latency_probes:
|
||||
for _ in range(latency_probes):
|
||||
start = time.time()
|
||||
await host.send_command(HCI_Read_Local_Version_Information_Command())
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
print(
|
||||
color('HCI Command Latency:', 'yellow'),
|
||||
(
|
||||
f'min={min(latencies):.2f}, '
|
||||
f'max={max(latencies):.2f}, '
|
||||
f'average={sum(latencies)/len(latencies):.2f}'
|
||||
),
|
||||
'\n',
|
||||
)
|
||||
|
||||
# Print version
|
||||
print(color('Version:', 'yellow'))
|
||||
print(
|
||||
@@ -154,19 +217,28 @@ async def async_main(transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
print(color('Supported Commands:', 'yellow'))
|
||||
for command in host.supported_commands:
|
||||
print(' ', HCI_Command.command_name(command))
|
||||
print(f' {HCI_Command.command_name(command)}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--latency-probes',
|
||||
metavar='N',
|
||||
type=int,
|
||||
help='Send N commands to measure HCI transport latency statistics',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(transport):
|
||||
def main(latency_probes, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
asyncio.run(async_main(transport))
|
||||
asyncio.run(async_main(latency_probes, transport))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
205
apps/controller_loopback.py
Normal file
205
apps/controller_loopback.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
import click
|
||||
|
||||
|
||||
class Loopback:
|
||||
"""Send and receive ACL data packets in local loopback mode"""
|
||||
|
||||
def __init__(self, packet_size: int, packet_count: int, transport: str):
|
||||
self.transport = transport
|
||||
self.packet_size = packet_size
|
||||
self.packet_count = packet_count
|
||||
self.connection_handle: Optional[int] = None
|
||||
self.connection_event = asyncio.Event()
|
||||
self.done = asyncio.Event()
|
||||
self.expected_cid = 0
|
||||
self.bytes_received = 0
|
||||
self.start_timestamp = 0.0
|
||||
self.last_timestamp = 0.0
|
||||
|
||||
def on_connection(self, connection_handle: int, *args):
|
||||
"""Retrieve connection handle from new connection event"""
|
||||
if not self.connection_event.is_set():
|
||||
# save first connection handle for ACL
|
||||
# subsequent connections are SCO
|
||||
self.connection_handle = connection_handle
|
||||
self.connection_event.set()
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
"""Calculate packet receive speed"""
|
||||
now = time.time()
|
||||
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
|
||||
assert connection_handle == self.connection_handle
|
||||
assert cid == self.expected_cid
|
||||
self.expected_cid += 1
|
||||
if cid == 0:
|
||||
self.start_timestamp = now
|
||||
else:
|
||||
elapsed_since_start = now - self.start_timestamp
|
||||
elapsed_since_last = now - self.last_timestamp
|
||||
self.bytes_received += len(pdu)
|
||||
instant_rx_speed = len(pdu) / elapsed_since_last
|
||||
average_rx_speed = self.bytes_received / elapsed_since_start
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f}',
|
||||
'cyan',
|
||||
)
|
||||
)
|
||||
|
||||
self.last_timestamp = now
|
||||
|
||||
if self.expected_cid == self.packet_count:
|
||||
print(color('@@@ Received last packet', 'green'))
|
||||
self.done.set()
|
||||
|
||||
async def run(self):
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
async with await open_transport_or_link(self.transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
print(color('>>> Connected', 'green'))
|
||||
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# make sure data can fit in one l2cap pdu
|
||||
l2cap_header_size = 4
|
||||
|
||||
max_packet_size = (
|
||||
host.acl_packet_queue
|
||||
if host.acl_packet_queue
|
||||
else host.le_acl_packet_queue
|
||||
).max_packet_size - l2cap_header_size
|
||||
if self.packet_size > max_packet_size:
|
||||
print(
|
||||
color(
|
||||
f'!!! Packet size ({self.packet_size}) larger than max supported'
|
||||
f' size ({max_packet_size})',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not host.supports_command(
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND
|
||||
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
|
||||
print(color('!!! Loopback mode not supported', 'red'))
|
||||
return
|
||||
|
||||
# set event callbacks
|
||||
host.on('connection', self.on_connection)
|
||||
host.on('l2cap_pdu', self.on_l2cap_pdu)
|
||||
|
||||
loopback_mode = LoopbackMode.LOCAL
|
||||
|
||||
print(color('### Setting loopback mode', 'blue'))
|
||||
await host.send_command(
|
||||
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
print(color('### Checking loopback mode', 'blue'))
|
||||
response = await host.send_command(
|
||||
HCI_Read_Loopback_Mode_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.loopback_mode != loopback_mode:
|
||||
print(color('!!! Loopback mode mismatch', 'red'))
|
||||
return
|
||||
|
||||
await self.connection_event.wait()
|
||||
print(color('### Connected', 'cyan'))
|
||||
|
||||
print(color('=== Start sending', 'magenta'))
|
||||
start_time = time.time()
|
||||
bytes_sent = 0
|
||||
for cid in range(0, self.packet_count):
|
||||
# using the cid as an incremental index
|
||||
host.send_l2cap_pdu(
|
||||
self.connection_handle, cid, bytes(self.packet_size)
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
|
||||
)
|
||||
)
|
||||
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
|
||||
await asyncio.sleep(0) # yield to allow packet receive
|
||||
|
||||
await self.done.wait()
|
||||
print(color('=== Done!', 'magenta'))
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
average_tx_speed = bytes_sent / elapsed
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
|
||||
f' in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--packet-size',
|
||||
'-s',
|
||||
metavar='SIZE',
|
||||
type=click.IntRange(8, 4096),
|
||||
default=500,
|
||||
help='Packet size',
|
||||
)
|
||||
@click.option(
|
||||
'--packet-count',
|
||||
'-c',
|
||||
metavar='COUNT',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=10,
|
||||
help='Packet count',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -21,6 +21,7 @@ import struct
|
||||
import logging
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.core import AdvertisingData
|
||||
@@ -204,7 +205,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.peer = None
|
||||
@@ -218,7 +219,12 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
psm = 0xFB
|
||||
device.register_l2cap_channel_server(0xFB, self.on_coc)
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=0xFB,
|
||||
),
|
||||
handler=self.on_coc,
|
||||
)
|
||||
print(f'### Listening for CoC connection on PSM {psm}')
|
||||
|
||||
# Setup the Gattlink service
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device
|
||||
@@ -47,16 +48,17 @@ class ServerBridge:
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device):
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
device.register_l2cap_channel_server(
|
||||
psm=self.psm,
|
||||
server=self.on_coc,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
async def start(self, device: Device) -> None:
|
||||
# Listen for incoming L2CAP channel connections
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
|
||||
),
|
||||
handler=self.on_channel,
|
||||
)
|
||||
print(
|
||||
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
@@ -73,7 +75,7 @@ class ServerBridge:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
def on_channel(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
@@ -83,7 +85,7 @@ class ServerBridge:
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
l2cap_channel.sink = self.on_channel_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
@@ -128,7 +130,7 @@ class ServerBridge:
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
def on_channel_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
@@ -183,7 +185,7 @@ class ClientBridge:
|
||||
peer_name = writer.get_extra_info('peer_name')
|
||||
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
def on_channel_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
@@ -195,11 +197,13 @@ class ClientBridge:
|
||||
# Connect a new L2CAP channel
|
||||
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
|
||||
try:
|
||||
l2cap_channel = await connection.open_l2cap_channel(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
l2cap_channel = await connection.create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
)
|
||||
)
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
except Exception as error:
|
||||
@@ -207,7 +211,7 @@ class ClientBridge:
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
@@ -272,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
|
||||
@click.option(
|
||||
'--l2cap-coc-max-credits',
|
||||
help='Maximum L2CAP CoC Credits',
|
||||
'--l2cap-max-credits',
|
||||
help='Maximum L2CAP Credits',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=128,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mtu',
|
||||
help='L2CAP CoC MTU',
|
||||
type=click.IntRange(23, 65535),
|
||||
default=1022,
|
||||
'--l2cap-mtu',
|
||||
help='L2CAP MTU',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mps',
|
||||
help='L2CAP CoC MPS',
|
||||
type=click.IntRange(23, 65533),
|
||||
'--l2cap-mps',
|
||||
help='L2CAP MPS',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
def cli(
|
||||
@@ -296,17 +306,17 @@ def cli(
|
||||
device_config,
|
||||
hci_transport,
|
||||
psm,
|
||||
l2cap_coc_max_credits,
|
||||
l2cap_coc_mtu,
|
||||
l2cap_coc_mps,
|
||||
l2cap_max_credits,
|
||||
l2cap_mtu,
|
||||
l2cap_mps,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
context.obj['max_credits'] = l2cap_max_credits
|
||||
context.obj['mtu'] = l2cap_mtu
|
||||
context.obj['mps'] = l2cap_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
577
apps/lea_unicast/app.py
Normal file
577
apps/lea_unicast/app.py
Normal file
@@ -0,0 +1,577 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
from importlib import resources
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional, List, cast
|
||||
import weakref
|
||||
import struct
|
||||
|
||||
import ctypes
|
||||
import wasmtime
|
||||
import wasmtime.loader
|
||||
import liblc3 # type: ignore
|
||||
import logging
|
||||
|
||||
import click
|
||||
import aiohttp.web
|
||||
|
||||
import bumble
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
||||
from bumble.transport import open_transport
|
||||
from bumble.profiles import bap
|
||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
|
||||
|
||||
def _sink_pac_record() -> bap.PacRecord:
|
||||
return bap.PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
bap.SupportedSamplingFrequency.FREQ_8000
|
||||
| bap.SupportedSamplingFrequency.FREQ_16000
|
||||
| bap.SupportedSamplingFrequency.FREQ_24000
|
||||
| bap.SupportedSamplingFrequency.FREQ_32000
|
||||
| bap.SupportedSamplingFrequency.FREQ_48000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_count=[1, 2],
|
||||
min_octets_per_codec_frame=26,
|
||||
max_octets_per_codec_frame=240,
|
||||
supported_max_codec_frames_per_sdu=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _source_pac_record() -> bap.PacRecord:
|
||||
return bap.PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
bap.SupportedSamplingFrequency.FREQ_8000
|
||||
| bap.SupportedSamplingFrequency.FREQ_16000
|
||||
| bap.SupportedSamplingFrequency.FREQ_24000
|
||||
| bap.SupportedSamplingFrequency.FREQ_32000
|
||||
| bap.SupportedSamplingFrequency.FREQ_48000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_count=[1],
|
||||
min_octets_per_codec_frame=30,
|
||||
max_octets_per_codec_frame=100,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# WASM - liblc3
|
||||
# -----------------------------------------------------------------------------
|
||||
store = wasmtime.loader.store
|
||||
_memory = cast(wasmtime.Memory, liblc3.memory)
|
||||
STACK_POINTER = _memory.data_len(store)
|
||||
_memory.grow(store, 1)
|
||||
# Mapping wasmtime memory to linear address
|
||||
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
|
||||
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class Liblc3PcmFormat(enum.IntEnum):
|
||||
S16 = 0
|
||||
S24 = 1
|
||||
S24_3LE = 2
|
||||
FLOAT = 3
|
||||
|
||||
|
||||
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
|
||||
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
|
||||
|
||||
DECODER_STACK_POINTER = STACK_POINTER
|
||||
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
|
||||
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
|
||||
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
|
||||
DEFAULT_PCM_SAMPLE_RATE = 48000
|
||||
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
encoders: List[int] = []
|
||||
decoders: List[int] = []
|
||||
|
||||
|
||||
def setup_encoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
encoders[:num_channels] = [
|
||||
liblc3.lc3_setup_encoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
|
||||
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def setup_decoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
decoders[:num_channels] = [
|
||||
liblc3.lc3_setup_decoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
|
||||
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def decode(
|
||||
frame_duration_us: int,
|
||||
num_channels: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
input_bytes_per_frame = input_buffer_size // num_channels
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
* num_channels
|
||||
)
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_decode(
|
||||
decoders[i],
|
||||
input_buffer_offset + input_bytes_per_frame * i,
|
||||
input_bytes_per_frame,
|
||||
DEFAULT_PCM_FORMAT,
|
||||
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
|
||||
num_channels, # Stride
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
|
||||
|
||||
def encode(
|
||||
sdu_length: int,
|
||||
num_channels: int,
|
||||
stride: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = sdu_length
|
||||
output_frame_size = output_buffer_size // num_channels
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_encode(
|
||||
encoders[i],
|
||||
DEFAULT_PCM_FORMAT,
|
||||
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
|
||||
stride,
|
||||
output_frame_size,
|
||||
output_buffer_offset + output_frame_size * i,
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
|
||||
|
||||
async def lc3_source_task(
|
||||
filename: str,
|
||||
sdu_length: int,
|
||||
frame_duration_us: int,
|
||||
device: Device,
|
||||
cis_handle: int,
|
||||
) -> None:
|
||||
with open(filename, 'rb') as f:
|
||||
header = f.read(44)
|
||||
assert header[8:12] == b'WAVE'
|
||||
|
||||
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
|
||||
struct.unpack("<HIIHH", header[22:36])
|
||||
)
|
||||
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
|
||||
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
|
||||
frame_bytes = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
)
|
||||
packet_sequence_number = 0
|
||||
|
||||
while True:
|
||||
next_round = datetime.datetime.now() + datetime.timedelta(
|
||||
microseconds=frame_duration_us
|
||||
)
|
||||
pcm_data = f.read(frame_bytes)
|
||||
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
|
||||
|
||||
iso_packet = HCI_IsoDataPacket(
|
||||
connection_handle=cis_handle,
|
||||
data_total_length=sdu_length + 4,
|
||||
packet_sequence_number=packet_sequence_number,
|
||||
pb_flag=0b10,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=sdu_length,
|
||||
iso_sdu_fragment=sdu,
|
||||
)
|
||||
device.host.send_hci_packet(iso_packet)
|
||||
packet_sequence_number += 1
|
||||
sleep_time = next_round - datetime.datetime.now()
|
||||
await asyncio.sleep(sleep_time.total_seconds())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class UiServer:
|
||||
speaker: weakref.ReferenceType[Speaker]
|
||||
port: int
|
||||
|
||||
def __init__(self, speaker: Speaker, port: int) -> None:
|
||||
self.speaker = weakref.ref(speaker)
|
||||
self.port = port
|
||||
self.channel_socket = None
|
||||
|
||||
async def start_http(self) -> None:
|
||||
"""Start the UI HTTP server."""
|
||||
|
||||
app = aiohttp.web.Application()
|
||||
app.add_routes(
|
||||
[
|
||||
aiohttp.web.get('/', self.get_static),
|
||||
aiohttp.web.get('/index.html', self.get_static),
|
||||
aiohttp.web.get('/channel', self.get_channel),
|
||||
]
|
||||
)
|
||||
|
||||
runner = aiohttp.web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
|
||||
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
|
||||
await site.start()
|
||||
|
||||
async def get_static(self, request):
|
||||
path = request.path
|
||||
if path == '/':
|
||||
path = '/index.html'
|
||||
if path.endswith('.html'):
|
||||
content_type = 'text/html'
|
||||
elif path.endswith('.js'):
|
||||
content_type = 'text/javascript'
|
||||
elif path.endswith('.css'):
|
||||
content_type = 'text/css'
|
||||
elif path.endswith('.svg'):
|
||||
content_type = 'image/svg+xml'
|
||||
else:
|
||||
content_type = 'text/plain'
|
||||
text = (
|
||||
resources.files("bumble.apps.lea_unicast")
|
||||
.joinpath(pathlib.Path(path).relative_to('/'))
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
return aiohttp.web.Response(text=text, content_type=content_type)
|
||||
|
||||
async def get_channel(self, request):
|
||||
ws = aiohttp.web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Process messages until the socket is closed.
|
||||
self.channel_socket = ws
|
||||
async for message in ws:
|
||||
if message.type == aiohttp.WSMsgType.TEXT:
|
||||
logger.debug(f'<<< received message: {message.data}')
|
||||
await self.on_message(message.data)
|
||||
elif message.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.debug(
|
||||
f'channel connection closed with exception {ws.exception()}'
|
||||
)
|
||||
|
||||
self.channel_socket = None
|
||||
logger.debug('--- channel connection closed')
|
||||
|
||||
return ws
|
||||
|
||||
async def on_message(self, message_str: str):
|
||||
# Parse the message as JSON
|
||||
message = json.loads(message_str)
|
||||
|
||||
# Dispatch the message
|
||||
message_type = message['type']
|
||||
message_params = message.get('params', {})
|
||||
handler = getattr(self, f'on_{message_type}_message')
|
||||
if handler:
|
||||
await handler(**message_params)
|
||||
|
||||
async def on_hello_message(self):
|
||||
await self.send_message(
|
||||
'hello',
|
||||
bumble_version=bumble.__version__,
|
||||
codec=self.speaker().codec,
|
||||
streamState=self.speaker().stream_state.name,
|
||||
)
|
||||
if connection := self.speaker().connection:
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=connection.peer_address.to_string(False),
|
||||
peer_name=connection.peer_name,
|
||||
)
|
||||
|
||||
async def send_message(self, message_type: str, **kwargs) -> None:
|
||||
if self.channel_socket is None:
|
||||
return
|
||||
|
||||
message = {'type': message_type, 'params': kwargs}
|
||||
await self.channel_socket.send_json(message)
|
||||
|
||||
async def send_audio(self, data: bytes) -> None:
|
||||
if self.channel_socket is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.channel_socket.send_bytes(data)
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while sending audio packet: {error}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Speaker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_config_path: Optional[str],
|
||||
ui_port: int,
|
||||
transport: str,
|
||||
lc3_input_file_path: str,
|
||||
):
|
||||
self.device_config_path = device_config_path
|
||||
self.transport = transport
|
||||
self.lc3_input_file_path = lc3_input_file_path
|
||||
|
||||
# Create an HTTP server for the UI
|
||||
self.ui_server = UiServer(speaker=self, port=ui_port)
|
||||
|
||||
async def run(self) -> None:
|
||||
await self.ui_server.start_http()
|
||||
|
||||
async with await open_transport(self.transport) as hci_transport:
|
||||
# Create a device
|
||||
if self.device_config_path:
|
||||
device_config = DeviceConfiguration.from_file(self.device_config_path)
|
||||
else:
|
||||
device_config = DeviceConfiguration(
|
||||
name="Bumble LE Headphone",
|
||||
class_of_device=0x244418,
|
||||
keystore="JsonKeyStore",
|
||||
advertising_interval_min=25,
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
device_config.cis_enabled = True
|
||||
self.device = Device.from_config_with_hci(
|
||||
device_config, hci_transport.source, hci_transport.sink
|
||||
)
|
||||
|
||||
self.device.add_service(
|
||||
bap.PublishedAudioCapabilitiesService(
|
||||
supported_source_context=bap.ContextType(0xFFFF),
|
||||
available_source_context=bap.ContextType(0xFFFF),
|
||||
supported_sink_context=bap.ContextType(0xFFFF), # All context types
|
||||
available_sink_context=bap.ContextType(0xFFFF), # All context types
|
||||
sink_audio_locations=(
|
||||
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
sink_pac=[_sink_pac_record()],
|
||||
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
|
||||
source_pac=[_source_pac_record()],
|
||||
)
|
||||
)
|
||||
|
||||
ascs = bap.AudioStreamControlService(
|
||||
self.device, sink_ase_id=[1], source_ase_id=[2]
|
||||
)
|
||||
self.device.add_service(ascs)
|
||||
|
||||
advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(device_config.name, 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(bap.PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
) + bytes(bap.UnicastServerAdvertisingData())
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
pcm = decode(
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
pdu.iso_sdu_fragment,
|
||||
)
|
||||
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
||||
|
||||
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
|
||||
if ase.state == bap.AseStateMachine.State.STREAMING:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
assert ase.cis_link
|
||||
if ase.role == bap.AudioRole.SOURCE:
|
||||
ase.cis_link.abort_on(
|
||||
'disconnection',
|
||||
lc3_source_task(
|
||||
filename=self.lc3_input_file_path,
|
||||
sdu_length=(
|
||||
codec_config.codec_frames_per_sdu
|
||||
* codec_config.octets_per_codec_frame
|
||||
),
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
device=self.device,
|
||||
cis_handle=ase.cis_link.handle,
|
||||
),
|
||||
)
|
||||
else:
|
||||
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
||||
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
if ase.role == bap.AudioRole.SOURCE:
|
||||
setup_encoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
else:
|
||||
setup_decoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
|
||||
for ase in ascs.ase_state_machines.values():
|
||||
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
|
||||
|
||||
await self.device.power_on()
|
||||
await self.device.create_advertising_set(
|
||||
advertising_data=advertising_data,
|
||||
auto_restart=True,
|
||||
advertising_parameters=AdvertisingParameters(
|
||||
primary_advertising_interval_min=100,
|
||||
primary_advertising_interval_max=100,
|
||||
),
|
||||
)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--ui-port',
|
||||
'ui_port',
|
||||
metavar='HTTP_PORT',
|
||||
default=DEFAULT_UI_PORT,
|
||||
show_default=True,
|
||||
help='HTTP port for the UI server',
|
||||
)
|
||||
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
|
||||
@click.argument('transport')
|
||||
@click.argument('lc3_file')
|
||||
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
|
||||
"""Run the speaker."""
|
||||
|
||||
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
speaker()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
68
apps/lea_unicast/index.html
Normal file
68
apps/lea_unicast/index.html
Normal file
@@ -0,0 +1,68 @@
|
||||
<html data-bs-theme="dark">
|
||||
|
||||
<head>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
|
||||
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
|
||||
<script src="https://unpkg.com/pcm-player"></script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<nav class="navbar navbar-dark bg-primary">
|
||||
<div class="container">
|
||||
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
|
||||
</div>
|
||||
</nav>
|
||||
<br>
|
||||
|
||||
<div class="container">
|
||||
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
|
||||
<button class="btn btn-primary" type="button" disabled>
|
||||
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
|
||||
<span role="status" id="ws-status">WebSocket Connecting...</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<script>
|
||||
let player = null;
|
||||
const wsStatus = document.getElementById("ws-status");
|
||||
const wsStatusSpinner = document.getElementById("ws-status-spinner");
|
||||
|
||||
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
|
||||
socket.binaryType = "arraybuffer";
|
||||
socket.onmessage = function (message) {
|
||||
if (typeof message.data === 'string' || message.data instanceof String) {
|
||||
console.log(`channel MESSAGE: ${message.data}`);
|
||||
} else {
|
||||
console.log(typeof (message.data))
|
||||
// BINARY audio data.
|
||||
if (player == null) return;
|
||||
player.feed(message.data);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onopen = (message) => {
|
||||
wsStatusSpinner.remove();
|
||||
wsStatus.textContent = "WebSocket Connected";
|
||||
}
|
||||
|
||||
socket.onclose = (message) => {
|
||||
wsStatus.textContent = "WebSocket Disconnected";
|
||||
}
|
||||
|
||||
function connectAudio() {
|
||||
player = new PCMPlayer({
|
||||
inputCodec: 'Int16',
|
||||
channels: 2,
|
||||
sampleRate: 48000,
|
||||
flushTime: 10,
|
||||
});
|
||||
const button = document.getElementById("connect-audio")
|
||||
button.disabled = true;
|
||||
button.textContent = "Audio Connected";
|
||||
}
|
||||
</script>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
BIN
apps/lea_unicast/liblc3.wasm
Executable file
BIN
apps/lea_unicast/liblc3.wasm
Executable file
Binary file not shown.
79
apps/pair.py
79
apps/pair.py
@@ -24,10 +24,16 @@ from prompt_toolkit.shortcuts import PromptSession
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.pairing import PairingDelegate, PairingConfig
|
||||
from bumble.pairing import OobData, PairingDelegate, PairingConfig
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
ProtocolError,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
@@ -46,11 +52,13 @@ from bumble.att import (
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
if not self.linger:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
@@ -60,7 +68,7 @@ class Waiter:
|
||||
class Delegate(PairingDelegate):
|
||||
def __init__(self, mode, connection, capability_string, do_prompt):
|
||||
super().__init__(
|
||||
{
|
||||
io_capability={
|
||||
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
|
||||
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
|
||||
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
|
||||
@@ -285,7 +293,9 @@ async def pair(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -294,7 +304,7 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
@@ -306,6 +316,7 @@ async def pair(
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
device.le_enabled = True
|
||||
device.add_service(
|
||||
Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
@@ -326,7 +337,6 @@ async def pair(
|
||||
# Select LE or Classic
|
||||
if mode == 'classic':
|
||||
device.classic_enabled = True
|
||||
device.le_enabled = False
|
||||
device.classic_smp_enabled = ctkd
|
||||
|
||||
# Get things going
|
||||
@@ -343,16 +353,51 @@ async def pair(
|
||||
await device.keystore.print(prefix=color('@@@ ', 'blue'))
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
|
||||
# Create an OOB context if needed
|
||||
if oob:
|
||||
our_oob_context = OobContext()
|
||||
shared_data = (
|
||||
None
|
||||
if oob == '-'
|
||||
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
|
||||
)
|
||||
legacy_context = OobLegacyContext()
|
||||
oob_contexts = PairingConfig.OobConfig(
|
||||
our_context=our_oob_context,
|
||||
peer_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
oob_data = OobData(
|
||||
address=device.random_address,
|
||||
shared_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
print(color('@@@ OOB Data:', 'yellow'))
|
||||
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
else:
|
||||
oob_contexts = None
|
||||
|
||||
# Set up a pairing config factory
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(
|
||||
sc, mitm, bond, Delegate(mode, connection, io, prompt)
|
||||
sc=sc,
|
||||
mitm=mitm,
|
||||
bonding=bond,
|
||||
oob=oob_contexts,
|
||||
delegate=Delegate(mode, connection, io, prompt),
|
||||
)
|
||||
|
||||
# Connect to a peer or wait for a connection
|
||||
device.on('connection', lambda connection: on_connection(connection, request))
|
||||
if address_or_name is not None:
|
||||
print(color(f'=== Connecting to {address_or_name}...', 'green'))
|
||||
connection = await device.connect(address_or_name)
|
||||
connection = await device.connect(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
|
||||
if not request:
|
||||
try:
|
||||
@@ -360,10 +405,9 @@ async def pair(
|
||||
await connection.pair()
|
||||
else:
|
||||
await connection.authenticate()
|
||||
return
|
||||
except ProtocolError as error:
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
return
|
||||
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
@@ -413,6 +457,7 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
|
||||
@click.option(
|
||||
'--io',
|
||||
type=click.Choice(
|
||||
@@ -421,6 +466,14 @@ class LogHandler(logging.Handler):
|
||||
default='display+keyboard',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--oob',
|
||||
metavar='<oob-data-hex>',
|
||||
help=(
|
||||
'Use OOB pairing with this data from the peer '
|
||||
'(use "-" to enable OOB without peer data)'
|
||||
),
|
||||
)
|
||||
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
|
||||
@click.option(
|
||||
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
|
||||
@@ -440,7 +493,9 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -463,7 +518,9 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
|
||||
511
apps/rfcomm_bridge.py
Normal file
511
apps/rfcomm_bridge.py
Normal file
@@ -0,0 +1,511 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, Connection
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble import rfcomm
|
||||
from bumble import transport
|
||||
from bumble import utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
|
||||
DEFAULT_MTU = 4096
|
||||
DEFAULT_CLIENT_TCP_PORT = 9544
|
||||
DEFAULT_SERVER_TCP_PORT = 9545
|
||||
|
||||
TRACE_MAX_SIZE = 48
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Tracer:
|
||||
"""
|
||||
Trace data buffers transmitted from one endpoint to another, with stats.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_name: str) -> None:
|
||||
self.channel_name = channel_name
|
||||
self.last_ts: float = 0.0
|
||||
|
||||
def trace_data(self, data: bytes) -> None:
|
||||
now = time.time()
|
||||
elapsed_s = now - self.last_ts if self.last_ts else 0
|
||||
elapsed_ms = int(elapsed_s * 1000)
|
||||
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
|
||||
|
||||
hex_str = data[:TRACE_MAX_SIZE].hex() + (
|
||||
"..." if len(data) > TRACE_MAX_SIZE else ""
|
||||
)
|
||||
print(
|
||||
f"[{self.channel_name}] {len(data):4} bytes "
|
||||
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
|
||||
f" {hex_str}"
|
||||
)
|
||||
|
||||
self.last_ts = now
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServerBridge:
|
||||
"""
|
||||
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
|
||||
The RFCOMM channel may be associated with a UUID published in an SDP service
|
||||
description, or simply be on a system-assigned channel number.
|
||||
When the connection is made, the bridge connects a TCP socket to a remote host and
|
||||
bridges the data in both directions, with flow control.
|
||||
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
|
||||
and waits for a new channel to be connected.
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 4096
|
||||
|
||||
def __init__(
|
||||
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
|
||||
) -> None:
|
||||
self.device: Optional[Device] = None
|
||||
self.channel = channel
|
||||
self.uuid = uuid
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.rfcomm_channel: Optional[rfcomm.DLC] = None
|
||||
self.tcp_tracer: Optional[Tracer]
|
||||
self.rfcomm_tracer: Optional[Tracer]
|
||||
|
||||
if trace:
|
||||
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
|
||||
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
|
||||
else:
|
||||
self.rfcomm_tracer = None
|
||||
self.tcp_tracer = None
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
self.device = device
|
||||
|
||||
# Create and register a server
|
||||
rfcomm_server = rfcomm.Server(self.device)
|
||||
|
||||
# Listen for incoming DLC connections
|
||||
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
|
||||
|
||||
# Setup the SDP to advertise this channel
|
||||
service_record_handle = 0x00010001
|
||||
self.device.sdp_service_records = {
|
||||
service_record_handle: rfcomm.make_service_sdp_records(
|
||||
service_record_handle, self.channel, core.UUID(self.uuid)
|
||||
)
|
||||
}
|
||||
|
||||
# We're ready for a connection
|
||||
self.device.on("connection", self.on_connection)
|
||||
await self.set_available(True)
|
||||
|
||||
print(
|
||||
color(
|
||||
(
|
||||
f"### Listening for RFCOMM connection on {device.public_address}, "
|
||||
f"channel {self.channel}"
|
||||
),
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
|
||||
async def set_available(self, available: bool):
|
||||
# Become discoverable and connectable
|
||||
assert self.device
|
||||
await self.device.set_connectable(available)
|
||||
await self.device.set_discoverable(available)
|
||||
|
||||
def on_connection(self, connection):
|
||||
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
|
||||
connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
# Don't accept new connections until we're disconnected
|
||||
utils.AsyncRunner.spawn(self.set_available(False))
|
||||
|
||||
def on_disconnection(self, reason: int):
|
||||
print(
|
||||
color("@@@ Bluetooth disconnection:", "red"),
|
||||
hci.HCI_Constant.error_name(reason),
|
||||
)
|
||||
|
||||
# We're ready for a new connection
|
||||
utils.AsyncRunner.spawn(self.set_available(True))
|
||||
|
||||
# Called when an RFCOMM channel is established
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_rfcomm_channel(self, rfcomm_channel):
|
||||
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
|
||||
|
||||
# Connect to the TCP server
|
||||
print(
|
||||
color(
|
||||
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
|
||||
except OSError:
|
||||
print(color("!!! Connection failed", "red"))
|
||||
await rfcomm_channel.disconnect()
|
||||
return
|
||||
|
||||
# Pipe data from RFCOMM to TCP
|
||||
def on_rfcomm_channel_closed():
|
||||
print(color("*** RFCOMM channel closed", "cyan"))
|
||||
writer.close()
|
||||
|
||||
def write_rfcomm_data(data):
|
||||
if self.rfcomm_tracer:
|
||||
self.rfcomm_tracer.trace_data(data)
|
||||
|
||||
writer.write(data)
|
||||
|
||||
rfcomm_channel.sink = write_rfcomm_data
|
||||
rfcomm_channel.on("close", on_rfcomm_channel_closed)
|
||||
|
||||
# Pipe data from TCP to RFCOMM
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(self.READ_CHUNK_SIZE)
|
||||
|
||||
if len(data) == 0:
|
||||
print(color("### TCP end of stream", "yellow"))
|
||||
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
|
||||
await rfcomm_channel.disconnect()
|
||||
return
|
||||
|
||||
if self.tcp_tracer:
|
||||
self.tcp_tracer.trace_data(data)
|
||||
|
||||
rfcomm_channel.write(data)
|
||||
await rfcomm_channel.drain()
|
||||
except Exception as error:
|
||||
print(f"!!! Exception: {error}")
|
||||
break
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
print(color("~~~ Bye bye", "magenta"))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientBridge:
|
||||
"""
|
||||
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
|
||||
TCP connection on a specified port number. When a TCP client connects, an
|
||||
RFCOMM connection to the device is established, and the data is bridged in both
|
||||
directions, with flow control.
|
||||
When the TCP connection is closed by the client, the RFCOMM channel is
|
||||
disconnected, but the connection to the device remains, ready for a new TCP client
|
||||
to connect.
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: int,
|
||||
uuid: str,
|
||||
trace: bool,
|
||||
address: str,
|
||||
tcp_host: str,
|
||||
tcp_port: int,
|
||||
encrypt: bool,
|
||||
):
|
||||
self.channel = channel
|
||||
self.uuid = uuid
|
||||
self.trace = trace
|
||||
self.address = address
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.encrypt = encrypt
|
||||
self.device: Optional[Device] = None
|
||||
self.connection: Optional[Connection] = None
|
||||
self.rfcomm_client: Optional[rfcomm.Client]
|
||||
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
|
||||
self.tcp_connected: bool = False
|
||||
|
||||
self.tcp_tracer: Optional[Tracer]
|
||||
self.rfcomm_tracer: Optional[Tracer]
|
||||
|
||||
if trace:
|
||||
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
|
||||
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
|
||||
else:
|
||||
self.rfcomm_tracer = None
|
||||
self.tcp_tracer = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self.connection:
|
||||
return
|
||||
|
||||
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
|
||||
assert self.device
|
||||
self.connection = await self.device.connect(
|
||||
self.address, transport=core.BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
|
||||
self.connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
if self.encrypt:
|
||||
print(color("@@@ Encrypting Bluetooth connection", "blue"))
|
||||
await self.connection.encrypt()
|
||||
print(color("@@@ Bluetooth connection encrypted", "blue"))
|
||||
|
||||
self.rfcomm_client = rfcomm.Client(self.connection)
|
||||
try:
|
||||
self.rfcomm_mux = await self.rfcomm_client.start()
|
||||
except BaseException as e:
|
||||
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
|
||||
raise
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
self.device = device
|
||||
await device.set_connectable(False)
|
||||
await device.set_discoverable(False)
|
||||
|
||||
# Called when a TCP connection is established
|
||||
async def on_tcp_connection(reader, writer):
|
||||
print(color("<<< TCP connection", "magenta"))
|
||||
if self.tcp_connected:
|
||||
print(
|
||||
color("!!! TCP connection already active, rejecting new one", "red")
|
||||
)
|
||||
writer.close()
|
||||
return
|
||||
self.tcp_connected = True
|
||||
|
||||
try:
|
||||
await self.pipe(reader, writer)
|
||||
except BaseException as error:
|
||||
print(color("!!! Exception while piping data:", "red"), error)
|
||||
return
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
self.tcp_connected = False
|
||||
|
||||
await asyncio.start_server(
|
||||
on_tcp_connection,
|
||||
host=self.tcp_host if self.tcp_host != "_" else None,
|
||||
port=self.tcp_port,
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
|
||||
)
|
||||
)
|
||||
|
||||
async def pipe(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
# Resolve the channel number from the UUID if needed
|
||||
if self.channel == 0:
|
||||
await self.connect()
|
||||
assert self.connection
|
||||
channel = await rfcomm.find_rfcomm_channel_with_uuid(
|
||||
self.connection, self.uuid
|
||||
)
|
||||
if channel:
|
||||
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
|
||||
else:
|
||||
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
|
||||
return
|
||||
else:
|
||||
channel = self.channel
|
||||
|
||||
# Connect a new RFCOMM channel
|
||||
await self.connect()
|
||||
assert self.rfcomm_mux
|
||||
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
|
||||
try:
|
||||
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
|
||||
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
|
||||
except Exception as error:
|
||||
print(color(f"!!! RFCOMM open failed: {error}", "red"))
|
||||
return
|
||||
|
||||
# Pipe data from RFCOMM to TCP
|
||||
def on_rfcomm_channel_closed():
|
||||
print(color("*** RFCOMM channel closed", "green"))
|
||||
|
||||
def write_rfcomm_data(data):
|
||||
if self.trace:
|
||||
self.rfcomm_tracer.trace_data(data)
|
||||
|
||||
writer.write(data)
|
||||
|
||||
rfcomm_channel.on("close", on_rfcomm_channel_closed)
|
||||
rfcomm_channel.sink = write_rfcomm_data
|
||||
|
||||
# Pipe data from TCP to RFCOMM
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(self.READ_CHUNK_SIZE)
|
||||
|
||||
if len(data) == 0:
|
||||
print(color("### TCP end of stream", "yellow"))
|
||||
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
|
||||
await rfcomm_channel.disconnect()
|
||||
self.tcp_connected = False
|
||||
return
|
||||
|
||||
if self.tcp_tracer:
|
||||
self.tcp_tracer.trace_data(data)
|
||||
|
||||
rfcomm_channel.write(data)
|
||||
await rfcomm_channel.drain()
|
||||
except Exception as error:
|
||||
print(f"!!! Exception: {error}")
|
||||
break
|
||||
|
||||
print(color("~~~ Bye bye", "magenta"))
|
||||
|
||||
def on_disconnection(self, reason: int) -> None:
|
||||
print(
|
||||
color("@@@ Bluetooth disconnection:", "red"),
|
||||
hci.HCI_Constant.error_name(reason),
|
||||
)
|
||||
self.connection = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device_config, hci_transport, bridge):
|
||||
print("<<< connecting to HCI...")
|
||||
async with await transport.open_transport_or_link(hci_transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
print("<<< connected")
|
||||
|
||||
if device_config:
|
||||
device = Device.from_config_file_with_hci(
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
device = Device.from_config_with_hci(
|
||||
DeviceConfiguration(), hci_source, hci_sink
|
||||
)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Let's go
|
||||
await device.power_on()
|
||||
try:
|
||||
await bridge.start(device)
|
||||
|
||||
# Wait until the transport terminates
|
||||
await hci_source.wait_for_termination()
|
||||
except core.ConnectionError as error:
|
||||
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
|
||||
except Exception as error:
|
||||
print(f"Exception while running bridge: {error}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
@click.option(
|
||||
"--device-config",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Device configuration file",
|
||||
)
|
||||
@click.option(
|
||||
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
|
||||
)
|
||||
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
|
||||
@click.option(
|
||||
"--channel",
|
||||
metavar="CHANNEL_NUMER",
|
||||
help="RFCOMM channel number",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
@click.option(
|
||||
"--uuid",
|
||||
metavar="UUID",
|
||||
help="UUID for the RFCOMM channel",
|
||||
default=DEFAULT_RFCOMM_UUID,
|
||||
)
|
||||
def cli(
|
||||
context,
|
||||
device_config,
|
||||
hci_transport,
|
||||
trace,
|
||||
channel,
|
||||
uuid,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj["device_config"] = device_config
|
||||
context.obj["hci_transport"] = hci_transport
|
||||
context.obj["trace"] = trace
|
||||
context.obj["channel"] = channel
|
||||
context.obj["uuid"] = uuid
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.option("--tcp-host", help="TCP host", default="localhost")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
|
||||
def server(context, tcp_host, tcp_port):
|
||||
bridge = ServerBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
context.obj["trace"],
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument("bluetooth-address")
|
||||
@click.option("--tcp-host", help="TCP host", default="_")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
|
||||
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
|
||||
bridge = ClientBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
context.obj["trace"],
|
||||
bluetooth_address,
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
encrypt,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
|
||||
if __name__ == "__main__":
|
||||
cli(obj={}) # pylint: disable=no-value-for-parameter
|
||||
52
apps/scan.py
52
apps/scan.py
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.smp import AddressResolver
|
||||
from bumble.device import Advertisement
|
||||
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
|
||||
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
|
||||
address.address_type
|
||||
]
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
if address.address_type in (
|
||||
Address.RANDOM_IDENTITY_ADDRESS,
|
||||
Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
):
|
||||
type_color = 'yellow'
|
||||
else:
|
||||
if address.is_static:
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
elif address.is_static:
|
||||
type_color = 'green'
|
||||
address_qualifier = '(static)'
|
||||
elif address.is_resolvable:
|
||||
@@ -116,6 +121,7 @@ async def scan(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irks,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -140,9 +146,21 @@ async def scan(
|
||||
|
||||
if device.keystore:
|
||||
resolving_keys = await device.keystore.get_resolving_keys()
|
||||
resolver = AddressResolver(resolving_keys)
|
||||
else:
|
||||
resolver = None
|
||||
resolving_keys = []
|
||||
|
||||
for irk_and_address in irks:
|
||||
if ':' not in irk_and_address:
|
||||
raise ValueError('invalid IRK:ADDRESS value')
|
||||
irk_hex, address_str = irk_and_address.split(':', 1)
|
||||
resolving_keys.append(
|
||||
(
|
||||
bytes.fromhex(irk_hex),
|
||||
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
|
||||
)
|
||||
)
|
||||
|
||||
resolver = AddressResolver(resolving_keys) if resolving_keys else None
|
||||
|
||||
printer = AdvertisementPrinter(min_rssi, resolver)
|
||||
if raw:
|
||||
@@ -187,8 +205,24 @@ async def scan(
|
||||
default=False,
|
||||
help='Listen for raw advertising reports instead of processed ones',
|
||||
)
|
||||
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
|
||||
@click.option('--device-config', help='Device config file for the scanning device')
|
||||
@click.option(
|
||||
'--irk',
|
||||
metavar='<IRK_HEX>:<ADDRESS>',
|
||||
help=(
|
||||
'Use this IRK for resolving private addresses ' '(may be used more than once)'
|
||||
),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
'--keystore-file',
|
||||
metavar='FILE_PATH',
|
||||
help='Keystore file to use when resolving addresses',
|
||||
)
|
||||
@click.option(
|
||||
'--device-config',
|
||||
metavar='FILE_PATH',
|
||||
help='Device config file for the scanning device',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(
|
||||
min_rssi,
|
||||
@@ -198,6 +232,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -212,6 +247,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
|
||||
75
apps/show.py
75
apps/show.py
@@ -15,7 +15,11 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
@@ -24,6 +28,14 @@ from bumble.transport.common import PacketReader
|
||||
from bumble.helpers import PacketTracer
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopPacketReader:
|
||||
'''
|
||||
@@ -36,12 +48,18 @@ class SnoopPacketReader:
|
||||
DATALINK_BSCP = 1003
|
||||
DATALINK_H5 = 1004
|
||||
|
||||
IDENTIFICATION_PATTERN = b'btsnoop\0'
|
||||
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
|
||||
TIMESTAMP_DELTA = 0x00E03AB44A676000
|
||||
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
|
||||
|
||||
def __init__(self, source):
|
||||
self.source = source
|
||||
self.at_end = False
|
||||
|
||||
# Read the header
|
||||
identification_pattern = source.read(8)
|
||||
if identification_pattern.hex().lower() != '6274736e6f6f7000':
|
||||
if identification_pattern != self.IDENTIFICATION_PATTERN:
|
||||
raise ValueError(
|
||||
'not a valid snoop file, unexpected identification pattern'
|
||||
)
|
||||
@@ -55,19 +73,32 @@ class SnoopPacketReader:
|
||||
# Read the record header
|
||||
header = self.source.read(24)
|
||||
if len(header) < 24:
|
||||
return (0, None)
|
||||
self.at_end = True
|
||||
return (None, 0, None)
|
||||
|
||||
# Parse the header
|
||||
(
|
||||
original_length,
|
||||
included_length,
|
||||
packet_flags,
|
||||
_cumulative_drops,
|
||||
_timestamp_seconds,
|
||||
_timestamp_microsecond,
|
||||
) = struct.unpack('>IIIIII', header)
|
||||
timestamp,
|
||||
) = struct.unpack('>IIIIQ', header)
|
||||
|
||||
# Abort on truncated packets
|
||||
# Skip truncated packets
|
||||
if original_length != included_length:
|
||||
return (0, None)
|
||||
print(
|
||||
color(
|
||||
f"!!! truncated packet ({included_length}/{original_length})", "red"
|
||||
)
|
||||
)
|
||||
self.source.read(included_length)
|
||||
return (None, 0, None)
|
||||
|
||||
# Convert the timestamp to a datetime object.
|
||||
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
|
||||
microseconds=timestamp - self.TIMESTAMP_DELTA
|
||||
)
|
||||
|
||||
if self.data_link_type == self.DATALINK_H1:
|
||||
# The packet is un-encapsulated, look at the flags to figure out its type
|
||||
@@ -89,7 +120,17 @@ class SnoopPacketReader:
|
||||
bytes([packet_type]) + self.source.read(included_length),
|
||||
)
|
||||
|
||||
return (packet_flags & 1, self.source.read(included_length))
|
||||
return (ts_dt, packet_flags & 1, self.source.read(included_length))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Printer:
|
||||
def __init__(self):
|
||||
self.index = 0
|
||||
|
||||
def print(self, message: str) -> None:
|
||||
self.index += 1
|
||||
print(f"[{self.index:8}]{message}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -122,24 +163,28 @@ def main(format, vendors, filename):
|
||||
packet_reader = PacketReader(input)
|
||||
|
||||
def read_next_packet():
|
||||
return (0, packet_reader.next_packet())
|
||||
return (None, 0, packet_reader.next_packet())
|
||||
|
||||
else:
|
||||
packet_reader = SnoopPacketReader(input)
|
||||
read_next_packet = packet_reader.next_packet
|
||||
|
||||
tracer = PacketTracer(emit_message=print)
|
||||
printer = Printer()
|
||||
tracer = PacketTracer(emit_message=printer.print)
|
||||
|
||||
while True:
|
||||
while not packet_reader.at_end:
|
||||
try:
|
||||
(direction, packet) = read_next_packet()
|
||||
if packet is None:
|
||||
break
|
||||
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
|
||||
(timestamp, direction, packet) = read_next_packet()
|
||||
if packet:
|
||||
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp)
|
||||
else:
|
||||
printer.print(color("[TRUNCATED]", "red"))
|
||||
except Exception as error:
|
||||
logger.exception()
|
||||
print(color(f'!!! {error}', 'red'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
||||
@@ -76,6 +76,7 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioExtractor:
|
||||
@staticmethod
|
||||
@@ -641,7 +642,7 @@ class Speaker:
|
||||
self.device.on('connection', self.on_bluetooth_connection)
|
||||
|
||||
# Create a listener to wait for AVDTP connections
|
||||
self.listener = Listener(Listener.create_registrar(self.device))
|
||||
self.listener = Listener.for_device(self.device)
|
||||
self.listener.on('connection', self.on_avdtp_connection)
|
||||
|
||||
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')
|
||||
|
||||
@@ -24,6 +24,7 @@ from bumble.device import Device
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def unbond_with_keystore(keystore, address):
|
||||
if address is None:
|
||||
|
||||
169
bumble/a2dp.py
169
bumble/a2dp.py
@@ -15,9 +15,13 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import struct
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import List, Callable, Awaitable
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
from .sdp import (
|
||||
@@ -180,8 +184,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -230,8 +238,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -239,24 +251,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcMediaCodecInformation(
|
||||
namedtuple(
|
||||
'SbcMediaCodecInformation',
|
||||
[
|
||||
'sampling_frequency',
|
||||
'channel_mode',
|
||||
'block_length',
|
||||
'subbands',
|
||||
'allocation_method',
|
||||
'minimum_bitpool_value',
|
||||
'maximum_bitpool_value',
|
||||
],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class SbcMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.3.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
sampling_frequency: int
|
||||
channel_mode: int
|
||||
block_length: int
|
||||
subbands: int
|
||||
allocation_method: int
|
||||
minimum_bitpool_value: int
|
||||
maximum_bitpool_value: int
|
||||
|
||||
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
|
||||
CHANNEL_MODE_BITS = {
|
||||
SBC_MONO_CHANNEL_MODE: 1 << 3,
|
||||
@@ -272,7 +280,7 @@ class SbcMediaCodecInformation(
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
|
||||
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
|
||||
sampling_frequency = (data[0] >> 4) & 0x0F
|
||||
channel_mode = (data[0] >> 0) & 0x0F
|
||||
block_length = (data[1] >> 4) & 0x0F
|
||||
@@ -293,14 +301,14 @@ class SbcMediaCodecInformation(
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls,
|
||||
sampling_frequency,
|
||||
channel_mode,
|
||||
block_length,
|
||||
subbands,
|
||||
allocation_method,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
sampling_frequency: int,
|
||||
channel_mode: int,
|
||||
block_length: int,
|
||||
subbands: int,
|
||||
allocation_method: int,
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
|
||||
@@ -314,14 +322,14 @@ class SbcMediaCodecInformation(
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
sampling_frequencies,
|
||||
channel_modes,
|
||||
block_lengths,
|
||||
subbands,
|
||||
allocation_methods,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
sampling_frequencies: List[int],
|
||||
channel_modes: List[int],
|
||||
block_lengths: List[int],
|
||||
subbands: List[int],
|
||||
allocation_methods: List[int],
|
||||
minimum_bitpool_value: int,
|
||||
maximum_bitpool_value: int,
|
||||
) -> SbcMediaCodecInformation:
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
@@ -348,7 +356,7 @@ class SbcMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
|
||||
allocation_methods = ['SNR', 'Loudness']
|
||||
return '\n'.join(
|
||||
@@ -367,16 +375,19 @@ class SbcMediaCodecInformation(
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacMediaCodecInformation(
|
||||
namedtuple(
|
||||
'AacMediaCodecInformation',
|
||||
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class AacMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.5.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
object_type: int
|
||||
sampling_frequency: int
|
||||
channels: int
|
||||
rfa: int
|
||||
vbr: int
|
||||
bitrate: int
|
||||
|
||||
OBJECT_TYPE_BITS = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
|
||||
@@ -400,7 +411,7 @@ class AacMediaCodecInformation(
|
||||
CHANNELS_BITS = {1: 1 << 1, 2: 1}
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
|
||||
def from_bytes(data: bytes) -> AacMediaCodecInformation:
|
||||
object_type = data[0]
|
||||
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
channels = (data[2] >> 2) & 0x03
|
||||
@@ -413,8 +424,13 @@ class AacMediaCodecInformation(
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls, object_type, sampling_frequency, channels, vbr, bitrate
|
||||
):
|
||||
cls,
|
||||
object_type: int,
|
||||
sampling_frequency: int,
|
||||
channels: int,
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=cls.OBJECT_TYPE_BITS[object_type],
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
@@ -425,7 +441,14 @@ class AacMediaCodecInformation(
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
|
||||
def from_lists(
|
||||
cls,
|
||||
object_types: List[int],
|
||||
sampling_frequencies: List[int],
|
||||
channels: List[int],
|
||||
vbr: int,
|
||||
bitrate: int,
|
||||
) -> AacMediaCodecInformation:
|
||||
return AacMediaCodecInformation(
|
||||
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
|
||||
sampling_frequency=sum(
|
||||
@@ -449,7 +472,7 @@ class AacMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
object_types = [
|
||||
'MPEG_2_AAC_LC',
|
||||
'MPEG_4_AAC_LC',
|
||||
@@ -474,26 +497,26 @@ class AacMediaCodecInformation(
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorSpecificMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.7.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
vendor_id: int
|
||||
codec_id: int
|
||||
value: bytes
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
|
||||
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __init__(self, vendor_id, codec_id, value):
|
||||
self.vendor_id = vendor_id
|
||||
self.codec_id = codec_id
|
||||
self.value = value
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
@@ -506,29 +529,27 @@ class VendorSpecificMediaCodecInformation:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
def __init__(
|
||||
self, sampling_frequency, block_count, channel_mode, subband_count, payload
|
||||
):
|
||||
self.sampling_frequency = sampling_frequency
|
||||
self.block_count = block_count
|
||||
self.channel_mode = channel_mode
|
||||
self.subband_count = subband_count
|
||||
self.payload = payload
|
||||
sampling_frequency: int
|
||||
block_count: int
|
||||
channel_mode: int
|
||||
subband_count: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
def sample_count(self):
|
||||
def sample_count(self) -> int:
|
||||
return self.subband_count * self.block_count
|
||||
|
||||
@property
|
||||
def bitrate(self):
|
||||
def bitrate(self) -> int:
|
||||
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
def duration(self) -> float:
|
||||
return self.sample_count / self.sampling_frequency
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'SBC(sf={self.sampling_frequency},'
|
||||
f'cm={self.channel_mode},'
|
||||
@@ -540,12 +561,12 @@ class SbcFrame:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcParser:
|
||||
def __init__(self, read):
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def frames(self):
|
||||
async def generate_frames():
|
||||
def frames(self) -> AsyncGenerator[SbcFrame, None]:
|
||||
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
|
||||
while True:
|
||||
# Read 4 bytes of header
|
||||
header = await self.read(4)
|
||||
@@ -589,7 +610,9 @@ class SbcParser:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcPacketSource:
|
||||
def __init__(self, read, mtu, codec_capabilities):
|
||||
def __init__(
|
||||
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
|
||||
) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
@@ -629,7 +652,9 @@ class SbcPacketSource:
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
timestamp += sum((frame.sample_count for frame in frames))
|
||||
timestamp &= 0xFFFFFFFF
|
||||
frames = [frame]
|
||||
frames_size = len(frame.payload)
|
||||
else:
|
||||
|
||||
18
bumble/at.py
18
bumble/at.py
@@ -14,13 +14,19 @@
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
class AtParsingError(core.InvalidPacketError):
|
||||
"""Error raised when parsing AT commands fails."""
|
||||
|
||||
|
||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
"""Split input parameters into tokens.
|
||||
Removes space characters outside of double quote blocks:
|
||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||
are ignored [..], unless they are embedded in numeric or string constants"
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = []
|
||||
in_quotes = False
|
||||
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
token = bytearray()
|
||||
elif char == b'(':
|
||||
if len(token) > 0:
|
||||
raise ValueError("open_paren following regular character")
|
||||
raise AtParsingError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
elif char == b'"':
|
||||
if len(token) > 0:
|
||||
raise ValueError("quote following regular character")
|
||||
raise AtParsingError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
else:
|
||||
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
|
||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
"""Parse the parameters using the comma and parenthesis separators.
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = tokenize_parameters(buffer)
|
||||
accumulator: List[list] = [[]]
|
||||
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
accumulator.append([])
|
||||
elif token == b')':
|
||||
if len(accumulator) < 2:
|
||||
raise ValueError("close_paren without matching open_paren")
|
||||
raise AtParsingError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
else:
|
||||
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
|
||||
accumulator[-1].append(current)
|
||||
if len(accumulator) > 1:
|
||||
raise ValueError("missing close_paren")
|
||||
raise AtParsingError("missing close_paren")
|
||||
return accumulator[0]
|
||||
|
||||
@@ -25,9 +25,21 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
|
||||
|
||||
from bumble.core import UUID, name_or_number, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
@@ -643,7 +655,7 @@ class ATT_Write_Command(ATT_PDU):
|
||||
@ATT_PDU.subclass(
|
||||
[
|
||||
('attribute_handle', HANDLE_FIELD_SPEC),
|
||||
('attribute_value', '*')
|
||||
('attribute_value', '*'),
|
||||
# ('authentication_signature', 'TODO')
|
||||
]
|
||||
)
|
||||
@@ -722,12 +734,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ConnectionValue(Protocol):
|
||||
def read(self, connection) -> bytes:
|
||||
...
|
||||
class AttributeValue:
|
||||
'''
|
||||
Attribute value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def write(self, connection, value: bytes) -> None:
|
||||
...
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], bytes],
|
||||
Callable[[Optional[Connection]], Awaitable[bytes]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], bytes], None],
|
||||
Callable[[Optional[Connection], bytes], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> Union[Awaitable[None], None]:
|
||||
if self._write:
|
||||
return self._write(connection, value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -770,13 +808,13 @@ class Attribute(EventEmitter):
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
value: Union[str, bytes, ConnectionValue]
|
||||
value: Union[bytes, AttributeValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, ConnectionValue] = b'',
|
||||
value: Union[str, bytes, AttributeValue] = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
@@ -806,7 +844,7 @@ class Attribute(EventEmitter):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
async def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
if (
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
@@ -832,6 +870,8 @@ class Attribute(EventEmitter):
|
||||
if hasattr(self.value, 'read'):
|
||||
try:
|
||||
value = self.value.read(connection)
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
@@ -841,7 +881,7 @@ class Attribute(EventEmitter):
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
@@ -864,7 +904,9 @@ class Attribute(EventEmitter):
|
||||
|
||||
if hasattr(self.value, 'write'):
|
||||
try:
|
||||
self.value.write(connection, value) # pylint: disable=not-callable
|
||||
result = self.value.write(connection, value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
|
||||
523
bumble/avc.py
Normal file
523
bumble/avc.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Dict, Type, Union, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Frame:
|
||||
class SubunitType(enum.IntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.4
|
||||
MONITOR = 0x00
|
||||
AUDIO = 0x01
|
||||
PRINTER = 0x02
|
||||
DISC = 0x03
|
||||
TAPE_RECORDER_OR_PLAYER = 0x04
|
||||
TUNER = 0x05
|
||||
CA = 0x06
|
||||
CAMERA = 0x07
|
||||
PANEL = 0x09
|
||||
BULLETIN_BOARD = 0x0A
|
||||
VENDOR_UNIQUE = 0x1C
|
||||
EXTENDED = 0x1E
|
||||
UNIT = 0x1F
|
||||
|
||||
class OperationCode(OpenIntEnum):
|
||||
# 0x00 - 0x0F: Unit and subunit commands
|
||||
VENDOR_DEPENDENT = 0x00
|
||||
RESERVE = 0x01
|
||||
PLUG_INFO = 0x02
|
||||
|
||||
# 0x10 - 0x3F: Unit commands
|
||||
DIGITAL_OUTPUT = 0x10
|
||||
DIGITAL_INPUT = 0x11
|
||||
CHANNEL_USAGE = 0x12
|
||||
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
|
||||
INPUT_PLUG_SIGNAL_FORMAT = 0x19
|
||||
GENERAL_BUS_SETUP = 0x1F
|
||||
CONNECT_AV = 0x20
|
||||
DISCONNECT_AV = 0x21
|
||||
CONNECTIONS = 0x22
|
||||
CONNECT = 0x24
|
||||
DISCONNECT = 0x25
|
||||
UNIT_INFO = 0x30
|
||||
SUBUNIT_INFO = 0x31
|
||||
|
||||
# 0x40 - 0x7F: Subunit commands
|
||||
PASS_THROUGH = 0x7C
|
||||
GUI_UPDATE = 0x7D
|
||||
PUSH_GUI_DATA = 0x7E
|
||||
USER_ACTION = 0x7F
|
||||
|
||||
# 0xA0 - 0xBF: Unit and subunit commands
|
||||
VERSION = 0xB0
|
||||
POWER = 0xB2
|
||||
|
||||
subunit_type: SubunitType
|
||||
subunit_id: int
|
||||
opcode: OperationCode
|
||||
operands: bytes
|
||||
|
||||
@staticmethod
|
||||
def subclass(subclass):
|
||||
# Infer the opcode from the class name
|
||||
if subclass.__name__.endswith("CommandFrame"):
|
||||
short_name = subclass.__name__.replace("CommandFrame", "")
|
||||
category_class = CommandFrame
|
||||
elif subclass.__name__.endswith("ResponseFrame"):
|
||||
short_name = subclass.__name__.replace("ResponseFrame", "")
|
||||
category_class = ResponseFrame
|
||||
else:
|
||||
raise core.InvalidArgumentError(
|
||||
f"invalid subclass name {subclass.__name__}"
|
||||
)
|
||||
|
||||
uppercase_indexes = [
|
||||
i for i in range(len(short_name)) if short_name[i].isupper()
|
||||
]
|
||||
uppercase_indexes.append(len(short_name))
|
||||
words = [
|
||||
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
|
||||
for i in range(len(uppercase_indexes) - 1)
|
||||
]
|
||||
opcode_name = "_".join(words)
|
||||
opcode = Frame.OperationCode[opcode_name]
|
||||
category_class.subclasses[opcode] = subclass
|
||||
return subclass
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Frame:
|
||||
if data[0] >> 4 != 0:
|
||||
raise core.InvalidPacketError("first 4 bits must be 0s")
|
||||
|
||||
ctype_or_response = data[0] & 0xF
|
||||
subunit_type = Frame.SubunitType(data[1] >> 3)
|
||||
subunit_id = data[1] & 7
|
||||
|
||||
if subunit_type == Frame.SubunitType.EXTENDED:
|
||||
# Not supported
|
||||
raise NotImplementedError("extended subunit types not supported")
|
||||
|
||||
if subunit_id < 5:
|
||||
opcode_offset = 2
|
||||
elif subunit_id == 5:
|
||||
# Extended to the next byte
|
||||
extension = data[2]
|
||||
if extension == 0:
|
||||
raise core.InvalidPacketError("extended subunit ID value reserved")
|
||||
if extension == 0xFF:
|
||||
subunit_id = 5 + 254 + data[3]
|
||||
opcode_offset = 4
|
||||
else:
|
||||
subunit_id = 5 + extension
|
||||
opcode_offset = 3
|
||||
|
||||
elif subunit_id == 6:
|
||||
raise core.InvalidPacketError("reserved subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
|
||||
# Look for a registered subclass
|
||||
if ctype_or_response < 8:
|
||||
# Command
|
||||
ctype = CommandFrame.CommandType(ctype_or_response)
|
||||
if c_subclass := CommandFrame.subclasses.get(opcode):
|
||||
return c_subclass(
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*c_subclass.parse_operands(operands),
|
||||
)
|
||||
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
|
||||
else:
|
||||
# Response
|
||||
response = ResponseFrame.ResponseCode(ctype_or_response)
|
||||
if r_subclass := ResponseFrame.subclasses.get(opcode):
|
||||
return r_subclass(
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*r_subclass.parse_operands(operands),
|
||||
)
|
||||
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
|
||||
|
||||
def to_bytes(
|
||||
self,
|
||||
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
|
||||
) -> bytes:
|
||||
# TODO: support extended subunit types and ids.
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
ctype_or_response,
|
||||
self.subunit_type << 3 | self.subunit_id,
|
||||
self.opcode,
|
||||
]
|
||||
)
|
||||
+ self.operands
|
||||
)
|
||||
|
||||
def to_string(self, extra: str) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}({extra}"
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"opcode={self.opcode.name}, "
|
||||
f"operands={self.operands.hex()})"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subunit_type: SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
self.subunit_type = subunit_type
|
||||
self.subunit_id = subunit_id
|
||||
self.opcode = opcode
|
||||
self.operands = operands
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommandFrame(Frame):
|
||||
class CommandType(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.1
|
||||
CONTROL = 0x00
|
||||
STATUS = 0x01
|
||||
SPECIFIC_INQUIRY = 0x02
|
||||
NOTIFY = 0x03
|
||||
GENERAL_INQUIRY = 0x04
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
|
||||
ctype: CommandType
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.ctype = ctype
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.ctype)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"ctype={self.ctype.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ResponseFrame(Frame):
|
||||
class ResponseCode(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.2
|
||||
NOT_IMPLEMENTED = 0x08
|
||||
ACCEPTED = 0x09
|
||||
REJECTED = 0x0A
|
||||
IN_TRANSITION = 0x0B
|
||||
IMPLEMENTED_OR_STABLE = 0x0C
|
||||
CHANGED = 0x0D
|
||||
INTERIM = 0x0F
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
|
||||
response: ResponseCode
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.response = response
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.response)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"response={self.response.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorDependentFrame:
|
||||
company_id: int
|
||||
vendor_dependent_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
struct.unpack(">I", b"\x00" + operands[:3])[0],
|
||||
operands[3:],
|
||||
)
|
||||
|
||||
def make_operands(self) -> bytes:
|
||||
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
|
||||
|
||||
def __init__(self, company_id: int, vendor_dependent_data: bytes):
|
||||
self.company_id = company_id
|
||||
self.vendor_dependent_data = vendor_dependent_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PassThroughFrame:
|
||||
"""
|
||||
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
|
||||
"""
|
||||
|
||||
class StateFlag(enum.IntEnum):
|
||||
PRESSED = 0
|
||||
RELEASED = 1
|
||||
|
||||
class OperationId(OpenIntEnum):
|
||||
SELECT = 0x00
|
||||
UP = 0x01
|
||||
DOWN = 0x01
|
||||
LEFT = 0x03
|
||||
RIGHT = 0x04
|
||||
RIGHT_UP = 0x05
|
||||
RIGHT_DOWN = 0x06
|
||||
LEFT_UP = 0x07
|
||||
LEFT_DOWN = 0x08
|
||||
ROOT_MENU = 0x09
|
||||
SETUP_MENU = 0x0A
|
||||
CONTENTS_MENU = 0x0B
|
||||
FAVORITE_MENU = 0x0C
|
||||
EXIT = 0x0D
|
||||
NUMBER_0 = 0x20
|
||||
NUMBER_1 = 0x21
|
||||
NUMBER_2 = 0x22
|
||||
NUMBER_3 = 0x23
|
||||
NUMBER_4 = 0x24
|
||||
NUMBER_5 = 0x25
|
||||
NUMBER_6 = 0x26
|
||||
NUMBER_7 = 0x27
|
||||
NUMBER_8 = 0x28
|
||||
NUMBER_9 = 0x29
|
||||
DOT = 0x2A
|
||||
ENTER = 0x2B
|
||||
CLEAR = 0x2C
|
||||
CHANNEL_UP = 0x30
|
||||
CHANNEL_DOWN = 0x31
|
||||
PREVIOUS_CHANNEL = 0x32
|
||||
SOUND_SELECT = 0x33
|
||||
INPUT_SELECT = 0x34
|
||||
DISPLAY_INFORMATION = 0x35
|
||||
HELP = 0x36
|
||||
PAGE_UP = 0x37
|
||||
PAGE_DOWN = 0x38
|
||||
POWER = 0x40
|
||||
VOLUME_UP = 0x41
|
||||
VOLUME_DOWN = 0x42
|
||||
MUTE = 0x43
|
||||
PLAY = 0x44
|
||||
STOP = 0x45
|
||||
PAUSE = 0x46
|
||||
RECORD = 0x47
|
||||
REWIND = 0x48
|
||||
FAST_FORWARD = 0x49
|
||||
EJECT = 0x4A
|
||||
FORWARD = 0x4B
|
||||
BACKWARD = 0x4C
|
||||
ANGLE = 0x50
|
||||
SUBPICTURE = 0x51
|
||||
F1 = 0x71
|
||||
F2 = 0x72
|
||||
F3 = 0x73
|
||||
F4 = 0x74
|
||||
F5 = 0x75
|
||||
VENDOR_UNIQUE = 0x7E
|
||||
|
||||
state_flag: StateFlag
|
||||
operation_id: OperationId
|
||||
operation_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
PassThroughFrame.StateFlag(operands[0] >> 7),
|
||||
PassThroughFrame.OperationId(operands[0] & 0x7F),
|
||||
operands[1 : 1 + operands[1]],
|
||||
)
|
||||
|
||||
def make_operands(self):
|
||||
return (
|
||||
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
|
||||
+ self.operation_data
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_flag: StateFlag,
|
||||
operation_id: OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
if len(operation_data) > 255:
|
||||
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
|
||||
self.state_flag = state_flag
|
||||
self.operation_id = operation_id
|
||||
self.operation_data = operation_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
292
bumble/avctp.py
Normal file
292
bumble/avctp.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Callable, cast, Dict, Optional
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import core
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AVCTP_PSM = 0x0017
|
||||
AVCTP_BROWSING_PSM = 0x001B
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MessageAssembler:
|
||||
Callback = Callable[[int, bool, bool, int, bytes], None]
|
||||
|
||||
transaction_label: int
|
||||
pid: int
|
||||
c_r: int
|
||||
ipid: int
|
||||
payload: bytes
|
||||
number_of_packets: int
|
||||
packets_received: int
|
||||
|
||||
def __init__(self, callback: Callback) -> None:
|
||||
self.callback = callback
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.packets_received = 0
|
||||
self.transaction_label = -1
|
||||
self.pid = -1
|
||||
self.c_r = -1
|
||||
self.ipid = -1
|
||||
self.payload = b''
|
||||
self.number_of_packets = 0
|
||||
self.packet_count = 0
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packets_received += 1
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
c_r = (pdu[0] >> 1) & 1
|
||||
ipid = pdu[0] & 1
|
||||
|
||||
if c_r == 0 and ipid != 0:
|
||||
logger.warning("invalid IPID in command frame")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
pid_offset = 1
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
|
||||
if self.transaction_label >= 0:
|
||||
# We are already in a transaction
|
||||
logger.warning("received START or SINGLE fragment while in transaction")
|
||||
self.reset()
|
||||
self.packets_received = 1
|
||||
|
||||
if packet_type == Protocol.PacketType.START:
|
||||
self.number_of_packets = pdu[1]
|
||||
pid_offset = 2
|
||||
|
||||
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
|
||||
self.payload += pdu[pid_offset + 2 :]
|
||||
|
||||
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
|
||||
if transaction_label != self.transaction_label:
|
||||
logger.warning("transaction label does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if pid != self.pid:
|
||||
logger.warning("PID does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if c_r != self.c_r:
|
||||
logger.warning("C/R does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if self.packets_received > self.number_of_packets:
|
||||
logger.warning("too many fragments in transaction")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if packet_type == Protocol.PacketType.END:
|
||||
if self.packets_received != self.number_of_packets:
|
||||
logger.warning("premature END")
|
||||
self.reset()
|
||||
return
|
||||
else:
|
||||
self.transaction_label = transaction_label
|
||||
self.c_r = c_r
|
||||
self.ipid = ipid
|
||||
self.pid = pid
|
||||
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
|
||||
self.on_message_complete()
|
||||
|
||||
def on_message_complete(self):
|
||||
try:
|
||||
self.callback(
|
||||
self.transaction_label,
|
||||
self.c_r == 0,
|
||||
self.ipid != 0,
|
||||
self.pid,
|
||||
self.payload,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(color(f"!!! exception in callback: {error}", "red"))
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol:
|
||||
CommandHandler = Callable[[int, avc.CommandFrame], None]
|
||||
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
|
||||
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
|
||||
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
|
||||
next_transaction_label: int
|
||||
message_assembler: MessageAssembler
|
||||
|
||||
class PacketType(IntEnum):
|
||||
SINGLE = 0b00
|
||||
START = 0b01
|
||||
CONTINUE = 0b10
|
||||
END = 0b11
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
self.command_handlers = {}
|
||||
self.response_handlers = {}
|
||||
self.l2cap_channel = l2cap_channel
|
||||
self.message_assembler = MessageAssembler(self.on_message)
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on("open", self.on_l2cap_channel_open)
|
||||
l2cap_channel.on("close", self.on_l2cap_channel_close)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
logger.debug(color("<<< AVCTP channel open", "magenta"))
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
logger.debug(color("<<< AVCTP channel closed", "magenta"))
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.message_assembler.on_pdu(pdu)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"<<< AVCTP Message: pid={pid}, "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"is_command={is_command}, "
|
||||
f"ipid={ipid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
|
||||
# Check for invalid PID responses.
|
||||
if ipid:
|
||||
logger.debug(f"received IPID for PID={pid}")
|
||||
|
||||
# Find the appropriate handler.
|
||||
if is_command:
|
||||
if pid not in self.command_handlers:
|
||||
logger.warning(f"no command handler for PID {pid}")
|
||||
self.send_ipid(transaction_label, pid)
|
||||
return
|
||||
|
||||
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
|
||||
self.command_handlers[pid](transaction_label, command_frame)
|
||||
else:
|
||||
if pid not in self.response_handlers:
|
||||
logger.warning(f"no response handler for PID {pid}")
|
||||
return
|
||||
|
||||
# By convention, for an ipid, send a None payload to the response handler.
|
||||
if ipid:
|
||||
response_frame = None
|
||||
else:
|
||||
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
|
||||
|
||||
self.response_handlers[pid](transaction_label, response_frame)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
# TODO: fragment large messages
|
||||
packet_type = Protocol.PacketType.SINGLE
|
||||
pdu = (
|
||||
struct.pack(
|
||||
">BH",
|
||||
transaction_label << 4
|
||||
| packet_type << 2
|
||||
| (0 if is_command else 1) << 1
|
||||
| (1 if ipid else 0),
|
||||
pid,
|
||||
)
|
||||
+ payload
|
||||
)
|
||||
self.l2cap_channel.send_pdu(pdu)
|
||||
|
||||
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP command: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, True, False, pid, payload)
|
||||
|
||||
def send_response(self, transaction_label: int, pid: int, payload: bytes):
|
||||
logger.debug(
|
||||
">>> AVCTP response: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, False, False, pid, payload)
|
||||
|
||||
def send_ipid(self, transaction_label: int, pid: int) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
|
||||
)
|
||||
self.send_message(transaction_label, False, True, pid, b'')
|
||||
|
||||
def register_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
self.command_handlers[pid] = handler
|
||||
|
||||
def unregister_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||
raise core.InvalidArgumentError("command handler not registered")
|
||||
del self.command_handlers[pid]
|
||||
|
||||
def register_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
self.response_handlers[pid] = handler
|
||||
|
||||
def unregister_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||
raise core.InvalidArgumentError("response handler not registered")
|
||||
del self.response_handlers[pid]
|
||||
512
bumble/avdtp.py
512
bumble/avdtp.py
File diff suppressed because it is too large
Load Diff
1919
bumble/avrcp.py
Normal file
1919
bumble/avrcp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,8 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BitReader:
|
||||
@@ -40,7 +42,7 @@ class BitReader:
|
||||
""" "Read up to 32 bits."""
|
||||
|
||||
if bits > 32:
|
||||
raise ValueError('maximum read size is 32')
|
||||
raise core.InvalidArgumentError('maximum read size is 32')
|
||||
|
||||
if self.bits_cached >= bits:
|
||||
# We have enough bits.
|
||||
@@ -53,7 +55,7 @@ class BitReader:
|
||||
feed_size = len(feed_bytes)
|
||||
feed_int = int.from_bytes(feed_bytes, byteorder='big')
|
||||
if 8 * feed_size + self.bits_cached < bits:
|
||||
raise ValueError('trying to read past the data')
|
||||
raise core.InvalidArgumentError('trying to read past the data')
|
||||
self.byte_position += feed_size
|
||||
|
||||
# Combine the new cache and the old cache
|
||||
@@ -68,7 +70,7 @@ class BitReader:
|
||||
|
||||
def read_bytes(self, count: int):
|
||||
if self.bit_position + 8 * count > 8 * len(self.data):
|
||||
raise ValueError('not enough data')
|
||||
raise core.InvalidArgumentError('not enough data')
|
||||
|
||||
if self.bit_position % 8:
|
||||
# Not byte aligned
|
||||
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
|
||||
|
||||
@staticmethod
|
||||
def program_config_element(reader: BitReader):
|
||||
raise ValueError('program_config_element not supported')
|
||||
raise core.InvalidPacketError('program_config_element not supported')
|
||||
|
||||
@dataclass
|
||||
class GASpecificConfig:
|
||||
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
|
||||
aac_spectral_data_resilience_flags = reader.read(1)
|
||||
extension_flag_3 = reader.read(1)
|
||||
if extension_flag_3 == 1:
|
||||
raise ValueError('extensionFlag3 == 1 not supported')
|
||||
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
|
||||
|
||||
@staticmethod
|
||||
def audio_object_type(reader: BitReader):
|
||||
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
|
||||
reader, self.channel_configuration, self.audio_object_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
raise core.InvalidPacketError(
|
||||
f'audioObjectType {self.audio_object_type} not supported'
|
||||
)
|
||||
|
||||
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
|
||||
else:
|
||||
audio_mux_version_a = 0
|
||||
if audio_mux_version_a != 0:
|
||||
raise ValueError('audioMuxVersionA != 0 not supported')
|
||||
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
|
||||
if audio_mux_version == 1:
|
||||
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
||||
stream_cnt = 0
|
||||
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
|
||||
num_sub_frames = reader.read(6)
|
||||
num_program = reader.read(4)
|
||||
if num_program != 0:
|
||||
raise ValueError('num_program != 0 not supported')
|
||||
raise core.InvalidPacketError('num_program != 0 not supported')
|
||||
num_layer = reader.read(3)
|
||||
if num_layer != 0:
|
||||
raise ValueError('num_layer != 0 not supported')
|
||||
raise core.InvalidPacketError('num_layer != 0 not supported')
|
||||
if audio_mux_version == 0:
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
reader
|
||||
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
|
||||
)
|
||||
audio_specific_config_len = reader.bit_position - marker
|
||||
if asc_len < audio_specific_config_len:
|
||||
raise ValueError('audio_specific_config_len > asc_len')
|
||||
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
|
||||
asc_len -= audio_specific_config_len
|
||||
reader.skip(asc_len)
|
||||
frame_length_type = reader.read(3)
|
||||
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
|
||||
elif frame_length_type == 1:
|
||||
frame_length = reader.read(9)
|
||||
else:
|
||||
raise ValueError(f'frame_length_type {frame_length_type} not supported')
|
||||
raise core.InvalidPacketError(
|
||||
f'frame_length_type {frame_length_type} not supported'
|
||||
)
|
||||
|
||||
self.other_data_present = reader.read(1)
|
||||
if self.other_data_present:
|
||||
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
|
||||
|
||||
def __init__(self, reader: BitReader, mux_config_present: int):
|
||||
if mux_config_present == 0:
|
||||
raise ValueError('muxConfigPresent == 0 not supported')
|
||||
raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
|
||||
|
||||
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
||||
use_same_stream_mux = reader.read(1)
|
||||
if use_same_stream_mux:
|
||||
raise ValueError('useSameStreamMux == 1 not supported')
|
||||
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
|
||||
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
||||
|
||||
# We only support:
|
||||
|
||||
@@ -16,6 +16,10 @@ from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class ColorError(ValueError):
|
||||
"""Error raised when a color spec is invalid."""
|
||||
|
||||
|
||||
# ANSI color names. There is also a "default"
|
||||
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
|
||||
|
||||
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
|
||||
elif isinstance(spec, int) and 0 <= spec <= 255:
|
||||
return _join(base + 8, 5, spec)
|
||||
else:
|
||||
raise ValueError('Invalid color spec "%s"' % spec)
|
||||
raise ColorError('Invalid color spec "%s"' % spec)
|
||||
|
||||
|
||||
def color(
|
||||
@@ -72,7 +76,7 @@ def color(
|
||||
if style_part in STYLES:
|
||||
codes.append(STYLES.index(style_part))
|
||||
else:
|
||||
raise ValueError('Invalid style "%s"' % style_part)
|
||||
raise ColorError('Invalid style "%s"' % style_part)
|
||||
|
||||
if codes:
|
||||
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import itertools
|
||||
import random
|
||||
import struct
|
||||
@@ -42,6 +43,7 @@ from bumble.hci import (
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_SUCCESS,
|
||||
HCI_UNKNOWN_HCI_COMMAND_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_0,
|
||||
Address,
|
||||
@@ -53,17 +55,21 @@ from bumble.hci import (
|
||||
HCI_Connection_Request_Event,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
HCI_Encryption_Change_Event,
|
||||
HCI_Synchronous_Connection_Complete_Event,
|
||||
HCI_LE_Advertising_Report_Event,
|
||||
HCI_LE_CIS_Established_Event,
|
||||
HCI_LE_CIS_Request_Event,
|
||||
HCI_LE_Connection_Complete_Event,
|
||||
HCI_LE_Read_Remote_Features_Complete_Event,
|
||||
HCI_Number_Of_Completed_Packets_Event,
|
||||
HCI_Packet,
|
||||
HCI_Role_Change_Event,
|
||||
)
|
||||
from typing import Optional, Union, Dict, TYPE_CHECKING
|
||||
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.transport.common import TransportSink, TransportSource
|
||||
from bumble.link import LocalLink
|
||||
from bumble.transport.common import TransportSink
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -79,15 +85,27 @@ class DataObject:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class CisLink:
|
||||
handle: int
|
||||
cis_id: int
|
||||
cig_id: int
|
||||
acl_connection: Optional[Connection] = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Connection:
|
||||
def __init__(self, controller, handle, role, peer_address, link, transport):
|
||||
self.controller = controller
|
||||
self.handle = handle
|
||||
self.role = role
|
||||
self.peer_address = peer_address
|
||||
self.link = link
|
||||
controller: Controller
|
||||
handle: int
|
||||
role: int
|
||||
peer_address: Address
|
||||
link: Any
|
||||
transport: int
|
||||
link_type: int
|
||||
|
||||
def __post_init__(self):
|
||||
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
|
||||
def on_hci_acl_data_packet(self, packet):
|
||||
self.assembler.feed_packet(packet)
|
||||
@@ -106,25 +124,27 @@ class Connection:
|
||||
class Controller:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
name: str,
|
||||
host_source=None,
|
||||
host_sink: Optional[TransportSink] = None,
|
||||
link=None,
|
||||
link: Optional[LocalLink] = None,
|
||||
public_address: Optional[Union[bytes, str, Address]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.hci_sink = None
|
||||
self.link = link
|
||||
|
||||
self.central_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections where this controller is the central
|
||||
self.peripheral_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections where this controller is the peripheral
|
||||
self.classic_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections in BR/EDR
|
||||
self.central_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the central
|
||||
self.peripheral_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the peripheral
|
||||
self.classic_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections in BR/EDR
|
||||
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
|
||||
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
|
||||
self.hci_revision = 0
|
||||
@@ -134,12 +154,14 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
@@ -301,7 +323,7 @@ class Controller:
|
||||
############################################################
|
||||
# Link connections
|
||||
############################################################
|
||||
def allocate_connection_handle(self):
|
||||
def allocate_connection_handle(self) -> int:
|
||||
handle = 0
|
||||
max_handle = 0
|
||||
for connection in itertools.chain(
|
||||
@@ -313,6 +335,13 @@ class Controller:
|
||||
if connection.handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
for cis_handle in itertools.chain(
|
||||
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
|
||||
):
|
||||
max_handle = max(max_handle, cis_handle)
|
||||
if cis_handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
return handle
|
||||
|
||||
def find_le_connection_by_address(self, address):
|
||||
@@ -357,12 +386,13 @@ class Controller:
|
||||
if connection is None:
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
peer_address,
|
||||
self.link,
|
||||
BT_LE_TRANSPORT,
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_PERIPHERAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.peripheral_connections[peer_address] = connection
|
||||
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
|
||||
@@ -416,12 +446,13 @@ class Controller:
|
||||
if connection is None:
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
BT_CENTRAL_ROLE,
|
||||
peer_address,
|
||||
self.link,
|
||||
BT_LE_TRANSPORT,
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_CENTRAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.central_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -538,6 +569,104 @@ class Controller:
|
||||
)
|
||||
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
|
||||
|
||||
def on_link_cis_request(
|
||||
self, central_address: Address, cig_id: int, cis_id: int
|
||||
) -> None:
|
||||
'''
|
||||
Called when an incoming CIS request occurs from a central on the link
|
||||
'''
|
||||
|
||||
connection = self.peripheral_connections.get(central_address)
|
||||
assert connection
|
||||
|
||||
pending_cis_link = CisLink(
|
||||
handle=self.allocate_connection_handle(),
|
||||
cis_id=cis_id,
|
||||
cig_id=cig_id,
|
||||
acl_connection=connection,
|
||||
)
|
||||
self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_LE_CIS_Request_Event(
|
||||
acl_connection_handle=connection.handle,
|
||||
cis_connection_handle=pending_cis_link.handle,
|
||||
cig_id=cig_id,
|
||||
cis_id=cis_id,
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_cis_established(self, cig_id: int, cis_id: int) -> None:
|
||||
'''
|
||||
Called when an incoming CIS established.
|
||||
'''
|
||||
|
||||
cis_link = next(
|
||||
cis_link
|
||||
for cis_link in itertools.chain(
|
||||
self.central_cis_links.values(), self.peripheral_cis_links.values()
|
||||
)
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_LE_CIS_Established_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=cis_link.handle,
|
||||
# CIS parameters are ignored.
|
||||
cig_sync_delay=0,
|
||||
cis_sync_delay=0,
|
||||
transport_latency_c_to_p=0,
|
||||
transport_latency_p_to_c=0,
|
||||
phy_c_to_p=0,
|
||||
phy_p_to_c=0,
|
||||
nse=0,
|
||||
bn_c_to_p=0,
|
||||
bn_p_to_c=0,
|
||||
ft_c_to_p=0,
|
||||
ft_p_to_c=0,
|
||||
max_pdu_c_to_p=0,
|
||||
max_pdu_p_to_c=0,
|
||||
iso_interval=0,
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None:
|
||||
'''
|
||||
Called when a CIS disconnected.
|
||||
'''
|
||||
|
||||
if cis_link := next(
|
||||
(
|
||||
cis_link
|
||||
for cis_link in self.peripheral_cis_links.values()
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# Remove peripheral CIS on disconnection.
|
||||
self.peripheral_cis_links.pop(cis_link.handle)
|
||||
elif cis_link := next(
|
||||
(
|
||||
cis_link
|
||||
for cis_link in self.central_cis_links.values()
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command.
|
||||
cis_link.acl_connection = None
|
||||
else:
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Disconnection_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=cis_link.handle,
|
||||
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Classic link connections
|
||||
############################################################
|
||||
@@ -566,6 +695,7 @@ class Controller:
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -619,6 +749,42 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_classic_sco_connection_complete(
|
||||
self, peer_address: Address, status: int, link_type: int
|
||||
):
|
||||
if status == HCI_SUCCESS:
|
||||
# Allocate (or reuse) a connection handle
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
# Role doesn't matter in SCO.
|
||||
role=BT_CENTRAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
link_type=link_type,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
|
||||
else:
|
||||
connection_handle = 0
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Synchronous_Connection_Complete_Event(
|
||||
status=status,
|
||||
connection_handle=connection_handle,
|
||||
bd_addr=peer_address,
|
||||
link_type=link_type,
|
||||
# TODO: Provide SCO connection parameters.
|
||||
transmission_interval=0,
|
||||
retransmission_window=0,
|
||||
rx_packet_length=0,
|
||||
tx_packet_length=0,
|
||||
air_mode=0,
|
||||
)
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Advertising support
|
||||
############################################################
|
||||
@@ -721,6 +887,17 @@ class Controller:
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.classic_connections[connection.peer_address]
|
||||
elif cis_link := (
|
||||
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
|
||||
):
|
||||
if self.link:
|
||||
self.link.disconnect_cis(
|
||||
initiator_controller=self,
|
||||
peer_address=cis_link.acl_connection.peer_address,
|
||||
cig_id=cis_link.cig_id,
|
||||
cis_id=cis_link.cis_id,
|
||||
)
|
||||
# Spec requires handle to be kept after disconnection.
|
||||
|
||||
def on_hci_accept_connection_request_command(self, command):
|
||||
'''
|
||||
@@ -738,6 +915,68 @@ class Controller:
|
||||
)
|
||||
self.link.classic_accept_connection(self, command.bd_addr, command.role)
|
||||
|
||||
def on_hci_enhanced_setup_synchronous_connection_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.1.45 Enhanced Setup Synchronous Connection command
|
||||
'''
|
||||
|
||||
if self.link is None:
|
||||
return
|
||||
|
||||
if not (
|
||||
connection := self.find_classic_connection_by_handle(
|
||||
command.connection_handle
|
||||
)
|
||||
):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_sco_connect(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
)
|
||||
|
||||
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.1.46 Enhanced Accept Synchronous Connection Request command
|
||||
'''
|
||||
|
||||
if self.link is None:
|
||||
return
|
||||
|
||||
if not (connection := self.find_classic_connection_by_address(command.bd_addr)):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_accept_sco_connection(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
)
|
||||
|
||||
def on_hci_switch_role_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
|
||||
@@ -912,7 +1151,41 @@ class Controller:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features[:8]
|
||||
|
||||
def on_hci_read_local_extended_features_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.4 Read Local Extended Features Command
|
||||
'''
|
||||
if command.page_number * 8 > len(self.lmp_features):
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
# Status
|
||||
HCI_SUCCESS,
|
||||
# Page number
|
||||
command.page_number,
|
||||
# Max page number
|
||||
len(self.lmp_features) // 8 - 1,
|
||||
]
|
||||
)
|
||||
# Features of the current page
|
||||
+ self.lmp_features[command.page_number * 8 : (command.page_number + 1) * 8]
|
||||
)
|
||||
|
||||
def on_hci_read_buffer_size_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
def on_hci_read_bd_addr_command(self, _command):
|
||||
'''
|
||||
@@ -1000,6 +1273,9 @@ class Controller:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
|
||||
'''
|
||||
if self.le_scan_enable:
|
||||
return bytes([HCI_COMMAND_DISALLOWED_ERROR])
|
||||
|
||||
self.le_scan_type = command.le_scan_type
|
||||
self.le_scan_interval = command.le_scan_interval
|
||||
self.le_scan_window = command.le_scan_window
|
||||
@@ -1086,6 +1362,18 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
|
||||
'''
|
||||
|
||||
handle = command.connection_handle
|
||||
|
||||
if not self.find_connection_by_handle(handle):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# First, say that the command is pending
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
@@ -1099,7 +1387,7 @@ class Controller:
|
||||
self.send_hci_packet(
|
||||
HCI_LE_Read_Remote_Features_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=0,
|
||||
connection_handle=handle,
|
||||
le_features=bytes.fromhex('dd40000000000000'),
|
||||
)
|
||||
)
|
||||
@@ -1255,8 +1543,135 @@ class Controller:
|
||||
}
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
|
||||
Length Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, 0x0672)
|
||||
|
||||
def on_hci_le_read_number_of_supported_advertising_sets_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.58 LE Read Number of Supported
|
||||
Advertising Set Command
|
||||
'''
|
||||
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
|
||||
|
||||
def on_hci_le_read_transmit_power_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
'''
|
||||
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
|
||||
|
||||
def on_hci_le_set_cig_parameters_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.97 LE Set CIG Parameter Command
|
||||
'''
|
||||
|
||||
# Remove old CIG implicitly.
|
||||
for handle, cis_link in self.central_cis_links.items():
|
||||
if cis_link.cig_id == command.cig_id:
|
||||
self.central_cis_links.pop(handle)
|
||||
|
||||
handles = []
|
||||
for cis_id in command.cis_id:
|
||||
handle = self.allocate_connection_handle()
|
||||
handles.append(handle)
|
||||
self.central_cis_links[handle] = CisLink(
|
||||
cis_id=cis_id,
|
||||
cig_id=command.cig_id,
|
||||
handle=handle,
|
||||
)
|
||||
return struct.pack(
|
||||
'<BBB', HCI_SUCCESS, command.cig_id, len(handles)
|
||||
) + b''.join([struct.pack('<H', handle) for handle in handles])
|
||||
|
||||
def on_hci_le_create_cis_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.99 LE Create CIS Command
|
||||
'''
|
||||
if not self.link:
|
||||
return
|
||||
|
||||
for cis_handle, acl_handle in zip(
|
||||
command.cis_connection_handle, command.acl_connection_handle
|
||||
):
|
||||
if not (connection := self.find_connection_by_handle(acl_handle)):
|
||||
logger.error(f'Cannot find connection with handle={acl_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
if not (cis_link := self.central_cis_links.get(cis_handle)):
|
||||
logger.error(f'Cannot find CIS with handle={cis_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
cis_link.acl_connection = connection
|
||||
|
||||
self.link.create_cis(
|
||||
self,
|
||||
peripheral_address=connection.peer_address,
|
||||
cig_id=cis_link.cig_id,
|
||||
cis_id=cis_link.cis_id,
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_remove_cig_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.100 LE Remove CIG Command
|
||||
'''
|
||||
|
||||
status = HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR
|
||||
|
||||
for cis_handle, cis_link in self.central_cis_links.items():
|
||||
if cis_link.cig_id == command.cig_id:
|
||||
self.central_cis_links.pop(cis_handle)
|
||||
status = HCI_SUCCESS
|
||||
|
||||
return struct.pack('<BH', status, command.cig_id)
|
||||
|
||||
def on_hci_le_accept_cis_request_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.101 LE Accept CIS Request Command
|
||||
'''
|
||||
if not self.link:
|
||||
return
|
||||
|
||||
if not (
|
||||
pending_cis_link := self.peripheral_cis_links.get(command.connection_handle)
|
||||
):
|
||||
logger.error(f'Cannot find CIS with handle={command.connection_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
assert pending_cis_link.acl_connection
|
||||
self.link.accept_cis(
|
||||
peripheral_controller=self,
|
||||
central_address=pending_cis_link.acl_connection.peer_address,
|
||||
cig_id=pending_cis_link.cig_id,
|
||||
cis_id=pending_cis_link.cis_id,
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
879
bumble/core.py
879
bumble/core.py
File diff suppressed because it is too large
Load Diff
158
bumble/crypto.py
158
bumble/crypto.py
@@ -21,6 +21,8 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import operator
|
||||
|
||||
@@ -29,11 +31,13 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -46,16 +50,18 @@ logger = logging.getLogger(__name__)
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class EccKey:
|
||||
def __init__(self, private_key):
|
||||
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
def generate(cls) -> EccKey:
|
||||
private_key = generate_private_key(SECP256R1())
|
||||
return cls(private_key)
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
|
||||
def from_private_key_bytes(
|
||||
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
|
||||
) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
|
||||
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
|
||||
@@ -65,7 +71,7 @@ class EccKey:
|
||||
return cls(private_key)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
def x(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
@@ -73,14 +79,14 @@ class EccKey:
|
||||
)
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
def y(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.y.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
def dh(self, public_key_x, public_key_y):
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
|
||||
@@ -93,14 +99,33 @@ class EccKey:
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x, y):
|
||||
def generate_prand() -> bytes:
|
||||
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part E - Table 1.2.
|
||||
'''
|
||||
prand_bytes = secrets.token_bytes(6)
|
||||
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x: bytes, y: bytes) -> bytes:
|
||||
assert len(x) == len(y)
|
||||
return bytes(map(operator.xor, x, y))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r():
|
||||
def reverse(input: bytes) -> bytes:
|
||||
'''
|
||||
Returns bytes of input in reversed endianness.
|
||||
'''
|
||||
return input[::-1]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r() -> bytes:
|
||||
'''
|
||||
Generate 16 bytes of random data
|
||||
'''
|
||||
@@ -108,20 +133,20 @@ def r():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def e(key, data):
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
|
||||
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
return bytes(reversed(encryptor.update(bytes(reversed(data)))))
|
||||
return reverse(encryptor.update(reverse(data)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
|
||||
'''
|
||||
@@ -132,7 +157,16 @@ def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
|
||||
def c1(
|
||||
k: bytes,
|
||||
r: bytes,
|
||||
preq: bytes,
|
||||
pres: bytes,
|
||||
iat: int,
|
||||
rat: int,
|
||||
ia: bytes,
|
||||
ra: bytes,
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
|
||||
LE Legacy Pairing
|
||||
@@ -144,7 +178,7 @@ def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-n
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(k, r1, r2):
|
||||
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
|
||||
Pairing
|
||||
@@ -154,7 +188,7 @@ def s1(k, r1, r2):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def aes_cmac(m, k):
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
@@ -166,20 +200,16 @@ def aes_cmac(m, k):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f4(u, v, x, z):
|
||||
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
|
||||
Generation Function f4
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
|
||||
)
|
||||
)
|
||||
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f5(w, n1, n2, a1, a2):
|
||||
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
|
||||
Function f5
|
||||
@@ -187,87 +217,83 @@ def f5(w, n1, n2, a1, a2):
|
||||
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
|
||||
'''
|
||||
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
|
||||
t = aes_cmac(bytes(reversed(w)), salt)
|
||||
t = aes_cmac(reverse(w), salt)
|
||||
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
|
||||
return (
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
|
||||
def f6(
|
||||
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
|
||||
Generation Function f6
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(r))
|
||||
+ bytes(reversed(io_cap))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2)),
|
||||
bytes(reversed(w)),
|
||||
)
|
||||
return reverse(
|
||||
aes_cmac(
|
||||
reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(r)
|
||||
+ reverse(io_cap)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2),
|
||||
reverse(w),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def g2(u, v, x, y):
|
||||
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
|
||||
Value Generation Function g2
|
||||
'''
|
||||
return int.from_bytes(
|
||||
aes_cmac(
|
||||
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
|
||||
bytes(reversed(x)),
|
||||
reverse(u) + reverse(v) + reverse(y),
|
||||
reverse(x),
|
||||
)[-4:],
|
||||
byteorder='big',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h6(w, key_id):
|
||||
def h6(w: bytes, key_id: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
|
||||
'''
|
||||
return aes_cmac(key_id, w)
|
||||
return reverse(aes_cmac(key_id, reverse(w)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h7(salt, w):
|
||||
def h7(salt: bytes, w: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
|
||||
'''
|
||||
return aes_cmac(w, salt)
|
||||
return reverse(aes_cmac(reverse(w), salt))
|
||||
|
||||
2413
bumble/device.py
2413
bumble/device.py
File diff suppressed because it is too large
Load Diff
@@ -19,12 +19,17 @@ like loading firmware after a cold start.
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
from . import rtk
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import rtk, intel
|
||||
from .common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -32,40 +37,31 @@ from . import rtk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_driver_for_host(host):
|
||||
"""Probe all known diver classes until one returns a valid instance for a host,
|
||||
or none is found.
|
||||
async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
"""Probe diver classes until one returns a valid instance for a host, or none is
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
if driver := await rtk.Driver.for_host(host):
|
||||
logger.debug("Instantiated RTK driver")
|
||||
return driver
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
probe_list = [driver_name]
|
||||
else:
|
||||
# Probe all drivers
|
||||
probe_list = driver_classes.keys()
|
||||
|
||||
for driver_name in probe_list:
|
||||
if driver_class := driver_classes.get(driver_name):
|
||||
logger.debug(f"Probing driver class: {driver_name}")
|
||||
if driver := await driver_class.for_host(host):
|
||||
logger.debug(f"Instantiated {driver_name} driver")
|
||||
return driver
|
||||
else:
|
||||
logger.debug(f"Skipping unknown driver class: {driver_name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
45
bumble/drivers/common.py
Normal file
45
bumble/drivers/common.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Common types for drivers.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
102
bumble/drivers/intel.py
Normal file
102
bumble/drivers/intel.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
from bumble.drivers import common
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code, # type: ignore
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constant
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
INTEL_USB_PRODUCTS = {
|
||||
# Intel AX210
|
||||
(0x8087, 0x0032),
|
||||
# Intel BE200
|
||||
(0x8087, 0x0036),
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command( # type: ignore
|
||||
fields=[("params", "*")],
|
||||
return_parameters_fields=[
|
||||
("params", "*"),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
class Driver(common.Driver):
|
||||
def __init__(self, host):
|
||||
self.host = host
|
||||
|
||||
@staticmethod
|
||||
def check(host):
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver == "intel":
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
|
||||
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
|
||||
logger.debug(
|
||||
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def for_host(cls, host, force=False): # type: ignore
|
||||
# Only instantiate this driver if explicitly selected
|
||||
if not force and not cls.check(host):
|
||||
return None
|
||||
|
||||
return cls(host)
|
||||
|
||||
async def init_controller(self):
|
||||
self.host.ready = True
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
await self.host.send_command(
|
||||
Hci_Intel_DDC_Config_Write_Command(
|
||||
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
|
||||
)
|
||||
)
|
||||
@@ -33,6 +33,7 @@ from typing import Tuple
|
||||
import weakref
|
||||
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code,
|
||||
STATUS_SPEC,
|
||||
@@ -41,7 +42,7 @@ from bumble.hci import (
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -49,6 +50,10 @@ from bumble.hci import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RtkFirmwareError(core.BaseBumbleError):
|
||||
"""Error raised when RTK firmware initialization fails."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -208,15 +213,15 @@ class Firmware:
|
||||
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
||||
|
||||
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
||||
raise ValueError("Firmware does not start with epatch signature")
|
||||
raise RtkFirmwareError("Firmware does not start with epatch signature")
|
||||
|
||||
if not firmware.endswith(extension_sig):
|
||||
raise ValueError("Firmware does not end with extension sig")
|
||||
raise RtkFirmwareError("Firmware does not end with extension sig")
|
||||
|
||||
# The firmware should start with a 14 byte header.
|
||||
epatch_header_size = 14
|
||||
if len(firmware) < epatch_header_size:
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Look for the "project ID", starting from the end.
|
||||
offset = len(firmware) - len(extension_sig)
|
||||
@@ -230,7 +235,7 @@ class Firmware:
|
||||
break
|
||||
|
||||
if length == 0:
|
||||
raise ValueError("Invalid 0-length instruction")
|
||||
raise RtkFirmwareError("Invalid 0-length instruction")
|
||||
|
||||
if opcode == 0 and length == 1:
|
||||
project_id = firmware[offset - 1]
|
||||
@@ -239,7 +244,7 @@ class Firmware:
|
||||
offset -= length
|
||||
|
||||
if project_id < 0:
|
||||
raise ValueError("Project ID not found")
|
||||
raise RtkFirmwareError("Project ID not found")
|
||||
|
||||
self.project_id = project_id
|
||||
|
||||
@@ -252,7 +257,7 @@ class Firmware:
|
||||
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
||||
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
||||
if epatch_header_size + 8 * num_patches > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
chip_id_table_offset = epatch_header_size
|
||||
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
||||
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
||||
@@ -266,7 +271,7 @@ class Firmware:
|
||||
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
||||
)
|
||||
if patch_offset + patch_length > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Get the SVN version for the patch
|
||||
(svn_version,) = struct.unpack_from(
|
||||
@@ -285,7 +290,7 @@ class Firmware:
|
||||
)
|
||||
|
||||
|
||||
class Driver:
|
||||
class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
@@ -470,8 +475,12 @@ class Driver:
|
||||
logger.debug("USB metadata not found")
|
||||
return False
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id", None)
|
||||
product_id = host.hci_metadata.get("product_id", None)
|
||||
if host.hci_metadata.get('driver') == 'rtk':
|
||||
# Forced driver
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
@@ -486,6 +495,9 @@ class Driver:
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
@@ -638,7 +650,7 @@ class Driver:
|
||||
):
|
||||
return await self.download_for_rtl8723b()
|
||||
|
||||
raise ValueError("ROM not supported")
|
||||
raise RtkFirmwareError("ROM not supported")
|
||||
|
||||
async def init_controller(self):
|
||||
await self.download_firmware()
|
||||
|
||||
@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAccessService(Service):
|
||||
def __init__(self, device_name, appearance=(0, 0)):
|
||||
|
||||
247
bumble/gatt.py
247
bumble/gatt.py
@@ -23,16 +23,28 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, Iterable, List, Union
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
from .att import Attribute
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -93,20 +105,35 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
|
||||
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
|
||||
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
|
||||
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
|
||||
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
|
||||
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
|
||||
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
|
||||
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
|
||||
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
|
||||
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
|
||||
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
|
||||
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
|
||||
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
|
||||
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
|
||||
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
|
||||
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
|
||||
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
|
||||
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
|
||||
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
|
||||
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
|
||||
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
|
||||
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
|
||||
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
|
||||
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
|
||||
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
|
||||
|
||||
# Types
|
||||
# Attribute Types
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
|
||||
@@ -129,6 +156,8 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
|
||||
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
|
||||
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
|
||||
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
|
||||
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
|
||||
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
|
||||
|
||||
# Device Information Service
|
||||
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
|
||||
@@ -156,6 +185,96 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
|
||||
# Battery Service
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
|
||||
|
||||
# Telephony And Media Audio Service (TMAS)
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
|
||||
|
||||
# Audio Input Control Service (AICS)
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
|
||||
|
||||
# Volume Control Service (VCS)
|
||||
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
|
||||
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
|
||||
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
|
||||
|
||||
# Volume Offset Control Service (VOCS)
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
|
||||
|
||||
# Coordinated Set Identification Service (CSIS)
|
||||
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
|
||||
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
|
||||
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
|
||||
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
|
||||
|
||||
# Media Control Service (MCS)
|
||||
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
|
||||
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
|
||||
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
|
||||
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
|
||||
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
|
||||
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
|
||||
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
|
||||
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
|
||||
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
|
||||
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
|
||||
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
|
||||
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
|
||||
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
|
||||
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
|
||||
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
|
||||
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
|
||||
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
|
||||
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
|
||||
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
|
||||
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
|
||||
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
|
||||
|
||||
# Telephone Bearer Service (TBS)
|
||||
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
|
||||
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
|
||||
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
|
||||
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
|
||||
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
|
||||
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
|
||||
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
|
||||
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
|
||||
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
|
||||
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
|
||||
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
|
||||
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
|
||||
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
|
||||
|
||||
# Microphone Control Service (MICS)
|
||||
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
|
||||
|
||||
# Audio Stream Control Service (ASCS)
|
||||
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
|
||||
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
|
||||
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
|
||||
|
||||
# Broadcast Audio Scan Service (BASS)
|
||||
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
|
||||
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
|
||||
|
||||
# Published Audio Capabilities Service (PACS)
|
||||
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
|
||||
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
|
||||
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
|
||||
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
|
||||
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
|
||||
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
|
||||
|
||||
# ASHA Service
|
||||
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
|
||||
@@ -177,6 +296,9 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
|
||||
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
|
||||
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
|
||||
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
|
||||
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
|
||||
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
|
||||
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -220,9 +342,11 @@ class Service(Attribute):
|
||||
uuid = UUID(uuid)
|
||||
|
||||
super().__init__(
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
if primary
|
||||
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
(
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
if primary
|
||||
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
|
||||
),
|
||||
Attribute.READABLE,
|
||||
uuid.to_pdu_bytes(),
|
||||
)
|
||||
@@ -258,9 +382,12 @@ class TemplateService(Service):
|
||||
UUID: UUID
|
||||
|
||||
def __init__(
|
||||
self, characteristics: List[Characteristic], primary: bool = True
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: List[Service] = [],
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
super().__init__(self.UUID, characteristics, primary, included_services)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -409,56 +536,43 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue:
|
||||
'''
|
||||
Characteristic value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(self, read=None, write=None):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection):
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(self, connection, value):
|
||||
if self._write:
|
||||
self._write(connection, value)
|
||||
class CharacteristicValue(AttributeValue):
|
||||
"""Same as AttributeValue, for backward compatibility"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter:
|
||||
'''
|
||||
An adapter that can adapt any object with `read_value` and `write_value`
|
||||
methods (like Characteristic and CharacteristicProxy objects) by wrapping
|
||||
those methods with ones that return/accept encoded/decoded values.
|
||||
Objects with async methods are considered proxies, so the adaptation is one
|
||||
where the return value of `read_value` is decoded and the value passed to
|
||||
`write_value` is encoded. Other objects are considered local characteristics
|
||||
so the adaptation is one where the return value of `read_value` is encoded
|
||||
and the value passed to `write_value` is decoded.
|
||||
If the characteristic has a `subscribe` method, it is wrapped with one where
|
||||
the values are decoded before being passed to the subscriber.
|
||||
An adapter that can adapt Characteristic and AttributeProxy objects
|
||||
by wrapping their `read_value()` and `write_value()` methods with ones that
|
||||
return/accept encoded/decoded values.
|
||||
|
||||
For proxies (i.e used by a GATT client), the adaptation is one where the return
|
||||
value of `read_value()` is decoded and the value passed to `write_value()` is
|
||||
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
|
||||
before being passed to the subscriber.
|
||||
|
||||
For local values (i.e hosted by a GATT server) the adaptation is one where the
|
||||
return value of `read_value()` is encoded and the value passed to `write_value()`
|
||||
is decoded.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
read_value: Callable
|
||||
write_value: Callable
|
||||
|
||||
if asyncio.iscoroutinefunction(
|
||||
characteristic.read_value
|
||||
) and asyncio.iscoroutinefunction(characteristic.write_value):
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
else:
|
||||
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers: Dict[Callable, Callable] = (
|
||||
{}
|
||||
) # Map from subscriber to proxy subscriber
|
||||
|
||||
if isinstance(characteristic, Characteristic):
|
||||
self.read_value = self.read_encoded_value
|
||||
self.write_value = self.write_encoded_value
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||
else:
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -477,11 +591,13 @@ class CharacteristicAdapter:
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
|
||||
def read_encoded_value(self, connection):
|
||||
return self.encode_value(self.wrapped_characteristic.read_value(connection))
|
||||
async def read_encoded_value(self, connection):
|
||||
return self.encode_value(
|
||||
await self.wrapped_characteristic.read_value(connection)
|
||||
)
|
||||
|
||||
def write_encoded_value(self, connection, value):
|
||||
return self.wrapped_characteristic.write_value(
|
||||
async def write_encoded_value(self, connection, value):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
connection, self.decode_value(value)
|
||||
)
|
||||
|
||||
@@ -616,13 +732,24 @@ class Descriptor(Attribute):
|
||||
'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||
f'type={self.type}, '
|
||||
f'value={self.read_value(None).hex()})'
|
||||
f'value={value_str})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
|
||||
|
||||
@@ -38,6 +38,7 @@ from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
Type,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
@@ -89,6 +90,22 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services: Iterable[ServiceProxy]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
for characteristic in service.characteristics:
|
||||
print(color(' ' + str(characteristic), 'magenta'))
|
||||
|
||||
for descriptor in characteristic.descriptors:
|
||||
print(color(' ' + str(descriptor), 'green'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Proxies
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -128,7 +145,7 @@ class ServiceProxy(AttributeProxy):
|
||||
included_services: List[ServiceProxy]
|
||||
|
||||
@staticmethod
|
||||
def from_client(service_class, client, service_uuid):
|
||||
def from_client(service_class, client: Client, service_uuid: UUID):
|
||||
# The service and its characteristics are considered to have already been
|
||||
# discovered
|
||||
services = client.get_services_by_uuid(service_uuid)
|
||||
@@ -206,11 +223,11 @@ class CharacteristicProxy(AttributeProxy):
|
||||
|
||||
return await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, subscriber=None):
|
||||
async def unsubscribe(self, subscriber=None, force=False):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber)
|
||||
return await self.client.unsubscribe(self, subscriber, force)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
@@ -246,8 +263,12 @@ class ProfileServiceProxy:
|
||||
class Client:
|
||||
services: List[ServiceProxy]
|
||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||
notification_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||
indication_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||
notification_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
indication_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||
pending_request: Optional[ATT_PDU]
|
||||
|
||||
@@ -257,10 +278,8 @@ class Client:
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
self.notification_subscribers = (
|
||||
{}
|
||||
) # Notification subscribers, by attribute handle
|
||||
self.indication_subscribers = {} # Indication subscribers, by attribute handle
|
||||
self.notification_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.indication_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
@@ -312,9 +331,9 @@ class Client:
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
if mtu > 0xFFFF:
|
||||
raise ValueError('MTU must be <= 0xFFFF')
|
||||
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
|
||||
|
||||
# We can only send one request per connection
|
||||
if self.mtu_exchange_done:
|
||||
@@ -349,9 +368,7 @@ class Client:
|
||||
if c.uuid == uuid
|
||||
]
|
||||
|
||||
def get_attribute_grouping(
|
||||
self, attribute_handle: int
|
||||
) -> Optional[
|
||||
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
|
||||
Union[
|
||||
ServiceProxy,
|
||||
Tuple[ServiceProxy, CharacteristicProxy],
|
||||
@@ -682,8 +699,8 @@ class Client:
|
||||
async def discover_descriptors(
|
||||
self,
|
||||
characteristic: Optional[CharacteristicProxy] = None,
|
||||
start_handle=None,
|
||||
end_handle=None,
|
||||
start_handle: Optional[int] = None,
|
||||
end_handle: Optional[int] = None,
|
||||
) -> List[DescriptorProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
|
||||
@@ -789,7 +806,12 @@ class Client:
|
||||
|
||||
return attributes
|
||||
|
||||
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
|
||||
async def subscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
) -> None:
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -826,6 +848,7 @@ class Client:
|
||||
subscriber_set = subscribers.setdefault(characteristic.handle, set())
|
||||
if subscriber is not None:
|
||||
subscriber_set.add(subscriber)
|
||||
|
||||
# Add the characteristic as a subscriber, which will result in the
|
||||
# characteristic emitting an 'update' event when a notification or indication
|
||||
# is received
|
||||
@@ -833,7 +856,18 @@ class Client:
|
||||
|
||||
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
|
||||
|
||||
async def unsubscribe(self, characteristic, subscriber=None):
|
||||
async def unsubscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
Unsubscribe from a characteristic.
|
||||
|
||||
If `force` is True, this will write zeros to the CCCD when there are no
|
||||
subscribers left, even if there were already no registered subscribers.
|
||||
'''
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -847,31 +881,45 @@ class Client:
|
||||
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
|
||||
return
|
||||
|
||||
# Check if the characteristic has subscribers
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
if not force:
|
||||
return
|
||||
|
||||
# Remove the subscriber(s)
|
||||
if subscriber is not None:
|
||||
# Remove matching subscriber from subscriber sets
|
||||
for subscriber_set in (
|
||||
self.notification_subscribers,
|
||||
self.indication_subscribers,
|
||||
):
|
||||
subscribers = subscriber_set.get(characteristic.handle, [])
|
||||
if subscriber in subscribers:
|
||||
if (
|
||||
subscribers := subscriber_set.get(characteristic.handle)
|
||||
) and subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
del subscriber_set[characteristic.handle]
|
||||
else:
|
||||
# Remove all subscribers for this attribute from the sets!
|
||||
# Remove all subscribers for this attribute from the sets
|
||||
self.notification_subscribers.pop(characteristic.handle, None)
|
||||
self.indication_subscribers.pop(characteristic.handle, None)
|
||||
|
||||
if not self.notification_subscribers and not self.indication_subscribers:
|
||||
# Update the CCCD
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
# No more subscribers left
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
|
||||
async def read_value(
|
||||
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
|
||||
) -> Any:
|
||||
) -> bytes:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
||||
|
||||
@@ -1034,7 +1082,7 @@ class Client:
|
||||
logger.warning('!!! unexpected response, there is no pending request')
|
||||
return
|
||||
|
||||
# Sanity check: the response should match the pending request unless it is
|
||||
# The response should match the pending request unless it is
|
||||
# an error response
|
||||
if att_pdu.op_code != ATT_ERROR_RESPONSE:
|
||||
expected_response_name = self.pending_request.name.replace(
|
||||
@@ -1067,7 +1115,7 @@ class Client:
|
||||
def on_att_handle_value_notification(self, notification):
|
||||
# Call all subscribers
|
||||
subscribers = self.notification_subscribers.get(
|
||||
notification.attribute_handle, []
|
||||
notification.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received notification with no subscriber')
|
||||
@@ -1081,7 +1129,9 @@ class Client:
|
||||
|
||||
def on_att_handle_value_indication(self, indication):
|
||||
# Call all subscribers
|
||||
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
|
||||
subscribers = self.indication_subscribers.get(
|
||||
indication.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received indication with no subscriber')
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID
|
||||
from .att import (
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
@@ -60,7 +60,7 @@ from .att import (
|
||||
ATT_Write_Response,
|
||||
Attribute,
|
||||
)
|
||||
from .gatt import (
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||
@@ -74,6 +74,7 @@ from .gatt import (
|
||||
Descriptor,
|
||||
Service,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
@@ -327,7 +328,7 @@ class Server(EventEmitter):
|
||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
# Check parameters
|
||||
if len(value) != 2:
|
||||
logger.warning('CCCD value not 2 bytes long')
|
||||
return
|
||||
@@ -379,7 +380,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -422,7 +423,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -444,9 +445,9 @@ class Server(EventEmitter):
|
||||
assert self.pending_confirmations[connection.handle] is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
pending_confirmation = self.pending_confirmations[
|
||||
connection.handle
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
pending_confirmation = self.pending_confirmations[connection.handle] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
@@ -650,7 +651,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_find_by_type_value_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
@@ -658,13 +660,13 @@ class Server(EventEmitter):
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
for attribute in (
|
||||
async for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
and attribute.type == request.attribute_type
|
||||
and attribute.read_value(connection) == request.attribute_value
|
||||
and (await attribute.read_value(connection)) == request.attribute_value
|
||||
and pdu_space_available >= 4
|
||||
):
|
||||
# TODO: check permissions
|
||||
@@ -702,7 +704,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
@@ -725,7 +728,7 @@ class Server(EventEmitter):
|
||||
and pdu_space_available
|
||||
):
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
@@ -767,14 +770,15 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -792,14 +796,15 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_blob_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -836,7 +841,8 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_group_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
'''
|
||||
@@ -864,7 +870,7 @@ class Server(EventEmitter):
|
||||
):
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
@@ -903,7 +909,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_write_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
'''
|
||||
@@ -936,12 +943,13 @@ class Server(EventEmitter):
|
||||
return
|
||||
|
||||
# Accept the value
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
|
||||
# Done
|
||||
self.send_response(connection, ATT_Write_Response())
|
||||
|
||||
def on_att_write_command(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
'''
|
||||
@@ -959,9 +967,9 @@ class Server(EventEmitter):
|
||||
|
||||
# Accept the value
|
||||
try:
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! ignoring exception: {error}')
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
'''
|
||||
|
||||
2323
bumble/hci.py
2323
bumble/hci.py
File diff suppressed because it is too large
Load Diff
@@ -15,30 +15,46 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, MutableMapping
|
||||
import datetime
|
||||
from typing import cast, Any, Optional
|
||||
import logging
|
||||
|
||||
from .colors import color
|
||||
from .att import ATT_CID, ATT_PDU
|
||||
from .smp import SMP_CID, SMP_Command
|
||||
from .core import name_or_number
|
||||
from .l2cap import (
|
||||
from bumble import avc
|
||||
from bumble import avctp
|
||||
from bumble import avdtp
|
||||
from bumble import avrcp
|
||||
from bumble import crypto
|
||||
from bumble import rfcomm
|
||||
from bumble import sdp
|
||||
from bumble.colors import color
|
||||
from bumble.att import ATT_CID, ATT_PDU
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
from bumble.core import name_or_number
|
||||
from bumble.l2cap import (
|
||||
L2CAP_PDU,
|
||||
L2CAP_CONNECTION_REQUEST,
|
||||
L2CAP_CONNECTION_RESPONSE,
|
||||
L2CAP_SIGNALING_CID,
|
||||
L2CAP_LE_SIGNALING_CID,
|
||||
L2CAP_Control_Frame,
|
||||
L2CAP_Connection_Request,
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from .hci import (
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Packet,
|
||||
HCI_Event,
|
||||
HCI_AclDataPacket,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
)
|
||||
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
|
||||
from .sdp import SDP_PDU, SDP_PSM
|
||||
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -48,26 +64,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
PSM_NAMES = {
|
||||
RFCOMM_PSM: 'RFCOMM',
|
||||
SDP_PSM: 'SDP',
|
||||
AVDTP_PSM: 'AVDTP'
|
||||
rfcomm.RFCOMM_PSM: 'RFCOMM',
|
||||
sdp.SDP_PSM: 'SDP',
|
||||
avdtp.AVDTP_PSM: 'AVDTP',
|
||||
avctp.AVCTP_PSM: 'AVCTP',
|
||||
# TODO: add more PSM values
|
||||
}
|
||||
|
||||
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketTracer:
|
||||
class AclStream:
|
||||
def __init__(self, analyzer):
|
||||
psms: MutableMapping[int, int]
|
||||
peer: Optional[PacketTracer.AclStream]
|
||||
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
|
||||
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
|
||||
|
||||
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
|
||||
self.analyzer = analyzer
|
||||
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
|
||||
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
|
||||
self.psms = {} # PSM, by source_cid
|
||||
self.peer = None # ACL stream in the other direction
|
||||
self.peer = None
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
def on_acl_pdu(self, pdu):
|
||||
def on_acl_pdu(self, pdu: bytes) -> None:
|
||||
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
if l2cap_pdu.cid == ATT_CID:
|
||||
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
|
||||
@@ -81,46 +107,59 @@ class PacketTracer:
|
||||
|
||||
# Check if this signals a new channel
|
||||
if control_frame.code == L2CAP_CONNECTION_REQUEST:
|
||||
self.psms[control_frame.source_cid] = control_frame.psm
|
||||
connection_request = cast(L2CAP_Connection_Request, control_frame)
|
||||
self.psms[connection_request.source_cid] = connection_request.psm
|
||||
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
|
||||
connection_response = cast(L2CAP_Connection_Response, control_frame)
|
||||
if (
|
||||
control_frame.result
|
||||
connection_response.result
|
||||
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
|
||||
):
|
||||
if self.peer:
|
||||
if psm := self.peer.psms.get(control_frame.source_cid):
|
||||
# Found a pending connection
|
||||
self.psms[control_frame.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
control_frame.source_cid
|
||||
] = AVDTP_MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
control_frame.destination_cid
|
||||
] = AVDTP_MessageAssembler(
|
||||
self.peer.on_avdtp_message
|
||||
)
|
||||
if self.peer and (
|
||||
psm := self.peer.psms.get(connection_response.source_cid)
|
||||
):
|
||||
# Found a pending connection
|
||||
self.psms[connection_response.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == avdtp.AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avdtp.MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.avctp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avctp.MessageAssembler(self.on_avctp_message)
|
||||
self.peer.avctp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avctp.MessageAssembler(self.peer.on_avctp_message)
|
||||
else:
|
||||
# Try to find the PSM associated with this PDU
|
||||
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
|
||||
if psm == SDP_PSM:
|
||||
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
if psm == sdp.SDP_PSM:
|
||||
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(sdp_pdu)
|
||||
elif psm == RFCOMM_PSM:
|
||||
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
elif psm == rfcomm.RFCOMM_PSM:
|
||||
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(rfcomm_frame)
|
||||
elif psm == AVDTP_PSM:
|
||||
elif psm == avdtp.AVDTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
|
||||
if assembler:
|
||||
assembler.on_pdu(l2cap_pdu.payload)
|
||||
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
|
||||
avdtp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
|
||||
avctp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
else:
|
||||
psm_string = name_or_number(PSM_NAMES, psm)
|
||||
self.analyzer.emit(
|
||||
@@ -130,22 +169,49 @@ class PacketTracer:
|
||||
else:
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
def on_avdtp_message(self, transaction_label, message):
|
||||
def on_avdtp_message(
|
||||
self, transaction_label: int, message: avdtp.Message
|
||||
) -> None:
|
||||
self.analyzer.emit(
|
||||
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet):
|
||||
def on_avctp_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
if pid == avrcp.AVRCP_PID:
|
||||
avc_frame = avc.Frame.from_bytes(payload)
|
||||
details = str(avc_frame)
|
||||
else:
|
||||
details = payload.hex()
|
||||
|
||||
c_r = 'Command' if is_command else 'Response'
|
||||
self.analyzer.emit(
|
||||
f'{color("AVCTP", "green")} '
|
||||
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
|
||||
f'{"#" if ipid else ""}'
|
||||
f'{details}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.packet_assembler.feed_packet(packet)
|
||||
|
||||
class Analyzer:
|
||||
def __init__(self, label, emit_message):
|
||||
acl_streams: MutableMapping[int, PacketTracer.AclStream]
|
||||
peer: PacketTracer.Analyzer
|
||||
|
||||
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
|
||||
self.label = label
|
||||
self.emit_message = emit_message
|
||||
self.acl_streams = {} # ACL streams, by connection handle
|
||||
self.peer = None # Analyzer in the other direction
|
||||
self.packet_timestamp: Optional[datetime.datetime] = None
|
||||
|
||||
def start_acl_stream(self, connection_handle):
|
||||
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
|
||||
logger.info(
|
||||
f'[{self.label}] +++ Creating ACL stream for connection '
|
||||
f'0x{connection_handle:04X}'
|
||||
@@ -160,7 +226,7 @@ class PacketTracer:
|
||||
|
||||
return stream
|
||||
|
||||
def end_acl_stream(self, connection_handle):
|
||||
def end_acl_stream(self, connection_handle: int) -> None:
|
||||
if connection_handle in self.acl_streams:
|
||||
logger.info(
|
||||
f'[{self.label}] --- Removing ACL stream for connection '
|
||||
@@ -171,34 +237,52 @@ class PacketTracer:
|
||||
# Let the other forwarder know so it can cleanup its stream as well
|
||||
self.peer.end_acl_stream(connection_handle)
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(
|
||||
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
|
||||
) -> None:
|
||||
self.packet_timestamp = timestamp
|
||||
self.emit(packet)
|
||||
|
||||
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
|
||||
acl_packet = cast(HCI_AclDataPacket, packet)
|
||||
# Look for an existing stream for this handle, create one if it is the
|
||||
# first ACL packet for that connection handle
|
||||
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
|
||||
stream = self.start_acl_stream(packet.connection_handle)
|
||||
stream.feed_packet(packet)
|
||||
if (
|
||||
stream := self.acl_streams.get(acl_packet.connection_handle)
|
||||
) is None:
|
||||
stream = self.start_acl_stream(acl_packet.connection_handle)
|
||||
stream.feed_packet(acl_packet)
|
||||
elif packet.hci_packet_type == HCI_EVENT_PACKET:
|
||||
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(packet.connection_handle)
|
||||
event_packet = cast(HCI_Event, packet)
|
||||
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(
|
||||
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
|
||||
)
|
||||
|
||||
def emit(self, message):
|
||||
self.emit_message(f'[{self.label}] {message}')
|
||||
def emit(self, message: Any) -> None:
|
||||
if self.packet_timestamp:
|
||||
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
|
||||
else:
|
||||
prefix = ""
|
||||
self.emit_message(f'{prefix}[{self.label}] {message}')
|
||||
|
||||
def trace(self, packet, direction=0):
|
||||
def trace(
|
||||
self,
|
||||
packet: HCI_Packet,
|
||||
direction: int = 0,
|
||||
timestamp: Optional[datetime.datetime] = None,
|
||||
) -> None:
|
||||
if direction == 0:
|
||||
self.host_to_controller_analyzer.on_packet(packet)
|
||||
self.host_to_controller_analyzer.on_packet(timestamp, packet)
|
||||
else:
|
||||
self.controller_to_host_analyzer.on_packet(packet)
|
||||
self.controller_to_host_analyzer.on_packet(timestamp, packet)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message=logger.info,
|
||||
):
|
||||
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message: Callable[..., None] = logger.info,
|
||||
) -> None:
|
||||
self.host_to_controller_analyzer = PacketTracer.Analyzer(
|
||||
host_to_controller_label, emit_message
|
||||
)
|
||||
@@ -207,3 +291,15 @@ class PacketTracer:
|
||||
)
|
||||
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
|
||||
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
|
||||
|
||||
|
||||
def generate_irk() -> bytes:
|
||||
return crypto.r()
|
||||
|
||||
|
||||
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
|
||||
rpa_bytes = bytes(rpa)
|
||||
prand_given = rpa_bytes[3:]
|
||||
hash_given = rpa_bytes[:3]
|
||||
hash_local = crypto.ah(irk, prand_given)
|
||||
return hash_local[:3] == hash_given
|
||||
|
||||
1538
bumble/hfp.py
1538
bumble/hfp.py
File diff suppressed because it is too large
Load Diff
555
bumble/hid.py
Normal file
555
bumble/hid.py
Normal file
@@ -0,0 +1,555 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap, device
|
||||
from bumble.colors import color
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
from .hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
# fmt: on
|
||||
HID_CONTROL_PSM = 0x0011
|
||||
HID_INTERRUPT_PSM = 0x0013
|
||||
|
||||
|
||||
class Message:
|
||||
message_type: MessageType
|
||||
|
||||
# Report types
|
||||
class ReportType(enum.IntEnum):
|
||||
OTHER_REPORT = 0x00
|
||||
INPUT_REPORT = 0x01
|
||||
OUTPUT_REPORT = 0x02
|
||||
FEATURE_REPORT = 0x03
|
||||
|
||||
# Handshake parameters
|
||||
class Handshake(enum.IntEnum):
|
||||
SUCCESSFUL = 0x00
|
||||
NOT_READY = 0x01
|
||||
ERR_INVALID_REPORT_ID = 0x02
|
||||
ERR_UNSUPPORTED_REQUEST = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
ERR_UNKNOWN = 0x0E
|
||||
ERR_FATAL = 0x0F
|
||||
|
||||
# Message Type
|
||||
class MessageType(enum.IntEnum):
|
||||
HANDSHAKE = 0x00
|
||||
CONTROL = 0x01
|
||||
GET_REPORT = 0x04
|
||||
SET_REPORT = 0x05
|
||||
GET_PROTOCOL = 0x06
|
||||
SET_PROTOCOL = 0x07
|
||||
DATA = 0x0A
|
||||
|
||||
# Protocol modes
|
||||
class ProtocolMode(enum.IntEnum):
|
||||
BOOT_PROTOCOL = 0x00
|
||||
REPORT_PROTOCOL = 0x01
|
||||
|
||||
# Control Operations
|
||||
class ControlCommand(enum.IntEnum):
|
||||
SUSPEND = 0x03
|
||||
EXIT_SUSPEND = 0x04
|
||||
VIRTUAL_CABLE_UNPLUG = 0x05
|
||||
|
||||
# Class Method to derive header
|
||||
@classmethod
|
||||
def header(cls, lower_bits: int = 0x00) -> bytes:
|
||||
return bytes([(cls.message_type << 4) | lower_bits])
|
||||
|
||||
|
||||
# HIDP messages
|
||||
@dataclass
|
||||
class GetReportMessage(Message):
|
||||
report_type: int
|
||||
report_id: int
|
||||
buffer_size: int
|
||||
message_type = Message.MessageType.GET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
packet_bytes = bytearray()
|
||||
packet_bytes.append(self.report_id)
|
||||
if self.buffer_size == 0:
|
||||
return self.header(self.report_type) + packet_bytes
|
||||
else:
|
||||
return (
|
||||
self.header(0x08 | self.report_type)
|
||||
+ packet_bytes
|
||||
+ struct.pack("<H", self.buffer_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetReportMessage(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.SET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendControlData(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetProtocolMessage(Message):
|
||||
message_type = Message.MessageType.GET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetProtocolMessage(Message):
|
||||
protocol_mode: int
|
||||
message_type = Message.MessageType.SET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.protocol_mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Suspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExitSuspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.EXIT_SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VirtualCableUnplug(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
|
||||
|
||||
|
||||
# Device sends input report, host sends output report.
|
||||
@dataclass
|
||||
class SendData(Message):
|
||||
data: bytes
|
||||
report_type: int
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHandshakeMessage(Message):
|
||||
result_code: int
|
||||
message_type = Message.MessageType.HANDSHAKE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.result_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HID(ABC, EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
|
||||
def __init__(self, device: device.Device, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.remote_device_bd_address: Optional[Address] = None
|
||||
self.device = device
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
try:
|
||||
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_CONTROL_PSM
|
||||
)
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_ctrl_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
|
||||
async def connect_interrupt_channel(self) -> None:
|
||||
# Create a new L2CAP connection - interrupt channel
|
||||
try:
|
||||
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_INTERRUPT_PSM
|
||||
)
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
assert self.l2cap_intr_channel is not None
|
||||
# Become a sink for the L2CAP channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
|
||||
async def disconnect_interrupt_channel(self) -> None:
|
||||
if self.l2cap_intr_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
async def disconnect_control_channel(self) -> None:
|
||||
if self.l2cap_ctrl_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = l2cap_channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
else:
|
||||
self.l2cap_intr_channel = l2cap_channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = None
|
||||
else:
|
||||
self.l2cap_intr_channel = None
|
||||
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
|
||||
|
||||
@abstractmethod
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
pass
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("interrupt_data", pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel.send_pdu(msg)
|
||||
|
||||
def send_pdu_on_intr(self, msg: bytes) -> None:
|
||||
assert self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel.send_pdu(msg)
|
||||
|
||||
def send_data(self, data: bytes) -> None:
|
||||
if self.role == HID.Role.HOST:
|
||||
report_type = Message.ReportType.OUTPUT_REPORT
|
||||
else:
|
||||
report_type = Message.ReportType.INPUT_REPORT
|
||||
msg = SendData(data, report_type)
|
||||
hid_message = bytes(msg)
|
||||
if self.l2cap_intr_channel is not None:
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
|
||||
def virtual_cable_unplug(self) -> None:
|
||||
msg = VirtualCableUnplug()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Device(HID):
|
||||
class GetSetReturn(enum.IntEnum):
|
||||
FAILURE = 0x00
|
||||
REPORT_ID_NOT_FOUND = 0x01
|
||||
ERR_UNSUPPORTED_REQUEST = 0x02
|
||||
ERR_UNKNOWN = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
class GetSetStatus:
|
||||
def __init__(self) -> None:
|
||||
self.data = bytearray()
|
||||
self.status = 0
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
|
||||
if message_type == Message.MessageType.GET_REPORT:
|
||||
logger.debug('<<< HID GET REPORT')
|
||||
self.handle_get_report(pdu)
|
||||
elif message_type == Message.MessageType.SET_REPORT:
|
||||
logger.debug('<<< HID SET REPORT')
|
||||
self.handle_set_report(pdu)
|
||||
elif message_type == Message.MessageType.GET_PROTOCOL:
|
||||
logger.debug('<<< HID GET PROTOCOL')
|
||||
self.handle_get_protocol(pdu)
|
||||
elif message_type == Message.MessageType.SET_PROTOCOL:
|
||||
logger.debug('<<< HID SET PROTOCOL')
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def send_handshake_message(self, result_code: int) -> None:
|
||||
msg = SendHandshakeMessage(result_code)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def send_control_data(self, report_type: int, data: bytes):
|
||||
msg = SendControlData(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def handle_get_report(self, pdu: bytes):
|
||||
if self.get_report_cb is None:
|
||||
logger.debug("GetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||
report_id = pdu[1]
|
||||
logger.debug(f"buffer_flag: {buffer_flag}")
|
||||
if buffer_flag == 1:
|
||||
buffer_size = (pdu[3] << 8) | pdu[2]
|
||||
else:
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
data = bytearray()
|
||||
data.append(report_id)
|
||||
data.extend(ret.data)
|
||||
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
|
||||
self.send_control_data(report_type=report_type, data=data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
def handle_set_report(self, pdu: bytes):
|
||||
if self.set_report_cb is None:
|
||||
logger.debug("SetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
report_id = pdu[1]
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], None]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
|
||||
def handle_get_protocol(self, pdu: bytes):
|
||||
if self.get_protocol_cb is None:
|
||||
logger.debug("GetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
def handle_set_protocol(self, pdu: bytes):
|
||||
if self.set_protocol_cb is None:
|
||||
logger.debug("SetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(HID):
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.HOST)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes) -> None:
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self) -> None:
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int) -> None:
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def suspend(self) -> None:
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def exit_suspend(self) -> None:
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
746
bumble/host.py
746
bumble/host.py
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,8 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
@@ -128,10 +129,10 @@ class PairingKeys:
|
||||
|
||||
def print(self, prefix=''):
|
||||
keys_dict = self.to_dict()
|
||||
for (container_property, value) in keys_dict.items():
|
||||
for container_property, value in keys_dict.items():
|
||||
if isinstance(value, dict):
|
||||
print(f'{prefix}{color(container_property, "cyan")}:')
|
||||
for (key_property, key_value) in value.items():
|
||||
for key_property, key_value in value.items():
|
||||
print(f'{prefix} {color(key_property, "green")}: {key_value}')
|
||||
else:
|
||||
print(f'{prefix}{color(container_property, "cyan")}: {value}')
|
||||
@@ -158,7 +159,7 @@ class KeyStore:
|
||||
async def get_resolving_keys(self):
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
for (name, keys) in all_keys:
|
||||
for name, keys in all_keys:
|
||||
if keys.irk is not None:
|
||||
if keys.address_type is None:
|
||||
address_type = Address.RANDOM_DEVICE_ADDRESS
|
||||
@@ -171,7 +172,7 @@ class KeyStore:
|
||||
async def print(self, prefix=''):
|
||||
entries = await self.get_all()
|
||||
separator = ''
|
||||
for (name, keys) in entries:
|
||||
for name, keys in entries:
|
||||
print(separator + prefix + color(name, 'yellow'))
|
||||
keys.print(prefix=prefix + ' ')
|
||||
separator = '\n'
|
||||
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
logger.debug(f'JSON keystore: {self.filename}')
|
||||
|
||||
@staticmethod
|
||||
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
|
||||
@classmethod
|
||||
def from_device(
|
||||
cls: Type[Self], device: Device, filename: Optional[str] = None
|
||||
) -> Self:
|
||||
if not filename:
|
||||
# Extract the filename from the config if there is one
|
||||
if device.config.keystore is not None:
|
||||
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
|
||||
else:
|
||||
namespace = JsonKeyStore.DEFAULT_NAMESPACE
|
||||
|
||||
return JsonKeyStore(namespace, filename)
|
||||
return cls(namespace, filename)
|
||||
|
||||
async def load(self):
|
||||
# Try to open the file, without failing. If the file does not exist, it
|
||||
|
||||
369
bumble/l2cap.py
369
bumble/l2cap.py
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
@@ -38,8 +39,16 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .utils import deprecated
|
||||
from .colors import color
|
||||
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
|
||||
from .core import (
|
||||
BT_CENTRAL_ROLE,
|
||||
InvalidStateError,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
OutOfResourcesError,
|
||||
ProtocolError,
|
||||
)
|
||||
from .hci import (
|
||||
HCI_LE_Connection_Update_Command,
|
||||
HCI_Object,
|
||||
@@ -68,6 +77,7 @@ L2CAP_LE_SIGNALING_CID = 0x05
|
||||
|
||||
L2CAP_MIN_LE_MTU = 23
|
||||
L2CAP_MIN_BR_EDR_MTU = 48
|
||||
L2CAP_MAX_BR_EDR_MTU = 65535
|
||||
|
||||
L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept
|
||||
|
||||
@@ -147,9 +157,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
|
||||
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
|
||||
|
||||
@@ -167,6 +178,37 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClassicChannelSpec:
|
||||
psm: Optional[int] = None
|
||||
mtu: int = L2CAP_DEFAULT_MTU
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LeCreditBasedChannelSpec:
|
||||
psm: Optional[int] = None
|
||||
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
|
||||
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
|
||||
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.max_credits < 1
|
||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise InvalidArgumentError('max credits out of range')
|
||||
if (
|
||||
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
||||
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
||||
):
|
||||
raise InvalidArgumentError('MTU out of range')
|
||||
if (
|
||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
):
|
||||
raise InvalidArgumentError('MPS out of range')
|
||||
|
||||
|
||||
class L2CAP_PDU:
|
||||
'''
|
||||
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
|
||||
@@ -174,9 +216,9 @@ class L2CAP_PDU:
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> L2CAP_PDU:
|
||||
# Sanity check
|
||||
# Check parameters
|
||||
if len(data) < 4:
|
||||
raise ValueError('not enough data for L2CAP header')
|
||||
raise InvalidPacketError('not enough data for L2CAP header')
|
||||
|
||||
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
|
||||
l2cap_pdu_payload = data[4:]
|
||||
@@ -361,6 +403,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
|
||||
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
|
||||
'''
|
||||
|
||||
psm: int
|
||||
source_cid: int
|
||||
|
||||
@staticmethod
|
||||
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
|
||||
psm_length = 2
|
||||
@@ -402,6 +447,11 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
|
||||
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
|
||||
'''
|
||||
|
||||
source_cid: int
|
||||
destination_cid: int
|
||||
status: int
|
||||
result: int
|
||||
|
||||
CONNECTION_SUCCESSFUL = 0x0000
|
||||
CONNECTION_PENDING = 0x0001
|
||||
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002
|
||||
@@ -676,7 +726,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Channel(EventEmitter):
|
||||
class ClassicChannel(EventEmitter):
|
||||
class State(enum.IntEnum):
|
||||
# States
|
||||
CLOSED = 0x00
|
||||
@@ -707,6 +757,8 @@ class Channel(EventEmitter):
|
||||
sink: Optional[Callable[[bytes], Any]]
|
||||
state: State
|
||||
connection: Connection
|
||||
mtu: int
|
||||
peer_mtu: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -723,6 +775,7 @@ class Channel(EventEmitter):
|
||||
self.signaling_cid = signaling_cid
|
||||
self.state = self.State.CLOSED
|
||||
self.mtu = mtu
|
||||
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
|
||||
self.psm = psm
|
||||
self.source_cid = source_cid
|
||||
self.destination_cid = 0
|
||||
@@ -770,7 +823,7 @@ class Channel(EventEmitter):
|
||||
|
||||
# Check that we can start a new connection
|
||||
if self.connection_result:
|
||||
raise RuntimeError('connection already pending')
|
||||
raise InvalidStateError('connection already pending')
|
||||
|
||||
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||
self.send_control_frame(
|
||||
@@ -787,7 +840,9 @@ class Channel(EventEmitter):
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
try:
|
||||
return await self.connection_result
|
||||
return await self.connection.abort_on(
|
||||
'disconnection', self.connection_result
|
||||
)
|
||||
finally:
|
||||
self.connection_result = None
|
||||
|
||||
@@ -819,7 +874,7 @@ class Channel(EventEmitter):
|
||||
[
|
||||
(
|
||||
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
|
||||
struct.pack('<H', L2CAP_DEFAULT_MTU),
|
||||
struct.pack('<H', self.mtu),
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -884,8 +939,8 @@ class Channel(EventEmitter):
|
||||
options = L2CAP_Control_Frame.decode_configuration_options(request.options)
|
||||
for option in options:
|
||||
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
|
||||
self.mtu = struct.unpack('<H', option[1])[0]
|
||||
logger.debug(f'MTU = {self.mtu}')
|
||||
self.peer_mtu = struct.unpack('<H', option[1])[0]
|
||||
logger.debug(f'peer MTU = {self.peer_mtu}')
|
||||
|
||||
self.send_control_frame(
|
||||
L2CAP_Configure_Response(
|
||||
@@ -984,13 +1039,13 @@ class Channel(EventEmitter):
|
||||
return (
|
||||
f'Channel({self.source_cid}->{self.destination_cid}, '
|
||||
f'PSM={self.psm}, '
|
||||
f'MTU={self.mtu}, '
|
||||
f'MTU={self.mtu}/{self.peer_mtu}, '
|
||||
f'state={self.state.name})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeConnectionOrientedChannel(EventEmitter):
|
||||
class LeCreditBasedChannel(EventEmitter):
|
||||
"""
|
||||
LE Credit-based Connection Oriented Channel
|
||||
"""
|
||||
@@ -1004,11 +1059,13 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
out_queue: Deque[bytes]
|
||||
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
||||
connection_result: Optional[asyncio.Future[LeCreditBasedChannel]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
in_sdu: Optional[bytes]
|
||||
out_sdu: Optional[bytes]
|
||||
state: State
|
||||
connection: Connection
|
||||
sink: Optional[Callable[[bytes], Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1071,7 +1128,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
|
||||
|
||||
async def connect(self) -> LeConnectionOrientedChannel:
|
||||
async def connect(self) -> LeCreditBasedChannel:
|
||||
# Check that we're in the right state
|
||||
if self.state != self.State.INIT:
|
||||
raise InvalidStateError('not in a connectable state')
|
||||
@@ -1079,7 +1136,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
# Check that we can start a new connection
|
||||
identifier = self.manager.next_identifier(self.connection)
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise RuntimeError('too many concurrent connection requests')
|
||||
raise InvalidStateError('too many concurrent connection requests')
|
||||
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
@@ -1342,15 +1399,67 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClassicChannelServer(EventEmitter):
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
psm: int,
|
||||
handler: Optional[Callable[[ClassicChannel], Any]],
|
||||
mtu: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.handler = handler
|
||||
self.psm = psm
|
||||
self.mtu = mtu
|
||||
|
||||
def on_connection(self, channel: ClassicChannel) -> None:
|
||||
self.emit('connection', channel)
|
||||
if self.handler:
|
||||
self.handler(channel)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.psm in self.manager.servers:
|
||||
del self.manager.servers[self.psm]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeCreditBasedChannelServer(EventEmitter):
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
psm: int,
|
||||
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
|
||||
max_credits: int,
|
||||
mtu: int,
|
||||
mps: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.handler = handler
|
||||
self.psm = psm
|
||||
self.max_credits = max_credits
|
||||
self.mtu = mtu
|
||||
self.mps = mps
|
||||
|
||||
def on_connection(self, channel: LeCreditBasedChannel) -> None:
|
||||
self.emit('connection', channel)
|
||||
if self.handler:
|
||||
self.handler(channel)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.psm in self.manager.le_coc_servers:
|
||||
del self.manager.le_coc_servers[self.psm]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ChannelManager:
|
||||
identifiers: Dict[int, int]
|
||||
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
|
||||
servers: Dict[int, Callable[[Channel], Any]]
|
||||
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
|
||||
le_coc_servers: Dict[
|
||||
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
|
||||
]
|
||||
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
|
||||
servers: Dict[int, ClassicChannelServer]
|
||||
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
|
||||
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
|
||||
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
|
||||
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
|
||||
_host: Optional[Host]
|
||||
@@ -1414,7 +1523,7 @@ class ChannelManager:
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
raise RuntimeError('no free CID available')
|
||||
raise OutOfResourcesError('no free CID available')
|
||||
|
||||
@staticmethod
|
||||
def find_free_le_cid(channels: Iterable[int]) -> int:
|
||||
@@ -1427,22 +1536,7 @@ class ChannelManager:
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
raise RuntimeError('no free CID')
|
||||
|
||||
@staticmethod
|
||||
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
|
||||
if (
|
||||
max_credits < 1
|
||||
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if (
|
||||
mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
):
|
||||
raise ValueError('MPS out of range')
|
||||
raise OutOfResourcesError('no free CID')
|
||||
|
||||
def next_identifier(self, connection: Connection) -> int:
|
||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||
@@ -1458,8 +1552,22 @@ class ChannelManager:
|
||||
if cid in self.fixed_channels:
|
||||
del self.fixed_channels[cid]
|
||||
|
||||
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
|
||||
if psm == 0:
|
||||
@deprecated("Please use create_classic_server")
|
||||
def register_server(
|
||||
self,
|
||||
psm: int,
|
||||
server: Callable[[ClassicChannel], Any],
|
||||
) -> int:
|
||||
return self.create_classic_server(
|
||||
handler=server, spec=ClassicChannelSpec(psm=psm)
|
||||
).psm
|
||||
|
||||
def create_classic_server(
|
||||
self,
|
||||
spec: ClassicChannelSpec,
|
||||
handler: Optional[Callable[[ClassicChannel], Any]] = None,
|
||||
) -> ClassicChannelServer:
|
||||
if not spec.psm:
|
||||
# Find a free PSM
|
||||
for candidate in range(
|
||||
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
|
||||
@@ -1468,62 +1576,75 @@ class ChannelManager:
|
||||
continue
|
||||
if candidate in self.servers:
|
||||
continue
|
||||
psm = candidate
|
||||
spec.psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.servers:
|
||||
raise ValueError('PSM already in use')
|
||||
if spec.psm in self.servers:
|
||||
raise InvalidArgumentError('PSM already in use')
|
||||
|
||||
# Check that the PSM is valid
|
||||
if psm % 2 == 0:
|
||||
raise ValueError('invalid PSM (not odd)')
|
||||
check = psm >> 8
|
||||
if spec.psm % 2 == 0:
|
||||
raise InvalidArgumentError('invalid PSM (not odd)')
|
||||
check = spec.psm >> 8
|
||||
while check:
|
||||
if check % 2 != 0:
|
||||
raise ValueError('invalid PSM')
|
||||
raise InvalidArgumentError('invalid PSM')
|
||||
check >>= 8
|
||||
|
||||
self.servers[psm] = server
|
||||
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
|
||||
|
||||
return psm
|
||||
return self.servers[spec.psm]
|
||||
|
||||
@deprecated("Please use create_le_credit_based_server()")
|
||||
def register_le_coc_server(
|
||||
self,
|
||||
psm: int,
|
||||
server: Callable[[LeConnectionOrientedChannel], Any],
|
||||
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
server: Callable[[LeCreditBasedChannel], Any],
|
||||
max_credits: int,
|
||||
mtu: int,
|
||||
mps: int,
|
||||
) -> int:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
return self.create_le_credit_based_server(
|
||||
spec=LeCreditBasedChannelSpec(
|
||||
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
|
||||
),
|
||||
handler=server,
|
||||
).psm
|
||||
|
||||
if psm == 0:
|
||||
def create_le_credit_based_server(
|
||||
self,
|
||||
spec: LeCreditBasedChannelSpec,
|
||||
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
|
||||
) -> LeCreditBasedChannelServer:
|
||||
if not spec.psm:
|
||||
# Find a free PSM
|
||||
for candidate in range(
|
||||
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
|
||||
):
|
||||
if candidate in self.le_coc_servers:
|
||||
continue
|
||||
psm = candidate
|
||||
spec.psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.le_coc_servers:
|
||||
raise ValueError('PSM already in use')
|
||||
if spec.psm in self.le_coc_servers:
|
||||
raise InvalidArgumentError('PSM already in use')
|
||||
|
||||
self.le_coc_servers[psm] = (
|
||||
server,
|
||||
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
|
||||
self,
|
||||
spec.psm,
|
||||
handler,
|
||||
max_credits=spec.max_credits,
|
||||
mtu=spec.mtu,
|
||||
mps=spec.mps,
|
||||
)
|
||||
|
||||
return psm
|
||||
return self.le_coc_servers[spec.psm]
|
||||
|
||||
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
|
||||
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
|
||||
@@ -1540,12 +1661,13 @@ class ChannelManager:
|
||||
|
||||
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||
pdu_bytes = bytes(pdu)
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
f'{connection.peer_address}: {pdu_str}'
|
||||
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
|
||||
|
||||
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
|
||||
@@ -1650,13 +1772,13 @@ class ChannelManager:
|
||||
logger.debug(
|
||||
f'creating server channel with cid={source_cid} for psm {request.psm}'
|
||||
)
|
||||
channel = Channel(
|
||||
self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU
|
||||
channel = ClassicChannel(
|
||||
self, connection, cid, request.psm, source_cid, server.mtu
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
# Notify
|
||||
server(channel)
|
||||
server.on_connection(channel)
|
||||
channel.on_connection_request(request)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -1822,7 +1944,7 @@ class ChannelManager:
|
||||
supervision_timeout=request.timeout,
|
||||
min_ce_length=0,
|
||||
max_ce_length=0,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.send_control_frame(
|
||||
@@ -1878,7 +2000,7 @@ class ChannelManager:
|
||||
self, connection: Connection, cid: int, request
|
||||
) -> None:
|
||||
if request.le_psm in self.le_coc_servers:
|
||||
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
|
||||
server = self.le_coc_servers[request.le_psm]
|
||||
|
||||
# Check that the CID isn't already used
|
||||
le_connection_channels = self.le_coc_channels.setdefault(
|
||||
@@ -1892,8 +2014,8 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=0,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
|
||||
@@ -1911,8 +2033,8 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=0,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
|
||||
@@ -1925,18 +2047,18 @@ class ChannelManager:
|
||||
f'creating LE CoC server channel with cid={source_cid} for psm '
|
||||
f'{request.le_psm}'
|
||||
)
|
||||
channel = LeConnectionOrientedChannel(
|
||||
channel = LeCreditBasedChannel(
|
||||
self,
|
||||
connection,
|
||||
request.le_psm,
|
||||
source_cid,
|
||||
request.source_cid,
|
||||
mtu,
|
||||
mps,
|
||||
server.mtu,
|
||||
server.mps,
|
||||
request.initial_credits,
|
||||
request.mtu,
|
||||
request.mps,
|
||||
max_credits,
|
||||
server.max_credits,
|
||||
True,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
@@ -1949,16 +2071,16 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=source_cid,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
initial_credits=max_credits,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=server.max_credits,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
|
||||
),
|
||||
)
|
||||
|
||||
# Notify
|
||||
server(channel)
|
||||
server.on_connection(channel)
|
||||
else:
|
||||
logger.info(
|
||||
f'No LE server for connection 0x{connection.handle:04X} '
|
||||
@@ -2013,37 +2135,51 @@ class ChannelManager:
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel: Channel) -> None:
|
||||
def on_channel_closed(self, channel: ClassicChannel) -> None:
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
if connection_channels:
|
||||
if channel.source_cid in connection_channels:
|
||||
del connection_channels[channel.source_cid]
|
||||
|
||||
@deprecated("Please use create_le_credit_based_channel()")
|
||||
async def open_le_coc(
|
||||
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
|
||||
) -> LeConnectionOrientedChannel:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
) -> LeCreditBasedChannel:
|
||||
return await self.create_le_credit_based_channel(
|
||||
connection=connection,
|
||||
spec=LeCreditBasedChannelSpec(
|
||||
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
|
||||
),
|
||||
)
|
||||
|
||||
async def create_le_credit_based_channel(
|
||||
self,
|
||||
connection: Connection,
|
||||
spec: LeCreditBasedChannelSpec,
|
||||
) -> LeCreditBasedChannel:
|
||||
# Find a free CID for the new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_le_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
raise OutOfResourcesError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise InvalidArgumentError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
|
||||
channel = LeConnectionOrientedChannel(
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
|
||||
channel = LeCreditBasedChannel(
|
||||
manager=self,
|
||||
connection=connection,
|
||||
le_psm=psm,
|
||||
le_psm=spec.psm,
|
||||
source_cid=source_cid,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=spec.mtu,
|
||||
mps=spec.mps,
|
||||
credits=0,
|
||||
peer_mtu=0,
|
||||
peer_mps=0,
|
||||
peer_credits=max_credits,
|
||||
peer_credits=spec.max_credits,
|
||||
connected=False,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
@@ -2062,27 +2198,62 @@ class ChannelManager:
|
||||
|
||||
return channel
|
||||
|
||||
async def connect(self, connection: Connection, psm: int) -> Channel:
|
||||
@deprecated("Please use create_classic_channel()")
|
||||
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
|
||||
return await self.create_classic_channel(
|
||||
connection=connection, spec=ClassicChannelSpec(psm=psm)
|
||||
)
|
||||
|
||||
async def create_classic_channel(
|
||||
self, connection: Connection, spec: ClassicChannelSpec
|
||||
) -> ClassicChannel:
|
||||
# NOTE: this implementation hard-codes BR/EDR
|
||||
|
||||
# Find a free CID for a new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_br_edr_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
raise OutOfResourcesError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise InvalidArgumentError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
|
||||
channel = Channel(
|
||||
self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU
|
||||
logger.debug(
|
||||
f'creating client channel with cid={source_cid} for psm {spec.psm}'
|
||||
)
|
||||
channel = ClassicChannel(
|
||||
self,
|
||||
connection,
|
||||
L2CAP_SIGNALING_CID,
|
||||
spec.psm,
|
||||
source_cid,
|
||||
spec.mtu,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
# Connect
|
||||
try:
|
||||
await channel.connect()
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
del connection_channels[source_cid]
|
||||
raise e
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Deprecated Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Channel(ClassicChannel):
|
||||
@deprecated("Please use ClassicChannel")
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class LeConnectionOrientedChannel(LeCreditBasedChannel):
|
||||
@deprecated("Please use LeCreditBasedChannel")
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
121
bumble/link.py
121
bumble/link.py
@@ -19,16 +19,25 @@ import logging
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
|
||||
from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
InvalidStateError,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_SUCCESS,
|
||||
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_Connection_Complete_Event,
|
||||
)
|
||||
from bumble import controller
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -57,6 +66,8 @@ class LocalLink:
|
||||
Link bus for controllers to communicate with each other
|
||||
'''
|
||||
|
||||
controllers: Set[controller.Controller]
|
||||
|
||||
def __init__(self):
|
||||
self.controllers = set()
|
||||
self.pending_connection = None
|
||||
@@ -79,7 +90,9 @@ class LocalLink:
|
||||
return controller
|
||||
return None
|
||||
|
||||
def find_classic_controller(self, address):
|
||||
def find_classic_controller(
|
||||
self, address: Address
|
||||
) -> Optional[controller.Controller]:
|
||||
for controller in self.controllers:
|
||||
if controller.public_address == address:
|
||||
return controller
|
||||
@@ -188,6 +201,60 @@ class LocalLink:
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
|
||||
|
||||
def create_cis(
|
||||
self,
|
||||
central_controller: controller.Controller,
|
||||
peripheral_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
|
||||
)
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_request,
|
||||
central_controller.random_address,
|
||||
cig_id,
|
||||
cis_id,
|
||||
)
|
||||
|
||||
def accept_cis(
|
||||
self,
|
||||
peripheral_controller: controller.Controller,
|
||||
central_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
|
||||
)
|
||||
if central_controller := self.find_controller(central_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
central_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
|
||||
def disconnect_cis(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
peer_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
|
||||
)
|
||||
if peer_controller := self.find_controller(peer_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peer_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Classic handlers
|
||||
############################################################
|
||||
@@ -271,6 +338,52 @@ class LocalLink:
|
||||
initiator_controller.public_address, int(not (initiator_new_role))
|
||||
)
|
||||
|
||||
def classic_sco_connect(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
responder_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
|
||||
)
|
||||
responder_controller = self.find_classic_controller(responder_address)
|
||||
# Initiator controller should handle it.
|
||||
assert responder_controller
|
||||
|
||||
responder_controller.on_classic_connection_request(
|
||||
initiator_controller.public_address,
|
||||
link_type,
|
||||
)
|
||||
|
||||
def classic_accept_sco_connection(
|
||||
self,
|
||||
responder_controller: controller.Controller,
|
||||
initiator_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
|
||||
)
|
||||
initiator_controller = self.find_classic_controller(initiator_address)
|
||||
if initiator_controller is None:
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
link_type,
|
||||
)
|
||||
return
|
||||
|
||||
async def task():
|
||||
initiator_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
asyncio.create_task(task())
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
initiator_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RemoteLink:
|
||||
@@ -297,12 +410,12 @@ class RemoteLink:
|
||||
|
||||
def add_controller(self, controller):
|
||||
if self.controller:
|
||||
raise ValueError('controller already set')
|
||||
raise InvalidStateError('controller already set')
|
||||
self.controller = controller
|
||||
|
||||
def remove_controller(self, controller):
|
||||
if self.controller != controller:
|
||||
raise ValueError('controller mismatch')
|
||||
raise InvalidStateError('controller mismatch')
|
||||
self.controller = None
|
||||
|
||||
def get_pending_connection(self):
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .hci import (
|
||||
@@ -35,7 +37,60 @@ from .smp import (
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
OobSharedData,
|
||||
)
|
||||
from .core import AdvertisingData, LeRole
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobData:
|
||||
"""OOB data that can be sent from one device to another."""
|
||||
|
||||
address: Optional[Address] = None
|
||||
role: Optional[LeRole] = None
|
||||
shared_data: Optional[OobSharedData] = None
|
||||
legacy_context: Optional[OobLegacyContext] = None
|
||||
|
||||
@classmethod
|
||||
def from_ad(cls, ad: AdvertisingData) -> OobData:
|
||||
instance = cls()
|
||||
shared_data_c: Optional[bytes] = None
|
||||
shared_data_r: Optional[bytes] = None
|
||||
for ad_type, ad_data in ad.ad_structures:
|
||||
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
|
||||
instance.address = Address(ad_data)
|
||||
elif ad_type == AdvertisingData.LE_ROLE:
|
||||
instance.role = LeRole(ad_data[0])
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
|
||||
shared_data_c = ad_data
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
|
||||
shared_data_r = ad_data
|
||||
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
|
||||
instance.legacy_context = OobLegacyContext(tk=ad_data)
|
||||
if shared_data_c and shared_data_r:
|
||||
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
|
||||
|
||||
return instance
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
ad_structures = []
|
||||
if self.address is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
|
||||
)
|
||||
if self.role is not None:
|
||||
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
|
||||
if self.shared_data is not None:
|
||||
ad_structures.extend(self.shared_data.to_ad().ad_structures)
|
||||
if self.legacy_context is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -173,6 +228,14 @@ class PairingConfig:
|
||||
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
|
||||
RANDOM = Address.RANDOM_DEVICE_ADDRESS
|
||||
|
||||
@dataclass
|
||||
class OobConfig:
|
||||
"""Config for OOB pairing."""
|
||||
|
||||
our_context: Optional[OobContext]
|
||||
peer_data: Optional[OobSharedData]
|
||||
legacy_context: Optional[OobLegacyContext]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sc: bool = True,
|
||||
@@ -180,17 +243,20 @@ class PairingConfig:
|
||||
bonding: bool = True,
|
||||
delegate: Optional[PairingDelegate] = None,
|
||||
identity_address_type: Optional[AddressType] = None,
|
||||
oob: Optional[OobConfig] = None,
|
||||
) -> None:
|
||||
self.sc = sc
|
||||
self.mitm = mitm
|
||||
self.bonding = bonding
|
||||
self.delegate = delegate or PairingDelegate()
|
||||
self.identity_address_type = identity_address_type
|
||||
self.oob = oob
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'PairingConfig(sc={self.sc}, '
|
||||
f'mitm={self.mitm}, bonding={self.bonding}, '
|
||||
f'identity_address_type={self.identity_address_type}, '
|
||||
f'delegate[{self.delegate.io_capability}])'
|
||||
f'delegate[{self.delegate.io_capability}]), '
|
||||
f'oob[{self.oob}])'
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
@@ -27,14 +28,18 @@ from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
Appearance,
|
||||
ConnectionError,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingParameters,
|
||||
AdvertisingEventProperties,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
Phy,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
@@ -46,9 +51,12 @@ from bumble.hci import (
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora import host_pb2
|
||||
from pandora.host_pb2 import (
|
||||
NOT_CONNECTABLE,
|
||||
NOT_DISCOVERABLE,
|
||||
DISCOVERABLE_LIMITED,
|
||||
DISCOVERABLE_GENERAL,
|
||||
PRIMARY_1M,
|
||||
PRIMARY_CODED,
|
||||
SECONDARY_1M,
|
||||
@@ -64,6 +72,7 @@ from pandora.host_pb2 import (
|
||||
ConnectResponse,
|
||||
DataTypes,
|
||||
DisconnectRequest,
|
||||
DiscoverabilityMode,
|
||||
InquiryResponse,
|
||||
PrimaryPhy,
|
||||
ReadLocalAddressResponse,
|
||||
@@ -93,6 +102,25 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
|
||||
3: SECONDARY_CODED,
|
||||
}
|
||||
|
||||
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
|
||||
PRIMARY_1M: Phy.LE_1M,
|
||||
PRIMARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
|
||||
SECONDARY_NONE: Phy.LE_1M,
|
||||
SECONDARY_1M: Phy.LE_1M,
|
||||
SECONDARY_2M: Phy.LE_2M,
|
||||
SECONDARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
|
||||
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
|
||||
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
|
||||
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
|
||||
}
|
||||
|
||||
|
||||
class HostService(HostServicer):
|
||||
waited_connections: Set[int]
|
||||
@@ -260,9 +288,9 @@ class HostService(HostServicer):
|
||||
self.log.debug(f"WaitDisconnection: {connection_handle}")
|
||||
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
disconnection_future: asyncio.Future[
|
||||
None
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
disconnection_future: asyncio.Future[None] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_disconnection(_: None) -> None:
|
||||
disconnection_future.set_result(None)
|
||||
@@ -280,14 +308,118 @@ class HostService(HostServicer):
|
||||
async def Advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if not request.legacy:
|
||||
raise NotImplementedError(
|
||||
"TODO: add support for extended advertising in Bumble"
|
||||
try:
|
||||
if request.legacy:
|
||||
async for rsp in self.legacy_advertise(request, context):
|
||||
yield rsp
|
||||
else:
|
||||
async for rsp in self.extended_advertise(request, context):
|
||||
yield rsp
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def extended_advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
advertising_data = bytes(self.unpack_data_types(request.data))
|
||||
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
|
||||
scannable = len(scan_response_data) != 0
|
||||
|
||||
advertising_event_properties = AdvertisingEventProperties(
|
||||
is_connectable=request.connectable,
|
||||
is_scannable=scannable,
|
||||
is_directed=request.target is not None,
|
||||
is_high_duty_cycle_directed_connectable=False,
|
||||
is_legacy=False,
|
||||
is_anonymous=False,
|
||||
include_tx_power=False,
|
||||
)
|
||||
|
||||
peer_address = Address.ANY
|
||||
if request.target:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
else:
|
||||
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
|
||||
advertising_parameters = AdvertisingParameters(
|
||||
advertising_event_properties=advertising_event_properties,
|
||||
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
|
||||
peer_address=peer_address,
|
||||
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
|
||||
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
|
||||
request.secondary_phy
|
||||
],
|
||||
)
|
||||
if advertising_interval := request.interval:
|
||||
advertising_parameters.primary_advertising_interval_min = int(
|
||||
advertising_interval
|
||||
)
|
||||
if request.interval:
|
||||
raise NotImplementedError("TODO: add support for `request.interval`")
|
||||
if request.interval_range:
|
||||
raise NotImplementedError("TODO: add support for `request.interval_range`")
|
||||
advertising_parameters.primary_advertising_interval_max = int(
|
||||
advertising_interval
|
||||
)
|
||||
if interval_range := request.interval_range:
|
||||
advertising_parameters.primary_advertising_interval_max += int(
|
||||
interval_range
|
||||
)
|
||||
|
||||
advertising_set = await self.device.create_advertising_set(
|
||||
advertising_parameters=advertising_parameters,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
|
||||
try:
|
||||
# Advertise until RPC is canceled
|
||||
while True:
|
||||
if not advertising_set.enabled:
|
||||
self.log.debug('Advertise (extended)')
|
||||
await advertising_set.start()
|
||||
|
||||
if not request.connectable:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
connection = await pending_connection
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
yield AdvertiseResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
try:
|
||||
self.log.debug('Stop Advertise (extended)')
|
||||
await advertising_set.stop()
|
||||
await advertising_set.remove()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def legacy_advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if advertising_interval := request.interval:
|
||||
self.device.config.advertising_interval_min = int(advertising_interval)
|
||||
self.device.config.advertising_interval_max = int(advertising_interval)
|
||||
if interval_range := request.interval_range:
|
||||
self.device.config.advertising_interval_max += int(interval_range)
|
||||
if request.primary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.primary_phy`")
|
||||
if request.secondary_phy:
|
||||
@@ -355,14 +487,10 @@ class HostService(HostServicer):
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
|
||||
else:
|
||||
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
|
||||
|
||||
if request.connectable:
|
||||
|
||||
@@ -389,9 +517,9 @@ class HostService(HostServicer):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
pending_connection: asyncio.Future[
|
||||
bumble.device.Connection
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
self.log.debug('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
@@ -420,23 +548,31 @@ class HostService(HostServicer):
|
||||
self, request: ScanRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[ScanningResponse, None]:
|
||||
# TODO: modify `start_scanning` to accept floats instead of int for ms values
|
||||
if request.phys:
|
||||
raise NotImplementedError("TODO: add support for `request.phys`")
|
||||
|
||||
self.log.debug('Scan')
|
||||
|
||||
scanning_phys = []
|
||||
if PRIMARY_1M in request.phys:
|
||||
scanning_phys.append(int(Phy.LE_1M))
|
||||
if PRIMARY_CODED in request.phys:
|
||||
scanning_phys.append(int(Phy.LE_CODED))
|
||||
if not scanning_phys:
|
||||
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
await self.device.start_scanning(
|
||||
legacy=request.legacy,
|
||||
active=not request.passive,
|
||||
own_address_type=request.own_address_type,
|
||||
scan_interval=int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
scan_window=int(request.window)
|
||||
if request.window
|
||||
else DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
scan_interval=(
|
||||
int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL
|
||||
),
|
||||
scan_window=(
|
||||
int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW
|
||||
),
|
||||
scanning_phys=scanning_phys,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -649,9 +785,11 @@ class HostService(HostServicer):
|
||||
*struct.pack('<H', dt.peripheral_connection_interval_min),
|
||||
*struct.pack(
|
||||
'<H',
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min,
|
||||
(
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
@@ -733,6 +871,16 @@ class HostService(HostServicer):
|
||||
)
|
||||
)
|
||||
|
||||
flag_map = {
|
||||
NOT_DISCOVERABLE: 0x00,
|
||||
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
|
||||
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
|
||||
}
|
||||
|
||||
if dt.le_discoverability_mode:
|
||||
flags = flag_map[dt.le_discoverability_mode]
|
||||
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
|
||||
@@ -841,8 +989,8 @@ class HostService(HostServicer):
|
||||
dt.random_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = i
|
||||
if appearance := cast(Appearance, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = int(appearance)
|
||||
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
|
||||
dt.advertising_interval = i
|
||||
if s := cast(str, ad.get(AdvertisingData.URI)):
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import grpc
|
||||
@@ -109,7 +110,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
@@ -124,7 +125,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(numeric_comparison=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
@@ -139,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -156,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -382,9 +383,9 @@ class SecurityService(SecurityServicer):
|
||||
connection.transport
|
||||
] == request.level_variant()
|
||||
|
||||
wait_for_security: asyncio.Future[
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
wait_for_security: asyncio.Future[str] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
pair_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
@@ -450,21 +451,18 @@ class SecurityService(SecurityServicer):
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
watcher.on(connection, event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bumble import l2cap
|
||||
from ..core import AdvertisingData
|
||||
from ..device import Device, Connection
|
||||
from ..gatt import (
|
||||
@@ -65,7 +67,7 @@ class AshaService(TemplateService):
|
||||
self.emit('volume', connection, value[0])
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection: Connection, value):
|
||||
def on_audio_control_point_write(connection: Optional[Connection], value):
|
||||
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == AshaService.OPCODE_START:
|
||||
@@ -149,7 +151,10 @@ class AshaService(TemplateService):
|
||||
channel.sink = on_data
|
||||
|
||||
# let the server find a free PSM
|
||||
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
|
||||
self.psm = device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
|
||||
handler=on_coc,
|
||||
).psm
|
||||
self.le_psm_out_characteristic = Characteristic(
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
|
||||
1421
bumble/profiles/bap.py
Normal file
1421
bumble/profiles/bap.py
Normal file
File diff suppressed because it is too large
Load Diff
52
bumble/profiles/cap.py
Normal file
52
bumble/profiles/cap.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
|
||||
) -> None:
|
||||
self.coordinated_set_identification_service = (
|
||||
coordinated_set_identification_service
|
||||
)
|
||||
super().__init__(
|
||||
characteristics=[],
|
||||
included_services=[coordinated_set_identification_service],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CommonAudioServiceService
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
257
bumble/profiles/csip.py
Normal file
257
bumble/profiles/csip.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
|
||||
|
||||
|
||||
class SirkType(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
|
||||
|
||||
ENCRYPTED = 0x00
|
||||
PLAINTEXT = 0x01
|
||||
|
||||
|
||||
class MemberLock(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
|
||||
|
||||
UNLOCKED = 0x01
|
||||
LOCKED = 0x02
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Crypto Toolbox
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(m: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
|
||||
'''
|
||||
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
|
||||
|
||||
|
||||
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.4 k1 derivation function.
|
||||
'''
|
||||
t = crypto.aes_cmac(n[::-1], salt[::-1])
|
||||
return crypto.aes_cmac(p[::-1], t)[::-1]
|
||||
|
||||
|
||||
def sef(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
|
||||
|
||||
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
|
||||
* Plaintext in encryption
|
||||
* Cipher in decryption
|
||||
'''
|
||||
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
|
||||
|
||||
|
||||
def sih(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
|
||||
'''
|
||||
return crypto.e(k, r + bytes(13))[:3]
|
||||
|
||||
|
||||
def generate_rsi(sirk: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
|
||||
'''
|
||||
prand = crypto.generate_prand()
|
||||
return sih(sirk, prand) + prand
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
set_identity_resolving_key: bytes,
|
||||
set_identity_resolving_key_type: SirkType,
|
||||
coordinated_set_size: Optional[int] = None,
|
||||
set_member_lock: Optional[MemberLock] = None,
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||
raise core.InvalidArgumentError(
|
||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||
)
|
||||
|
||||
characteristics = []
|
||||
|
||||
self.set_identity_resolving_key = set_identity_resolving_key
|
||||
self.set_identity_resolving_key_type = set_identity_resolving_key_type
|
||||
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self.on_sirk_read),
|
||||
)
|
||||
characteristics.append(self.set_identity_resolving_key_characteristic)
|
||||
|
||||
if coordinated_set_size is not None:
|
||||
self.coordinated_set_size_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', coordinated_set_size),
|
||||
)
|
||||
characteristics.append(self.coordinated_set_size_characteristic)
|
||||
|
||||
if set_member_lock is not None:
|
||||
self.set_member_lock_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=struct.pack('B', set_member_lock),
|
||||
)
|
||||
characteristics.append(self.set_member_lock_characteristic)
|
||||
|
||||
if set_member_rank is not None:
|
||||
self.set_member_rank_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', set_member_rank),
|
||||
)
|
||||
characteristics.append(self.set_member_rank_characteristic)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await connection.device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||
|
||||
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
|
||||
generate_rsi(self.set_identity_resolving_key),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CoordinatedSetIdentificationService
|
||||
|
||||
set_identity_resolving_key: gatt_client.CharacteristicProxy
|
||||
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
|
||||
):
|
||||
self.coordinated_set_size = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_lock = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_rank = characteristics[0]
|
||||
|
||||
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
raise core.InvalidPacketError('Invalid SIRK value')
|
||||
|
||||
sirk_type = SirkType(response[0])
|
||||
if sirk_type == SirkType.PLAINTEXT:
|
||||
sirk = response[1:]
|
||||
else:
|
||||
connection = self.service_proxy.client.connection
|
||||
device = connection.device
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk = sef(key, response[1:])
|
||||
|
||||
return (sirk_type, sirk)
|
||||
@@ -19,8 +19,8 @@
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..gatt import (
|
||||
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_INFORMATION_SERVICE,
|
||||
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
@@ -59,7 +59,7 @@ class DeviceInformationService(TemplateService):
|
||||
firmware_revision: Optional[str] = None,
|
||||
software_revision: Optional[str] = None,
|
||||
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
|
||||
ieee_regulatory_certification_data_list: Optional[bytes] = None
|
||||
ieee_regulatory_certification_data_list: Optional[bytes] = None,
|
||||
# TODO: pnp_id
|
||||
):
|
||||
characteristics = [
|
||||
@@ -104,10 +104,19 @@ class DeviceInformationService(TemplateService):
|
||||
class DeviceInformationServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = DeviceInformationService
|
||||
|
||||
def __init__(self, service_proxy):
|
||||
manufacturer_name: Optional[UTF8CharacteristicAdapter]
|
||||
model_number: Optional[UTF8CharacteristicAdapter]
|
||||
serial_number: Optional[UTF8CharacteristicAdapter]
|
||||
hardware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
firmware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
software_revision: Optional[UTF8CharacteristicAdapter]
|
||||
system_id: Optional[DelegatedCharacteristicAdapter]
|
||||
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
for (field, uuid) in (
|
||||
for field, uuid in (
|
||||
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
|
||||
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
|
||||
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
from enum import IntEnum
|
||||
import struct
|
||||
|
||||
from bumble import core
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..att import ATT_Error
|
||||
from ..gatt import (
|
||||
@@ -42,12 +43,12 @@ class HeartRateService(TemplateService):
|
||||
RESET_ENERGY_EXPENDED = 0x01
|
||||
|
||||
class BodySensorLocation(IntEnum):
|
||||
OTHER = (0,)
|
||||
CHEST = (1,)
|
||||
WRIST = (2,)
|
||||
FINGER = (3,)
|
||||
HAND = (4,)
|
||||
EAR_LOBE = (5,)
|
||||
OTHER = 0
|
||||
CHEST = 1
|
||||
WRIST = 2
|
||||
FINGER = 3
|
||||
HAND = 4
|
||||
EAR_LOBE = 5
|
||||
FOOT = 6
|
||||
|
||||
class HeartRateMeasurement:
|
||||
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
|
||||
rr_intervals=None,
|
||||
):
|
||||
if heart_rate < 0 or heart_rate > 0xFFFF:
|
||||
raise ValueError('heart_rate out of range')
|
||||
raise core.InvalidArgumentError('heart_rate out of range')
|
||||
|
||||
if energy_expended is not None and (
|
||||
energy_expended < 0 or energy_expended > 0xFFFF
|
||||
):
|
||||
raise ValueError('energy_expended out of range')
|
||||
raise core.InvalidArgumentError('energy_expended out of range')
|
||||
|
||||
if rr_intervals:
|
||||
for rr_interval in rr_intervals:
|
||||
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
|
||||
raise ValueError('rr_intervals out of range')
|
||||
raise core.InvalidArgumentError('rr_intervals out of range')
|
||||
|
||||
self.heart_rate = heart_rate
|
||||
self.sensor_contact_detected = sensor_contact_detected
|
||||
|
||||
49
bumble/profiles/le_audio.py
Normal file
49
bumble/profiles/le_audio.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
from typing import List
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Metadata:
|
||||
@dataclasses.dataclass
|
||||
class Entry:
|
||||
tag: int
|
||||
data: bytes
|
||||
|
||||
entries: List[Entry]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
entries = []
|
||||
offset = 0
|
||||
length = len(data)
|
||||
while length >= 2:
|
||||
entry_length = data[offset]
|
||||
entry_tag = data[offset + 1]
|
||||
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
|
||||
entries.append(cls.Entry(entry_tag, entry_data))
|
||||
length -= entry_length
|
||||
offset += entry_length
|
||||
|
||||
return cls(entries)
|
||||
46
bumble/profiles/pbp.py
Normal file
46
bumble/profiles/pbp.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble.profiles import le_audio
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PublicBroadcastAnnouncement:
|
||||
class Features(enum.IntFlag):
|
||||
ENCRYPTED = 1 << 0
|
||||
STANDARD_QUALITY_CONFIGURATION = 1 << 1
|
||||
HIGH_QUALITY_CONFIGURATION = 1 << 2
|
||||
|
||||
features: Features
|
||||
metadata: le_audio.Metadata
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
features = cls.Features(data[0])
|
||||
metadata_length = data[1]
|
||||
metadata_ltv = data[1 : 1 + metadata_length]
|
||||
return cls(
|
||||
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
|
||||
)
|
||||
228
bumble/profiles/vcp.py
Normal file
228
bumble/profiles/vcp.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
|
||||
from bumble import att
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
MIN_VOLUME = 0
|
||||
MAX_VOLUME = 255
|
||||
|
||||
|
||||
class ErrorCode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service 1.6. Application error codes.
|
||||
'''
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
|
||||
|
||||
class VolumeFlags(enum.IntFlag):
|
||||
'''
|
||||
See Volume Control Service 3.3. Volume Flags.
|
||||
'''
|
||||
|
||||
VOLUME_SETTING_PERSISTED = 0x01
|
||||
# RFU
|
||||
|
||||
|
||||
class VolumeControlPointOpcode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
|
||||
'''
|
||||
|
||||
# fmt: off
|
||||
RELATIVE_VOLUME_DOWN = 0x00
|
||||
RELATIVE_VOLUME_UP = 0x01
|
||||
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
|
||||
UNMUTE_RELATIVE_VOLUME_UP = 0x03
|
||||
SET_ABSOLUTE_VOLUME = 0x04
|
||||
UNMUTE = 0x05
|
||||
MUTE = 0x06
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
|
||||
|
||||
volume_state: gatt.Characteristic
|
||||
volume_control_point: gatt.Characteristic
|
||||
volume_flags: gatt.Characteristic
|
||||
|
||||
volume_setting: int
|
||||
muted: int
|
||||
change_counter: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_size: int = 16,
|
||||
volume_setting: int = 0,
|
||||
muted: int = 0,
|
||||
change_counter: int = 0,
|
||||
volume_flags: int = 0,
|
||||
) -> None:
|
||||
self.step_size = step_size
|
||||
self.volume_setting = volume_setting
|
||||
self.muted = muted
|
||||
self.change_counter = change_counter
|
||||
|
||||
self.volume_state = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
|
||||
)
|
||||
self.volume_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
|
||||
)
|
||||
self.volume_flags = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes([volume_flags]),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.volume_state,
|
||||
self.volume_control_point,
|
||||
self.volume_flags,
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def volume_state_bytes(self) -> bytes:
|
||||
return bytes([self.volume_setting, self.muted, self.change_counter])
|
||||
|
||||
@volume_state_bytes.setter
|
||||
def volume_state_bytes(self, new_value: bytes) -> None:
|
||||
self.volume_setting, self.muted, self.change_counter = new_value
|
||||
|
||||
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
|
||||
return self.volume_state_bytes
|
||||
|
||||
def _on_write_volume_control_point(
|
||||
self, connection: Optional[device.Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = VolumeControlPointOpcode(value[0])
|
||||
change_counter = value[1]
|
||||
|
||||
if change_counter != self.change_counter:
|
||||
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
if handler(*value[2:]):
|
||||
self.change_counter = (self.change_counter + 1) % 256
|
||||
connection.abort_on(
|
||||
'disconnection',
|
||||
connection.device.notify_subscribers(
|
||||
attribute=self.volume_state,
|
||||
value=self.volume_state_bytes,
|
||||
),
|
||||
)
|
||||
self.emit(
|
||||
'volume_state', self.volume_setting, self.muted, self.change_counter
|
||||
)
|
||||
|
||||
def _on_relative_volume_down(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_relative_volume_up(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_unmute_relative_volume_down(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_unmute_relative_volume_up(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
|
||||
old_volume_setting = self.volume_setting
|
||||
self.volume_setting = volume_setting
|
||||
return old_volume_setting != self.volume_setting
|
||||
|
||||
def _on_unmute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 0
|
||||
return self.muted != old_muted_state
|
||||
|
||||
def _on_mute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 1
|
||||
return self.muted != old_muted_state
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeControlService
|
||||
|
||||
volume_control_point: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.volume_state = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
|
||||
)[0],
|
||||
'BBB',
|
||||
)
|
||||
|
||||
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
self.volume_flags = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
563
bumble/rfcomm.py
563
bumble/rfcomm.py
@@ -19,30 +19,28 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import collections
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from . import core, l2cap
|
||||
from bumble import core
|
||||
from bumble import l2cap
|
||||
from bumble import sdp
|
||||
from .colors import color
|
||||
from .core import (
|
||||
UUID,
|
||||
BT_RFCOMM_PROTOCOL_ID,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
InvalidArgumentError,
|
||||
InvalidStateError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
)
|
||||
from .sdp import (
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
@@ -59,28 +57,20 @@ logger = logging.getLogger(__name__)
|
||||
# fmt: off
|
||||
|
||||
RFCOMM_PSM = 0x0003
|
||||
DEFAULT_RX_QUEUE_SIZE = 32
|
||||
|
||||
class FrameType(enum.IntEnum):
|
||||
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
|
||||
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
|
||||
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
|
||||
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
|
||||
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
|
||||
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
|
||||
|
||||
# Frame types
|
||||
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
|
||||
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
|
||||
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
|
||||
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
|
||||
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
|
||||
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
|
||||
class MccType(enum.IntEnum):
|
||||
PN = 0x20
|
||||
MSC = 0x38
|
||||
|
||||
RFCOMM_FRAME_TYPE_NAMES = {
|
||||
RFCOMM_SABM_FRAME: 'SABM',
|
||||
RFCOMM_UA_FRAME: 'UA',
|
||||
RFCOMM_DM_FRAME: 'DM',
|
||||
RFCOMM_DISC_FRAME: 'DISC',
|
||||
RFCOMM_UIH_FRAME: 'UIH',
|
||||
RFCOMM_UI_FRAME: 'UI'
|
||||
}
|
||||
|
||||
# MCC Types
|
||||
RFCOMM_MCC_PN_TYPE = 0x20
|
||||
RFCOMM_MCC_MSC_TYPE = 0x38
|
||||
|
||||
# FCS CRC
|
||||
CRC_TABLE = bytes([
|
||||
@@ -118,8 +108,11 @@ CRC_TABLE = bytes([
|
||||
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
|
||||
])
|
||||
|
||||
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
|
||||
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
|
||||
RFCOMM_DEFAULT_L2CAP_MTU = 2048
|
||||
RFCOMM_DEFAULT_INITIAL_CREDITS = 7
|
||||
RFCOMM_DEFAULT_MAX_CREDITS = 32
|
||||
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
|
||||
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
|
||||
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
@@ -130,29 +123,33 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
# -----------------------------------------------------------------------------
|
||||
def make_service_sdp_records(
|
||||
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
|
||||
) -> List[ServiceAttribute]:
|
||||
) -> List[sdp.ServiceAttribute]:
|
||||
"""
|
||||
Create SDP records for an RFComm service given a channel number and an
|
||||
optional UUID. A Service Class Attribute is included only if the UUID is not None.
|
||||
"""
|
||||
records = [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(service_record_handle),
|
||||
sdp.ServiceAttribute(
|
||||
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
sdp.DataElement.unsigned_integer_32(service_record_handle),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
|
||||
sdp.ServiceAttribute(
|
||||
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
sdp.DataElement.sequence(
|
||||
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
sdp.ServiceAttribute(
|
||||
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
sdp.DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
|
||||
DataElement.sequence(
|
||||
sdp.DataElement.sequence(
|
||||
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
|
||||
),
|
||||
sdp.DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_8(channel),
|
||||
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
|
||||
sdp.DataElement.unsigned_integer_8(channel),
|
||||
]
|
||||
),
|
||||
]
|
||||
@@ -162,15 +159,81 @@ def make_service_sdp_records(
|
||||
|
||||
if uuid:
|
||||
records.append(
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence([DataElement.uuid(uuid)]),
|
||||
sdp.ServiceAttribute(
|
||||
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
|
||||
)
|
||||
)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
|
||||
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
|
||||
|
||||
Args:
|
||||
connection: ACL connection to make SDP search.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping from channel number to service class UUID list.
|
||||
"""
|
||||
results = {}
|
||||
async with sdp.Client(connection) as sdp_client:
|
||||
search_result = await sdp_client.search_attributes(
|
||||
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
|
||||
attribute_ids=[
|
||||
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
],
|
||||
)
|
||||
for attribute_lists in search_result:
|
||||
service_classes: List[UUID] = []
|
||||
channel: Optional[int] = None
|
||||
for attribute in attribute_lists:
|
||||
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
|
||||
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
|
||||
protocol_descriptor_list = attribute.value.value
|
||||
channel = protocol_descriptor_list[1].value[1].value
|
||||
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
|
||||
service_class_id_list = attribute.value.value
|
||||
service_classes = [
|
||||
service_class.value for service_class in service_class_id_list
|
||||
]
|
||||
if not service_classes or not channel:
|
||||
logger.warning(f"Bad result {attribute_lists}.")
|
||||
else:
|
||||
results[channel] = service_classes
|
||||
return results
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def find_rfcomm_channel_with_uuid(
|
||||
connection: Connection, uuid: str | UUID
|
||||
) -> Optional[int]:
|
||||
"""Searches an RFCOMM channel associated with given UUID from service records.
|
||||
|
||||
Args:
|
||||
connection: ACL connection to make SDP search.
|
||||
uuid: UUID of service record to search for.
|
||||
|
||||
Returns:
|
||||
RFCOMM channel number if found, otherwise None.
|
||||
"""
|
||||
if isinstance(uuid, str):
|
||||
uuid = UUID(uuid)
|
||||
return next(
|
||||
(
|
||||
channel
|
||||
for channel, class_id_list in (
|
||||
await find_rfcomm_channels(connection)
|
||||
).items()
|
||||
if uuid in class_id_list
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def compute_fcs(buffer: bytes) -> int:
|
||||
result = 0xFF
|
||||
@@ -183,7 +246,7 @@ def compute_fcs(buffer: bytes) -> int:
|
||||
class RFCOMM_Frame:
|
||||
def __init__(
|
||||
self,
|
||||
frame_type: int,
|
||||
frame_type: FrameType,
|
||||
c_r: int,
|
||||
dlci: int,
|
||||
p_f: int,
|
||||
@@ -206,14 +269,11 @@ class RFCOMM_Frame:
|
||||
self.length = bytes([(length << 1) | 1])
|
||||
self.address = (dlci << 2) | (c_r << 1) | 1
|
||||
self.control = frame_type | (p_f << 4)
|
||||
if frame_type == RFCOMM_UIH_FRAME:
|
||||
if frame_type == FrameType.UIH:
|
||||
self.fcs = compute_fcs(bytes([self.address, self.control]))
|
||||
else:
|
||||
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
|
||||
|
||||
def type_name(self) -> str:
|
||||
return RFCOMM_FRAME_TYPE_NAMES[self.type]
|
||||
|
||||
@staticmethod
|
||||
def parse_mcc(data) -> Tuple[int, bool, bytes]:
|
||||
mcc_type = data[0] >> 2
|
||||
@@ -237,24 +297,24 @@ class RFCOMM_Frame:
|
||||
|
||||
@staticmethod
|
||||
def sabm(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
|
||||
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def ua(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
|
||||
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def dm(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
|
||||
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def disc(c_r: int, dlci: int):
|
||||
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
|
||||
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)
|
||||
|
||||
@staticmethod
|
||||
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
|
||||
return RFCOMM_Frame(
|
||||
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
|
||||
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -262,7 +322,7 @@ class RFCOMM_Frame:
|
||||
# Extract fields
|
||||
dlci = (data[0] >> 2) & 0x3F
|
||||
c_r = (data[0] >> 1) & 0x01
|
||||
frame_type = data[1] & 0xEF
|
||||
frame_type = FrameType(data[1] & 0xEF)
|
||||
p_f = (data[1] >> 4) & 0x01
|
||||
length = data[2]
|
||||
if length & 0x01:
|
||||
@@ -277,7 +337,7 @@ class RFCOMM_Frame:
|
||||
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
|
||||
if frame.fcs != fcs:
|
||||
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
|
||||
raise ValueError('fcs mismatch')
|
||||
raise InvalidPacketError('fcs mismatch')
|
||||
|
||||
return frame
|
||||
|
||||
@@ -291,7 +351,7 @@ class RFCOMM_Frame:
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'{color(self.type_name(), "yellow")}'
|
||||
f'{color(self.type.name, "yellow")}'
|
||||
f'(c/r={self.c_r},'
|
||||
f'dlci={self.dlci},'
|
||||
f'p/f={self.p_f},'
|
||||
@@ -301,6 +361,7 @@ class RFCOMM_Frame:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class RFCOMM_MCC_PN:
|
||||
dlci: int
|
||||
cl: int
|
||||
@@ -308,25 +369,13 @@ class RFCOMM_MCC_PN:
|
||||
ack_timer: int
|
||||
max_frame_size: int
|
||||
max_retransmissions: int
|
||||
window_size: int
|
||||
initial_credits: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dlci: int,
|
||||
cl: int,
|
||||
priority: int,
|
||||
ack_timer: int,
|
||||
max_frame_size: int,
|
||||
max_retransmissions: int,
|
||||
window_size: int,
|
||||
) -> None:
|
||||
self.dlci = dlci
|
||||
self.cl = cl
|
||||
self.priority = priority
|
||||
self.ack_timer = ack_timer
|
||||
self.max_frame_size = max_frame_size
|
||||
self.max_retransmissions = max_retransmissions
|
||||
self.window_size = window_size
|
||||
def __post_init__(self) -> None:
|
||||
if self.initial_credits < 1 or self.initial_credits > 7:
|
||||
logger.warning(
|
||||
f'Initial credits {self.initial_credits} is out of range [1, 7].'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
|
||||
@@ -337,7 +386,7 @@ class RFCOMM_MCC_PN:
|
||||
ack_timer=data[3],
|
||||
max_frame_size=data[4] | data[5] << 8,
|
||||
max_retransmissions=data[6],
|
||||
window_size=data[7],
|
||||
initial_credits=data[7] & 0x07,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -350,23 +399,14 @@ class RFCOMM_MCC_PN:
|
||||
self.max_frame_size & 0xFF,
|
||||
(self.max_frame_size >> 8) & 0xFF,
|
||||
self.max_retransmissions & 0xFF,
|
||||
self.window_size & 0xFF,
|
||||
# Only 3 bits are meaningful.
|
||||
self.initial_credits & 0x07,
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'PN(dlci={self.dlci},'
|
||||
f'cl={self.cl},'
|
||||
f'priority={self.priority},'
|
||||
f'ack_timer={self.ack_timer},'
|
||||
f'max_frame_size={self.max_frame_size},'
|
||||
f'max_retransmissions={self.max_retransmissions},'
|
||||
f'window_size={self.window_size})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class RFCOMM_MCC_MSC:
|
||||
dlci: int
|
||||
fc: int
|
||||
@@ -375,16 +415,6 @@ class RFCOMM_MCC_MSC:
|
||||
ic: int
|
||||
dv: int
|
||||
|
||||
def __init__(
|
||||
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
|
||||
) -> None:
|
||||
self.dlci = dlci
|
||||
self.fc = fc
|
||||
self.rtc = rtc
|
||||
self.rtr = rtr
|
||||
self.ic = ic
|
||||
self.dv = dv
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
|
||||
return RFCOMM_MCC_MSC(
|
||||
@@ -409,16 +439,6 @@ class RFCOMM_MCC_MSC:
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'MSC(dlci={self.dlci},'
|
||||
f'fc={self.fc},'
|
||||
f'rtc={self.rtc},'
|
||||
f'rtr={self.rtr},'
|
||||
f'ic={self.ic},'
|
||||
f'dv={self.dv})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class DLC(EventEmitter):
|
||||
@@ -430,35 +450,58 @@ class DLC(EventEmitter):
|
||||
DISCONNECTED = 0x04
|
||||
RESET = 0x05
|
||||
|
||||
connection_result: Optional[asyncio.Future]
|
||||
sink: Optional[Callable[[bytes], None]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multiplexer: Multiplexer,
|
||||
dlci: int,
|
||||
max_frame_size: int,
|
||||
initial_tx_credits: int,
|
||||
tx_max_frame_size: int,
|
||||
tx_initial_credits: int,
|
||||
rx_max_frame_size: int,
|
||||
rx_initial_credits: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplexer = multiplexer
|
||||
self.dlci = dlci
|
||||
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
|
||||
self.rx_threshold = self.rx_credits // 2
|
||||
self.tx_credits = initial_tx_credits
|
||||
self.rx_max_frame_size = rx_max_frame_size
|
||||
self.rx_initial_credits = rx_initial_credits
|
||||
self.rx_max_credits = RFCOMM_DEFAULT_MAX_CREDITS
|
||||
self.rx_credits = rx_initial_credits
|
||||
self.rx_credits_threshold = RFCOMM_DEFAULT_CREDIT_THRESHOLD
|
||||
self.tx_max_frame_size = tx_max_frame_size
|
||||
self.tx_credits = tx_initial_credits
|
||||
self.tx_buffer = b''
|
||||
self.state = DLC.State.INIT
|
||||
self.role = multiplexer.role
|
||||
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
|
||||
self.sink = None
|
||||
self.connection_result = None
|
||||
self.connection_result: Optional[asyncio.Future] = None
|
||||
self.disconnection_result: Optional[asyncio.Future] = None
|
||||
self.drained = asyncio.Event()
|
||||
self.drained.set()
|
||||
# Queued packets when sink is not set.
|
||||
self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
|
||||
maxlen=DEFAULT_RX_QUEUE_SIZE
|
||||
)
|
||||
self._sink: Optional[Callable[[bytes], None]] = None
|
||||
|
||||
# Compute the MTU
|
||||
max_overhead = 4 + 1 # header with 2-byte length + fcs
|
||||
self.mtu = min(
|
||||
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
|
||||
tx_max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
|
||||
)
|
||||
|
||||
@property
|
||||
def sink(self) -> Optional[Callable[[bytes], None]]:
|
||||
return self._sink
|
||||
|
||||
@sink.setter
|
||||
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
|
||||
self._sink = sink
|
||||
# Dump queued packets to sink
|
||||
if sink:
|
||||
for packet in self._enqueued_rx_packets:
|
||||
sink(packet) # pylint: disable=not-callable
|
||||
self._enqueued_rx_packets.clear()
|
||||
|
||||
def change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
|
||||
self.state = new_state
|
||||
@@ -467,7 +510,7 @@ class DLC(EventEmitter):
|
||||
self.multiplexer.send_frame(frame)
|
||||
|
||||
def on_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
|
||||
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
|
||||
handler(frame)
|
||||
|
||||
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
|
||||
@@ -481,9 +524,7 @@ class DLC(EventEmitter):
|
||||
|
||||
# Exchange the modem status with the peer
|
||||
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
|
||||
mcc = RFCOMM_Frame.make_mcc(
|
||||
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
|
||||
logger.debug(f'>>> MCC MSC Command: {msc}')
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
|
||||
@@ -491,22 +532,35 @@ class DLC(EventEmitter):
|
||||
self.emit('open')
|
||||
|
||||
def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
|
||||
if self.state != DLC.State.CONNECTING:
|
||||
if self.state == DLC.State.CONNECTING:
|
||||
# Exchange the modem status with the peer
|
||||
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
|
||||
logger.debug(f'>>> MCC MSC Command: {msc}')
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
|
||||
self.change_state(DLC.State.CONNECTED)
|
||||
if self.connection_result:
|
||||
self.connection_result.set_result(None)
|
||||
self.connection_result = None
|
||||
self.multiplexer.on_dlc_open_complete(self)
|
||||
elif self.state == DLC.State.DISCONNECTING:
|
||||
self.change_state(DLC.State.DISCONNECTED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
self.multiplexer.on_dlc_disconnection(self)
|
||||
self.emit('close')
|
||||
else:
|
||||
logger.warning(
|
||||
color('!!! received SABM when not in CONNECTING state', 'red')
|
||||
color(
|
||||
(
|
||||
'!!! received UA frame when not in '
|
||||
'CONNECTING or DISCONNECTING state'
|
||||
),
|
||||
'red',
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Exchange the modem status with the peer
|
||||
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
|
||||
mcc = RFCOMM_Frame.make_mcc(
|
||||
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
|
||||
)
|
||||
logger.debug(f'>>> MCC MSC Command: {msc}')
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
|
||||
self.change_state(DLC.State.CONNECTED)
|
||||
self.multiplexer.on_dlc_open_complete(self)
|
||||
|
||||
def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
# TODO: handle all states
|
||||
@@ -534,14 +588,22 @@ class DLC(EventEmitter):
|
||||
f'[{self.dlci}] {len(data)} bytes, '
|
||||
f'rx_credits={self.rx_credits}: {data.hex()}'
|
||||
)
|
||||
if len(data) and self.sink:
|
||||
self.sink(data) # pylint: disable=not-callable
|
||||
if data:
|
||||
if self._sink:
|
||||
self._sink(data) # pylint: disable=not-callable
|
||||
else:
|
||||
self._enqueued_rx_packets.append(data)
|
||||
if (
|
||||
self._enqueued_rx_packets.maxlen
|
||||
and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
|
||||
):
|
||||
logger.warning(f'DLC [{self.dlci}] received packet queue is full')
|
||||
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
|
||||
# Check if there's anything to send (including credits)
|
||||
self.process_tx()
|
||||
@@ -554,9 +616,7 @@ class DLC(EventEmitter):
|
||||
# Command
|
||||
logger.debug(f'<<< MCC MSC Command: {msc}')
|
||||
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
|
||||
mcc = RFCOMM_Frame.make_mcc(
|
||||
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
|
||||
logger.debug(f'>>> MCC MSC Response: {msc}')
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
else:
|
||||
@@ -571,6 +631,19 @@ class DLC(EventEmitter):
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.state != DLC.State.CONNECTED:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
self.change_state(DLC.State.DISCONNECTING)
|
||||
self.send_frame(
|
||||
RFCOMM_Frame.disc(
|
||||
c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=self.dlci
|
||||
)
|
||||
)
|
||||
await self.disconnection_result
|
||||
|
||||
def accept(self) -> None:
|
||||
if self.state != DLC.State.INIT:
|
||||
raise InvalidStateError('invalid state')
|
||||
@@ -580,18 +653,18 @@ class DLC(EventEmitter):
|
||||
cl=0xE0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=self.rx_max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
initial_credits=self.rx_initial_credits,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
|
||||
logger.debug(f'>>> PN Response: {pn}')
|
||||
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
|
||||
self.change_state(DLC.State.CONNECTING)
|
||||
|
||||
def rx_credits_needed(self) -> int:
|
||||
if self.rx_credits <= self.rx_threshold:
|
||||
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
|
||||
if self.rx_credits <= self.rx_credits_threshold:
|
||||
return self.rx_max_credits - self.rx_credits
|
||||
|
||||
return 0
|
||||
|
||||
@@ -631,6 +704,8 @@ class DLC(EventEmitter):
|
||||
)
|
||||
|
||||
rx_credits_needed = 0
|
||||
if not self.tx_buffer:
|
||||
self.drained.set()
|
||||
|
||||
# Stream protocol
|
||||
def write(self, data: Union[bytes, str]) -> None:
|
||||
@@ -640,17 +715,37 @@ class DLC(EventEmitter):
|
||||
# Automatically convert strings to bytes using UTF-8
|
||||
data = data.encode('utf-8')
|
||||
else:
|
||||
raise ValueError('write only accept bytes or strings')
|
||||
raise InvalidArgumentError('write only accept bytes or strings')
|
||||
|
||||
self.tx_buffer += data
|
||||
self.drained.clear()
|
||||
self.process_tx()
|
||||
|
||||
def drain(self) -> None:
|
||||
# TODO
|
||||
pass
|
||||
async def drain(self) -> None:
|
||||
await self.drained.wait()
|
||||
|
||||
def abort(self) -> None:
|
||||
logger.debug(f'aborting DLC: {self}')
|
||||
if self.connection_result:
|
||||
self.connection_result.cancel()
|
||||
self.connection_result = None
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.cancel()
|
||||
self.disconnection_result = None
|
||||
self.change_state(DLC.State.RESET)
|
||||
self.emit('close')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'DLC(dlci={self.dlci},state={self.state.name})'
|
||||
return (
|
||||
f'DLC(dlci={self.dlci}, '
|
||||
f'state={self.state.name}, '
|
||||
f'rx_max_frame_size={self.rx_max_frame_size}, '
|
||||
f'rx_credits={self.rx_credits}, '
|
||||
f'rx_max_credits={self.rx_max_credits}, '
|
||||
f'tx_max_frame_size={self.tx_max_frame_size}, '
|
||||
f'tx_credits={self.tx_credits}'
|
||||
')'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -671,10 +766,10 @@ class Multiplexer(EventEmitter):
|
||||
connection_result: Optional[asyncio.Future]
|
||||
disconnection_result: Optional[asyncio.Future]
|
||||
open_result: Optional[asyncio.Future]
|
||||
acceptor: Optional[Callable[[int], bool]]
|
||||
acceptor: Optional[Callable[[int], Optional[Tuple[int, int]]]]
|
||||
dlcs: Dict[int, DLC]
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None:
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.role = role
|
||||
self.l2cap_channel = l2cap_channel
|
||||
@@ -683,11 +778,15 @@ class Multiplexer(EventEmitter):
|
||||
self.connection_result = None
|
||||
self.disconnection_result = None
|
||||
self.open_result = None
|
||||
self.open_pn: Optional[RFCOMM_MCC_PN] = None
|
||||
self.open_rx_max_credits = 0
|
||||
self.acceptor = None
|
||||
|
||||
# Become a sink for the L2CAP channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_channel_close)
|
||||
|
||||
def change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||
self.state = new_state
|
||||
@@ -704,7 +803,7 @@ class Multiplexer(EventEmitter):
|
||||
if frame.dlci == 0:
|
||||
self.on_frame(frame)
|
||||
else:
|
||||
if frame.type == RFCOMM_DM_FRAME:
|
||||
if frame.type == FrameType.DM:
|
||||
# DM responses are for a DLCI, but since we only create the dlc when we
|
||||
# receive a PN response (because we need the parameters), we handle DM
|
||||
# frames at the Multiplexer level
|
||||
@@ -717,7 +816,7 @@ class Multiplexer(EventEmitter):
|
||||
dlc.on_frame(frame)
|
||||
|
||||
def on_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
|
||||
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
|
||||
handler(frame)
|
||||
|
||||
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
|
||||
@@ -751,6 +850,7 @@ class Multiplexer(EventEmitter):
|
||||
'rfcomm',
|
||||
)
|
||||
)
|
||||
self.open_result = None
|
||||
else:
|
||||
logger.warning(f'unexpected state for DM: {self}')
|
||||
|
||||
@@ -765,10 +865,10 @@ class Multiplexer(EventEmitter):
|
||||
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
|
||||
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
|
||||
|
||||
if mcc_type == RFCOMM_MCC_PN_TYPE:
|
||||
if mcc_type == MccType.PN:
|
||||
pn = RFCOMM_MCC_PN.from_bytes(value)
|
||||
self.on_mcc_pn(c_r, pn)
|
||||
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
|
||||
elif mcc_type == MccType.MSC:
|
||||
mcs = RFCOMM_MCC_MSC.from_bytes(value)
|
||||
self.on_mcc_msc(c_r, mcs)
|
||||
|
||||
@@ -788,9 +888,16 @@ class Multiplexer(EventEmitter):
|
||||
else:
|
||||
if self.acceptor:
|
||||
channel_number = pn.dlci >> 1
|
||||
if self.acceptor(channel_number):
|
||||
if dlc_params := self.acceptor(channel_number):
|
||||
# Create a new DLC
|
||||
dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
|
||||
dlc = DLC(
|
||||
self,
|
||||
dlci=pn.dlci,
|
||||
tx_max_frame_size=pn.max_frame_size,
|
||||
tx_initial_credits=pn.initial_credits,
|
||||
rx_max_frame_size=dlc_params[0],
|
||||
rx_initial_credits=dlc_params[1],
|
||||
)
|
||||
self.dlcs[pn.dlci] = dlc
|
||||
|
||||
# Re-emit the handshake completion event
|
||||
@@ -808,8 +915,17 @@ class Multiplexer(EventEmitter):
|
||||
# Response
|
||||
logger.debug(f'>>> PN Response: {pn}')
|
||||
if self.state == Multiplexer.State.OPENING:
|
||||
dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
|
||||
assert self.open_pn
|
||||
dlc = DLC(
|
||||
self,
|
||||
dlci=pn.dlci,
|
||||
tx_max_frame_size=pn.max_frame_size,
|
||||
tx_initial_credits=pn.initial_credits,
|
||||
rx_max_frame_size=self.open_pn.max_frame_size,
|
||||
rx_initial_credits=self.open_pn.initial_credits,
|
||||
)
|
||||
self.dlcs[pn.dlci] = dlc
|
||||
self.open_pn = None
|
||||
dlc.connect()
|
||||
else:
|
||||
logger.warning('ignoring PN response')
|
||||
@@ -843,24 +959,31 @@ class Multiplexer(EventEmitter):
|
||||
)
|
||||
await self.disconnection_result
|
||||
|
||||
async def open_dlc(self, channel: int) -> DLC:
|
||||
async def open_dlc(
|
||||
self,
|
||||
channel: int,
|
||||
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
|
||||
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
|
||||
) -> DLC:
|
||||
if self.state != Multiplexer.State.CONNECTED:
|
||||
if self.state == Multiplexer.State.OPENING:
|
||||
raise InvalidStateError('open already in progress')
|
||||
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
pn = RFCOMM_MCC_PN(
|
||||
self.open_pn = RFCOMM_MCC_PN(
|
||||
dlci=channel << 1,
|
||||
cl=0xF0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
initial_credits=initial_credits,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
|
||||
logger.debug(f'>>> Sending MCC: {pn}')
|
||||
mcc = RFCOMM_Frame.make_mcc(
|
||||
mcc_type=MccType.PN, c_r=1, data=bytes(self.open_pn)
|
||||
)
|
||||
logger.debug(f'>>> Sending MCC: {self.open_pn}')
|
||||
self.open_result = asyncio.get_running_loop().create_future()
|
||||
self.change_state(Multiplexer.State.OPENING)
|
||||
self.send_frame(
|
||||
@@ -870,15 +993,31 @@ class Multiplexer(EventEmitter):
|
||||
information=mcc,
|
||||
)
|
||||
)
|
||||
result = await self.open_result
|
||||
self.open_result = None
|
||||
return result
|
||||
return await self.open_result
|
||||
|
||||
def on_dlc_open_complete(self, dlc: DLC) -> None:
|
||||
logger.debug(f'DLC [{dlc.dlci}] open complete')
|
||||
|
||||
self.change_state(Multiplexer.State.CONNECTED)
|
||||
|
||||
if self.open_result:
|
||||
self.open_result.set_result(dlc)
|
||||
self.open_result = None
|
||||
|
||||
def on_dlc_disconnection(self, dlc: DLC) -> None:
|
||||
logger.debug(f'DLC [{dlc.dlci}] disconnection')
|
||||
self.dlcs.pop(dlc.dlci, None)
|
||||
|
||||
def on_l2cap_channel_close(self) -> None:
|
||||
logger.debug('L2CAP channel closed, cleaning up')
|
||||
if self.open_result:
|
||||
self.open_result.cancel()
|
||||
self.open_result = None
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.cancel()
|
||||
self.disconnection_result = None
|
||||
for dlc in self.dlcs.values():
|
||||
dlc.abort()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'Multiplexer(state={self.state.name})'
|
||||
@@ -887,26 +1026,28 @@ class Multiplexer(EventEmitter):
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
multiplexer: Optional[Multiplexer]
|
||||
l2cap_channel: Optional[l2cap.Channel]
|
||||
l2cap_channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, device: Device, connection: Connection) -> None:
|
||||
self.device = device
|
||||
def __init__(
|
||||
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
|
||||
) -> None:
|
||||
self.connection = connection
|
||||
self.l2cap_mtu = l2cap_mtu
|
||||
self.l2cap_channel = None
|
||||
self.multiplexer = None
|
||||
|
||||
async def start(self) -> Multiplexer:
|
||||
# Create a new L2CAP connection
|
||||
try:
|
||||
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, RFCOMM_PSM
|
||||
self.l2cap_channel = await self.connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
|
||||
)
|
||||
except ProtocolError as error:
|
||||
logger.warning(f'L2CAP connection failed: {error}')
|
||||
raise
|
||||
|
||||
assert self.l2cap_channel is not None
|
||||
# Create a mutliplexer to manage DLCs with the server
|
||||
# Create a multiplexer to manage DLCs with the server
|
||||
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)
|
||||
|
||||
# Connect the multiplexer
|
||||
@@ -922,23 +1063,40 @@ class Client:
|
||||
self.multiplexer = None
|
||||
|
||||
# Close the L2CAP channel
|
||||
# TODO
|
||||
if self.l2cap_channel:
|
||||
await self.l2cap_channel.disconnect()
|
||||
self.l2cap_channel = None
|
||||
|
||||
async def __aenter__(self) -> Multiplexer:
|
||||
return await self.start()
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
await self.shutdown()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
acceptors: Dict[int, Callable[[DLC], None]]
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
def __init__(
|
||||
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.multiplexer = None
|
||||
self.acceptors = {}
|
||||
self.acceptors: Dict[int, Callable[[DLC], None]] = {}
|
||||
self.dlc_configs: Dict[int, Tuple[int, int]] = {}
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
|
||||
self.l2cap_server = device.create_l2cap_server(
|
||||
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu),
|
||||
handler=self.on_connection,
|
||||
)
|
||||
|
||||
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
|
||||
def listen(
|
||||
self,
|
||||
acceptor: Callable[[DLC], None],
|
||||
channel: int = 0,
|
||||
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
|
||||
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
|
||||
) -> int:
|
||||
if channel:
|
||||
if channel in self.acceptors:
|
||||
# Busy
|
||||
@@ -958,13 +1116,15 @@ class Server(EventEmitter):
|
||||
return 0
|
||||
|
||||
self.acceptors[channel] = acceptor
|
||||
self.dlc_configs[channel] = (max_frame_size, initial_credits)
|
||||
|
||||
return channel
|
||||
|
||||
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
# Create a new multiplexer for the channel
|
||||
@@ -975,13 +1135,18 @@ class Server(EventEmitter):
|
||||
# Notify
|
||||
self.emit('start', multiplexer)
|
||||
|
||||
def accept_dlc(self, channel_number: int) -> bool:
|
||||
return channel_number in self.acceptors
|
||||
def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]:
|
||||
return self.dlc_configs.get(channel_number)
|
||||
|
||||
def on_dlc(self, dlc: DLC) -> None:
|
||||
logger.debug(f'@@@ new DLC connected: {dlc}')
|
||||
|
||||
# Let the acceptor know
|
||||
acceptor = self.acceptors.get(dlc.dlci >> 1)
|
||||
if acceptor:
|
||||
if acceptor := self.acceptors.get(dlc.dlci >> 1):
|
||||
acceptor(dlc)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.l2cap_server.close()
|
||||
|
||||
@@ -19,10 +19,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
import struct
|
||||
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import InvalidStateError
|
||||
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -97,7 +98,8 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
|
||||
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
|
||||
|
||||
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
|
||||
|
||||
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
|
||||
# used by AVRCP, HFP and A2DP
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
|
||||
|
||||
@@ -115,7 +117,8 @@ SDP_ATTRIBUTE_ID_NAMES = {
|
||||
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
|
||||
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
|
||||
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
|
||||
}
|
||||
|
||||
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
|
||||
@@ -167,7 +170,7 @@ class DataElement:
|
||||
UUID: lambda x: DataElement(
|
||||
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
|
||||
),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
|
||||
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
|
||||
SEQUENCE: lambda x: DataElement(
|
||||
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
|
||||
@@ -186,7 +189,9 @@ class DataElement:
|
||||
self.bytes = None
|
||||
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||
if value_size is None:
|
||||
raise ValueError('integer types must have a value size specified')
|
||||
raise InvalidArgumentError(
|
||||
'integer types must have a value size specified'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def nil() -> DataElement:
|
||||
@@ -229,7 +234,7 @@ class DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value: str) -> DataElement:
|
||||
def text_string(value: bytes) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -262,7 +267,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>Q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_from_bytes(data):
|
||||
@@ -278,7 +283,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def list_from_bytes(data):
|
||||
@@ -351,7 +356,7 @@ class DataElement:
|
||||
data = b''
|
||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise ValueError('UNSIGNED_INTEGER cannot be negative')
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('B', self.value)
|
||||
@@ -362,7 +367,7 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.SIGNED_INTEGER:
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('b', self.value)
|
||||
@@ -373,10 +378,10 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
|
||||
elif self.type == DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
@@ -389,7 +394,7 @@ class DataElement:
|
||||
size_bytes = b''
|
||||
if self.type == DataElement.NIL:
|
||||
if size != 0:
|
||||
raise ValueError('NIL must be empty')
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif self.type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
@@ -407,7 +412,7 @@ class DataElement:
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type in (
|
||||
DataElement.TEXT_STRING,
|
||||
DataElement.SEQUENCE,
|
||||
@@ -424,10 +429,10 @@ class DataElement:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise ValueError('boolean must be 1 byte')
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
|
||||
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
@@ -758,16 +763,17 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
channel: Optional[l2cap.Channel]
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
self.pending_request = None
|
||||
self.channel = None
|
||||
|
||||
async def connect(self, connection: Connection) -> None:
|
||||
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
|
||||
self.channel = result
|
||||
async def connect(self) -> None:
|
||||
self.channel = await self.connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.channel:
|
||||
@@ -821,11 +827,13 @@ class Client:
|
||||
)
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
for attribute_id in attribute_ids
|
||||
]
|
||||
)
|
||||
@@ -877,11 +885,13 @@ class Client:
|
||||
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
for attribute_id in attribute_ids
|
||||
]
|
||||
)
|
||||
@@ -917,11 +927,18 @@ class Client:
|
||||
|
||||
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
channel: Optional[l2cap.Channel]
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', List[ServiceAttribute])
|
||||
service_records: Dict[int, Service]
|
||||
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||
@@ -933,7 +950,9 @@ class Server:
|
||||
self.current_response = None
|
||||
|
||||
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
|
||||
l2cap_channel_manager.create_classic_server(
|
||||
spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
|
||||
)
|
||||
|
||||
def send_response(self, response):
|
||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||
@@ -980,7 +999,7 @@ class Server:
|
||||
try:
|
||||
handler(sdp_pdu)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=sdp_pdu.transaction_id,
|
||||
|
||||
299
bumble/smp.py
299
bumble/smp.py
@@ -27,6 +27,7 @@ import logging
|
||||
import asyncio
|
||||
import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -53,6 +54,8 @@ from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_CENTRAL_ROLE,
|
||||
BT_LE_TRANSPORT,
|
||||
AdvertisingData,
|
||||
InvalidArgumentError,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
@@ -185,8 +188,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
|
||||
SMP_CT2_AUTHREQ = 0b00100000
|
||||
|
||||
# Crypto salt
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -563,6 +566,54 @@ class PairingMethod(enum.IntEnum):
|
||||
CTKD_OVER_CLASSIC = 4
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobContext:
|
||||
"""Cryptographic context for LE SC OOB pairing."""
|
||||
|
||||
ecc_key: crypto.EccKey
|
||||
r: bytes
|
||||
|
||||
def __init__(
|
||||
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
|
||||
) -> None:
|
||||
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
|
||||
self.r = crypto.r() if r is None else r
|
||||
|
||||
def share(self) -> OobSharedData:
|
||||
pkx = self.ecc_key.x[::-1]
|
||||
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobLegacyContext:
|
||||
"""Cryptographic context for LE Legacy OOB pairing."""
|
||||
|
||||
tk: bytes
|
||||
|
||||
def __init__(self, tk: Optional[bytes] = None) -> None:
|
||||
self.tk = crypto.r() if tk is None else tk
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobSharedData:
|
||||
"""Shareable data for LE SC OOB pairing."""
|
||||
|
||||
c: bytes
|
||||
r: bytes
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
return AdvertisingData(
|
||||
[
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Session:
|
||||
# I/O Capability to pairing method decision matrix
|
||||
@@ -627,6 +678,13 @@ class Session:
|
||||
},
|
||||
}
|
||||
|
||||
ea: bytes
|
||||
eb: bytes
|
||||
ltk: bytes
|
||||
preq: bytes
|
||||
pres: bytes
|
||||
tk: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: Manager,
|
||||
@@ -636,17 +694,10 @@ class Session:
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.preq: Optional[bytes] = None
|
||||
self.pres: Optional[bytes] = None
|
||||
self.ea = None
|
||||
self.eb = None
|
||||
self.tk = bytes(16)
|
||||
self.r = bytes(16)
|
||||
self.stk = None
|
||||
self.ltk = None
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key = None
|
||||
self.link_key: Optional[bytes] = None
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
@@ -659,7 +710,7 @@ class Session:
|
||||
self.peer_bd_addr: Optional[Address] = None
|
||||
self.peer_signature_key = None
|
||||
self.peer_expected_distributions: List[Type[SMP_Command]] = []
|
||||
self.dh_key = None
|
||||
self.dh_key = b''
|
||||
self.confirm_value = None
|
||||
self.passkey: Optional[int] = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
@@ -687,9 +738,9 @@ class Session:
|
||||
|
||||
# Create a future that can be used to wait for the session to complete
|
||||
if self.is_initiator:
|
||||
self.pairing_result: Optional[
|
||||
asyncio.Future[None]
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
self.pairing_result: Optional[asyncio.Future[None]] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
else:
|
||||
self.pairing_result = None
|
||||
|
||||
@@ -712,8 +763,8 @@ class Session:
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
|
||||
# OOB (not supported yet)
|
||||
self.oob = False
|
||||
# OOB
|
||||
self.oob_data_flag = 0 if pairing_config.oob is None else 1
|
||||
|
||||
# Set up addresses
|
||||
self_address = connection.self_address
|
||||
@@ -729,9 +780,35 @@ class Session:
|
||||
self.ia = bytes(peer_address)
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
# Select the ECC key, TK and r initial value
|
||||
if pairing_config.oob:
|
||||
self.peer_oob_data = pairing_config.oob.peer_data
|
||||
if pairing_config.sc:
|
||||
if pairing_config.oob.our_context is None:
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a context when sc is True"
|
||||
)
|
||||
self.r = pairing_config.oob.our_context.r
|
||||
self.ecc_key = pairing_config.oob.our_context.ecc_key
|
||||
if pairing_config.oob.legacy_context is not None:
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
if pairing_config.oob.legacy_context is None:
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a legacy context when sc is False"
|
||||
)
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
self.peer_oob_data = None
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = bytes(16)
|
||||
|
||||
@property
|
||||
def pkx(self) -> Tuple[bytes, bytes]:
|
||||
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
|
||||
return (self.ecc_key.x[::-1], self.peer_public_key_x)
|
||||
|
||||
@property
|
||||
def pka(self) -> bytes:
|
||||
@@ -768,7 +845,10 @@ class Session:
|
||||
return None
|
||||
|
||||
def decide_pairing_method(
|
||||
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
|
||||
self,
|
||||
auth_req: int,
|
||||
initiator_io_capability: int,
|
||||
responder_io_capability: int,
|
||||
) -> None:
|
||||
if self.connection.transport == BT_BR_EDR_TRANSPORT:
|
||||
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
||||
@@ -909,7 +989,7 @@ class Session:
|
||||
|
||||
command = SMP_Pairing_Request_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
@@ -921,7 +1001,7 @@ class Session:
|
||||
def send_pairing_response_command(self) -> None:
|
||||
response = SMP_Pairing_Response_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
@@ -982,8 +1062,8 @@ class Session:
|
||||
def send_public_key_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_Public_Key_Command(
|
||||
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
|
||||
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
|
||||
public_key_x=self.ecc_key.x[::-1],
|
||||
public_key_y=self.ecc_key.y[::-1],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1011,7 +1091,7 @@ class Session:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
@@ -1019,18 +1099,56 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
async def derive_ltk(self) -> None:
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
@classmethod
|
||||
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Long Term Key from Link Key.
|
||||
|
||||
Args:
|
||||
link_key: BR/EDR Link Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
LE Long Tern Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
|
||||
if self.ct2
|
||||
if ct2
|
||||
else crypto.h6(link_key, b'tmp2')
|
||||
)
|
||||
self.ltk = crypto.h6(ilk, b'brle')
|
||||
return crypto.h6(ilk, b'brle')
|
||||
|
||||
@classmethod
|
||||
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Link Key from Long Term Key.
|
||||
|
||||
Args:
|
||||
ltk: LE Long Term Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
BR/EDR Link Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
|
||||
if ct2
|
||||
else crypto.h6(ltk, b'tmp1')
|
||||
)
|
||||
return crypto.h6(ilk, b'lebr')
|
||||
|
||||
async def get_link_key_and_derive_ltk(self) -> None:
|
||||
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
|
||||
self.link_key = await self.manager.device.get_link_key(
|
||||
self.connection.peer_address
|
||||
)
|
||||
if self.link_key is None:
|
||||
logging.warning(
|
||||
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
||||
)
|
||||
self.send_pairing_failed(
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
|
||||
)
|
||||
else:
|
||||
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
||||
|
||||
def distribute_keys(self) -> None:
|
||||
|
||||
# Distribute the keys as required
|
||||
if self.is_initiator:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1039,7 +1157,7 @@ class Session:
|
||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
elif not self.sc:
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
@@ -1069,12 +1187,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
else:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1083,7 +1196,7 @@ class Session:
|
||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
elif not self.sc:
|
||||
@@ -1113,12 +1226,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
@@ -1296,7 +1404,7 @@ class Session:
|
||||
try:
|
||||
handler(command)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
response = SMP_Pairing_Failed_Command(
|
||||
reason=SMP_UNSPECIFIED_REASON_ERROR
|
||||
)
|
||||
@@ -1333,15 +1441,28 @@ class Session:
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if command.oob_data_flag != 0:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req,
|
||||
command.io_capability,
|
||||
self.io_capability,
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, command.io_capability, self.io_capability
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
@@ -1390,15 +1511,26 @@ class Session:
|
||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if self.sc and command.oob_data_flag:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
@@ -1549,12 +1681,13 @@ class Session:
|
||||
if self.passkey_step < 20:
|
||||
self.send_pairing_confirm_command()
|
||||
return
|
||||
else:
|
||||
elif self.pairing_method != PairingMethod.OOB:
|
||||
return
|
||||
else:
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
@@ -1591,6 +1724,7 @@ class Session:
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
ra = bytes(16)
|
||||
rb = ra
|
||||
@@ -1599,7 +1733,6 @@ class Session:
|
||||
ra = self.passkey.to_bytes(16, byteorder='little')
|
||||
rb = ra
|
||||
else:
|
||||
# OOB not implemented yet
|
||||
return
|
||||
|
||||
assert self.preq and self.pres
|
||||
@@ -1651,18 +1784,33 @@ class Session:
|
||||
self.peer_public_key_y = command.public_key_y
|
||||
|
||||
# Compute the DH key
|
||||
self.dh_key = bytes(
|
||||
reversed(
|
||||
self.manager.ecc_key.dh(
|
||||
bytes(reversed(command.public_key_x)),
|
||||
bytes(reversed(command.public_key_y)),
|
||||
)
|
||||
)
|
||||
)
|
||||
self.dh_key = self.ecc_key.dh(
|
||||
command.public_key_x[::-1],
|
||||
command.public_key_y[::-1],
|
||||
)[::-1]
|
||||
logger.debug(f'DH key: {self.dh_key.hex()}')
|
||||
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
# Check against shared OOB data
|
||||
if self.peer_oob_data:
|
||||
confirm_verifier = crypto.f4(
|
||||
self.peer_public_key_x,
|
||||
self.peer_public_key_x,
|
||||
self.peer_oob_data.r,
|
||||
bytes(1),
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.peer_oob_data.c,
|
||||
confirm_verifier,
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
):
|
||||
return
|
||||
|
||||
if self.is_initiator:
|
||||
self.send_pairing_confirm_command()
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.display_or_input_passkey()
|
||||
@@ -1673,6 +1821,7 @@ class Session:
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
# We can now send the confirmation value
|
||||
self.send_pairing_confirm_command()
|
||||
@@ -1701,7 +1850,6 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
assert self.ltk
|
||||
self.start_encryption(self.ltk)
|
||||
|
||||
def on_smp_pairing_failed_command(
|
||||
@@ -1751,6 +1899,7 @@ class Manager(EventEmitter):
|
||||
sessions: Dict[int, Session]
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
session_proxy: Type[Session]
|
||||
_ecc_key: Optional[crypto.EccKey]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1845,10 +1994,8 @@ class Manager(EventEmitter):
|
||||
) -> None:
|
||||
# Store the keys in the key store
|
||||
if self.device.keystore and identity_address is not None:
|
||||
self.device.abort_on(
|
||||
'flush', self.device.update_keys(str(identity_address), keys)
|
||||
)
|
||||
|
||||
# Make sure on_pairing emits after key update.
|
||||
await self.device.update_keys(str(identity_address), keys)
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import datetime
|
||||
from typing import BinaryIO, Generator
|
||||
import os
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||
|
||||
|
||||
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
|
||||
"""
|
||||
if ':' not in spec:
|
||||
raise ValueError('snooper type prefix missing')
|
||||
raise core.InvalidArgumentError('snooper type prefix missing')
|
||||
|
||||
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
||||
|
||||
if snooper_type == 'btsnoop':
|
||||
if ':' not in snooper_args:
|
||||
raise ValueError('I/O type for btsnoop snooper type missing')
|
||||
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
|
||||
|
||||
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
||||
if io_type == 'file':
|
||||
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
_SNOOPER_INSTANCE_COUNT -= 1
|
||||
return
|
||||
|
||||
raise ValueError(f'I/O type {io_type} not supported')
|
||||
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
|
||||
|
||||
raise ValueError(f'snooper type {snooper_type} not found')
|
||||
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
|
||||
|
||||
@@ -18,8 +18,9 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
|
||||
from ..snoop import create_snooper
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
|
||||
async def open_transport(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The name must be <type>:<metadata><parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types), and
|
||||
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
|
||||
enclosed in [].
|
||||
If there are not metadata or parameter, the : after the <type> may be omitted.
|
||||
Examples:
|
||||
* usb:0
|
||||
* usb:[driver=rtk]0
|
||||
* android-netsim
|
||||
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
@@ -71,89 +80,107 @@ async def open_transport(name: str) -> Transport:
|
||||
* android-netsim
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
scheme, *tail = name.split(':', 1)
|
||||
spec = tail[0] if tail else None
|
||||
metadata = None
|
||||
if spec:
|
||||
# Metadata may precede the spec
|
||||
if spec.startswith('['):
|
||||
metadata_str, *tail = spec[1:].split(']')
|
||||
spec = tail[0] if tail else None
|
||||
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
|
||||
|
||||
transport = await _open_transport(scheme, spec)
|
||||
if metadata:
|
||||
transport.source.metadata = { # type: ignore[attr-defined]
|
||||
**metadata,
|
||||
**getattr(transport.source, 'metadata', {}),
|
||||
}
|
||||
# pylint: disable=line-too-long
|
||||
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
|
||||
|
||||
return _wrap_transport(transport)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
scheme, *spec = name.split(':', 1)
|
||||
if scheme == 'serial' and spec:
|
||||
from .serial import open_serial_transport
|
||||
|
||||
return await open_serial_transport(spec[0])
|
||||
return await open_serial_transport(spec)
|
||||
|
||||
if scheme == 'udp' and spec:
|
||||
from .udp import open_udp_transport
|
||||
|
||||
return await open_udp_transport(spec[0])
|
||||
return await open_udp_transport(spec)
|
||||
|
||||
if scheme == 'tcp-client' and spec:
|
||||
from .tcp_client import open_tcp_client_transport
|
||||
|
||||
return await open_tcp_client_transport(spec[0])
|
||||
return await open_tcp_client_transport(spec)
|
||||
|
||||
if scheme == 'tcp-server' and spec:
|
||||
from .tcp_server import open_tcp_server_transport
|
||||
|
||||
return await open_tcp_server_transport(spec[0])
|
||||
return await open_tcp_server_transport(spec)
|
||||
|
||||
if scheme == 'ws-client' and spec:
|
||||
from .ws_client import open_ws_client_transport
|
||||
|
||||
return await open_ws_client_transport(spec[0])
|
||||
return await open_ws_client_transport(spec)
|
||||
|
||||
if scheme == 'ws-server' and spec:
|
||||
from .ws_server import open_ws_server_transport
|
||||
|
||||
return await open_ws_server_transport(spec[0])
|
||||
return await open_ws_server_transport(spec)
|
||||
|
||||
if scheme == 'pty':
|
||||
from .pty import open_pty_transport
|
||||
|
||||
return await open_pty_transport(spec[0] if spec else None)
|
||||
return await open_pty_transport(spec)
|
||||
|
||||
if scheme == 'file':
|
||||
from .file import open_file_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_file_transport(spec[0])
|
||||
return await open_file_transport(spec)
|
||||
|
||||
if scheme == 'vhci':
|
||||
from .vhci import open_vhci_transport
|
||||
|
||||
return await open_vhci_transport(spec[0] if spec else None)
|
||||
return await open_vhci_transport(spec)
|
||||
|
||||
if scheme == 'hci-socket':
|
||||
from .hci_socket import open_hci_socket_transport
|
||||
|
||||
return await open_hci_socket_transport(spec[0] if spec else None)
|
||||
return await open_hci_socket_transport(spec)
|
||||
|
||||
if scheme == 'usb':
|
||||
from .usb import open_usb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_usb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_usb_transport(spec)
|
||||
|
||||
if scheme == 'pyusb':
|
||||
from .pyusb import open_pyusb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_pyusb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_pyusb_transport(spec)
|
||||
|
||||
if scheme == 'android-emulator':
|
||||
from .android_emulator import open_android_emulator_transport
|
||||
|
||||
return await open_android_emulator_transport(spec[0] if spec else None)
|
||||
return await open_android_emulator_transport(spec)
|
||||
|
||||
if scheme == 'android-netsim':
|
||||
from .android_netsim import open_android_netsim_transport
|
||||
|
||||
return await open_android_netsim_transport(spec[0] if spec else None)
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
raise TransportSpecError('unknown transport scheme')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -170,12 +197,13 @@ async def open_transport_or_link(name: str) -> Transport:
|
||||
|
||||
"""
|
||||
if name.startswith('link-relay:'):
|
||||
logger.warning('Link Relay has been deprecated.')
|
||||
from ..controller import Controller
|
||||
from ..link import RemoteLink # lazy import
|
||||
|
||||
link = RemoteLink(name[11:])
|
||||
await link.wait_until_connected()
|
||||
controller = Controller('remote', link=link)
|
||||
controller = Controller('remote', link=link) # type:ignore[arg-type]
|
||||
|
||||
class LinkTransport(Transport):
|
||||
async def close(self):
|
||||
|
||||
@@ -20,7 +20,13 @@ import grpc.aio
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
||||
from .common import (
|
||||
PumpedTransport,
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
||||
@@ -69,7 +75,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = '8554'
|
||||
if spec is not None:
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
elif ':' in param:
|
||||
server_host, server_port = param.split(':')
|
||||
else:
|
||||
raise ValueError('invalid parameter')
|
||||
raise TransportSpecError('invalid parameter')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
service = VhciForwardingServiceStub(channel)
|
||||
hci_device = HciDevice(service.attachVhci())
|
||||
else:
|
||||
raise ValueError('invalid mode')
|
||||
raise TransportSpecError('invalid mode')
|
||||
|
||||
# Create the transport object
|
||||
class EmulatorTransport(PumpedTransport):
|
||||
|
||||
@@ -31,6 +31,8 @@ from .common import (
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
TransportInitError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise ValueError('invalid port')
|
||||
raise TransportSpecError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
|
||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||
server_port = find_grpc_port(instance_number)
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
raise TransportInitError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
|
||||
|
||||
if response_type == 'error':
|
||||
logger.warning(f'received error: {response.error}')
|
||||
raise RuntimeError(response.error)
|
||||
raise TransportInitError(response.error)
|
||||
|
||||
if response_type == 'hci_packet':
|
||||
return (
|
||||
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
|
||||
+ response.hci_packet.packet
|
||||
)
|
||||
|
||||
raise ValueError('unsupported response type')
|
||||
raise TransportSpecError('unsupported response type')
|
||||
|
||||
async def write(self, packet):
|
||||
await self.hci_device.write(
|
||||
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
options: Dict[str, str] = {}
|
||||
for param in params[params_offset:]:
|
||||
if '=' not in param:
|
||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
||||
raise TransportSpecError('invalid parameter, expected <name>=<value>')
|
||||
option_name, option_value = param.split('=')
|
||||
options[option_name] = option_value
|
||||
|
||||
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
)
|
||||
if mode == 'controller':
|
||||
if host is None:
|
||||
raise ValueError('<host>:<port> missing')
|
||||
raise TransportSpecError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port, options)
|
||||
|
||||
raise ValueError('invalid mode option')
|
||||
raise TransportSpecError('invalid mode option')
|
||||
|
||||
@@ -21,8 +21,9 @@ import struct
|
||||
import asyncio
|
||||
import logging
|
||||
import io
|
||||
from typing import ContextManager, Tuple, Optional, Protocol, Dict
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.snoop import Snooper
|
||||
@@ -42,31 +43,36 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
||||
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
|
||||
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Errors
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportLostError(Exception):
|
||||
"""
|
||||
The Transport has been lost/disconnected.
|
||||
"""
|
||||
class TransportLostError(core.BaseBumbleError, RuntimeError):
|
||||
"""The Transport has been lost/disconnected."""
|
||||
|
||||
|
||||
class TransportInitError(core.BaseBumbleError, RuntimeError):
|
||||
"""Error raised when the transport cannot be initialized."""
|
||||
|
||||
|
||||
class TransportSpecError(core.BaseBumbleError, ValueError):
|
||||
"""Error raised when the transport spec is invalid."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing Protocols
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportSink(Protocol):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
...
|
||||
def on_packet(self, packet: bytes) -> None: ...
|
||||
|
||||
|
||||
class TransportSource(Protocol):
|
||||
terminated: asyncio.Future[None]
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
...
|
||||
def set_packet_sink(self, sink: TransportSink) -> None: ...
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -133,7 +139,9 @@ class PacketParser:
|
||||
packet_type
|
||||
) or self.extended_packet_info.get(packet_type)
|
||||
if self.packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type}')
|
||||
raise core.InvalidPacketError(
|
||||
f'invalid packet type {packet_type}'
|
||||
)
|
||||
self.state = PacketParser.NEED_LENGTH
|
||||
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
||||
elif self.state == PacketParser.NEED_LENGTH:
|
||||
@@ -150,7 +158,7 @@ class PacketParser:
|
||||
try:
|
||||
self.sink.on_packet(bytes(self.packet))
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
color(f'!!! Exception in on_packet: {error}', 'red')
|
||||
)
|
||||
self.reset()
|
||||
@@ -167,29 +175,31 @@ class PacketReader:
|
||||
|
||||
def __init__(self, source: io.BufferedReader) -> None:
|
||||
self.source = source
|
||||
self.at_end = False
|
||||
|
||||
def next_packet(self) -> Optional[bytes]:
|
||||
# Get the packet type
|
||||
packet_type = self.source.read(1)
|
||||
if len(packet_type) != 1:
|
||||
self.at_end = True
|
||||
return None
|
||||
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
header = self.source.read(header_size)
|
||||
if len(header) != header_size:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
# Read the body
|
||||
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
|
||||
body = self.source.read(body_length)
|
||||
if len(body) != body_length:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
return packet_type + header + body
|
||||
|
||||
@@ -210,7 +220,7 @@ class AsyncPacketReader:
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type[0]} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
@@ -419,11 +429,15 @@ class SnoopingTransport(Transport):
|
||||
return SnoopingTransport(
|
||||
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
||||
)
|
||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
||||
raise core.UnreachableError() # Satisfy the type checker
|
||||
|
||||
class Source:
|
||||
sink: TransportSink
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
return getattr(self.source, 'metadata', {})
|
||||
|
||||
def __init__(self, source: TransportSource, snooper: Snooper):
|
||||
self.source = source
|
||||
self.snooper = snooper
|
||||
|
||||
@@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||
) from error
|
||||
|
||||
# Compute the adapter index
|
||||
if spec is None:
|
||||
adapter_index = 0
|
||||
else:
|
||||
adapter_index = int(spec)
|
||||
adapter_index = int(spec) if spec else 0
|
||||
|
||||
# Bind the socket
|
||||
# NOTE: since Python doesn't support binding with the required address format (yet),
|
||||
|
||||
@@ -23,11 +23,24 @@ import time
|
||||
import usb.core
|
||||
import usb.util
|
||||
|
||||
from .common import Transport, ParserSource
|
||||
from typing import Optional
|
||||
from usb.core import Device as UsbDevice
|
||||
from usb.core import USBError
|
||||
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
|
||||
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
|
||||
|
||||
from .common import Transport, ParserSource, TransportInitError
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constant
|
||||
# -----------------------------------------------------------------------------
|
||||
USB_PORT_FEATURE_POWER = 8
|
||||
POWER_CYCLE_DELAY = 1
|
||||
RESET_DELAY = 3
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -113,9 +126,10 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
self.loop.call_soon_threadsafe(self.stop_event.set)
|
||||
|
||||
class UsbPacketSource(asyncio.Protocol, ParserSource):
|
||||
def __init__(self, device, sco_enabled):
|
||||
def __init__(self, device, metadata, sco_enabled):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.metadata = metadata
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue = asyncio.Queue()
|
||||
self.dequeue_task = None
|
||||
@@ -213,9 +227,22 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
usb_find = libusb_package.find
|
||||
|
||||
# Find the device according to the spec moniker
|
||||
power_cycle = False
|
||||
if spec.startswith('!'):
|
||||
power_cycle = True
|
||||
spec = spec[1:]
|
||||
if ':' in spec:
|
||||
vendor_id, product_id = spec.split(':')
|
||||
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
|
||||
elif '-' in spec:
|
||||
|
||||
def device_path(device):
|
||||
if device.port_numbers:
|
||||
return f'{device.bus}-{".".join(map(str, device.port_numbers))}'
|
||||
else:
|
||||
return str(device.bus)
|
||||
|
||||
device = usb_find(custom_match=lambda device: device_path(device) == spec)
|
||||
else:
|
||||
device_index = int(spec)
|
||||
devices = list(
|
||||
@@ -232,9 +259,20 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
device = None
|
||||
|
||||
if device is None:
|
||||
raise ValueError('device not found')
|
||||
raise TransportInitError('device not found')
|
||||
logger.debug(f'USB Device: {device}')
|
||||
|
||||
# Power Cycle the device
|
||||
if power_cycle:
|
||||
try:
|
||||
device = await _power_cycle(device) # type: ignore
|
||||
except Exception as e:
|
||||
logging.debug(e)
|
||||
logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore
|
||||
|
||||
# Collect the metadata
|
||||
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
|
||||
|
||||
# Detach the kernel driver if needed
|
||||
if device.is_kernel_driver_active(0):
|
||||
logger.debug("detaching kernel driver")
|
||||
@@ -289,9 +327,79 @@ async def open_pyusb_transport(spec: str) -> Transport:
|
||||
# except usb.USBError:
|
||||
# logger.warning('failed to set alternate setting')
|
||||
|
||||
packet_source = UsbPacketSource(device, sco_enabled)
|
||||
packet_source = UsbPacketSource(device, device_metadata, sco_enabled)
|
||||
packet_sink = UsbPacketSink(device)
|
||||
packet_source.start()
|
||||
packet_sink.start()
|
||||
|
||||
return UsbTransport(device, packet_source, packet_sink)
|
||||
|
||||
|
||||
async def _power_cycle(device: UsbDevice) -> UsbDevice:
|
||||
"""
|
||||
For devices connected to compatible USB hubs: Performs a power cycle on a given USB device.
|
||||
This involves temporarily disabling its port on the hub and then re-enabling it.
|
||||
"""
|
||||
device_path = f'{device.bus}-{".".join(map(str, device.port_numbers))}' # type: ignore
|
||||
hub = _find_hub_by_device_path(device_path)
|
||||
|
||||
if hub:
|
||||
try:
|
||||
device_port = device.port_numbers[-1] # type: ignore
|
||||
_set_port_status(hub, device_port, False)
|
||||
await asyncio.sleep(POWER_CYCLE_DELAY)
|
||||
_set_port_status(hub, device_port, True)
|
||||
await asyncio.sleep(RESET_DELAY)
|
||||
|
||||
# Device needs to be find again otherwise it will appear as disconnected
|
||||
return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore
|
||||
except USBError as e:
|
||||
logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore
|
||||
logger.error(e)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def _set_port_status(device: UsbDevice, port: int, on: bool):
|
||||
"""Sets the power status of a specific port on a USB hub."""
|
||||
device.ctrl_transfer(
|
||||
bmRequestType=CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER,
|
||||
bRequest=REQ_SET_FEATURE if on else REQ_CLEAR_FEATURE,
|
||||
wIndex=port,
|
||||
wValue=USB_PORT_FEATURE_POWER,
|
||||
)
|
||||
|
||||
|
||||
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
|
||||
"""Finds a USB device based on its system path."""
|
||||
bus_num, *port_parts = sys_path.split('-')
|
||||
ports = [int(port) for port in port_parts[0].split('.')]
|
||||
devices = usb.core.find(find_all=True, bus=int(bus_num))
|
||||
if devices:
|
||||
for device in devices:
|
||||
if device.bus == int(bus_num) and list(device.port_numbers) == ports: # type: ignore
|
||||
return device
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
|
||||
"""Finds the USB hub associated with a specific device path."""
|
||||
hub_sys_path = sys_path.rsplit('.', 1)[0]
|
||||
hub_device = _find_device_by_path(hub_sys_path)
|
||||
|
||||
if hub_device is None:
|
||||
return None
|
||||
else:
|
||||
return hub_device if _is_hub(hub_device) else None
|
||||
|
||||
|
||||
def _is_hub(device: UsbDevice) -> bool:
|
||||
"""Checks if a USB device is a hub"""
|
||||
if device.bDeviceClass == CLASS_HUB: # type: ignore
|
||||
return True
|
||||
for config in device:
|
||||
for interface in config:
|
||||
if interface.bInterfaceClass == CLASS_HUB: # type: ignore
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from .common import Transport, StreamPacketSource
|
||||
|
||||
@@ -28,6 +29,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# A pass-through function to ease mock testing.
|
||||
async def _create_server(*args, **kw_args):
|
||||
await asyncio.get_running_loop().create_server(*args, **kw_args)
|
||||
|
||||
|
||||
async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport.
|
||||
@@ -38,7 +46,22 @@ async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
|
||||
Example: _:9001
|
||||
'''
|
||||
local_host, local_port = spec.split(':')
|
||||
return await _open_tcp_server_transport_impl(
|
||||
host=local_host if local_host != '_' else None, port=int(local_port)
|
||||
)
|
||||
|
||||
|
||||
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
|
||||
'''
|
||||
Open a TCP server transport with an existing socket.
|
||||
|
||||
One reason to use this variant is to let python pick an unused port.
|
||||
'''
|
||||
return await _open_tcp_server_transport_impl(sock=sock)
|
||||
|
||||
|
||||
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
|
||||
class TcpServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
@@ -77,13 +100,10 @@ async def open_tcp_server_transport(spec: str) -> Transport:
|
||||
else:
|
||||
logger.debug('no client, dropping packet')
|
||||
|
||||
local_host, local_port = spec.split(':')
|
||||
packet_source = StreamPacketSource()
|
||||
packet_sink = TcpServerPacketSink()
|
||||
await asyncio.get_running_loop().create_server(
|
||||
lambda: TcpServerProtocol(packet_source, packet_sink),
|
||||
host=local_host if local_host != '_' else None,
|
||||
port=int(local_port),
|
||||
await _create_server(
|
||||
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
|
||||
)
|
||||
|
||||
return TcpServerTransport(packet_source, packet_sink)
|
||||
|
||||
@@ -15,18 +15,18 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import collections
|
||||
import ctypes
|
||||
import platform
|
||||
|
||||
import usb1
|
||||
|
||||
from .common import Transport, ParserSource
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
from bumble.transport.common import Transport, ParserSource, TransportInitError
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -107,20 +107,24 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
|
||||
)
|
||||
|
||||
READ_SIZE = 1024
|
||||
READ_SIZE = 4096
|
||||
|
||||
class UsbPacketSink:
|
||||
def __init__(self, device, acl_out):
|
||||
self.device = device
|
||||
self.acl_out = acl_out
|
||||
self.transfer = device.getTransfer()
|
||||
self.packets = collections.deque() # Queue of packets waiting to be sent
|
||||
self.acl_out_transfer = device.getTransfer()
|
||||
self.acl_out_transfer_ready = asyncio.Semaphore(1)
|
||||
self.packets: asyncio.Queue[bytes] = (
|
||||
asyncio.Queue()
|
||||
) # Queue of packets waiting to be sent
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue_task = None
|
||||
self.cancel_done = self.loop.create_future()
|
||||
self.closed = False
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
self.queue_task = asyncio.create_task(self.process_queue())
|
||||
|
||||
def on_packet(self, packet):
|
||||
# Ignore packets if we're closed
|
||||
@@ -132,72 +136,71 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
return
|
||||
|
||||
# Queue the packet
|
||||
self.packets.append(packet)
|
||||
if len(self.packets) == 1:
|
||||
# The queue was previously empty, re-prime the pump
|
||||
self.process_queue()
|
||||
self.packets.put_nowait(packet)
|
||||
|
||||
def on_packet_sent(self, transfer):
|
||||
def transfer_callback(self, transfer):
|
||||
self.acl_out_transfer_ready.release()
|
||||
status = transfer.getStatus()
|
||||
# logger.debug(f'<<< USB out transfer callback: status={status}')
|
||||
|
||||
# pylint: disable=no-member
|
||||
if status == usb1.TRANSFER_COMPLETED:
|
||||
self.loop.call_soon_threadsafe(self.on_packet_sent_)
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
if status == usb1.TRANSFER_CANCELLED:
|
||||
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
|
||||
else:
|
||||
return
|
||||
|
||||
if status != usb1.TRANSFER_COMPLETED:
|
||||
logger.warning(
|
||||
color(f'!!! out transfer not completed: status={status}', 'red')
|
||||
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
||||
)
|
||||
|
||||
def on_packet_sent_(self):
|
||||
if self.packets:
|
||||
self.packets.popleft()
|
||||
self.process_queue()
|
||||
async def process_queue(self):
|
||||
while True:
|
||||
# Wait for a packet to transfer.
|
||||
packet = await self.packets.get()
|
||||
|
||||
def process_queue(self):
|
||||
if len(self.packets) == 0:
|
||||
return # Nothing to do
|
||||
# Wait until we can start a transfer.
|
||||
await self.acl_out_transfer_ready.acquire()
|
||||
|
||||
packet = self.packets[0]
|
||||
packet_type = packet[0]
|
||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.on_packet_sent
|
||||
)
|
||||
logger.debug('submit ACL')
|
||||
self.transfer.submit()
|
||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.transfer.setControl(
|
||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
packet[1:],
|
||||
callback=self.on_packet_sent,
|
||||
)
|
||||
logger.debug('submit COMMAND')
|
||||
self.transfer.submit()
|
||||
else:
|
||||
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
|
||||
# Transfer the packet.
|
||||
packet_type = packet[0]
|
||||
if packet_type == hci.HCI_ACL_DATA_PACKET:
|
||||
self.acl_out_transfer.setBulk(
|
||||
self.acl_out, packet[1:], callback=self.transfer_callback
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
elif packet_type == hci.HCI_COMMAND_PACKET:
|
||||
self.acl_out_transfer.setControl(
|
||||
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
packet[1:],
|
||||
callback=self.transfer_callback,
|
||||
)
|
||||
self.acl_out_transfer.submit()
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'unsupported packet type {packet_type}', 'red')
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
if self.queue_task:
|
||||
self.queue_task.cancel()
|
||||
|
||||
async def terminate(self):
|
||||
if not self.closed:
|
||||
self.close()
|
||||
|
||||
# Empty the packet queue so that we don't send any more data
|
||||
self.packets.clear()
|
||||
while not self.packets.empty():
|
||||
self.packets.get_nowait()
|
||||
|
||||
# If we have a transfer in flight, cancel it
|
||||
if self.transfer.isSubmitted():
|
||||
if self.acl_out_transfer.isSubmitted():
|
||||
# Try to cancel the transfer, but that may fail because it may have
|
||||
# already completed
|
||||
try:
|
||||
self.transfer.cancel()
|
||||
self.acl_out_transfer.cancel()
|
||||
|
||||
logger.debug('waiting for OUT transfer cancellation to be done...')
|
||||
await self.cancel_done
|
||||
@@ -206,27 +209,22 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
logger.debug('OUT transfer likely already completed')
|
||||
|
||||
class UsbPacketSource(asyncio.Protocol, ParserSource):
|
||||
def __init__(self, context, device, metadata, acl_in, events_in):
|
||||
def __init__(self, device, metadata, acl_in, events_in):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.device = device
|
||||
self.metadata = metadata
|
||||
self.acl_in = acl_in
|
||||
self.acl_in_transfer = None
|
||||
self.events_in = events_in
|
||||
self.events_in_transfer = None
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.queue = asyncio.Queue()
|
||||
self.dequeue_task = None
|
||||
self.closed = False
|
||||
self.event_loop_done = self.loop.create_future()
|
||||
self.cancel_done = {
|
||||
hci.HCI_EVENT_PACKET: self.loop.create_future(),
|
||||
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
|
||||
}
|
||||
self.events_in_transfer = None
|
||||
self.acl_in_transfer = None
|
||||
|
||||
# Create a thread to process events
|
||||
self.event_thread = threading.Thread(target=self.run)
|
||||
self.closed = False
|
||||
|
||||
def start(self):
|
||||
# Set up transfer objects for input
|
||||
@@ -234,7 +232,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
self.events_in_transfer.setInterrupt(
|
||||
self.events_in,
|
||||
READ_SIZE,
|
||||
callback=self.on_packet_received,
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_EVENT_PACKET,
|
||||
)
|
||||
self.events_in_transfer.submit()
|
||||
@@ -243,22 +241,23 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
self.acl_in_transfer.setBulk(
|
||||
self.acl_in,
|
||||
READ_SIZE,
|
||||
callback=self.on_packet_received,
|
||||
callback=self.transfer_callback,
|
||||
user_data=hci.HCI_ACL_DATA_PACKET,
|
||||
)
|
||||
self.acl_in_transfer.submit()
|
||||
|
||||
self.dequeue_task = self.loop.create_task(self.dequeue())
|
||||
self.event_thread.start()
|
||||
|
||||
def on_packet_received(self, transfer):
|
||||
@property
|
||||
def usb_transfer_submitted(self):
|
||||
return (
|
||||
self.events_in_transfer.isSubmitted()
|
||||
or self.acl_in_transfer.isSubmitted()
|
||||
)
|
||||
|
||||
def transfer_callback(self, transfer):
|
||||
packet_type = transfer.getUserData()
|
||||
status = transfer.getStatus()
|
||||
# logger.debug(
|
||||
# f'<<< USB IN transfer callback: status={status} '
|
||||
# f'packet_type={packet_type} '
|
||||
# f'length={transfer.getActualLength()}'
|
||||
# )
|
||||
|
||||
# pylint: disable=no-member
|
||||
if status == usb1.TRANSFER_COMPLETED:
|
||||
@@ -267,18 +266,18 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
+ transfer.getBuffer()[: transfer.getActualLength()]
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
transfer.submit()
|
||||
elif status == usb1.TRANSFER_CANCELLED:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.cancel_done[packet_type].set_result, None
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! transfer not completed: status={status}', 'red')
|
||||
color(f'!!! IN transfer not completed: status={status}', 'red')
|
||||
)
|
||||
|
||||
# Re-submit the transfer so we can receive more data
|
||||
transfer.submit()
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
|
||||
async def dequeue(self):
|
||||
while not self.closed:
|
||||
@@ -288,21 +287,6 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
return
|
||||
self.parser.feed_data(packet)
|
||||
|
||||
def run(self):
|
||||
logger.debug('starting USB event loop')
|
||||
while (
|
||||
self.events_in_transfer.isSubmitted()
|
||||
or self.acl_in_transfer.isSubmitted()
|
||||
):
|
||||
# pylint: disable=no-member
|
||||
try:
|
||||
self.context.handleEvents()
|
||||
except usb1.USBErrorInterrupted:
|
||||
pass
|
||||
|
||||
logger.debug('USB event loop done')
|
||||
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
@@ -331,15 +315,14 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
f'IN[{packet_type}] transfer likely already completed'
|
||||
)
|
||||
|
||||
# Wait for the thread to terminate
|
||||
await self.event_loop_done
|
||||
|
||||
class UsbTransport(Transport):
|
||||
def __init__(self, context, device, interface, setting, source, sink):
|
||||
super().__init__(source, sink)
|
||||
self.context = context
|
||||
self.device = device
|
||||
self.interface = interface
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.event_loop_done = self.loop.create_future()
|
||||
|
||||
# Get exclusive access
|
||||
device.claimInterface(interface)
|
||||
@@ -352,6 +335,22 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
source.start()
|
||||
sink.start()
|
||||
|
||||
# Create a thread to process events
|
||||
self.event_thread = threading.Thread(target=self.run)
|
||||
self.event_thread.start()
|
||||
|
||||
def run(self):
|
||||
logger.debug('starting USB event loop')
|
||||
while self.source.usb_transfer_submitted:
|
||||
# pylint: disable=no-member
|
||||
try:
|
||||
self.context.handleEvents()
|
||||
except usb1.USBErrorInterrupted:
|
||||
pass
|
||||
|
||||
logger.debug('USB event loop done')
|
||||
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
|
||||
|
||||
async def close(self):
|
||||
self.source.close()
|
||||
self.sink.close()
|
||||
@@ -361,6 +360,9 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
self.device.close()
|
||||
self.context.close()
|
||||
|
||||
# Wait for the thread to terminate
|
||||
await self.event_loop_done
|
||||
|
||||
# Find the device according to the spec moniker
|
||||
load_libusb()
|
||||
context = usb1.USBContext()
|
||||
@@ -399,6 +401,16 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
break
|
||||
device_index -= 1
|
||||
device.close()
|
||||
elif '-' in spec:
|
||||
|
||||
def device_path(device):
|
||||
return f'{device.getBusNumber()}-{".".join(map(str, device.getPortNumberList()))}'
|
||||
|
||||
for device in context.getDeviceIterator(skip_on_error=True):
|
||||
if device_path(device) == spec:
|
||||
found = device
|
||||
break
|
||||
device.close()
|
||||
else:
|
||||
# Look for a compatible device by index
|
||||
def device_is_bluetooth_hci(device):
|
||||
@@ -435,14 +447,14 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
|
||||
if found is None:
|
||||
context.close()
|
||||
raise ValueError('device not found')
|
||||
raise TransportInitError('device not found')
|
||||
|
||||
logger.debug(f'USB Device: {found}')
|
||||
|
||||
# Look for the first interface with the right class and endpoints
|
||||
def find_endpoints(device):
|
||||
# pylint: disable-next=too-many-nested-blocks
|
||||
for (configuration_index, configuration) in enumerate(device):
|
||||
for configuration_index, configuration in enumerate(device):
|
||||
interface = None
|
||||
for interface in configuration:
|
||||
setting = None
|
||||
@@ -500,7 +512,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
|
||||
endpoints = find_endpoints(found)
|
||||
if endpoints is None:
|
||||
raise ValueError('no compatible interface found for device')
|
||||
raise TransportInitError('no compatible interface found for device')
|
||||
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
|
||||
logger.debug(
|
||||
f'selected endpoints: configuration={configuration}, '
|
||||
@@ -540,7 +552,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
except usb1.USBError:
|
||||
logger.warning('failed to set configuration')
|
||||
|
||||
source = UsbPacketSource(context, device, device_metadata, acl_in, events_in)
|
||||
source = UsbPacketSource(device, device_metadata, acl_in, events_in)
|
||||
sink = UsbPacketSink(device, acl_out)
|
||||
return UsbTransport(context, device, interface, setting, source, sink)
|
||||
except usb1.USBError as error:
|
||||
|
||||
139
bumble/utils.py
139
bumble/utils.py
@@ -17,10 +17,12 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import collections
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Set,
|
||||
@@ -33,7 +35,7 @@ from typing import (
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from functools import wraps
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
@@ -115,12 +117,12 @@ class EventWatcher:
|
||||
self.handlers = []
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
def on(
|
||||
self, emitter: EventEmitter, event: str
|
||||
) -> Callable[[_Handler], _Handler]: ...
|
||||
|
||||
@overload
|
||||
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ...
|
||||
|
||||
def on(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
@@ -130,23 +132,26 @@ class EventWatcher:
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
|
||||
handler: (Optional) Event handler. When nothing is passed, this method
|
||||
works as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(f: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, f))
|
||||
emitter.on(event, f)
|
||||
return f
|
||||
def wrapper(wrapped: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, wrapped))
|
||||
emitter.on(event, wrapped)
|
||||
return wrapped
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||
...
|
||||
def once(
|
||||
self, emitter: EventEmitter, event: str
|
||||
) -> Callable[[_Handler], _Handler]: ...
|
||||
|
||||
@overload
|
||||
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||
...
|
||||
def once(
|
||||
self, emitter: EventEmitter, event: str, handler: _Handler
|
||||
) -> _Handler: ...
|
||||
|
||||
def once(
|
||||
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||
@@ -156,13 +161,14 @@ class EventWatcher:
|
||||
Args:
|
||||
emitter: EventEmitter to watch
|
||||
event: Event name
|
||||
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
|
||||
handler: (Optional) Event handler. When nothing passed, this method works
|
||||
as a decorator.
|
||||
'''
|
||||
|
||||
def wrapper(f: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, f))
|
||||
emitter.once(event, f)
|
||||
return f
|
||||
def wrapper(wrapped: _Handler) -> _Handler:
|
||||
self.handlers.append((emitter, event, wrapped))
|
||||
emitter.once(event, wrapped)
|
||||
return wrapped
|
||||
|
||||
return wrapper if handler is None else wrapper(handler)
|
||||
|
||||
@@ -222,13 +228,13 @@ class CompositeEventEmitter(AbortableEventEmitter):
|
||||
if self._listener:
|
||||
# Call the deregistration methods for each base class that has them
|
||||
for cls in self._listener.__class__.mro():
|
||||
if hasattr(cls, '_bumble_register_composite'):
|
||||
cls._bumble_deregister_composite(listener, self)
|
||||
if '_bumble_register_composite' in cls.__dict__:
|
||||
cls._bumble_deregister_composite(self._listener, self)
|
||||
self._listener = listener
|
||||
if listener:
|
||||
# Call the registration methods for each base class that has them
|
||||
for cls in listener.__class__.mro():
|
||||
if hasattr(cls, '_bumble_deregister_composite'):
|
||||
if '_bumble_deregister_composite' in cls.__dict__:
|
||||
cls._bumble_register_composite(listener, self)
|
||||
|
||||
|
||||
@@ -275,21 +281,18 @@ class AsyncRunner:
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
coroutine = func(*args, **kwargs)
|
||||
if queue is None:
|
||||
# Create a task to run the coroutine
|
||||
# Spawn the coroutine as a task
|
||||
async def run():
|
||||
try:
|
||||
await coroutine
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'{color("!!! Exception in wrapper:", "red")} '
|
||||
f'{traceback.format_exc()}'
|
||||
)
|
||||
logger.exception(color("!!! Exception in wrapper:", "red"))
|
||||
|
||||
asyncio.create_task(run())
|
||||
AsyncRunner.spawn(run())
|
||||
else:
|
||||
# Queue the coroutine to be awaited by the work queue
|
||||
queue.enqueue(coroutine)
|
||||
@@ -410,3 +413,77 @@ class FlowControlAsyncPipe:
|
||||
self.resume_source()
|
||||
|
||||
self.check_pump()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_call(function, *args, **kwargs):
|
||||
"""
|
||||
Immediately calls the function with provided args and kwargs, wrapping it in an
|
||||
async function.
|
||||
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
|
||||
a running loop.
|
||||
|
||||
result = await async_call(some_function, ...)
|
||||
"""
|
||||
return function(*args, **kwargs)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def wrap_async(function):
|
||||
"""
|
||||
Wraps the provided function in an async function.
|
||||
"""
|
||||
return functools.partial(async_call, function)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def deprecated(msg: str):
|
||||
"""
|
||||
Throw deprecation warning before execution.
|
||||
"""
|
||||
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def experimental(msg: str):
|
||||
"""
|
||||
Throws a future warning before execution.
|
||||
"""
|
||||
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, FutureWarning)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpenIntEnum(enum.IntEnum):
|
||||
"""
|
||||
Subclass of enum.IntEnum that can hold integer values outside the set of
|
||||
predefined values. This is convenient for implementing protocols where some
|
||||
integer constants may be added over time.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if not isinstance(value, int):
|
||||
return None
|
||||
|
||||
obj = int.__new__(cls, value)
|
||||
obj._value_ = value
|
||||
obj._name_ = f"{cls.__name__}[{value}]"
|
||||
return obj
|
||||
|
||||
BIN
docs/images/favicon.ico
Normal file
BIN
docs/images/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
@@ -10,7 +10,7 @@ nav:
|
||||
- Contributing: development/contributing.md
|
||||
- Code Style: development/code_style.md
|
||||
- Use Cases:
|
||||
- Overview: use_cases/index.md
|
||||
- use_cases/index.md
|
||||
- Use Case 1: use_cases/use_case_1.md
|
||||
- Use Case 2: use_cases/use_case_2.md
|
||||
- Use Case 3: use_cases/use_case_3.md
|
||||
@@ -23,7 +23,7 @@ nav:
|
||||
- GATT: components/gatt.md
|
||||
- Security Manager: components/security_manager.md
|
||||
- Transports:
|
||||
- Overview: transports/index.md
|
||||
- transports/index.md
|
||||
- Serial: transports/serial.md
|
||||
- USB: transports/usb.md
|
||||
- PTY: transports/pty.md
|
||||
@@ -37,14 +37,14 @@ nav:
|
||||
- Android Emulator: transports/android_emulator.md
|
||||
- File: transports/file.md
|
||||
- Drivers:
|
||||
- Overview: drivers/index.md
|
||||
- drivers/index.md
|
||||
- Realtek: drivers/realtek.md
|
||||
- API:
|
||||
- Guide: api/guide.md
|
||||
- Examples: api/examples.md
|
||||
- Reference: api/reference.md
|
||||
- Apps & Tools:
|
||||
- Overview: apps_and_tools/index.md
|
||||
- apps_and_tools/index.md
|
||||
- Console: apps_and_tools/console.md
|
||||
- Bench: apps_and_tools/bench.md
|
||||
- Speaker: apps_and_tools/speaker.md
|
||||
@@ -57,16 +57,25 @@ nav:
|
||||
- USB Probe: apps_and_tools/usb_probe.md
|
||||
- Link Relay: apps_and_tools/link_relay.md
|
||||
- Hardware:
|
||||
- Overview: hardware/index.md
|
||||
- hardware/index.md
|
||||
- Platforms:
|
||||
- Overview: platforms/index.md
|
||||
- platforms/index.md
|
||||
- macOS: platforms/macos.md
|
||||
- Linux: platforms/linux.md
|
||||
- Windows: platforms/windows.md
|
||||
- Android: platforms/android.md
|
||||
- Zephyr: platforms/zephyr.md
|
||||
- Examples:
|
||||
- Overview: examples/index.md
|
||||
- examples/index.md
|
||||
- Extras:
|
||||
- extras/index.md
|
||||
- Android Remote HCI: extras/android_remote_hci.md
|
||||
- Android BT Bench: extras/android_bt_bench.md
|
||||
- Hive:
|
||||
- hive/index.md
|
||||
- Speaker: hive/web/speaker/speaker.html
|
||||
- Scanner: hive/web/scanner/scanner.html
|
||||
- Heart Rate Monitor: hive/web/heart_rate_monitor/heart_rate_monitor.html
|
||||
|
||||
copyright: Copyright 2021-2023 Google LLC
|
||||
|
||||
@@ -75,6 +84,8 @@ theme:
|
||||
logo: 'images/logo.png'
|
||||
favicon: 'images/favicon.ico'
|
||||
custom_dir: 'theme'
|
||||
features:
|
||||
- navigation.indexes
|
||||
|
||||
plugins:
|
||||
- mkdocstrings:
|
||||
@@ -99,6 +110,8 @@ markdown_extensions:
|
||||
- pymdownx.emoji:
|
||||
emoji_index: !!python/name:materialx.emoji.twemoji
|
||||
emoji_generator: !!python/name:materialx.emoji.to_svg
|
||||
- pymdownx.tabbed:
|
||||
alternate_style: true
|
||||
- codehilite:
|
||||
guess_lang: false
|
||||
- toc:
|
||||
|
||||
@@ -7,16 +7,36 @@ throughput and/or latency between two devices.
|
||||
# General Usage
|
||||
|
||||
```
|
||||
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
|
||||
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--device-config FILENAME Device configuration file
|
||||
--role [sender|receiver|ping|pong]
|
||||
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
|
||||
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
|
||||
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (server role)
|
||||
-sd, --start-delay SECONDS Start delay (server role)
|
||||
--extended-data-length TEXT Request a data length upon connection,
|
||||
specified as tx_octets/tx_time
|
||||
--rfcomm-channel INTEGER RFComm channel to use
|
||||
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
|
||||
--rfcomm-channel is not 0)
|
||||
--l2cap-psm INTEGER L2CAP PSM to use
|
||||
--l2cap-mtu INTEGER L2CAP MTU to use
|
||||
--l2cap-mps INTEGER L2CAP MPS to use
|
||||
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
|
||||
the peer
|
||||
-s, --packet-size SIZE Packet size (client or ping role)
|
||||
[8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (client or ping role)
|
||||
-sd, --start-delay SECONDS Start delay (client or ping role)
|
||||
--repeat N Repeat the run N times (client and ping
|
||||
roles)(0, which is the fault, to run just
|
||||
once)
|
||||
--repeat-delay SECONDS Delay, in seconds, between repeats
|
||||
--pace MILLISECONDS Wait N milliseconds between packets (0,
|
||||
which is the fault, to send as fast as
|
||||
possible)
|
||||
--linger Don't exit at the end of a run (server and
|
||||
pong roles)
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
@@ -35,17 +55,18 @@ Options:
|
||||
--connection-interval, --ci CONNECTION_INTERVAL
|
||||
Connection interval (in ms)
|
||||
--phy [1m|2m|coded] PHY to use
|
||||
--authenticate Authenticate (RFComm only)
|
||||
--encrypt Encrypt the connection (RFComm only)
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
||||
To test once device against another, one of the two devices must be running
|
||||
To test once device against another, one of the two devices must be running
|
||||
the ``peripheral`` command and the other the ``central`` command. The device
|
||||
running the ``peripheral`` command will accept connections from the device
|
||||
running the ``central`` command.
|
||||
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
using the ``--peripheral`` option. The address will be printed by the Peripheral when
|
||||
it starts.
|
||||
|
||||
@@ -83,7 +104,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
$ bumble-bench central usb:1
|
||||
```
|
||||
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
connecting to the Peripheral running a Receiver, as a GATT server.
|
||||
|
||||
!!! example "L2CAP Throughput"
|
||||
|
||||
@@ -12,12 +12,25 @@ a host that send custom HCI commands that the controller may not understand.
|
||||
```
|
||||
python hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]
|
||||
```
|
||||
The command-short-circuit-list field is specified by a series of comma separated Opcode Group
|
||||
Field (OGF) : OpCode Command Field (OCF) pairs. The OGF/OCF values are specified in the Blutooth
|
||||
core specification.
|
||||
|
||||
For the commands that are listed in the short-circuit-list, the HCI bridge will always generate
|
||||
a Command Complete Event for the specified op code. The return parameter will be HCI_SUCCESS.
|
||||
|
||||
This feature can only be used for commands that return Command Complete. Other events will not be
|
||||
generated by the HCI bridge tool.
|
||||
|
||||
!!! example "UDP to Serial"
|
||||
```
|
||||
python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078
|
||||
```
|
||||
|
||||
In this example, the short circuit list is specified to respond to the Vendor-specific Opcode Group
|
||||
Field (0x3f) commands 0x70, 0x74, 0x77, 0x78 with Command Complete. The short circuit list can be
|
||||
used where the Host uses some HCI commands that are not supported/implemented by the Controller.
|
||||
|
||||
!!! example "PTY to Link Relay"
|
||||
```
|
||||
python hci_bridge.py serial:emulated_uart_pty,1000000 link-relay:ws://127.0.0.1:10723/test
|
||||
@@ -28,3 +41,4 @@ a host that send custom HCI commands that the controller may not understand.
|
||||
(through which the communication with other virtual controllers will be mediated).
|
||||
|
||||
NOTE: this assumes you're running a Link Relay on port `10723`.
|
||||
|
||||
|
||||
@@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly.
|
||||
This may include, for instance, loading a Firmware image or patch,
|
||||
loading a configuration.
|
||||
|
||||
By default, drivers will be automatically probed to determine if they should be
|
||||
used with particular HCI controller.
|
||||
When the transport for an HCI controller is instantiated from a transport name,
|
||||
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
|
||||
metadata portion of the transport name. For example,
|
||||
``usb:[driver=rtk]0`` indicates that the ``rtk`` driver should be used with the
|
||||
first USB device, even if a normal probe would not have selected it based on the
|
||||
USB vendor ID and product ID.
|
||||
|
||||
Drivers included in the module are:
|
||||
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
@@ -1,13 +1,16 @@
|
||||
REALTEK DRIVER
|
||||
==============
|
||||
|
||||
This driver supports loading firmware images and optional config data to
|
||||
This driver supports loading firmware images and optional config data to
|
||||
USB dongles with a Realtek chipset.
|
||||
A number of USB dongles are supported, but likely not all.
|
||||
When using a USB dongle, the USB product ID and manufacturer ID are used
|
||||
When using a USB dongle, the USB product ID and vendor ID are used
|
||||
to find whether a matching set of firmware image and config data
|
||||
is needed for that specific model. If a match exists, the driver will try
|
||||
load the firmware image and, if needed, config data.
|
||||
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
|
||||
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
|
||||
``usb:0`` for the first USB device).
|
||||
The driver will look for those files by name, in order, in:
|
||||
|
||||
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
|
||||
|
||||
64
docs/mkdocs/src/extras/android_bt_bench.md
Normal file
64
docs/mkdocs/src/extras/android_bt_bench.md
Normal file
@@ -0,0 +1,64 @@
|
||||
ANDROID BENCH APP
|
||||
=================
|
||||
|
||||
This Android app that is compatible with the Bumble `bench` command line app.
|
||||
This app can be used to test the throughput and latency between two Android
|
||||
devices, or between an Android device and another device running the Bumble
|
||||
`bench` app.
|
||||
Only the RFComm Client, RFComm Server, L2CAP Client and L2CAP Server modes are
|
||||
supported.
|
||||
|
||||
Building
|
||||
--------
|
||||
|
||||
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `BtBench` top level directory.
|
||||
You can also build with Android Studio: open the `BtBench` project. You can build and/or debug from there.
|
||||
|
||||
If the build succeeds, you can find the app APKs (debug and release) at:
|
||||
|
||||
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
|
||||
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
|
||||
|
||||
|
||||
Running
|
||||
-------
|
||||
|
||||
### Starting the app
|
||||
You can start the app from the Android launcher, from Android Studio, or with `adb`
|
||||
|
||||
#### Launching from the launcher
|
||||
Just tap the app icon on the launcher, check the parameters, and tap
|
||||
one of the benchmark action buttons.
|
||||
|
||||
#### Launching with `adb`
|
||||
Using the `am` command, you can start the activity, and pass it arguments so that you can
|
||||
automatically start the benchmark test, and/or set the parameters.
|
||||
|
||||
| Parameter Name | Parameter Type | Description
|
||||
|------------------------|----------------|------------
|
||||
| autostart | String | Benchmark to start. (rfcomm-client, rfcomm-server, l2cap-client or l2cap-server)
|
||||
| packet-count | Integer | Number of packets to send (rfcomm-client and l2cap-client only)
|
||||
| packet-size | Integer | Number of bytes per packet (rfcomm-client and l2cap-client only)
|
||||
| peer-bluetooth-address | Integer | Peer Bluetooth address to connect to (rfcomm-client and l2cap-client | only)
|
||||
|
||||
|
||||
!!! tip "Launching from adb with auto-start"
|
||||
In this example, we auto-start the Rfcomm Server bench action.
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-server
|
||||
```
|
||||
|
||||
!!! tip "Launching from adb with auto-start and some parameters"
|
||||
In this example, we auto-start the Rfcomm Client bench action, set the packet count to 100,
|
||||
and the packet size to 1024, and connect to DA:4C:10:DE:17:02
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-client --ei packet-count 100 --ei packet-size 1024 --es peer-bluetooth-address DA:4C:10:DE:17:02
|
||||
```
|
||||
|
||||
#### Selecting a Peer Bluetooth Address
|
||||
The app's main activity has a "Peer Bluetooth Address" setting where you can change the address.
|
||||
|
||||
!!! note "Bluetooth Address for L2CAP vs RFComm"
|
||||
For BLE (L2CAP mode), the address of a device typically changes regularly (it is randomized for privacy), whereas the Bluetooth Classic addresses will remain the same (RFComm mode).
|
||||
If two devices are paired and bonded, then they will each "see" a non-changing address for each other even with BLE (Resolvable Private Address)
|
||||
|
||||
181
docs/mkdocs/src/extras/android_remote_hci.md
Normal file
181
docs/mkdocs/src/extras/android_remote_hci.md
Normal file
@@ -0,0 +1,181 @@
|
||||
ANDROID REMOTE HCI APP
|
||||
======================
|
||||
|
||||
This application allows using an android phone's built-in Bluetooth controller with
|
||||
a Bumble host stack running outside the phone (typically a development laptop or desktop).
|
||||
The app runs an HCI proxy between a TCP socket on the "outside" and the Bluetooth HCI HAL
|
||||
on the "inside". (See [this page](https://source.android.com/docs/core/connect/bluetooth) for a high level
|
||||
description of the Android Bluetooth HCI HAL).
|
||||
The HCI packets received on the TCP socket are forwarded to the phone's controller, and the
|
||||
packets coming from the controller are forwarded to the TCP socket.
|
||||
|
||||
|
||||
Building
|
||||
--------
|
||||
|
||||
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `extras/android/RemoteHCI` top level directory.
|
||||
You can also build with Android Studio: open the `RemoteHCI` project. You can build and/or debug from there.
|
||||
|
||||
If the build succeeds, you can find the app APKs (debug and release) at:
|
||||
|
||||
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
|
||||
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
|
||||
|
||||
|
||||
Running
|
||||
-------
|
||||
|
||||
!!! note
|
||||
In the following examples, it is assumed that shell commands are executed while in the
|
||||
app's root directory, `extras/android/RemoteHCI`. If you are in a different directory,
|
||||
adjust the relative paths accordingly.
|
||||
|
||||
### Preconditions
|
||||
When the proxy starts (tapping the "Start" button in the app's main activity, or running the proxy
|
||||
from an `adb shell` command line), it will try to bind to the Bluetooth HAL.
|
||||
This requires that there is no other HAL client, and requires certain privileges.
|
||||
For running as a regular app, this requires disabling SELinux temporarily.
|
||||
For running as a command-line executable, this just requires a root shell.
|
||||
|
||||
#### Root Shell
|
||||
!!! tip "Restart `adb` as root"
|
||||
```bash
|
||||
$ adb root
|
||||
```
|
||||
|
||||
#### Disabling SELinux
|
||||
Binding to the Bluetooth HCI HAL requires certain SELinux permissions that can't simply be changed
|
||||
on a device without rebuilding its system image. To bypass these restrictions, you will need
|
||||
to disable SELinux on your phone (please be aware that this is global, not just for the proxy app,
|
||||
so proceed with caution).
|
||||
In order to disable SELinux, you need to root the phone (it may be advisable to do this on a
|
||||
development phone).
|
||||
|
||||
!!! tip "Disabling SELinux Temporarily"
|
||||
Restart `adb` as root:
|
||||
```bash
|
||||
$ adb root
|
||||
```
|
||||
|
||||
Then disable SELinux
|
||||
```bash
|
||||
$ adb shell setenforce 0
|
||||
```
|
||||
|
||||
Once you're done using the proxy, you can restore SELinux, if you need to, with
|
||||
```bash
|
||||
$ adb shell setenforce 1
|
||||
```
|
||||
|
||||
This state will also reset to the normal SELinux enforcement when you reboot.
|
||||
|
||||
#### Stopping the bluetooth process
|
||||
Since the Bluetooth HAL service can only accept one client, and that in normal conditions
|
||||
that client is the Android's bluetooth stack, it is required to first shut down the
|
||||
Android bluetooth stack process.
|
||||
|
||||
!!! tip "Checking if the Bluetooth process is running"
|
||||
```bash
|
||||
$ adb shell "ps -A | grep com.google.android.bluetooth"
|
||||
```
|
||||
If the process is running, you will get a line like:
|
||||
```
|
||||
bluetooth 10759 876 17455796 136620 do_epoll_wait 0 S com.google.android.bluetooth
|
||||
```
|
||||
If you don't, it means that the process is not running and you are clear to proceed.
|
||||
|
||||
Simply turning Bluetooth off from the phone's settings does not ensure that the bluetooth process will exit.
|
||||
If the bluetooth process is still running after toggling Bluetooth off from the settings, you may try enabling
|
||||
Airplane Mode, then rebooting. The bluetooth process should, in theory, not restart after the reboot.
|
||||
|
||||
!!! tip "Stopping the bluetooth process with adb"
|
||||
```bash
|
||||
$ adb shell cmd bluetooth_manager disable
|
||||
```
|
||||
|
||||
### Running as a command line app
|
||||
|
||||
You push the built APK to a temporary location on the phone's filesystem, then launch the command
|
||||
line executable with an `adb shell` command.
|
||||
|
||||
!!! tip "Pushing the executable"
|
||||
```bash
|
||||
$ adb push app/build/outputs/apk/release/app-release-unsigned.apk /data/local/tmp/remotehci.apk
|
||||
```
|
||||
Do this every time you rebuild. Alternatively, you can push the `debug` APK instead:
|
||||
```bash
|
||||
$ adb push app/build/outputs/apk/debug/app-debug.apk /data/local/tmp/remotehci.apk
|
||||
```
|
||||
|
||||
!!! tip "Start the proxy from the command line"
|
||||
```bash
|
||||
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface"
|
||||
```
|
||||
This will run the proxy, listening on the default TCP port.
|
||||
If you want a different port, pass it as a command line parameter
|
||||
|
||||
!!! tip "Start the proxy from the command line with a specific TCP port"
|
||||
```bash
|
||||
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface 12345"
|
||||
```
|
||||
|
||||
### Running as a normal app
|
||||
You can start the app from the Android launcher, from Android Studio, or with `adb`
|
||||
|
||||
#### Launching from the launcher
|
||||
Just tap the app icon on the launcher, check the TCP port that is configured, and tap
|
||||
the "Start" button.
|
||||
|
||||
#### Launching with `adb`
|
||||
Using the `am` command, you can start the activity, and pass it arguments so that you can
|
||||
automatically start the proxy, and/or set the port number.
|
||||
|
||||
!!! tip "Launching from adb with auto-start"
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true
|
||||
```
|
||||
|
||||
!!! tip "Launching from adb with auto-start and a port"
|
||||
In this example, we auto-start the proxy upon launch, with the port set to 9995
|
||||
```bash
|
||||
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true --ei port 9995
|
||||
```
|
||||
|
||||
#### Selecting a TCP port
|
||||
The RemoteHCI app's main activity has a "TCP Port" setting where you can change the port on
|
||||
which the proxy is accepting connections. If the default value isn't suitable, you can
|
||||
change it there (you can also use the special value 0 to let the OS assign a port number for you).
|
||||
|
||||
### Connecting to the proxy
|
||||
To connect the Bumble stack to the proxy, you need to be able to reach the phone's network
|
||||
stack. This can be done over the phone's WiFi connection, or, alternatively, using an `adb`
|
||||
TCP forward (which should be faster than over WiFi).
|
||||
|
||||
!!! tip "Forwarding TCP with `adb`"
|
||||
To connect to the proxy via an `adb` TCP forward, use:
|
||||
```bash
|
||||
$ adb forward tcp:<outside-port> tcp:<inside-port>
|
||||
```
|
||||
Where ``<outside-port>`` is the port number for a listening socket on your laptop or
|
||||
desktop machine, and <inside-port> is the TCP port selected in the app's user interface.
|
||||
Those two ports may be the same, of course.
|
||||
For example, with the default TCP port 9993:
|
||||
```bash
|
||||
$ adb forward tcp:9993 tcp:9993
|
||||
```
|
||||
|
||||
Once you've ensured that you can reach the proxy's TCP port on the phone, either directly or
|
||||
via an `adb` forward, you can then use it as a Bumble transport, using the transport name:
|
||||
``tcp-client:<host>:<port>`` syntax.
|
||||
|
||||
!!! example "Connecting a Bumble client"
|
||||
Connecting the `bumble-controller-info` app to the phone's controller.
|
||||
Assuming you have set up an `adb` forward on port 9993:
|
||||
```bash
|
||||
$ bumble-controller-info tcp-client:localhost:9993
|
||||
```
|
||||
|
||||
Or over WiFi with, in this example, the IP address of the phone being ```192.168.86.27```
|
||||
```bash
|
||||
$ bumble-controller-info tcp-client:192.168.86.27:9993
|
||||
```
|
||||
19
docs/mkdocs/src/extras/index.md
Normal file
19
docs/mkdocs/src/extras/index.md
Normal file
@@ -0,0 +1,19 @@
|
||||
EXTRAS
|
||||
======
|
||||
|
||||
A collection of add-ons, apps and tools, to the Bumble project.
|
||||
|
||||
Android Remote HCI
|
||||
------------------
|
||||
|
||||
Allows using an Android phone's built-in Bluetooth controller with a Bumble
|
||||
stack running on a development machine.
|
||||
See [Android Remote HCI](android_remote_hci.md) for details.
|
||||
|
||||
Android BT Bench
|
||||
----------------
|
||||
|
||||
An Android app that is compatible with the Bumble `bench` command line app.
|
||||
This app can be used to test the throughput and latency between two Android
|
||||
devices, or between an Android device and another device running the Bumble
|
||||
`bench` app.
|
||||
@@ -3,7 +3,7 @@ HARDWARE
|
||||
|
||||
The Bumble Host connects to a controller over an [HCI Transport](../transports/index.md).
|
||||
To use a hardware controller attached to the host on which the host application is running, the transport is typically either [HCI over UART](../transports/serial.md) or [HCI over USB](../transports/usb.md).
|
||||
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge) bridging the network transport to a physical controller on a remote host.
|
||||
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge.md) bridging the network transport to a physical controller on a remote host.
|
||||
|
||||
In theory, any controller that is compliant with the HCI over UART or HCI over USB protocols can be used.
|
||||
|
||||
|
||||
59
docs/mkdocs/src/hive/index.md
Normal file
59
docs/mkdocs/src/hive/index.md
Normal file
@@ -0,0 +1,59 @@
|
||||
HIVE
|
||||
====
|
||||
|
||||
Welcome to the Bumble Hive.
|
||||
This is a collection of apps and virtual devices that can run entirely in a browser page.
|
||||
The code for the apps and devices, as well as the Bumble runtime code, runs via [Pyodide](https://pyodide.org/).
|
||||
Pyodide is a Python distribution for the browser and Node.js based on WebAssembly.
|
||||
|
||||
The Bumble stack uses a WebSocket to exchange HCI packets with a virtual or physical
|
||||
Bluetooth controller.
|
||||
|
||||
The apps and devices in the hive can be accessed by following the links below. Each
|
||||
page has a settings button that may be used to configure the WebSocket URL to use for
|
||||
the virtual HCI connection. This will typically be the WebSocket URL for a `netsim`
|
||||
daemon.
|
||||
There is also a [TOML index](index.toml) that can be used by tools to know at which URL to access
|
||||
each of the apps and devices, as well as their names and short descriptions.
|
||||
|
||||
!!! tip "Using `netsim`"
|
||||
When the `netsimd` daemon is running (for example when using the Android Emulator that
|
||||
is included in Android Studio), the daemon listens for connections on a TCP port.
|
||||
To find out what this TCP port is, you can read the `netsim.ini` file that `netsimd`
|
||||
creates, it includes a line with `web.port=<tcp-port>` (for example `web.port=7681`).
|
||||
The location of the `netsim.ini` file is platform-specific.
|
||||
|
||||
=== "macOS"
|
||||
On macOS, the directory where `netsim.ini` is stored is $TMPDIR
|
||||
```bash
|
||||
$ cat $TMPDIR/netsim.ini
|
||||
```
|
||||
|
||||
=== "Linux"
|
||||
On Linux, the directory where `netsim.ini` is stored is $XDG_RUNTIME_DIR
|
||||
```bash
|
||||
$ cat $XDG_RUNTIME_DIR/netsim.ini
|
||||
```
|
||||
|
||||
|
||||
!!! tip "Using a local radio"
|
||||
You can connect the hive virtual apps and devices to a local Bluetooth radio, like,
|
||||
for example, a USB dongle.
|
||||
For that, you need to run a local HCI bridge to bridge a local HCI device to a WebSocket
|
||||
that a web page can connect to.
|
||||
Use the `bumble-hci-bridge` app, with the host transport set to a WebSocket server on an
|
||||
available port (ex: `ws-server:_:7682`) and the controller transport set to the transport
|
||||
name for the radio you want to use (ex: `usb:0` for the first USB dongle)
|
||||
|
||||
|
||||
Applications
|
||||
------------
|
||||
|
||||
* [Scanner](web/scanner/scanner.html) - Scans for BLE devices.
|
||||
|
||||
Virtual Devices
|
||||
---------------
|
||||
|
||||
* [Speaker](web/speaker/speaker.html) - Virtual speaker that plays audio in a browser page.
|
||||
* [Heart Rate Monitor](web/heart_rate_monitor/heart_rate_monitor.html) - Virtual heart rate monitor.
|
||||
|
||||
21
docs/mkdocs/src/hive/index.toml
Normal file
21
docs/mkdocs/src/hive/index.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
version = "1.0.0"
|
||||
base_url = "https://google.github.io/bumble/hive/web"
|
||||
default_hci_query_param = "hci"
|
||||
|
||||
[[index]]
|
||||
name = "speaker"
|
||||
description = "Bumble Virtual Speaker"
|
||||
type = "Device"
|
||||
url = "speaker/speaker.html"
|
||||
|
||||
[[index]]
|
||||
name = "scanner"
|
||||
description = "Simple Scanner Application"
|
||||
type = "Application"
|
||||
url = "scanner/scanner.html"
|
||||
|
||||
[[index]]
|
||||
name = "heart-rate-monitor"
|
||||
description = "Virtual Heart Rate Monitor"
|
||||
type = "Device"
|
||||
url = "heart_rate_monitor/heart_rate_monitor.html"
|
||||
1
docs/mkdocs/src/hive/web/bumble.js
Symbolic link
1
docs/mkdocs/src/hive/web/bumble.js
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../web/bumble.js
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.html
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.js
|
||||
@@ -0,0 +1 @@
|
||||
../../../../../../web/heart_rate_monitor/heart_rate_monitor.py
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.css
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.css
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.css
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.html
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.html
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.html
|
||||
1
docs/mkdocs/src/hive/web/scanner/scanner.js
Symbolic link
1
docs/mkdocs/src/hive/web/scanner/scanner.js
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../web/scanner/scanner.js
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user