mirror of
https://github.com/google/bumble.git
synced 2026-04-16 00:25:31 +00:00
Compare commits
690 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af466c2970 | ||
|
|
931e2de854 | ||
|
|
55eb7eb237 | ||
|
|
bade4502f9 | ||
|
|
9f952f202f | ||
|
|
5a477eb391 | ||
|
|
86cda8771d | ||
|
|
c1ea0ddd35 | ||
|
|
f567711a6c | ||
|
|
509df4c676 | ||
|
|
b375ed07b4 | ||
|
|
69d62d3dd1 | ||
|
|
fe3fa3d505 | ||
|
|
27fcd43224 | ||
|
|
c3b2bb19d5 | ||
|
|
34287177b9 | ||
|
|
d238dd4059 | ||
|
|
865f3a249f | ||
|
|
7324d322fe | ||
|
|
af148b476d | ||
|
|
80d60aaf15 | ||
|
|
c80f89d20f | ||
|
|
a27f55a588 | ||
|
|
62e4670a39 | ||
|
|
99695bb264 | ||
|
|
eb54898106 | ||
|
|
4f5ee204d2 | ||
|
|
2552e21db1 | ||
|
|
6168f87e2f | ||
|
|
ca7d2ca4df | ||
|
|
60723323e9 | ||
|
|
3ce7b9255b | ||
|
|
97fcfc2fa0 | ||
|
|
19674e3758 | ||
|
|
1130e1db8f | ||
|
|
37c7f3a58a | ||
|
|
0a12b2bf2e | ||
|
|
d014acbe63 | ||
|
|
07f9997a49 | ||
|
|
b9f91f695a | ||
|
|
082d55af10 | ||
|
|
4c3fd5688d | ||
|
|
9d3d5495ce | ||
|
|
b3869f267c | ||
|
|
8715333706 | ||
|
|
b57096abe2 | ||
|
|
48685c8587 | ||
|
|
100bea6b41 | ||
|
|
63819bf9dd | ||
|
|
6e55390930 | ||
|
|
e3fdab4175 | ||
|
|
bbcd14dbf0 | ||
|
|
01dc0d574b | ||
|
|
5e959d638e | ||
|
|
8d908288c8 | ||
|
|
c88b32a406 | ||
|
|
5a72eefb89 | ||
|
|
430046944b | ||
|
|
21d23320eb | ||
|
|
d0990ee04d | ||
|
|
2d88e853e8 | ||
|
|
a060a70fba | ||
|
|
a06394ad4a | ||
|
|
a1414c2b5b | ||
|
|
b2864dac2d | ||
|
|
b78f895143 | ||
|
|
c4e9726828 | ||
|
|
d4b8e8348a | ||
|
|
19debaa52e | ||
|
|
73fe564321 | ||
|
|
a00abd65b3 | ||
|
|
f169ceaebb | ||
|
|
528af0d338 | ||
|
|
4b25eed869 | ||
|
|
fcd6bd7136 | ||
|
|
32642c5d7c | ||
|
|
ff8b0c375d | ||
|
|
ae0228aeb8 | ||
|
|
5d2dac18c8 | ||
|
|
d03fc14cfd | ||
|
|
ad7ce79bc4 | ||
|
|
c6bf27fd2c | ||
|
|
7584daa3f9 | ||
|
|
654030e789 | ||
|
|
1de7d2cd6f | ||
|
|
68db78c833 | ||
|
|
e1714c16cc | ||
|
|
0a20f14ea9 | ||
|
|
23f46b36b3 | ||
|
|
009649abd1 | ||
|
|
855a007116 | ||
|
|
d064de35e0 | ||
|
|
dab4d13303 | ||
|
|
2bed50b353 | ||
|
|
1fe3778a74 | ||
|
|
f5443a9826 | ||
|
|
db723a5196 | ||
|
|
5e31bcf23d | ||
|
|
fe429cb2eb | ||
|
|
c91695c23a | ||
|
|
55f99e6887 | ||
|
|
b190069f48 | ||
|
|
e16be1a8f4 | ||
|
|
2fa8075fb0 | ||
|
|
566ca13d23 | ||
|
|
e5666c0510 | ||
|
|
46ec39ccfb | ||
|
|
eef418ae5f | ||
|
|
9e663ad051 | ||
|
|
f28eac4c14 | ||
|
|
669bb3f3a8 | ||
|
|
347fe8b272 | ||
|
|
d56c4d0a11 | ||
|
|
034140ccbd | ||
|
|
35bef7d7b7 | ||
|
|
d069708c79 | ||
|
|
bdba5c9d95 | ||
|
|
ff659383f9 | ||
|
|
f06a35713f | ||
|
|
737abdc481 | ||
|
|
02eb4d2e1c | ||
|
|
e7f9acb421 | ||
|
|
976e6cce57 | ||
|
|
dfdf37019c | ||
|
|
56ca19600b | ||
|
|
cd9feeb455 | ||
|
|
f8e5b88be6 | ||
|
|
0f71a63b42 | ||
|
|
b7259abe3c | ||
|
|
00e660d410 | ||
|
|
88e3a2b87f | ||
|
|
aa658418bc | ||
|
|
ac0cff43b6 | ||
|
|
8051c23375 | ||
|
|
7b34bb4050 | ||
|
|
fe38ab35cf | ||
|
|
65a9102ba1 | ||
|
|
1256170985 | ||
|
|
4394a36332 | ||
|
|
0c9fd64434 | ||
|
|
2e99153696 | ||
|
|
54a6f3cb36 | ||
|
|
4a691c11d4 | ||
|
|
b114c0d63f | ||
|
|
a311c3f723 | ||
|
|
04311b4c90 | ||
|
|
b2bb82a432 | ||
|
|
597560ff80 | ||
|
|
db383bb3e6 | ||
|
|
ccc5bbdad4 | ||
|
|
11c8229017 | ||
|
|
2248f9ae5e | ||
|
|
c44c89cc6e | ||
|
|
03c79aacb2 | ||
|
|
0c31713a8e | ||
|
|
9dd814f32e | ||
|
|
ab6e595bcb | ||
|
|
f08fac8c8a | ||
|
|
a699520188 | ||
|
|
f66633459e | ||
|
|
f3b776c343 | ||
|
|
de7b99ce34 | ||
|
|
c0b17d9aff | ||
|
|
3c12be59c5 | ||
|
|
c6b3deb8df | ||
|
|
414f2f3efb | ||
|
|
a0b5606047 | ||
|
|
ed00d44ae1 | ||
|
|
3824e38485 | ||
|
|
b164524380 | ||
|
|
29e4a843df | ||
|
|
619b32d36e | ||
|
|
4433184048 | ||
|
|
312fc8db36 | ||
|
|
615691ec81 | ||
|
|
ae8b83f294 | ||
|
|
4a8e21f4db | ||
|
|
3462e7c437 | ||
|
|
0f2e5239ad | ||
|
|
ee48cdc63f | ||
|
|
1c278bec93 | ||
|
|
6a51166af7 | ||
|
|
85d79fa914 | ||
|
|
142bdce94a | ||
|
|
881a5a64b5 | ||
|
|
5aae44b610 | ||
|
|
e3ea167827 | ||
|
|
eec145e095 | ||
|
|
87fa02d6e5 | ||
|
|
ad94c1e1f3 | ||
|
|
546a0bce8d | ||
|
|
cb7ca44a1c | ||
|
|
4081b93407 | ||
|
|
26203ebaad | ||
|
|
3389e3e1ed | ||
|
|
7e1f01c01e | ||
|
|
613e15548a | ||
|
|
e09c91df8e | ||
|
|
df206667b6 | ||
|
|
0f19dd5263 | ||
|
|
b98e4937f3 | ||
|
|
c2c46e9ace | ||
|
|
27791cf218 | ||
|
|
32a41a815d | ||
|
|
df5fc2ddfe | ||
|
|
79122313a6 | ||
|
|
d7d03e2e92 | ||
|
|
ea493480a9 | ||
|
|
658f641a53 | ||
|
|
f8a2d4f0e0 | ||
|
|
00edd1fbf8 | ||
|
|
999d7b07e1 | ||
|
|
2e3aeb8648 | ||
|
|
f910a696ad | ||
|
|
e1d10bc482 | ||
|
|
181467f11b | ||
|
|
394137b6f7 | ||
|
|
dea907be86 | ||
|
|
f5baf51132 | ||
|
|
f2dc8bd84e | ||
|
|
090309302f | ||
|
|
28e6229b24 | ||
|
|
1b66f03dbe | ||
|
|
e34f6b5fd3 | ||
|
|
8a0482c947 | ||
|
|
938a189f3f | ||
|
|
2005b4a11b | ||
|
|
951fdc8bdd | ||
|
|
12af7a526c | ||
|
|
8781943646 | ||
|
|
7fbfdb634c | ||
|
|
9682077f6b | ||
|
|
22eb405fde | ||
|
|
593c61973f | ||
|
|
ccff32102f | ||
|
|
851d62c6c9 | ||
|
|
a5ac5f26e2 | ||
|
|
090158820f | ||
|
|
26e6650038 | ||
|
|
c48568aabe | ||
|
|
1b33c9eb74 | ||
|
|
6633228975 | ||
|
|
e9cba788a4 | ||
|
|
98822cfc6b | ||
|
|
97ad7e5741 | ||
|
|
71df062e07 | ||
|
|
049f9021e9 | ||
|
|
50eae2ef54 | ||
|
|
c8883a7d0f | ||
|
|
51321caf5b | ||
|
|
51a94288e2 | ||
|
|
8758856e8c | ||
|
|
deba181857 | ||
|
|
c65188dcbf | ||
|
|
21d607898d | ||
|
|
2698d4534e | ||
|
|
bbcd64286a | ||
|
|
9140afbf8c | ||
|
|
90a682c71b | ||
|
|
e8737a8243 | ||
|
|
72fceca72e | ||
|
|
732294abbc | ||
|
|
dc1204531e | ||
|
|
962114379c | ||
|
|
e6913a3055 | ||
|
|
e21d122aef | ||
|
|
58d4ab913a | ||
|
|
76bca03fe3 | ||
|
|
f1e5c9e59e | ||
|
|
ec82242462 | ||
|
|
a4efdd3f3e | ||
|
|
69c6643bb8 | ||
|
|
b8214bf948 | ||
|
|
a9c62c44b3 | ||
|
|
7d0b4ef4e0 | ||
|
|
313340f1c6 | ||
|
|
e8ed69fb09 | ||
|
|
16d5cf6770 | ||
|
|
a2caf1deb2 | ||
|
|
01bfdd2c98 | ||
|
|
4a60df108a | ||
|
|
ad48109748 | ||
|
|
1ceeccbbc0 | ||
|
|
44c51c13ac | ||
|
|
7507be1eab | ||
|
|
cbe9446dcf | ||
|
|
174930399a | ||
|
|
35db4a4c93 | ||
|
|
1f3aee5566 | ||
|
|
256044a789 | ||
|
|
6205199d7f | ||
|
|
e554bd1033 | ||
|
|
38981cefa1 | ||
|
|
f2d601f411 | ||
|
|
6e7c64c1de | ||
|
|
565d51f4db | ||
|
|
de8f3d9c1e | ||
|
|
cde6d48690 | ||
|
|
02180088b3 | ||
|
|
90f49267d1 | ||
|
|
0e6d69cd7b | ||
|
|
9eccc583d5 | ||
|
|
f4aeaa6eb3 | ||
|
|
d7489a644a | ||
|
|
a877283360 | ||
|
|
6d91e7e79b | ||
|
|
567146b143 | ||
|
|
1a3272d7ca | ||
|
|
1ee1ff0b62 | ||
|
|
729fd97748 | ||
|
|
e308051885 | ||
|
|
10e53553d7 | ||
|
|
ef0b30d059 | ||
|
|
e7e9f9509a | ||
|
|
c6cfd101df | ||
|
|
d2dcf063ee | ||
|
|
d15bc7d664 | ||
|
|
e4364d18a7 | ||
|
|
6a34c9f224 | ||
|
|
2a764fd6bb | ||
|
|
3e8ce38eba | ||
|
|
8d2f37aa7a | ||
|
|
b7b70ebcbb | ||
|
|
8ba91f4986 | ||
|
|
79a5e953bc | ||
|
|
20de5ea250 | ||
|
|
bad9ce272c | ||
|
|
d3273ffa8c | ||
|
|
071fc2723a | ||
|
|
ef4ea86f58 | ||
|
|
dfdaa149d0 | ||
|
|
986343a807 | ||
|
|
5211d7ba96 | ||
|
|
a167342778 | ||
|
|
1efb8cdbee | ||
|
|
80d83e6a70 | ||
|
|
31ec1c41ce | ||
|
|
aba1ac0cea | ||
|
|
c40824e51c | ||
|
|
2920f05dae | ||
|
|
bc911d6da0 | ||
|
|
4f87f587e4 | ||
|
|
3e38ab3638 | ||
|
|
21bb911fea | ||
|
|
744dfa33a2 | ||
|
|
ec5f8535a8 | ||
|
|
5a83734a00 | ||
|
|
b4ae8af3a7 | ||
|
|
da60386385 | ||
|
|
45c4c4f4c5 | ||
|
|
9187c75d68 | ||
|
|
abeec22546 | ||
|
|
a6bab755cf | ||
|
|
acd9d994c3 | ||
|
|
37afda3ed3 | ||
|
|
54f2981267 | ||
|
|
bb025514e7 | ||
|
|
e228597269 | ||
|
|
95b0d6c6f2 | ||
|
|
fa4df6e3a2 | ||
|
|
46ceea7ecd | ||
|
|
30f89d5739 | ||
|
|
481cf40831 | ||
|
|
eff05afb7a | ||
|
|
d8e6700611 | ||
|
|
56eb5a933b | ||
|
|
caacc0c133 | ||
|
|
5f377c024b | ||
|
|
00cd8fbdd0 | ||
|
|
aeeff18428 | ||
|
|
c48e3f5e9c | ||
|
|
d6bbc1145a | ||
|
|
e2fec67bd9 | ||
|
|
88cb3b2a4d | ||
|
|
9ebb03be46 | ||
|
|
80d84af76c | ||
|
|
8f4721758f | ||
|
|
8864af4acd | ||
|
|
8980fb8cc7 | ||
|
|
2c5f3472a9 | ||
|
|
f18277ac78 | ||
|
|
8d46bc04d2 | ||
|
|
09e5ea5dec | ||
|
|
d43281c57e | ||
|
|
6810865670 | ||
|
|
3e9e06a02c | ||
|
|
ccd12f6591 | ||
|
|
f9a7843f7e | ||
|
|
210c334db7 | ||
|
|
f297cdfcce | ||
|
|
5b536d00ab | ||
|
|
b4af46ebd5 | ||
|
|
c08da3193e | ||
|
|
f2925ca647 | ||
|
|
fd4d68e5c0 | ||
|
|
5d83deffa4 | ||
|
|
2878cca478 | ||
|
|
53934716db | ||
|
|
d885d45824 | ||
|
|
b90d0f8710 | ||
|
|
8ccfc90fe6 | ||
|
|
92aa7e9e2a | ||
|
|
afc6d19e04 | ||
|
|
c05f073b33 | ||
|
|
2b4c2a22f4 | ||
|
|
47fe93a148 | ||
|
|
6139ca8045 | ||
|
|
87c76a4a0e | ||
|
|
f7b66db873 | ||
|
|
0b314bd7f7 | ||
|
|
9da2e32ad7 | ||
|
|
93c0875740 | ||
|
|
a286700239 | ||
|
|
98ed772e8a | ||
|
|
f0b55a4f97 | ||
|
|
b74503d345 | ||
|
|
f911163e49 | ||
|
|
b083cc99ad | ||
|
|
d35643524e | ||
|
|
62a8ced447 | ||
|
|
085f163c92 | ||
|
|
81a6b1e097 | ||
|
|
dd090c9e6b | ||
|
|
11faa48422 | ||
|
|
55596176c2 | ||
|
|
4d6822d312 | ||
|
|
985c365e6d | ||
|
|
af57762227 | ||
|
|
3575f9030e | ||
|
|
698d947d85 | ||
|
|
ff6528d2bf | ||
|
|
72ac75a98d | ||
|
|
5e3ecb74e4 | ||
|
|
c59be293c8 | ||
|
|
88b4cbdf1a | ||
|
|
d6afbc6f4e | ||
|
|
fc90de3e7b | ||
|
|
847c2ef114 | ||
|
|
a0bf0c1f4d | ||
|
|
8400ff0802 | ||
|
|
0ed6aa230b | ||
|
|
6d22ed80ec | ||
|
|
72d5360af9 | ||
|
|
ac3961e763 | ||
|
|
843466c822 | ||
|
|
8385035400 | ||
|
|
3adcc8be09 | ||
|
|
c853d56302 | ||
|
|
dc97be5b35 | ||
|
|
73dbdfff9f | ||
|
|
dff14e1258 | ||
|
|
10a3833893 | ||
|
|
247cb89332 | ||
|
|
3fc71a0266 | ||
|
|
392dcc3a05 | ||
|
|
f27015d1b7 | ||
|
|
86a19b41aa | ||
|
|
320164d476 | ||
|
|
40ae661ee5 | ||
|
|
ffb3eca68b | ||
|
|
c5def93bb8 | ||
|
|
a9c4c5833d | ||
|
|
58c9c4f590 | ||
|
|
24524d88cb | ||
|
|
b8849ab311 | ||
|
|
f3cd8f8ed0 | ||
|
|
2b26de3f3a | ||
|
|
0149c4c212 | ||
|
|
f2ed898784 | ||
|
|
464a476f9f | ||
|
|
e85d067fb5 | ||
|
|
7eb493990f | ||
|
|
04d5bf3afc | ||
|
|
403a13e4c6 | ||
|
|
ad0f035df5 | ||
|
|
a13e193d3b | ||
|
|
28a1a5ebc2 | ||
|
|
6310dc777f | ||
|
|
07f71fc895 | ||
|
|
f47b9178ad | ||
|
|
863de18877 | ||
|
|
4f399249bd | ||
|
|
f0e5cdee1a | ||
|
|
7bc7d0f5af | ||
|
|
a65a215fd7 | ||
|
|
80d34a226d | ||
|
|
a9628f73e3 | ||
|
|
9324237828 | ||
|
|
d1033c018a | ||
|
|
0f29052ade | ||
|
|
0578e84586 | ||
|
|
6ab41c466f | ||
|
|
98a1093ebf | ||
|
|
caf04373f3 | ||
|
|
d4e8526766 | ||
|
|
515b83a8c7 | ||
|
|
9bf2e03354 | ||
|
|
dc18595c8a | ||
|
|
488bcfe9c6 | ||
|
|
2900b93bb3 | ||
|
|
284cc8a321 | ||
|
|
3dc2e4036c | ||
|
|
268f6b0d51 | ||
|
|
46239b321b | ||
|
|
8a536cd522 | ||
|
|
f9f5d7ccbd | ||
|
|
d6cefdff8e | ||
|
|
dc410b14c4 | ||
|
|
4c49ef9403 | ||
|
|
ba85dcbda5 | ||
|
|
e08c84dd20 | ||
|
|
8b46136703 | ||
|
|
9c7089c8ff | ||
|
|
aac8d89cd0 | ||
|
|
24e75bfeab | ||
|
|
42868b08d3 | ||
|
|
19b61d9ac0 | ||
|
|
db2a2e2bb9 | ||
|
|
e1fdb12647 | ||
|
|
a8ec1b0949 | ||
|
|
2e30b2de77 | ||
|
|
7e407ccae1 | ||
|
|
0667e83919 | ||
|
|
1a6c9a4d04 | ||
|
|
14f5b912ad | ||
|
|
46d6242171 | ||
|
|
753b966148 | ||
|
|
5a307c19b8 | ||
|
|
2cd4f84800 | ||
|
|
4ae612090b | ||
|
|
c67ca4a09e | ||
|
|
94506220d3 | ||
|
|
dbd865a484 | ||
|
|
9d2f3e932a | ||
|
|
49d32f5b5b | ||
|
|
f7b74c0bcb | ||
|
|
c75cb0c7b7 | ||
|
|
a63b335149 | ||
|
|
d8517ce407 | ||
|
|
ad13b11464 | ||
|
|
99bc92d53d | ||
|
|
72199f5615 | ||
|
|
78b8b50082 | ||
|
|
3ab64ce00d | ||
|
|
651e44e0b6 | ||
|
|
963fa41a49 | ||
|
|
493f4f8b95 | ||
|
|
fc1bf36ace | ||
|
|
5ddee17411 | ||
|
|
5ce353bcde | ||
|
|
16d33199eb | ||
|
|
e02303a448 | ||
|
|
36fc966ad6 | ||
|
|
644f74400d | ||
|
|
b7cd451ddb | ||
|
|
59d7717963 | ||
|
|
88392efca4 | ||
|
|
907f2acc7e | ||
|
|
6616477bcf | ||
|
|
5b173cb879 | ||
|
|
dc6b466a42 | ||
|
|
8b04161da3 | ||
|
|
5a85765360 | ||
|
|
333940919b | ||
|
|
b9476be9ad | ||
|
|
704c60491c | ||
|
|
4a8e612c6e | ||
|
|
5e5c9c2580 | ||
|
|
4e71ec5738 | ||
|
|
1004f10384 | ||
|
|
1051648ffb | ||
|
|
7255a09705 | ||
|
|
c2bf6b5f13 | ||
|
|
d8e699b588 | ||
|
|
3e4d4705f5 | ||
|
|
c8b2804446 | ||
|
|
e732f2589f | ||
|
|
aec5543081 | ||
|
|
e03d90ca57 | ||
|
|
495ce62d9c | ||
|
|
fbc3959a5a | ||
|
|
246b11925c | ||
|
|
dfa9131192 | ||
|
|
88c801b4c2 | ||
|
|
a1b55b94e0 | ||
|
|
80db9e2e2f | ||
|
|
ce74690420 | ||
|
|
50de4dfb5d | ||
|
|
9bcdf860f4 | ||
|
|
511ab4b630 | ||
|
|
6f2b623e3c | ||
|
|
fa12165cd3 | ||
|
|
c0c6f3329d | ||
|
|
406a932467 | ||
|
|
cc96d4245f | ||
|
|
c6cdca8923 | ||
|
|
45edcafb06 | ||
|
|
9f0bcc131f | ||
|
|
7e331c2944 | ||
|
|
10347765cb | ||
|
|
c12dee4e76 | ||
|
|
772c188674 | ||
|
|
7c1a3bb8f9 | ||
|
|
8c3c0b1e13 | ||
|
|
1ad84ad51c | ||
|
|
64937c3f77 | ||
|
|
50fd2218fa | ||
|
|
4c29a16271 | ||
|
|
762d3e92de | ||
|
|
2f97531d78 | ||
|
|
f6c7cae661 | ||
|
|
f1777a5bd2 | ||
|
|
78a06ae8cf | ||
|
|
d290df4aa9 | ||
|
|
e559744f32 | ||
|
|
67418e649a | ||
|
|
5adf9fab53 | ||
|
|
2491b686fa | ||
|
|
efd02b2f3e | ||
|
|
3b14078646 | ||
|
|
eb9d5632bc | ||
|
|
45f60edbb6 | ||
|
|
393ea6a7bb | ||
|
|
6ec6f1efe5 | ||
|
|
5d9598ea51 | ||
|
|
0d36d99a73 | ||
|
|
d8a9f5a724 | ||
|
|
2c66e1a042 | ||
|
|
d5eccdb00f | ||
|
|
32626573a6 | ||
|
|
caa82b8f7e | ||
|
|
5af347b499 | ||
|
|
4ed5bb5a9e | ||
|
|
2478d45673 | ||
|
|
1bc7d94111 | ||
|
|
6432414cd5 | ||
|
|
179064ba15 | ||
|
|
783b2d70a5 | ||
|
|
80824f3fc1 | ||
|
|
f39f5f531c | ||
|
|
56139c622f | ||
|
|
da02f6a39b | ||
|
|
548d5597c0 | ||
|
|
7fd65d2412 | ||
|
|
05a54a4af9 | ||
|
|
1e00c8f456 | ||
|
|
90d165aa01 | ||
|
|
01603ca9e4 | ||
|
|
a1b6eb61f2 | ||
|
|
25f300d3ec | ||
|
|
41fe63df06 | ||
|
|
b312170d5f | ||
|
|
cf7f2e8f44 | ||
|
|
d292083ed1 | ||
|
|
9b11142b45 | ||
|
|
acdbc4d7b9 | ||
|
|
838d10a09d | ||
|
|
3852aa056b | ||
|
|
ae77e4528f | ||
|
|
9303f4fc5b | ||
|
|
8be9f4cb0e | ||
|
|
1ea12b1bf7 | ||
|
|
65e6d68355 | ||
|
|
9732eb8836 | ||
|
|
5ae668bc70 | ||
|
|
fd4d1bcca3 | ||
|
|
0a251c9f8e | ||
|
|
351d77be59 | ||
|
|
0e2fc80509 | ||
|
|
8f3fdecb93 | ||
|
|
249a205d8e | ||
|
|
7485801222 | ||
|
|
4678e59737 | ||
|
|
952d351c00 | ||
|
|
901eb55b0e | ||
|
|
727586e40e | ||
|
|
3aa678a58e | ||
|
|
fc7c1a8113 | ||
|
|
f62a0bbe75 | ||
|
|
7341172739 | ||
|
|
91b9fbe450 | ||
|
|
e6b566b848 | ||
|
|
2527a711dc | ||
|
|
5fba6b1cae | ||
|
|
85a61dc39d | ||
|
|
6226bfd196 | ||
|
|
71e11b7cf8 | ||
|
|
800c62fdb6 | ||
|
|
640b9cd53a | ||
|
|
2af3494d8c |
30
.devcontainer/devcontainer.json
Normal file
30
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,30 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/python
|
||||
{
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/universal:2",
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
"postCreateCommand":
|
||||
"python -m pip install '.[build,test,development,documentation]'",
|
||||
|
||||
// Configure tool-specific properties.
|
||||
"customizations": {
|
||||
// Configure properties specific to VS Code.
|
||||
"vscode": {
|
||||
// Add the IDs of extensions you want installed when the container is created.
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
6
.github/workflows/code-check.yml
vendored
6
.github/workflows/code-check.yml
vendored
@@ -14,6 +14,10 @@ jobs:
|
||||
check:
|
||||
name: Check Code
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -25,7 +29,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
50
.github/workflows/python-avatar.yml
vendored
Normal file
50
.github/workflows/python-avatar.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Python Avatar
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Avatar [${{ matrix.shard }}]
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
shard: [
|
||||
1/24, 2/24, 3/24, 4/24,
|
||||
5/24, 6/24, 7/24, 8/24,
|
||||
9/24, 10/24, 11/24, 12/24,
|
||||
13/24, 14/24, 15/24, 16/24,
|
||||
17/24, 18/24, 19/24, 20/24,
|
||||
21/24, 22/24, 23/24, 24/24,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set Up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
run: |
|
||||
avatar --list | grep -Ev '^=' > test-names.txt
|
||||
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
|
||||
- name: Rootcanal Logs
|
||||
if: always()
|
||||
run: cat rootcanal.log
|
||||
- name: Upload Mobly logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: mobly-logs
|
||||
path: /tmp/logs/mobly/bumble.bumbles/
|
||||
30
.github/workflows/python-build-test.yml
vendored
30
.github/workflows/python-build-test.yml
vendored
@@ -12,11 +12,11 @@ permissions:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -41,11 +41,13 @@ jobs:
|
||||
run: |
|
||||
inv build
|
||||
inv build.mkdocs
|
||||
|
||||
build-rust:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ "3.8", "3.9", "3.10" ]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
rust-version: [ "1.76.0", "stable" ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out from Git
|
||||
@@ -54,7 +56,7 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,documentation]"
|
||||
@@ -62,9 +64,19 @@ jobs:
|
||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||
with:
|
||||
components: clippy,rustfmt
|
||||
- name: Rust Lints
|
||||
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings
|
||||
toolchain: ${{ matrix.rust-version }}
|
||||
- name: Install Rust dependencies
|
||||
run: cargo install cargo-all-features # allows building/testing combinations of features
|
||||
- name: Check License Headers
|
||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||
- name: Rust Build
|
||||
run: cd rust && cargo build --all-targets
|
||||
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
|
||||
# Lints after build so what clippy needs is already built
|
||||
- name: Rust Lints
|
||||
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
|
||||
- name: Rust Tests
|
||||
run: cd rust && cargo test
|
||||
run: cd rust && cargo test-all-features
|
||||
# At some point, hook up publishing the binary. For now, just make sure it builds.
|
||||
# Once we're ready to publish binaries, this should be built with `--release`.
|
||||
- name: Build Bumble CLI
|
||||
run: cd rust && cargo build --features bumble-tools --bin bumble
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -6,7 +6,14 @@ dist/
|
||||
docs/mkdocs/site
|
||||
test-results.xml
|
||||
__pycache__
|
||||
# Vim
|
||||
.*.sw*
|
||||
# generated by setuptools_scm
|
||||
bumble/_version.py
|
||||
.vscode/launch.json
|
||||
.vscode/settings.json
|
||||
/.idea
|
||||
venv/
|
||||
.venv/
|
||||
# snoop logs
|
||||
out/
|
||||
|
||||
26
.vscode/settings.json
vendored
26
.vscode/settings.json
vendored
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"cSpell.words": [
|
||||
"Abortable",
|
||||
"aiohttp",
|
||||
"altsetting",
|
||||
"ansiblue",
|
||||
"ansicyan",
|
||||
@@ -9,10 +10,16 @@
|
||||
"ansired",
|
||||
"ansiyellow",
|
||||
"appendleft",
|
||||
"ascs",
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"auracast",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"biginfo",
|
||||
"bigs",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -21,7 +28,10 @@
|
||||
"cccds",
|
||||
"cmac",
|
||||
"CONNECTIONLESS",
|
||||
"csip",
|
||||
"csis",
|
||||
"csrcs",
|
||||
"CVSD",
|
||||
"datagram",
|
||||
"DATALINK",
|
||||
"delayreport",
|
||||
@@ -29,6 +39,9 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"ediv",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
"GATTLINK",
|
||||
"HANDSFREE",
|
||||
@@ -36,15 +49,24 @@
|
||||
"keyup",
|
||||
"levelname",
|
||||
"libc",
|
||||
"liblc",
|
||||
"libusb",
|
||||
"maxs",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
"netsim",
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"PDUS",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"prefs",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
"Pyodide",
|
||||
"pyusb",
|
||||
"rfcomm",
|
||||
"ROHC",
|
||||
@@ -52,6 +74,7 @@
|
||||
"SEID",
|
||||
"seids",
|
||||
"SERV",
|
||||
"SIRK",
|
||||
"ssrc",
|
||||
"strerror",
|
||||
"subband",
|
||||
@@ -61,8 +84,11 @@
|
||||
"substates",
|
||||
"tobytes",
|
||||
"tsep",
|
||||
"UNMUTE",
|
||||
"unmuted",
|
||||
"usbmodem",
|
||||
"vhci",
|
||||
"wasmtime",
|
||||
"websockets",
|
||||
"xcursor",
|
||||
"ycursor"
|
||||
|
||||
986
apps/auracast.py
Normal file
986
apps/auracast.py
Normal file
@@ -0,0 +1,986 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import wave
|
||||
import itertools
|
||||
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
|
||||
|
||||
import click
|
||||
import pyee
|
||||
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import company_ids
|
||||
from bumble import core
|
||||
from bumble import gatt
|
||||
from bumble import hci
|
||||
from bumble.profiles import bap
|
||||
from bumble.profiles import le_audio
|
||||
from bumble.profiles import pbp
|
||||
from bumble.profiles import bass
|
||||
import bumble.device
|
||||
import bumble.transport
|
||||
import bumble.utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5')
|
||||
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
|
||||
AURACAST_DEFAULT_ATT_MTU = 256
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Scan For Broadcasts
|
||||
# -----------------------------------------------------------------------------
|
||||
class BroadcastScanner(pyee.EventEmitter):
|
||||
@dataclasses.dataclass
|
||||
class Broadcast(pyee.EventEmitter):
|
||||
name: str | None
|
||||
sync: bumble.device.PeriodicAdvertisingSync
|
||||
broadcast_id: int
|
||||
rssi: int = 0
|
||||
public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None
|
||||
broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None
|
||||
basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None
|
||||
appearance: Optional[core.Appearance] = None
|
||||
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
|
||||
manufacturer_data: Optional[Tuple[str, bytes]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.sync.on('establishment', self.on_sync_establishment)
|
||||
self.sync.on('loss', self.on_sync_loss)
|
||||
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
|
||||
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
|
||||
|
||||
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
self.rssi = advertisement.rssi
|
||||
for service_data in advertisement.data.get_all(
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE:
|
||||
self.public_broadcast_announcement = (
|
||||
pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.broadcast_audio_announcement = (
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
self.appearance = advertisement.data.get( # type: ignore[assignment]
|
||||
core.AdvertisingData.APPEARANCE
|
||||
)
|
||||
|
||||
if manufacturer_data := advertisement.data.get(
|
||||
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
):
|
||||
assert isinstance(manufacturer_data, tuple)
|
||||
company_id = cast(int, manufacturer_data[0])
|
||||
data = cast(bytes, manufacturer_data[1])
|
||||
self.manufacturer_data = (
|
||||
company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_id, f'0x{company_id:04X}'
|
||||
),
|
||||
data,
|
||||
)
|
||||
|
||||
self.emit('update')
|
||||
|
||||
def print(self) -> None:
|
||||
print(
|
||||
color('Broadcast:', 'yellow'),
|
||||
self.sync.advertiser_address,
|
||||
color(self.sync.state.name, 'green'),
|
||||
)
|
||||
if self.name is not None:
|
||||
print(f' {color("Name", "cyan")}: {self.name}')
|
||||
if self.appearance:
|
||||
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
|
||||
print(f' {color("RSSI", "cyan")}: {self.rssi}')
|
||||
print(f' {color("SID", "cyan")}: {self.sync.sid}')
|
||||
|
||||
if self.manufacturer_data:
|
||||
print(
|
||||
f' {color("Manufacturer Data", "cyan")}: '
|
||||
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
|
||||
)
|
||||
|
||||
if self.broadcast_audio_announcement:
|
||||
print(
|
||||
f' {color("Broadcast ID", "cyan")}: '
|
||||
f'{self.broadcast_audio_announcement.broadcast_id}'
|
||||
)
|
||||
|
||||
if self.public_broadcast_announcement:
|
||||
print(
|
||||
f' {color("Features", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.features}'
|
||||
)
|
||||
print(
|
||||
f' {color("Metadata", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.metadata}'
|
||||
)
|
||||
|
||||
if self.basic_audio_announcement:
|
||||
print(color(' Audio:', 'cyan'))
|
||||
print(
|
||||
color(' Presentation Delay:', 'magenta'),
|
||||
self.basic_audio_announcement.presentation_delay,
|
||||
)
|
||||
for subgroup in self.basic_audio_announcement.subgroups:
|
||||
print(color(' Subgroup:', 'magenta'))
|
||||
print(color(' Codec ID:', 'yellow'))
|
||||
print(
|
||||
color(' Coding Format: ', 'green'),
|
||||
subgroup.codec_id.codec_id.name,
|
||||
)
|
||||
print(
|
||||
color(' Company ID: ', 'green'),
|
||||
subgroup.codec_id.company_id,
|
||||
)
|
||||
print(
|
||||
color(' Vendor Specific Codec ID:', 'green'),
|
||||
subgroup.codec_id.vendor_specific_codec_id,
|
||||
)
|
||||
print(
|
||||
color(' Codec Config:', 'yellow'),
|
||||
subgroup.codec_specific_configuration,
|
||||
)
|
||||
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
|
||||
|
||||
for bis in subgroup.bis:
|
||||
print(color(f' BIS [{bis.index}]:', 'yellow'))
|
||||
print(
|
||||
color(' Codec Config:', 'green'),
|
||||
bis.codec_specific_configuration,
|
||||
)
|
||||
|
||||
if self.biginfo:
|
||||
print(color(' BIG:', 'cyan'))
|
||||
print(
|
||||
color(' Number of BIS:', 'magenta'),
|
||||
self.biginfo.num_bis,
|
||||
)
|
||||
print(
|
||||
color(' PHY: ', 'magenta'),
|
||||
self.biginfo.phy.name,
|
||||
)
|
||||
print(
|
||||
color(' Framed: ', 'magenta'),
|
||||
self.biginfo.framed,
|
||||
)
|
||||
print(
|
||||
color(' Encrypted: ', 'magenta'),
|
||||
self.biginfo.encrypted,
|
||||
)
|
||||
|
||||
def on_sync_establishment(self) -> None:
|
||||
self.emit('sync_establishment')
|
||||
|
||||
def on_sync_loss(self) -> None:
|
||||
self.basic_audio_announcement = None
|
||||
self.biginfo = None
|
||||
self.emit('sync_loss')
|
||||
|
||||
def on_periodic_advertisement(
|
||||
self, advertisement: bumble.device.PeriodicAdvertisement
|
||||
) -> None:
|
||||
if advertisement.data is None:
|
||||
return
|
||||
|
||||
for service_data in advertisement.data.get_all(
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.basic_audio_announcement = (
|
||||
bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
break
|
||||
|
||||
self.emit('change')
|
||||
|
||||
def on_biginfo_advertisement(
|
||||
self, advertisement: bumble.device.BIGInfoAdvertisement
|
||||
) -> None:
|
||||
self.biginfo = advertisement
|
||||
self.emit('change')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: bumble.device.Device,
|
||||
filter_duplicates: bool,
|
||||
sync_timeout: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.filter_duplicates = filter_duplicates
|
||||
self.sync_timeout = sync_timeout
|
||||
self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]()
|
||||
device.on('advertisement', self.on_advertisement)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.device.start_scanning(
|
||||
active=False,
|
||||
filter_duplicates=False,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.device.stop_scanning()
|
||||
|
||||
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
if not (
|
||||
ads := advertisement.data.get_all(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
|
||||
)
|
||||
) or not (
|
||||
broadcast_audio_announcement := next(
|
||||
(
|
||||
ad
|
||||
for ad in ads
|
||||
if isinstance(ad, tuple)
|
||||
and ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
),
|
||||
None,
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
broadcast_name = advertisement.data.get(core.AdvertisingData.BROADCAST_NAME)
|
||||
assert isinstance(broadcast_name, str) or broadcast_name is None
|
||||
assert isinstance(broadcast_audio_announcement[1], bytes)
|
||||
|
||||
if broadcast := self.broadcasts.get(advertisement.address):
|
||||
broadcast.update(advertisement)
|
||||
return
|
||||
|
||||
bumble.utils.AsyncRunner.spawn(
|
||||
self.on_new_broadcast(
|
||||
broadcast_name,
|
||||
advertisement,
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(
|
||||
broadcast_audio_announcement[1]
|
||||
).broadcast_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_new_broadcast(
|
||||
self,
|
||||
name: str | None,
|
||||
advertisement: bumble.device.Advertisement,
|
||||
broadcast_id: int,
|
||||
) -> None:
|
||||
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
|
||||
advertiser_address=advertisement.address,
|
||||
sid=advertisement.sid,
|
||||
sync_timeout=self.sync_timeout,
|
||||
filter_duplicates=self.filter_duplicates,
|
||||
)
|
||||
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
|
||||
broadcast.update(advertisement)
|
||||
self.broadcasts[advertisement.address] = broadcast
|
||||
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
||||
self.emit('new_broadcast', broadcast)
|
||||
|
||||
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
|
||||
del self.broadcasts[broadcast.sync.advertiser_address]
|
||||
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
|
||||
self.emit('broadcast_loss', broadcast)
|
||||
|
||||
|
||||
class PrintingBroadcastScanner(pyee.EventEmitter):
|
||||
def __init__(
|
||||
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||
self.scanner.on('new_broadcast', self.on_new_broadcast)
|
||||
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
|
||||
self.scanner.on('update', self.refresh)
|
||||
self.status_message = ''
|
||||
|
||||
async def start(self) -> None:
|
||||
self.status_message = color('Scanning...', 'green')
|
||||
await self.scanner.start()
|
||||
|
||||
def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
self.status_message = color(
|
||||
f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green'
|
||||
)
|
||||
broadcast.on('change', self.refresh)
|
||||
broadcast.on('update', self.refresh)
|
||||
self.refresh()
|
||||
|
||||
def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
self.status_message = color(
|
||||
f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green'
|
||||
)
|
||||
self.refresh()
|
||||
|
||||
def refresh(self) -> None:
|
||||
# Clear the screen from the top
|
||||
print('\033[H')
|
||||
print('\033[0J')
|
||||
print('\033[H')
|
||||
|
||||
# Print the status message
|
||||
print(self.status_message)
|
||||
print("==========================================")
|
||||
|
||||
# Print all broadcasts
|
||||
for broadcast in self.scanner.broadcasts.values():
|
||||
broadcast.print()
|
||||
print('------------------------------------------')
|
||||
|
||||
# Clear the screen to the bottom
|
||||
print('\033[0J')
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]:
|
||||
async with await bumble.transport.open_transport(transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
device_config = bumble.device.DeviceConfiguration(
|
||||
name=AURACAST_DEFAULT_DEVICE_NAME,
|
||||
address=AURACAST_DEFAULT_DEVICE_ADDRESS,
|
||||
keystore='JsonKeyStore',
|
||||
)
|
||||
|
||||
device = bumble.device.Device.from_config_with_hci(
|
||||
device_config,
|
||||
hci_source,
|
||||
hci_sink,
|
||||
)
|
||||
await device.power_on()
|
||||
|
||||
yield device
|
||||
|
||||
|
||||
async def find_broadcast_by_name(
|
||||
device: bumble.device.Device, name: Optional[str]
|
||||
) -> BroadcastScanner.Broadcast:
|
||||
result = asyncio.get_running_loop().create_future()
|
||||
|
||||
def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
if broadcast.basic_audio_announcement and not result.done():
|
||||
print(color('Broadcast basic audio announcement received', 'green'))
|
||||
result.set_result(broadcast)
|
||||
|
||||
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
if name is None or broadcast.name == name:
|
||||
print(color('Broadcast found:', 'green'), broadcast.name)
|
||||
broadcast.on('change', lambda: on_broadcast_change(broadcast))
|
||||
return
|
||||
|
||||
print(color(f'Skipping broadcast {broadcast.name}'))
|
||||
|
||||
scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT)
|
||||
scanner.on('new_broadcast', on_new_broadcast)
|
||||
await scanner.start()
|
||||
|
||||
broadcast = await result
|
||||
await scanner.stop()
|
||||
|
||||
return broadcast
|
||||
|
||||
|
||||
async def run_scan(
|
||||
filter_duplicates: bool, sync_timeout: float, transport: str
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||
await scanner.start()
|
||||
await asyncio.get_running_loop().create_future()
|
||||
|
||||
|
||||
async def run_assist(
|
||||
broadcast_name: Optional[str],
|
||||
source_id: Optional[int],
|
||||
command: str,
|
||||
transport: str,
|
||||
address: str,
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
# Connect to the server
|
||||
print(f'=== Connecting to {address}...')
|
||||
connection = await device.connect(address)
|
||||
peer = bumble.device.Peer(connection)
|
||||
print(f'=== Connected to {peer}')
|
||||
|
||||
print("+++ Encrypting connection...")
|
||||
await peer.connection.encrypt()
|
||||
print("+++ Connection encrypted")
|
||||
|
||||
# Request a larger MTU
|
||||
mtu = AURACAST_DEFAULT_ATT_MTU
|
||||
print(color(f'$$$ Requesting MTU={mtu}', 'yellow'))
|
||||
await peer.request_mtu(mtu)
|
||||
|
||||
# Get the BASS service
|
||||
bass_client = await peer.discover_service_and_create_proxy(
|
||||
bass.BroadcastAudioScanServiceProxy
|
||||
)
|
||||
|
||||
# Check that the service was found
|
||||
if not bass_client:
|
||||
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
|
||||
return
|
||||
|
||||
# Subscribe to and read the broadcast receive state characteristics
|
||||
for i, broadcast_receive_state in enumerate(
|
||||
bass_client.broadcast_receive_states
|
||||
):
|
||||
try:
|
||||
await broadcast_receive_state.subscribe(
|
||||
lambda value, i=i: print(
|
||||
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
|
||||
)
|
||||
)
|
||||
except core.ProtocolError as error:
|
||||
print(
|
||||
color(
|
||||
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
|
||||
'red',
|
||||
),
|
||||
error,
|
||||
)
|
||||
value = await broadcast_receive_state.read_value()
|
||||
print(
|
||||
f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}'
|
||||
)
|
||||
|
||||
if command == 'monitor-state':
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
if command == 'add-source':
|
||||
# Find the requested broadcast
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
print(color('Scanning for any broadcast', 'cyan'))
|
||||
broadcast = await find_broadcast_by_name(device, broadcast_name)
|
||||
|
||||
if broadcast.broadcast_audio_announcement is None:
|
||||
print(color('No broadcast audio announcement found', 'red'))
|
||||
return
|
||||
|
||||
if (
|
||||
broadcast.basic_audio_announcement is None
|
||||
or not broadcast.basic_audio_announcement.subgroups
|
||||
):
|
||||
print(color('No subgroups found', 'red'))
|
||||
return
|
||||
|
||||
# Add the source
|
||||
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
|
||||
await bass_client.add_source(
|
||||
broadcast.sync.advertiser_address,
|
||||
broadcast.sync.sid,
|
||||
broadcast.broadcast_audio_announcement.broadcast_id,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Initiate a PA Sync Transfer
|
||||
await broadcast.sync.transfer(peer.connection)
|
||||
|
||||
# Notify the sink that we're done scanning.
|
||||
await bass_client.remote_scan_stopped()
|
||||
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
if command == 'modify-source':
|
||||
if source_id is None:
|
||||
print(color('!!! modify-source requires --source-id'))
|
||||
return
|
||||
|
||||
# Find the requested broadcast
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
print(color('Scanning for any broadcast', 'cyan'))
|
||||
broadcast = await find_broadcast_by_name(device, broadcast_name)
|
||||
|
||||
if broadcast.broadcast_audio_announcement is None:
|
||||
print(color('No broadcast audio announcement found', 'red'))
|
||||
return
|
||||
|
||||
if (
|
||||
broadcast.basic_audio_announcement is None
|
||||
or not broadcast.basic_audio_announcement.subgroups
|
||||
):
|
||||
print(color('No subgroups found', 'red'))
|
||||
return
|
||||
|
||||
# Modify the source
|
||||
print(
|
||||
color('Modifying source:', 'blue'),
|
||||
source_id,
|
||||
)
|
||||
await bass_client.modify_source(
|
||||
source_id,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
)
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
if command == 'remove-source':
|
||||
if source_id is None:
|
||||
print(color('!!! remove-source requires --source-id'))
|
||||
return
|
||||
|
||||
# Remove the source
|
||||
print(color('Removing source:', 'blue'), source_id)
|
||||
await bass_client.remove_source(source_id)
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
print(color(f'!!! invalid command {command}'))
|
||||
|
||||
|
||||
async def run_pair(transport: str, address: str) -> None:
|
||||
async with create_device(transport) as device:
|
||||
|
||||
# Connect to the server
|
||||
print(f'=== Connecting to {address}...')
|
||||
async with device.connect_as_gatt(address) as peer:
|
||||
print(f'=== Connected to {peer}')
|
||||
|
||||
print("+++ Initiating pairing...")
|
||||
await peer.connection.pair()
|
||||
print("+++ Paired")
|
||||
|
||||
|
||||
async def run_receive(
|
||||
transport: str,
|
||||
broadcast_id: int,
|
||||
broadcast_code: str | None,
|
||||
sync_timeout: float,
|
||||
subgroup_index: int,
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
scanner = BroadcastScanner(device, False, sync_timeout)
|
||||
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
if scan_result.done():
|
||||
return
|
||||
if broadcast.broadcast_id == broadcast_id:
|
||||
scan_result.set_result(broadcast)
|
||||
|
||||
scanner.on('new_broadcast', on_new_broadcast)
|
||||
await scanner.start()
|
||||
print('Start scanning...')
|
||||
broadcast = await scan_result
|
||||
print('Advertisement found:')
|
||||
broadcast.print()
|
||||
basic_audio_announcement_scanned = asyncio.Event()
|
||||
|
||||
def on_change() -> None:
|
||||
if (
|
||||
broadcast.basic_audio_announcement
|
||||
and not basic_audio_announcement_scanned.is_set()
|
||||
):
|
||||
basic_audio_announcement_scanned.set()
|
||||
|
||||
broadcast.on('change', on_change)
|
||||
if not broadcast.basic_audio_announcement:
|
||||
print('Wait for Basic Audio Announcement...')
|
||||
await basic_audio_announcement_scanned.wait()
|
||||
print('Basic Audio Announcement found')
|
||||
broadcast.print()
|
||||
print('Stop scanning')
|
||||
await scanner.stop()
|
||||
print('Start sync to BIG')
|
||||
|
||||
assert broadcast.basic_audio_announcement
|
||||
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
|
||||
configuration = subgroup.codec_specific_configuration
|
||||
assert configuration
|
||||
assert (sampling_frequency := configuration.sampling_frequency)
|
||||
assert (frame_duration := configuration.frame_duration)
|
||||
|
||||
big_sync = await device.create_big_sync(
|
||||
broadcast.sync,
|
||||
bumble.device.BigSyncParameters(
|
||||
big_sync_timeout=0x4000,
|
||||
bis=[bis.index for bis in subgroup.bis],
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
num_bis = len(big_sync.bis_links)
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=num_bis,
|
||||
)
|
||||
sdus = [b''] * num_bis
|
||||
subprocess = await asyncio.create_subprocess_shell(
|
||||
f'stdbuf -i0 ffplay -ar {sampling_frequency.hz} -ac {num_bis} -f f32le pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
for i, bis_link in enumerate(big_sync.bis_links):
|
||||
print(f'Setup ISO for BIS {bis_link.handle}')
|
||||
|
||||
def sink(index: int, packet: hci.HCI_IsoDataPacket):
|
||||
nonlocal sdus
|
||||
sdus[index] = packet.iso_sdu_fragment
|
||||
if all(sdus) and subprocess.stdin:
|
||||
subprocess.stdin.write(decoder.decode(b''.join(sdus)).tobytes())
|
||||
sdus = [b''] * num_bis
|
||||
|
||||
bis_link.sink = functools.partial(sink, i)
|
||||
await bis_link.setup_data_path(
|
||||
direction=bis_link.Direction.CONTROLLER_TO_HOST
|
||||
)
|
||||
|
||||
terminated = asyncio.Event()
|
||||
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
|
||||
await terminated.wait()
|
||||
|
||||
|
||||
async def run_broadcast(
|
||||
transport: str, broadcast_id: int, broadcast_code: str | None, wav_file_path: str
|
||||
) -> None:
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
with wave.open(wav_file_path, 'rb') as wav:
|
||||
print('Encoding wav file into lc3...')
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=10000,
|
||||
sample_rate_hz=48000,
|
||||
num_channels=2,
|
||||
input_sample_rate_hz=wav.getframerate(),
|
||||
)
|
||||
frames = list[bytes]()
|
||||
while pcm := wav.readframes(encoder.get_frame_samples()):
|
||||
frames.append(
|
||||
encoder.encode(pcm, num_bytes=200, bit_depth=wav.getsampwidth() * 8)
|
||||
)
|
||||
del encoder
|
||||
print('Encoding complete.')
|
||||
|
||||
basic_audio_announcement = bap.BasicAudioAnnouncement(
|
||||
presentation_delay=40000,
|
||||
subgroups=[
|
||||
bap.BasicAudioAnnouncement.Subgroup(
|
||||
codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3),
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
sampling_frequency=bap.SamplingFrequency.FREQ_48000,
|
||||
frame_duration=bap.FrameDuration.DURATION_10000_US,
|
||||
octets_per_codec_frame=100,
|
||||
),
|
||||
metadata=le_audio.Metadata(
|
||||
[
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng'
|
||||
),
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco'
|
||||
),
|
||||
]
|
||||
),
|
||||
bis=[
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=1,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_LEFT
|
||||
),
|
||||
),
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=2,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
|
||||
print('Start Advertising')
|
||||
advertising_set = await device.create_advertising_set(
|
||||
advertising_parameters=bumble.device.AdvertisingParameters(
|
||||
advertising_event_properties=bumble.device.AdvertisingEventProperties(
|
||||
is_connectable=False
|
||||
),
|
||||
primary_advertising_interval_min=100,
|
||||
primary_advertising_interval_max=200,
|
||||
),
|
||||
advertising_data=(
|
||||
broadcast_audio_announcement.get_advertising_data()
|
||||
+ bytes(
|
||||
core.AdvertisingData(
|
||||
[(core.AdvertisingData.BROADCAST_NAME, b'Bumble Auracast')]
|
||||
)
|
||||
)
|
||||
),
|
||||
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
|
||||
periodic_advertising_interval_min=80,
|
||||
periodic_advertising_interval_max=160,
|
||||
),
|
||||
periodic_advertising_data=basic_audio_announcement.get_advertising_data(),
|
||||
auto_restart=True,
|
||||
auto_start=True,
|
||||
)
|
||||
print('Start Periodic Advertising')
|
||||
await advertising_set.start_periodic()
|
||||
print('Setup BIG')
|
||||
big = await device.create_big(
|
||||
advertising_set,
|
||||
parameters=bumble.device.BigParameters(
|
||||
num_bis=2,
|
||||
sdu_interval=10000,
|
||||
max_sdu=100,
|
||||
max_transport_latency=65,
|
||||
rtn=4,
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
print('Setup ISO Data Path')
|
||||
for bis_link in big.bis_links:
|
||||
await bis_link.setup_data_path(
|
||||
direction=bis_link.Direction.HOST_TO_CONTROLLER
|
||||
)
|
||||
|
||||
for frame in itertools.cycle(frames):
|
||||
mid = len(frame) // 2
|
||||
big.bis_links[0].write(frame[:mid])
|
||||
big.bis_links[1].write(frame[mid:])
|
||||
await asyncio.sleep(0.009)
|
||||
|
||||
|
||||
def run_async(async_command: Coroutine) -> None:
|
||||
try:
|
||||
asyncio.run(async_command)
|
||||
except core.ProtocolError as error:
|
||||
if error.error_namespace == 'att' and error.error_code in list(
|
||||
bass.ApplicationError
|
||||
):
|
||||
message = bass.ApplicationError(error.error_code).name
|
||||
else:
|
||||
message = str(error)
|
||||
|
||||
print(
|
||||
color('!!! An error occurred while executing the command:', 'red'), message
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Main
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
def auracast(ctx):
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
|
||||
@auracast.command('scan')
|
||||
@click.option(
|
||||
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
|
||||
)
|
||||
@click.option(
|
||||
'--sync-timeout',
|
||||
metavar='SYNC_TIMEOUT',
|
||||
type=float,
|
||||
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
|
||||
help='Sync timeout (in seconds)',
|
||||
)
|
||||
@click.argument('transport')
|
||||
@click.pass_context
|
||||
def scan(ctx, filter_duplicates, sync_timeout, transport):
|
||||
"""Scan for public broadcasts"""
|
||||
run_async(run_scan(filter_duplicates, sync_timeout, transport))
|
||||
|
||||
|
||||
@auracast.command('assist')
|
||||
@click.option(
|
||||
'--broadcast-name',
|
||||
metavar='BROADCAST_NAME',
|
||||
help='Broadcast Name to tune to',
|
||||
)
|
||||
@click.option(
|
||||
'--source-id',
|
||||
metavar='SOURCE_ID',
|
||||
type=int,
|
||||
help='Source ID (for remove-source command)',
|
||||
)
|
||||
@click.option(
|
||||
'--command',
|
||||
type=click.Choice(
|
||||
['monitor-state', 'add-source', 'modify-source', 'remove-source']
|
||||
),
|
||||
required=True,
|
||||
)
|
||||
@click.argument('transport')
|
||||
@click.argument('address')
|
||||
@click.pass_context
|
||||
def assist(ctx, broadcast_name, source_id, command, transport, address):
|
||||
"""Scan for broadcasts on behalf of a audio server"""
|
||||
run_async(run_assist(broadcast_name, source_id, command, transport, address))
|
||||
|
||||
|
||||
@auracast.command('pair')
|
||||
@click.argument('transport')
|
||||
@click.argument('address')
|
||||
@click.pass_context
|
||||
def pair(ctx, transport, address):
|
||||
"""Pair with an audio server"""
|
||||
run_async(run_pair(transport, address))
|
||||
|
||||
|
||||
@auracast.command('receive')
|
||||
@click.argument('transport')
|
||||
@click.argument('broadcast_id', type=int)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
type=str,
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.option(
|
||||
'--sync-timeout',
|
||||
metavar='SYNC_TIMEOUT',
|
||||
type=float,
|
||||
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
|
||||
help='Sync timeout (in seconds)',
|
||||
)
|
||||
@click.option(
|
||||
'--subgroup',
|
||||
metavar='SUBGROUP',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Index of Subgroup',
|
||||
)
|
||||
@click.pass_context
|
||||
def receive(ctx, transport, broadcast_id, broadcast_code, sync_timeout, subgroup):
|
||||
"""Receive a broadcast source"""
|
||||
run_async(
|
||||
run_receive(transport, broadcast_id, broadcast_code, sync_timeout, subgroup)
|
||||
)
|
||||
|
||||
|
||||
@auracast.command('broadcast')
|
||||
@click.argument('transport')
|
||||
@click.argument('wav_file_path', type=str)
|
||||
@click.option(
|
||||
'--broadcast-id',
|
||||
metavar='BROADCAST_ID',
|
||||
type=int,
|
||||
default=123456,
|
||||
help='Broadcast ID',
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
type=str,
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.pass_context
|
||||
def broadcast(ctx, transport, broadcast_id, broadcast_code, wav_file_path):
|
||||
"""Start a broadcast as a source."""
|
||||
run_async(
|
||||
run_broadcast(
|
||||
transport=transport,
|
||||
broadcast_id=broadcast_id,
|
||||
broadcast_code=broadcast_code,
|
||||
wav_file_path=wav_file_path,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
auracast()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
1051
apps/bench.py
1051
apps/bench.py
File diff suppressed because it is too large
Load Diff
63
apps/ble_rpa_tool.py
Normal file
63
apps/ble_rpa_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
'''
|
||||
This is a tool for generating IRK, RPA,
|
||||
and verifying IRK/RPA pairs
|
||||
'''
|
||||
|
||||
|
||||
@click.command()
|
||||
def gen_irk() -> None:
|
||||
print(generate_irk().hex())
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
def gen_rpa(irk: str) -> None:
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
rpa = Address.generate_private_address(irk_bytes)
|
||||
print(rpa.to_string(with_type_qualifier=False))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
@click.argument("rpa", type=str)
|
||||
def verify_rpa(irk: str, rpa: str) -> None:
|
||||
address = Address(rpa)
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
if verify_rpa_with_irk(address, irk_bytes):
|
||||
print(color("Verified", "green"))
|
||||
else:
|
||||
print(color("Not Verified", "red"))
|
||||
|
||||
|
||||
def main():
|
||||
cli.add_command(gen_irk)
|
||||
cli.add_command(gen_rpa)
|
||||
cli.add_command(verify_rpa)
|
||||
cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -63,6 +63,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_Constant,
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_LE_2M_PHY,
|
||||
@@ -289,11 +290,7 @@ class ConsoleApp:
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
random_address = (
|
||||
f"{random.randint(192,255):02X}" # address is static random
|
||||
)
|
||||
for random_byte in random.sample(range(255), 5):
|
||||
random_address += f":{random_byte:02X}"
|
||||
random_address = Address.generate_static_address()
|
||||
self.append_to_log(f"Setting random address: {random_address}")
|
||||
self.device = Device.with_hci(
|
||||
'Bumble', random_address, hci_source, hci_sink
|
||||
@@ -503,21 +500,9 @@ class ConsoleApp:
|
||||
self.show_error('not connected')
|
||||
return
|
||||
|
||||
# Discover all services, characteristics and descriptors
|
||||
self.append_to_output('discovering services...')
|
||||
await self.connected_peer.discover_services()
|
||||
self.append_to_output(
|
||||
f'found {len(self.connected_peer.services)} services,'
|
||||
' discovering characteristics...'
|
||||
)
|
||||
await self.connected_peer.discover_characteristics()
|
||||
self.append_to_output('found characteristics, discovering descriptors...')
|
||||
for service in self.connected_peer.services:
|
||||
for characteristic in service.characteristics:
|
||||
await self.connected_peer.discover_descriptors(characteristic)
|
||||
self.append_to_output('discovery completed')
|
||||
|
||||
self.show_remote_services(self.connected_peer.services)
|
||||
self.append_to_output('Service Discovery starting...')
|
||||
await self.connected_peer.discover_all()
|
||||
self.append_to_output('Service Discovery done!')
|
||||
|
||||
async def discover_attributes(self):
|
||||
if not self.connected_peer:
|
||||
@@ -777,7 +762,7 @@ class ConsoleApp:
|
||||
if not service:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
@@ -796,11 +781,11 @@ class ConsoleApp:
|
||||
if not characteristic:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
values = [attribute.read_value(None)]
|
||||
values = [await attribute.read_value(None)]
|
||||
|
||||
# TODO: future optimization: convert CCCD value to human readable string
|
||||
|
||||
@@ -944,7 +929,7 @@ class ConsoleApp:
|
||||
|
||||
# send data to any subscribers
|
||||
if isinstance(attribute, Characteristic):
|
||||
attribute.write_value(None, value)
|
||||
await attribute.write_value(None, value)
|
||||
if attribute.has_properties(Characteristic.NOTIFY):
|
||||
await self.device.gatt_server.notify_subscribers(attribute)
|
||||
if attribute.has_properties(Characteristic.INDICATE):
|
||||
@@ -1172,7 +1157,7 @@ class ScanResult:
|
||||
name = ''
|
||||
|
||||
# Remove any '/P' qualifier suffix from the address string
|
||||
address_str = str(self.address).replace('/P', '')
|
||||
address_str = self.address.to_string(with_type_qualifier=False)
|
||||
|
||||
# RSSI bar
|
||||
bar_string = rssi_bar(self.rssi)
|
||||
|
||||
@@ -18,30 +18,42 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.colors import color
|
||||
from bumble.core import name_or_number
|
||||
from bumble.hci import (
|
||||
map_null_terminated_utf8_string,
|
||||
CodecID,
|
||||
LeFeature,
|
||||
HCI_SUCCESS,
|
||||
HCI_LE_SUPPORTED_FEATURES_NAMES,
|
||||
HCI_VERSION_NAMES,
|
||||
LMP_VERSION_NAMES,
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
|
||||
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command,
|
||||
HCI_Read_Local_Supported_Codecs_Command,
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
@@ -57,13 +69,14 @@ def command_succeeded(response):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_classic_info(host):
|
||||
async def get_classic_info(host: Host) -> None:
|
||||
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
|
||||
response = await host.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if command_succeeded(response):
|
||||
print()
|
||||
print(
|
||||
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
|
||||
color('Public Address:', 'yellow'),
|
||||
response.return_parameters.bd_addr.to_string(False),
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
|
||||
@@ -77,7 +90,7 @@ async def get_classic_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_le_info(host):
|
||||
async def get_le_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
|
||||
@@ -116,13 +129,104 @@ async def get_le_info(host):
|
||||
'\n',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command()
|
||||
)
|
||||
if command_succeeded(response):
|
||||
print(
|
||||
color('Suggested Default Data Length:', 'yellow'),
|
||||
f'{response.return_parameters.suggested_max_tx_octets}/'
|
||||
f'{response.return_parameters.suggested_max_tx_time}',
|
||||
'\n',
|
||||
)
|
||||
|
||||
print(color('LE Features:', 'yellow'))
|
||||
for feature in host.supported_le_features:
|
||||
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
|
||||
print(f' {LeFeature(feature).name}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(transport):
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_codecs_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
|
||||
)
|
||||
print(color('Codecs:', 'yellow'))
|
||||
|
||||
for codec_id, transport in zip(
|
||||
response.return_parameters.standard_codec_ids,
|
||||
response.return_parameters.standard_codec_transports,
|
||||
):
|
||||
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
|
||||
transport
|
||||
).name
|
||||
codec_name = CodecID(codec_id).name
|
||||
print(f' {codec_name} - {transport_name}')
|
||||
|
||||
for codec_id, transport in zip(
|
||||
response.return_parameters.vendor_specific_codec_ids,
|
||||
response.return_parameters.vendor_specific_codec_transports,
|
||||
):
|
||||
transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
|
||||
transport
|
||||
).name
|
||||
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
|
||||
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
|
||||
|
||||
if not response.return_parameters.standard_codec_ids:
|
||||
print(' No standard codecs')
|
||||
if not response.return_parameters.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Supported_Codecs_Command(), check_result=True
|
||||
)
|
||||
print(color('Codecs (BR/EDR):', 'yellow'))
|
||||
for codec_id in response.return_parameters.standard_codec_ids:
|
||||
codec_name = CodecID(codec_id).name
|
||||
print(f' {codec_name}')
|
||||
|
||||
for codec_id in response.return_parameters.vendor_specific_codec_ids:
|
||||
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
|
||||
print(f' {company} / {codec_id & 0xFFFF}')
|
||||
|
||||
if not response.return_parameters.standard_codec_ids:
|
||||
print(' No standard codecs')
|
||||
if not response.return_parameters.vendor_specific_codec_ids:
|
||||
print(' No Vendor-specific codecs')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(latency_probes, transport):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
@@ -130,6 +234,23 @@ async def async_main(transport):
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# Measure the latency if requested
|
||||
latencies = []
|
||||
if latency_probes:
|
||||
for _ in range(latency_probes):
|
||||
start = time.time()
|
||||
await host.send_command(HCI_Read_Local_Version_Information_Command())
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
print(
|
||||
color('HCI Command Latency:', 'yellow'),
|
||||
(
|
||||
f'min={min(latencies):.2f}, '
|
||||
f'max={max(latencies):.2f}, '
|
||||
f'average={sum(latencies)/len(latencies):.2f}'
|
||||
),
|
||||
'\n',
|
||||
)
|
||||
|
||||
# Print version
|
||||
print(color('Version:', 'yellow'))
|
||||
print(
|
||||
@@ -153,19 +274,31 @@ async def async_main(transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
|
||||
# Get codec info
|
||||
await get_codecs_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
print(color('Supported Commands:', 'yellow'))
|
||||
for command in host.supported_commands:
|
||||
print(' ', HCI_Command.command_name(command))
|
||||
print(f' {HCI_Command.command_name(command)}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--latency-probes',
|
||||
metavar='N',
|
||||
type=int,
|
||||
help='Send N commands to measure HCI transport latency statistics',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(transport):
|
||||
def main(latency_probes, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
asyncio.run(async_main(transport))
|
||||
asyncio.run(async_main(latency_probes, transport))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
205
apps/controller_loopback.py
Normal file
205
apps/controller_loopback.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
HCI_READ_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Read_Loopback_Mode_Command,
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND,
|
||||
HCI_Write_Loopback_Mode_Command,
|
||||
LoopbackMode,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
import click
|
||||
|
||||
|
||||
class Loopback:
|
||||
"""Send and receive ACL data packets in local loopback mode"""
|
||||
|
||||
def __init__(self, packet_size: int, packet_count: int, transport: str):
|
||||
self.transport = transport
|
||||
self.packet_size = packet_size
|
||||
self.packet_count = packet_count
|
||||
self.connection_handle: Optional[int] = None
|
||||
self.connection_event = asyncio.Event()
|
||||
self.done = asyncio.Event()
|
||||
self.expected_cid = 0
|
||||
self.bytes_received = 0
|
||||
self.start_timestamp = 0.0
|
||||
self.last_timestamp = 0.0
|
||||
|
||||
def on_connection(self, connection_handle: int, *args):
|
||||
"""Retrieve connection handle from new connection event"""
|
||||
if not self.connection_event.is_set():
|
||||
# save first connection handle for ACL
|
||||
# subsequent connections are SCO
|
||||
self.connection_handle = connection_handle
|
||||
self.connection_event.set()
|
||||
|
||||
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
|
||||
"""Calculate packet receive speed"""
|
||||
now = time.time()
|
||||
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
|
||||
assert connection_handle == self.connection_handle
|
||||
assert cid == self.expected_cid
|
||||
self.expected_cid += 1
|
||||
if cid == 0:
|
||||
self.start_timestamp = now
|
||||
else:
|
||||
elapsed_since_start = now - self.start_timestamp
|
||||
elapsed_since_last = now - self.last_timestamp
|
||||
self.bytes_received += len(pdu)
|
||||
instant_rx_speed = len(pdu) / elapsed_since_last
|
||||
average_rx_speed = self.bytes_received / elapsed_since_start
|
||||
print(
|
||||
color(
|
||||
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
|
||||
f' average={average_rx_speed:.4f}',
|
||||
'cyan',
|
||||
)
|
||||
)
|
||||
|
||||
self.last_timestamp = now
|
||||
|
||||
if self.expected_cid == self.packet_count:
|
||||
print(color('@@@ Received last packet', 'green'))
|
||||
self.done.set()
|
||||
|
||||
async def run(self):
|
||||
"""Run a loopback throughput test"""
|
||||
print(color('>>> Connecting to HCI...', 'green'))
|
||||
async with await open_transport_or_link(self.transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
print(color('>>> Connected', 'green'))
|
||||
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# make sure data can fit in one l2cap pdu
|
||||
l2cap_header_size = 4
|
||||
|
||||
max_packet_size = (
|
||||
host.acl_packet_queue
|
||||
if host.acl_packet_queue
|
||||
else host.le_acl_packet_queue
|
||||
).max_packet_size - l2cap_header_size
|
||||
if self.packet_size > max_packet_size:
|
||||
print(
|
||||
color(
|
||||
f'!!! Packet size ({self.packet_size}) larger than max supported'
|
||||
f' size ({max_packet_size})',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not host.supports_command(
|
||||
HCI_WRITE_LOOPBACK_MODE_COMMAND
|
||||
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
|
||||
print(color('!!! Loopback mode not supported', 'red'))
|
||||
return
|
||||
|
||||
# set event callbacks
|
||||
host.on('connection', self.on_connection)
|
||||
host.on('l2cap_pdu', self.on_l2cap_pdu)
|
||||
|
||||
loopback_mode = LoopbackMode.LOCAL
|
||||
|
||||
print(color('### Setting loopback mode', 'blue'))
|
||||
await host.send_command(
|
||||
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
print(color('### Checking loopback mode', 'blue'))
|
||||
response = await host.send_command(
|
||||
HCI_Read_Loopback_Mode_Command(), check_result=True
|
||||
)
|
||||
if response.return_parameters.loopback_mode != loopback_mode:
|
||||
print(color('!!! Loopback mode mismatch', 'red'))
|
||||
return
|
||||
|
||||
await self.connection_event.wait()
|
||||
print(color('### Connected', 'cyan'))
|
||||
|
||||
print(color('=== Start sending', 'magenta'))
|
||||
start_time = time.time()
|
||||
bytes_sent = 0
|
||||
for cid in range(0, self.packet_count):
|
||||
# using the cid as an incremental index
|
||||
host.send_l2cap_pdu(
|
||||
self.connection_handle, cid, bytes(self.packet_size)
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
|
||||
)
|
||||
)
|
||||
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
|
||||
await asyncio.sleep(0) # yield to allow packet receive
|
||||
|
||||
await self.done.wait()
|
||||
print(color('=== Done!', 'magenta'))
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
average_tx_speed = bytes_sent / elapsed
|
||||
print(
|
||||
color(
|
||||
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
|
||||
f' in {elapsed:.2f} seconds)',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--packet-size',
|
||||
'-s',
|
||||
metavar='SIZE',
|
||||
type=click.IntRange(8, 4096),
|
||||
default=500,
|
||||
help='Packet size',
|
||||
)
|
||||
@click.option(
|
||||
'--packet-count',
|
||||
'-c',
|
||||
metavar='COUNT',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=10,
|
||||
help='Packet count',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(packet_size, packet_count, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
|
||||
loopback = Loopback(packet_size, packet_count, transport)
|
||||
asyncio.run(loopback.run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
230
apps/device_info.py
Normal file
230
apps/device_info.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import click
|
||||
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.gatt import Service
|
||||
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
|
||||
from bumble.profiles.battery_service import BatteryServiceProxy
|
||||
from bumble.profiles.gap import GenericAccessServiceProxy
|
||||
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def try_show(function: Callable, *args, **kwargs) -> None:
|
||||
try:
|
||||
await function(*args, **kwargs)
|
||||
except ProtocolError as error:
|
||||
print(color('ERROR:', 'red'), error)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
for characteristic in service.characteristics:
|
||||
print(color(' ' + str(characteristic), 'magenta'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_gap_information(
|
||||
gap_service: GenericAccessServiceProxy,
|
||||
):
|
||||
print(color('### Generic Access Profile', 'yellow'))
|
||||
|
||||
if gap_service.device_name:
|
||||
print(
|
||||
color(' Device Name:', 'green'),
|
||||
await gap_service.device_name.read_value(),
|
||||
)
|
||||
|
||||
if gap_service.appearance:
|
||||
print(
|
||||
color(' Appearance: ', 'green'),
|
||||
await gap_service.appearance.read_value(),
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_device_information(
|
||||
device_information_service: DeviceInformationServiceProxy,
|
||||
):
|
||||
print(color('### Device Information', 'yellow'))
|
||||
|
||||
if device_information_service.manufacturer_name:
|
||||
print(
|
||||
color(' Manufacturer Name:', 'green'),
|
||||
await device_information_service.manufacturer_name.read_value(),
|
||||
)
|
||||
|
||||
if device_information_service.model_number:
|
||||
print(
|
||||
color(' Model Number: ', 'green'),
|
||||
await device_information_service.model_number.read_value(),
|
||||
)
|
||||
|
||||
if device_information_service.serial_number:
|
||||
print(
|
||||
color(' Serial Number: ', 'green'),
|
||||
await device_information_service.serial_number.read_value(),
|
||||
)
|
||||
|
||||
if device_information_service.firmware_revision:
|
||||
print(
|
||||
color(' Firmware Revision:', 'green'),
|
||||
await device_information_service.firmware_revision.read_value(),
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_battery_level(
|
||||
battery_service: BatteryServiceProxy,
|
||||
):
|
||||
print(color('### Battery Information', 'yellow'))
|
||||
|
||||
if battery_service.battery_level:
|
||||
print(
|
||||
color(' Battery Level:', 'green'),
|
||||
await battery_service.battery_level.read_value(),
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_tmas(
|
||||
tmas: TelephonyAndMediaAudioServiceProxy,
|
||||
):
|
||||
print(color('### Telephony And Media Audio Service', 'yellow'))
|
||||
|
||||
if tmas.role:
|
||||
print(
|
||||
color(' Role:', 'green'),
|
||||
await tmas.role.read_value(),
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
try:
|
||||
# Discover all services
|
||||
print(color('### Discovering Services and Characteristics', 'magenta'))
|
||||
await peer.discover_services()
|
||||
for service in peer.services:
|
||||
await service.discover_characteristics()
|
||||
|
||||
print(color('=== Services ===', 'yellow'))
|
||||
show_services(peer.services)
|
||||
print()
|
||||
|
||||
if gap_service := peer.create_service_proxy(GenericAccessServiceProxy):
|
||||
await try_show(show_gap_information, gap_service)
|
||||
|
||||
if device_information_service := peer.create_service_proxy(
|
||||
DeviceInformationServiceProxy
|
||||
):
|
||||
await try_show(show_device_information, device_information_service)
|
||||
|
||||
if battery_service := peer.create_service_proxy(BatteryServiceProxy):
|
||||
await try_show(show_battery_level, battery_service)
|
||||
|
||||
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
|
||||
await try_show(show_tmas, tmas)
|
||||
|
||||
if done is not None:
|
||||
done.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
print(color('!!! Operation canceled', 'red'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(device_config, encrypt, transport, address_or_name):
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
|
||||
# Create a device
|
||||
if device_config:
|
||||
device = Device.from_config_file_with_hci(
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
device = Device.with_hci(
|
||||
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
|
||||
)
|
||||
await device.power_on()
|
||||
|
||||
if address_or_name:
|
||||
# Connect to the target peer
|
||||
print(color('>>> Connecting...', 'green'))
|
||||
connection = await device.connect(address_or_name)
|
||||
print(color('>>> Connected', 'green'))
|
||||
|
||||
# Encrypt the connection if required
|
||||
if encrypt:
|
||||
print(color('+++ Encrypting connection...', 'blue'))
|
||||
await connection.encrypt()
|
||||
print(color('+++ Encryption established', 'blue'))
|
||||
|
||||
await show_device_info(Peer(connection), None)
|
||||
else:
|
||||
# Wait for a connection
|
||||
done = asyncio.get_running_loop().create_future()
|
||||
device.on(
|
||||
'connection',
|
||||
lambda connection: asyncio.create_task(
|
||||
show_device_info(Peer(connection), done)
|
||||
),
|
||||
)
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
print(color('### Waiting for connection...', 'blue'))
|
||||
await done
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option('--device-config', help='Device configuration', type=click.Path())
|
||||
@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False)
|
||||
@click.argument('transport')
|
||||
@click.argument('address-or-name', required=False)
|
||||
def main(device_config, encrypt, transport, address_or_name):
|
||||
"""
|
||||
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
|
||||
wait for an incoming connection.
|
||||
"""
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -75,11 +75,15 @@ async def async_main(device_config, encrypt, transport, address_or_name):
|
||||
|
||||
if address_or_name:
|
||||
# Connect to the target peer
|
||||
print(color('>>> Connecting...', 'green'))
|
||||
connection = await device.connect(address_or_name)
|
||||
print(color('>>> Connected', 'green'))
|
||||
|
||||
# Encrypt the connection if required
|
||||
if encrypt:
|
||||
print(color('+++ Encrypting connection...', 'blue'))
|
||||
await connection.encrypt()
|
||||
print(color('+++ Encryption established', 'blue'))
|
||||
|
||||
await dump_gatt_db(Peer(connection), None)
|
||||
else:
|
||||
|
||||
@@ -21,6 +21,7 @@ import struct
|
||||
import logging
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.core import AdvertisingData
|
||||
@@ -204,7 +205,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.peer = None
|
||||
@@ -218,7 +219,12 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
psm = 0xFB
|
||||
device.register_l2cap_channel_server(0xFB, self.on_coc)
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=0xFB,
|
||||
),
|
||||
handler=self.on_coc,
|
||||
)
|
||||
print(f'### Listening for CoC connection on PSM {psm}')
|
||||
|
||||
# Setup the Gattlink service
|
||||
|
||||
@@ -83,7 +83,7 @@ async def async_main():
|
||||
return_parameters=bytes([hci.HCI_SUCCESS]),
|
||||
)
|
||||
# Return a packet with 'respond to sender' set to True
|
||||
return (response.to_bytes(), True)
|
||||
return (bytes(response), True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import click
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device
|
||||
@@ -47,16 +48,17 @@ class ServerBridge:
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device):
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
device.register_l2cap_channel_server(
|
||||
psm=self.psm,
|
||||
server=self.on_coc,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
async def start(self, device: Device) -> None:
|
||||
# Listen for incoming L2CAP channel connections
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
|
||||
),
|
||||
handler=self.on_channel,
|
||||
)
|
||||
print(
|
||||
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
@@ -73,7 +75,7 @@ class ServerBridge:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
def on_channel(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
@@ -83,7 +85,7 @@ class ServerBridge:
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
l2cap_channel.sink = self.on_channel_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
@@ -105,7 +107,7 @@ class ServerBridge:
|
||||
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
|
||||
|
||||
def data_received(self, data):
|
||||
print(f'<<< Received on TCP: {len(data)}')
|
||||
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
|
||||
self.pipe.l2cap_channel.write(data)
|
||||
|
||||
try:
|
||||
@@ -123,11 +125,12 @@ class ServerBridge:
|
||||
await self.l2cap_channel.disconnect()
|
||||
|
||||
def on_l2cap_close(self):
|
||||
print(color('*** L2CAP channel closed', 'red'))
|
||||
self.l2cap_channel = None
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
def on_channel_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
@@ -182,7 +185,7 @@ class ClientBridge:
|
||||
peer_name = writer.get_extra_info('peer_name')
|
||||
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
def on_channel_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
@@ -194,11 +197,13 @@ class ClientBridge:
|
||||
# Connect a new L2CAP channel
|
||||
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
|
||||
try:
|
||||
l2cap_channel = await connection.open_l2cap_channel(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
l2cap_channel = await connection.create_l2cap_channel(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm,
|
||||
max_credits=self.max_credits,
|
||||
mtu=self.mtu,
|
||||
mps=self.mps,
|
||||
)
|
||||
)
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
except Exception as error:
|
||||
@@ -206,7 +211,7 @@ class ClientBridge:
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
@@ -271,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
|
||||
@click.option(
|
||||
'--l2cap-coc-max-credits',
|
||||
help='Maximum L2CAP CoC Credits',
|
||||
'--l2cap-max-credits',
|
||||
help='Maximum L2CAP Credits',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=128,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mtu',
|
||||
help='L2CAP CoC MTU',
|
||||
type=click.IntRange(23, 65535),
|
||||
default=1022,
|
||||
'--l2cap-mtu',
|
||||
help='L2CAP MTU',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mps',
|
||||
help='L2CAP CoC MPS',
|
||||
type=click.IntRange(23, 65533),
|
||||
'--l2cap-mps',
|
||||
help='L2CAP MPS',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
def cli(
|
||||
@@ -295,17 +306,17 @@ def cli(
|
||||
device_config,
|
||||
hci_transport,
|
||||
psm,
|
||||
l2cap_coc_max_credits,
|
||||
l2cap_coc_mtu,
|
||||
l2cap_coc_mps,
|
||||
l2cap_max_credits,
|
||||
l2cap_mtu,
|
||||
l2cap_mps,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
context.obj['max_credits'] = l2cap_max_credits
|
||||
context.obj['mtu'] = l2cap_mtu
|
||||
context.obj['mps'] = l2cap_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
454
apps/lea_unicast/app.py
Normal file
454
apps/lea_unicast/app.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import functools
|
||||
from importlib import resources
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
import weakref
|
||||
import wave
|
||||
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
|
||||
|
||||
import click
|
||||
import aiohttp.web
|
||||
|
||||
import bumble
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
|
||||
from bumble.transport import open_transport
|
||||
from bumble.profiles import ascs, bap, pacs
|
||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
def _sink_pac_record() -> pacs.PacRecord:
|
||||
return pacs.PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
bap.SupportedSamplingFrequency.FREQ_8000
|
||||
| bap.SupportedSamplingFrequency.FREQ_16000
|
||||
| bap.SupportedSamplingFrequency.FREQ_24000
|
||||
| bap.SupportedSamplingFrequency.FREQ_32000
|
||||
| bap.SupportedSamplingFrequency.FREQ_48000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_count=[1, 2],
|
||||
min_octets_per_codec_frame=26,
|
||||
max_octets_per_codec_frame=240,
|
||||
supported_max_codec_frames_per_sdu=2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _source_pac_record() -> pacs.PacRecord:
|
||||
return pacs.PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=bap.CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
bap.SupportedSamplingFrequency.FREQ_8000
|
||||
| bap.SupportedSamplingFrequency.FREQ_16000
|
||||
| bap.SupportedSamplingFrequency.FREQ_24000
|
||||
| bap.SupportedSamplingFrequency.FREQ_32000
|
||||
| bap.SupportedSamplingFrequency.FREQ_48000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_count=[1],
|
||||
min_octets_per_codec_frame=30,
|
||||
max_octets_per_codec_frame=100,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
decoder: lc3.Decoder | None = None
|
||||
encoding_config: bap.CodecSpecificConfiguration | None = None
|
||||
|
||||
|
||||
async def lc3_source_task(
|
||||
filename: str,
|
||||
sdu_length: int,
|
||||
frame_duration_us: int,
|
||||
device: Device,
|
||||
cis_link: CisLink,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
|
||||
filename,
|
||||
sdu_length,
|
||||
frame_duration_us / 1000,
|
||||
)
|
||||
with wave.open(filename, 'rb') as wav:
|
||||
bits_per_sample = wav.getsampwidth() * 8
|
||||
|
||||
encoder: lc3.Encoder | None = None
|
||||
|
||||
while True:
|
||||
next_round = datetime.datetime.now() + datetime.timedelta(
|
||||
microseconds=frame_duration_us
|
||||
)
|
||||
if not encoder:
|
||||
if (
|
||||
encoding_config
|
||||
and (frame_duration := encoding_config.frame_duration)
|
||||
and (sampling_frequency := encoding_config.sampling_frequency)
|
||||
and (
|
||||
audio_channel_allocation := encoding_config.audio_channel_allocation
|
||||
)
|
||||
):
|
||||
logger.info("Use %s", encoding_config)
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=audio_channel_allocation.channel_count,
|
||||
input_sample_rate_hz=wav.getframerate(),
|
||||
)
|
||||
else:
|
||||
sdu = encoder.encode(
|
||||
pcm=wav.readframes(encoder.get_frame_samples()),
|
||||
num_bytes=sdu_length,
|
||||
bit_depth=bits_per_sample,
|
||||
)
|
||||
cis_link.write(sdu)
|
||||
|
||||
sleep_time = next_round - datetime.datetime.now()
|
||||
await asyncio.sleep(sleep_time.total_seconds() * 0.9)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class UiServer:
|
||||
speaker: weakref.ReferenceType[Speaker]
|
||||
port: int
|
||||
|
||||
def __init__(self, speaker: Speaker, port: int) -> None:
|
||||
self.speaker = weakref.ref(speaker)
|
||||
self.port = port
|
||||
self.channel_socket = None
|
||||
|
||||
async def start_http(self) -> None:
|
||||
"""Start the UI HTTP server."""
|
||||
|
||||
app = aiohttp.web.Application()
|
||||
app.add_routes(
|
||||
[
|
||||
aiohttp.web.get('/', self.get_static),
|
||||
aiohttp.web.get('/index.html', self.get_static),
|
||||
aiohttp.web.get('/channel', self.get_channel),
|
||||
]
|
||||
)
|
||||
|
||||
runner = aiohttp.web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
|
||||
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
|
||||
await site.start()
|
||||
|
||||
async def get_static(self, request):
|
||||
path = request.path
|
||||
if path == '/':
|
||||
path = '/index.html'
|
||||
if path.endswith('.html'):
|
||||
content_type = 'text/html'
|
||||
elif path.endswith('.js'):
|
||||
content_type = 'text/javascript'
|
||||
elif path.endswith('.css'):
|
||||
content_type = 'text/css'
|
||||
elif path.endswith('.svg'):
|
||||
content_type = 'image/svg+xml'
|
||||
else:
|
||||
content_type = 'text/plain'
|
||||
text = (
|
||||
resources.files("bumble.apps.lea_unicast")
|
||||
.joinpath(pathlib.Path(path).relative_to('/'))
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
return aiohttp.web.Response(text=text, content_type=content_type)
|
||||
|
||||
async def get_channel(self, request):
|
||||
ws = aiohttp.web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Process messages until the socket is closed.
|
||||
self.channel_socket = ws
|
||||
async for message in ws:
|
||||
if message.type == aiohttp.WSMsgType.TEXT:
|
||||
logger.debug(f'<<< received message: {message.data}')
|
||||
await self.on_message(message.data)
|
||||
elif message.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.debug(
|
||||
f'channel connection closed with exception {ws.exception()}'
|
||||
)
|
||||
|
||||
self.channel_socket = None
|
||||
logger.debug('--- channel connection closed')
|
||||
|
||||
return ws
|
||||
|
||||
async def on_message(self, message_str: str):
|
||||
# Parse the message as JSON
|
||||
message = json.loads(message_str)
|
||||
|
||||
# Dispatch the message
|
||||
message_type = message['type']
|
||||
message_params = message.get('params', {})
|
||||
handler = getattr(self, f'on_{message_type}_message')
|
||||
if handler:
|
||||
await handler(**message_params)
|
||||
|
||||
async def on_hello_message(self):
|
||||
await self.send_message(
|
||||
'hello',
|
||||
bumble_version=bumble.__version__,
|
||||
codec=self.speaker().codec,
|
||||
streamState=self.speaker().stream_state.name,
|
||||
)
|
||||
if connection := self.speaker().connection:
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=connection.peer_address.to_string(False),
|
||||
peer_name=connection.peer_name,
|
||||
)
|
||||
|
||||
async def send_message(self, message_type: str, **kwargs) -> None:
|
||||
if self.channel_socket is None:
|
||||
return
|
||||
|
||||
message = {'type': message_type, 'params': kwargs}
|
||||
await self.channel_socket.send_json(message)
|
||||
|
||||
async def send_audio(self, data: bytes) -> None:
|
||||
if self.channel_socket is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.channel_socket.send_bytes(data)
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while sending audio packet: {error}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Speaker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_config_path: str | None,
|
||||
ui_port: int,
|
||||
transport: str,
|
||||
lc3_input_file_path: str,
|
||||
):
|
||||
self.device_config_path = device_config_path
|
||||
self.transport = transport
|
||||
self.lc3_input_file_path = lc3_input_file_path
|
||||
|
||||
# Create an HTTP server for the UI
|
||||
self.ui_server = UiServer(speaker=self, port=ui_port)
|
||||
|
||||
async def run(self) -> None:
|
||||
await self.ui_server.start_http()
|
||||
|
||||
async with await open_transport(self.transport) as hci_transport:
|
||||
# Create a device
|
||||
if self.device_config_path:
|
||||
device_config = DeviceConfiguration.from_file(self.device_config_path)
|
||||
else:
|
||||
device_config = DeviceConfiguration(
|
||||
name="Bumble LE Headphone",
|
||||
class_of_device=0x244418,
|
||||
keystore="JsonKeyStore",
|
||||
advertising_interval_min=25,
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
device_config.cis_enabled = True
|
||||
self.device = Device.from_config_with_hci(
|
||||
device_config, hci_transport.source, hci_transport.sink
|
||||
)
|
||||
|
||||
self.device.add_service(
|
||||
pacs.PublishedAudioCapabilitiesService(
|
||||
supported_source_context=bap.ContextType(0xFFFF),
|
||||
available_source_context=bap.ContextType(0xFFFF),
|
||||
supported_sink_context=bap.ContextType(0xFFFF), # All context types
|
||||
available_sink_context=bap.ContextType(0xFFFF), # All context types
|
||||
sink_audio_locations=(
|
||||
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
sink_pac=[_sink_pac_record()],
|
||||
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
|
||||
source_pac=[_source_pac_record()],
|
||||
)
|
||||
)
|
||||
|
||||
ascs_service = ascs.AudioStreamControlService(
|
||||
self.device, sink_ase_id=[1], source_ase_id=[2]
|
||||
)
|
||||
self.device.add_service(ascs_service)
|
||||
|
||||
advertising_data = bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(device_config.name, 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
) + bytes(bap.UnicastServerAdvertisingData())
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
|
||||
codec_config = ase.codec_specific_configuration
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
or decoder is None
|
||||
or not pdu.iso_sdu_fragment
|
||||
):
|
||||
return
|
||||
pcm = decoder.decode(
|
||||
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
)
|
||||
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
||||
|
||||
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
if ase.state == ascs.AseStateMachine.State.STREAMING:
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or ase.cis_link is None
|
||||
or codec_config.octets_per_codec_frame is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.codec_frames_per_sdu is None
|
||||
):
|
||||
return
|
||||
ase.cis_link.abort_on(
|
||||
'disconnection',
|
||||
lc3_source_task(
|
||||
filename=self.lc3_input_file_path,
|
||||
sdu_length=(
|
||||
codec_config.codec_frames_per_sdu
|
||||
* codec_config.octets_per_codec_frame
|
||||
),
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
device=self.device,
|
||||
cis_link=ase.cis_link,
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not ase.cis_link:
|
||||
return
|
||||
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
||||
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.sampling_frequency is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
):
|
||||
return
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
global encoding_config
|
||||
encoding_config = codec_config
|
||||
else:
|
||||
global decoder
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
sample_rate_hz=codec_config.sampling_frequency.hz,
|
||||
num_channels=codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
|
||||
for ase in ascs_service.ase_state_machines.values():
|
||||
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
|
||||
|
||||
await self.device.power_on()
|
||||
await self.device.create_advertising_set(
|
||||
advertising_data=advertising_data,
|
||||
auto_restart=True,
|
||||
advertising_parameters=AdvertisingParameters(
|
||||
primary_advertising_interval_min=100,
|
||||
primary_advertising_interval_max=100,
|
||||
),
|
||||
)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--ui-port',
|
||||
'ui_port',
|
||||
metavar='HTTP_PORT',
|
||||
default=DEFAULT_UI_PORT,
|
||||
show_default=True,
|
||||
help='HTTP port for the UI server',
|
||||
)
|
||||
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
|
||||
@click.argument('transport')
|
||||
@click.argument('lc3_file')
|
||||
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
|
||||
"""Run the speaker."""
|
||||
|
||||
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
speaker()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
68
apps/lea_unicast/index.html
Normal file
68
apps/lea_unicast/index.html
Normal file
@@ -0,0 +1,68 @@
|
||||
<html data-bs-theme="dark">
|
||||
|
||||
<head>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
|
||||
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
|
||||
<script src="https://unpkg.com/pcm-player"></script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<nav class="navbar navbar-dark bg-primary">
|
||||
<div class="container">
|
||||
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
|
||||
</div>
|
||||
</nav>
|
||||
<br>
|
||||
|
||||
<div class="container">
|
||||
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
|
||||
<button class="btn btn-primary" type="button" disabled>
|
||||
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
|
||||
<span role="status" id="ws-status">WebSocket Connecting...</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<script>
|
||||
let player = null;
|
||||
const wsStatus = document.getElementById("ws-status");
|
||||
const wsStatusSpinner = document.getElementById("ws-status-spinner");
|
||||
|
||||
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
|
||||
socket.binaryType = "arraybuffer";
|
||||
socket.onmessage = function (message) {
|
||||
if (typeof message.data === 'string' || message.data instanceof String) {
|
||||
console.log(`channel MESSAGE: ${message.data}`);
|
||||
} else {
|
||||
console.log(typeof (message.data))
|
||||
// BINARY audio data.
|
||||
if (player == null) return;
|
||||
player.feed(message.data);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onopen = (message) => {
|
||||
wsStatusSpinner.remove();
|
||||
wsStatus.textContent = "WebSocket Connected";
|
||||
}
|
||||
|
||||
socket.onclose = (message) => {
|
||||
wsStatus.textContent = "WebSocket Disconnected";
|
||||
}
|
||||
|
||||
function connectAudio() {
|
||||
player = new PCMPlayer({
|
||||
inputCodec: 'Int16',
|
||||
channels: 2,
|
||||
sampleRate: 48000,
|
||||
flushTime: 10,
|
||||
});
|
||||
const button = document.getElementById("connect-audio")
|
||||
button.disabled = true;
|
||||
button.textContent = "Audio Connected";
|
||||
}
|
||||
</script>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
121
apps/pair.py
121
apps/pair.py
@@ -24,10 +24,16 @@ from prompt_toolkit.shortcuts import PromptSession
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.pairing import PairingDelegate, PairingConfig
|
||||
from bumble.pairing import OobData, PairingDelegate, PairingConfig
|
||||
from bumble.smp import OobContext, OobLegacyContext
|
||||
from bumble.smp import error_name as smp_error_name
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
ProtocolError,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
@@ -40,17 +46,25 @@ from bumble.att import (
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
POST_PAIRING_DELAY = 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
if not self.linger:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
@@ -60,7 +74,7 @@ class Waiter:
|
||||
class Delegate(PairingDelegate):
|
||||
def __init__(self, mode, connection, capability_string, do_prompt):
|
||||
super().__init__(
|
||||
{
|
||||
io_capability={
|
||||
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
|
||||
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
|
||||
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
|
||||
@@ -227,8 +241,10 @@ def on_connection(connection, request):
|
||||
|
||||
# Listen for pairing events
|
||||
connection.on('pairing_start', on_pairing_start)
|
||||
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys))
|
||||
connection.on('pairing_failure', on_pairing_failure)
|
||||
connection.on('pairing', lambda keys: on_pairing(connection, keys))
|
||||
connection.on(
|
||||
'pairing_failure', lambda reason: on_pairing_failure(connection, reason)
|
||||
)
|
||||
|
||||
# Listen for encryption changes
|
||||
connection.on(
|
||||
@@ -262,19 +278,24 @@ def on_pairing_start():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_pairing(address, keys):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_pairing(connection, keys):
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
print(color(f'*** Paired! (peer identity={address})', 'cyan'))
|
||||
print(color(f'*** Paired! (peer identity={connection.peer_address})', 'cyan'))
|
||||
keys.print(prefix=color('*** ', 'cyan'))
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
await asyncio.sleep(POST_PAIRING_DELAY)
|
||||
await connection.disconnect()
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_pairing_failure(reason):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_pairing_failure(connection, reason):
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
await connection.disconnect()
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
@@ -285,7 +306,10 @@ async def pair(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -294,7 +318,7 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
@@ -306,6 +330,7 @@ async def pair(
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
device.le_enabled = True
|
||||
device.add_service(
|
||||
Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
@@ -326,7 +351,6 @@ async def pair(
|
||||
# Select LE or Classic
|
||||
if mode == 'classic':
|
||||
device.classic_enabled = True
|
||||
device.le_enabled = False
|
||||
device.classic_smp_enabled = ctkd
|
||||
|
||||
# Get things going
|
||||
@@ -343,16 +367,63 @@ async def pair(
|
||||
await device.keystore.print(prefix=color('@@@ ', 'blue'))
|
||||
print(color('@@@-----------------------------------', 'blue'))
|
||||
|
||||
# Create an OOB context if needed
|
||||
if oob:
|
||||
our_oob_context = OobContext()
|
||||
shared_data = (
|
||||
None
|
||||
if oob == '-'
|
||||
else OobData.from_ad(
|
||||
AdvertisingData.from_bytes(bytes.fromhex(oob))
|
||||
).shared_data
|
||||
)
|
||||
legacy_context = OobLegacyContext()
|
||||
oob_contexts = PairingConfig.OobConfig(
|
||||
our_context=our_oob_context,
|
||||
peer_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
print(color('@@@ OOB Data:', 'yellow'))
|
||||
if shared_data is None:
|
||||
oob_data = OobData(
|
||||
address=device.random_address, shared_data=our_oob_context.share()
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
else:
|
||||
oob_contexts = None
|
||||
|
||||
# Set up a pairing config factory
|
||||
if identity_address == 'public':
|
||||
identity_address_type = PairingConfig.AddressType.PUBLIC
|
||||
elif identity_address == 'random':
|
||||
identity_address_type = PairingConfig.AddressType.RANDOM
|
||||
else:
|
||||
identity_address_type = None
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(
|
||||
sc, mitm, bond, Delegate(mode, connection, io, prompt)
|
||||
sc=sc,
|
||||
mitm=mitm,
|
||||
bonding=bond,
|
||||
oob=oob_contexts,
|
||||
identity_address_type=identity_address_type,
|
||||
delegate=Delegate(mode, connection, io, prompt),
|
||||
)
|
||||
|
||||
# Connect to a peer or wait for a connection
|
||||
device.on('connection', lambda connection: on_connection(connection, request))
|
||||
if address_or_name is not None:
|
||||
print(color(f'=== Connecting to {address_or_name}...', 'green'))
|
||||
connection = await device.connect(address_or_name)
|
||||
connection = await device.connect(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
|
||||
if not request:
|
||||
try:
|
||||
@@ -360,10 +431,9 @@ async def pair(
|
||||
await connection.pair()
|
||||
else:
|
||||
await connection.authenticate()
|
||||
return
|
||||
except ProtocolError as error:
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
return
|
||||
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
@@ -413,6 +483,11 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--identity-address',
|
||||
type=click.Choice(['random', 'public']),
|
||||
)
|
||||
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
|
||||
@click.option(
|
||||
'--io',
|
||||
type=click.Choice(
|
||||
@@ -421,6 +496,14 @@ class LogHandler(logging.Handler):
|
||||
default='display+keyboard',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
'--oob',
|
||||
metavar='<oob-data-hex>',
|
||||
help=(
|
||||
'Use OOB pairing with this data from the peer '
|
||||
'(use "-" to enable OOB without peer data)'
|
||||
),
|
||||
)
|
||||
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
|
||||
@click.option(
|
||||
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
|
||||
@@ -440,7 +523,10 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
@@ -463,7 +549,10 @@ def main(
|
||||
mitm,
|
||||
bond,
|
||||
ctkd,
|
||||
identity_address,
|
||||
linger,
|
||||
io,
|
||||
oob,
|
||||
prompt,
|
||||
request,
|
||||
print_keys,
|
||||
|
||||
@@ -3,7 +3,7 @@ import click
|
||||
import logging
|
||||
import json
|
||||
|
||||
from bumble.pandora import PandoraDevice, serve
|
||||
from bumble.pandora import PandoraDevice, Config, serve
|
||||
from typing import Dict, Any
|
||||
|
||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
|
||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||
|
||||
bumble_config = retrieve_config(config)
|
||||
if 'transport' not in bumble_config.keys():
|
||||
bumble_config.update({'transport': transport})
|
||||
bumble_config.setdefault('transport', transport)
|
||||
device = PandoraDevice(bumble_config)
|
||||
|
||||
server_config = Config()
|
||||
server_config.load_from_dict(bumble_config.get('server', {}))
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
asyncio.run(serve(device, port=grpc_port))
|
||||
asyncio.run(serve(device, config=server_config, port=grpc_port))
|
||||
|
||||
|
||||
def retrieve_config(config: str) -> Dict[str, Any]:
|
||||
|
||||
608
apps/player/player.py
Normal file
608
apps/player/player.py
Normal file
@@ -0,0 +1,608 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import asyncio.subprocess
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import click
|
||||
|
||||
from bumble.a2dp import (
|
||||
make_audio_source_service_sdp_records,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
A2DP_NON_A2DP_CODEC_TYPE,
|
||||
AacFrame,
|
||||
AacParser,
|
||||
AacPacketSource,
|
||||
AacMediaCodecInformation,
|
||||
SbcFrame,
|
||||
SbcParser,
|
||||
SbcPacketSource,
|
||||
SbcMediaCodecInformation,
|
||||
OpusPacket,
|
||||
OpusParser,
|
||||
OpusPacketSource,
|
||||
OpusMediaCodecInformation,
|
||||
)
|
||||
from bumble.avrcp import Protocol as AvrcpProtocol
|
||||
from bumble.avdtp import (
|
||||
find_avdtp_service_with_connection,
|
||||
AVDTP_AUDIO_MEDIA_TYPE,
|
||||
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
|
||||
MediaCodecCapabilities,
|
||||
MediaPacketPump,
|
||||
Protocol as AvdtpProtocol,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
AdvertisingData,
|
||||
ConnectionError as BumbleConnectionError,
|
||||
DeviceClass,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
from bumble.device import Connection, Device, DeviceConfiguration
|
||||
from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant
|
||||
from bumble.pairing import PairingConfig
|
||||
from bumble.transport import open_transport
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def a2dp_source_sdp_records():
|
||||
service_record_handle = 0x00010001
|
||||
return {
|
||||
service_record_handle: make_audio_source_service_sdp_records(
|
||||
service_record_handle
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def sbc_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
sbc_parser = SbcParser(read_function)
|
||||
sbc_frame: SbcFrame
|
||||
async for sbc_frame in sbc_parser.frames:
|
||||
# We only need the first frame
|
||||
print(color(f"SBC format: {sbc_frame}", "cyan"))
|
||||
break
|
||||
|
||||
channel_mode = [
|
||||
SbcMediaCodecInformation.ChannelMode.MONO,
|
||||
SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL,
|
||||
SbcMediaCodecInformation.ChannelMode.STEREO,
|
||||
SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
|
||||
][sbc_frame.channel_mode]
|
||||
block_length = {
|
||||
4: SbcMediaCodecInformation.BlockLength.BL_4,
|
||||
8: SbcMediaCodecInformation.BlockLength.BL_8,
|
||||
12: SbcMediaCodecInformation.BlockLength.BL_12,
|
||||
16: SbcMediaCodecInformation.BlockLength.BL_16,
|
||||
}[sbc_frame.block_count]
|
||||
subbands = {
|
||||
4: SbcMediaCodecInformation.Subbands.S_4,
|
||||
8: SbcMediaCodecInformation.Subbands.S_8,
|
||||
}[sbc_frame.subband_count]
|
||||
allocation_method = [
|
||||
SbcMediaCodecInformation.AllocationMethod.LOUDNESS,
|
||||
SbcMediaCodecInformation.AllocationMethod.SNR,
|
||||
][sbc_frame.allocation_method]
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_SBC_CODEC_TYPE,
|
||||
media_codec_information=SbcMediaCodecInformation(
|
||||
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.from_int(
|
||||
sbc_frame.sampling_frequency
|
||||
),
|
||||
channel_mode=channel_mode,
|
||||
block_length=block_length,
|
||||
subbands=subbands,
|
||||
allocation_method=allocation_method,
|
||||
minimum_bitpool_value=2,
|
||||
maximum_bitpool_value=40,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def aac_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
aac_parser = AacParser(read_function)
|
||||
aac_frame: AacFrame
|
||||
async for aac_frame in aac_parser.frames:
|
||||
# We only need the first frame
|
||||
print(color(f"AAC format: {aac_frame}", "cyan"))
|
||||
break
|
||||
|
||||
sampling_frequency = AacMediaCodecInformation.SamplingFrequency.from_int(
|
||||
aac_frame.sampling_frequency
|
||||
)
|
||||
channels = (
|
||||
AacMediaCodecInformation.Channels.MONO
|
||||
if aac_frame.channel_configuration == 1
|
||||
else AacMediaCodecInformation.Channels.STEREO
|
||||
)
|
||||
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
media_codec_information=AacMediaCodecInformation(
|
||||
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
|
||||
sampling_frequency=sampling_frequency,
|
||||
channels=channels,
|
||||
vbr=1,
|
||||
bitrate=128000,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def opus_codec_capabilities(read_function) -> MediaCodecCapabilities:
|
||||
opus_parser = OpusParser(read_function)
|
||||
opus_packet: OpusPacket
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only need the first packet
|
||||
print(color(f"Opus format: {opus_packet}", "cyan"))
|
||||
break
|
||||
|
||||
if opus_packet.channel_mode == OpusPacket.ChannelMode.MONO:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.MONO
|
||||
elif opus_packet.channel_mode == OpusPacket.ChannelMode.STEREO:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.STEREO
|
||||
else:
|
||||
channel_mode = OpusMediaCodecInformation.ChannelMode.DUAL_MONO
|
||||
|
||||
if opus_packet.duration == 10:
|
||||
frame_size = OpusMediaCodecInformation.FrameSize.FS_10MS
|
||||
else:
|
||||
frame_size = OpusMediaCodecInformation.FrameSize.FS_20MS
|
||||
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
|
||||
media_codec_information=OpusMediaCodecInformation(
|
||||
channel_mode=channel_mode,
|
||||
sampling_frequency=OpusMediaCodecInformation.SamplingFrequency.SF_48000,
|
||||
frame_size=frame_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Player:
|
||||
def __init__(
|
||||
self,
|
||||
transport: str,
|
||||
device_config: Optional[str],
|
||||
authenticate: bool,
|
||||
encrypt: bool,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.device_config = device_config
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt
|
||||
self.avrcp_protocol: Optional[AvrcpProtocol] = None
|
||||
self.done: Optional[asyncio.Event]
|
||||
|
||||
async def run(self, workload) -> None:
|
||||
self.done = asyncio.Event()
|
||||
try:
|
||||
await self._run(workload)
|
||||
except Exception as error:
|
||||
print(color(f"!!! ERROR: {error}", "red"))
|
||||
|
||||
async def _run(self, workload) -> None:
|
||||
async with await open_transport(self.transport) as (hci_source, hci_sink):
|
||||
# Create a device
|
||||
device_config = DeviceConfiguration()
|
||||
if self.device_config:
|
||||
device_config.load_from_file(self.device_config)
|
||||
else:
|
||||
device_config.name = "Bumble Player"
|
||||
device_config.class_of_device = DeviceClass.pack_class_of_device(
|
||||
DeviceClass.AUDIO_SERVICE_CLASS,
|
||||
DeviceClass.AUDIO_VIDEO_MAJOR_DEVICE_CLASS,
|
||||
DeviceClass.AUDIO_VIDEO_UNCATEGORIZED_MINOR_DEVICE_CLASS,
|
||||
)
|
||||
device_config.keystore = "JsonKeyStore"
|
||||
|
||||
device_config.classic_enabled = True
|
||||
device_config.le_enabled = False
|
||||
device_config.le_simultaneous_enabled = False
|
||||
device_config.classic_sc_enabled = False
|
||||
device_config.classic_smp_enabled = False
|
||||
device = Device.from_config_with_hci(device_config, hci_source, hci_sink)
|
||||
|
||||
# Setup the SDP records to expose the SRC service
|
||||
device.sdp_service_records = a2dp_source_sdp_records()
|
||||
|
||||
# Setup AVRCP
|
||||
self.avrcp_protocol = AvrcpProtocol()
|
||||
self.avrcp_protocol.listen(device)
|
||||
|
||||
# Don't require MITM when pairing.
|
||||
device.pairing_config_factory = lambda connection: PairingConfig(mitm=False)
|
||||
|
||||
# Start the controller
|
||||
await device.power_on()
|
||||
|
||||
# Print some of the config/properties
|
||||
print(
|
||||
"Player Bluetooth Address:",
|
||||
color(
|
||||
device.public_address.to_string(with_type_qualifier=False),
|
||||
"yellow",
|
||||
),
|
||||
)
|
||||
|
||||
# Listen for connections
|
||||
device.on("connection", self.on_bluetooth_connection)
|
||||
|
||||
# Run the workload
|
||||
try:
|
||||
await workload(device)
|
||||
except BumbleConnectionError as error:
|
||||
if error.error_code == HCI_CONNECTION_ALREADY_EXISTS_ERROR:
|
||||
print(color("Connection already established", "blue"))
|
||||
else:
|
||||
print(color(f"Failed to connect: {error}", "red"))
|
||||
|
||||
# Wait until it is time to exit
|
||||
assert self.done is not None
|
||||
await asyncio.wait(
|
||||
[hci_source.terminated, asyncio.ensure_future(self.done.wait())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
def on_bluetooth_connection(self, connection: Connection) -> None:
|
||||
print(color(f"--- Connected: {connection}", "cyan"))
|
||||
connection.on("disconnection", self.on_bluetooth_disconnection)
|
||||
|
||||
def on_bluetooth_disconnection(self, reason) -> None:
|
||||
print(color(f"--- Disconnected: {HCI_Constant.error_name(reason)}", "cyan"))
|
||||
self.set_done()
|
||||
|
||||
async def connect(self, device: Device, address: str) -> Connection:
|
||||
print(color(f"Connecting to {address}...", "green"))
|
||||
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
|
||||
|
||||
# Request authentication
|
||||
if self.authenticate:
|
||||
print(color("*** Authenticating...", "blue"))
|
||||
await connection.authenticate()
|
||||
print(color("*** Authenticated", "blue"))
|
||||
|
||||
# Enable encryption
|
||||
if self.encrypt:
|
||||
print(color("*** Enabling encryption...", "blue"))
|
||||
await connection.encrypt()
|
||||
print(color("*** Encryption on", "blue"))
|
||||
|
||||
return connection
|
||||
|
||||
async def create_avdtp_protocol(self, connection: Connection) -> AvdtpProtocol:
|
||||
# Look for an A2DP service
|
||||
avdtp_version = await find_avdtp_service_with_connection(connection)
|
||||
if not avdtp_version:
|
||||
raise RuntimeError("no A2DP service found")
|
||||
|
||||
print(color(f"AVDTP Version: {avdtp_version}"))
|
||||
|
||||
# Create a client to interact with the remote device
|
||||
return await AvdtpProtocol.connect(connection, avdtp_version)
|
||||
|
||||
async def stream_packets(
|
||||
self,
|
||||
protocol: AvdtpProtocol,
|
||||
codec_type: int,
|
||||
vendor_id: int,
|
||||
codec_id: int,
|
||||
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
|
||||
codec_capabilities: MediaCodecCapabilities,
|
||||
):
|
||||
# Discover all endpoints on the remote device
|
||||
endpoints = await protocol.discover_remote_endpoints()
|
||||
for endpoint in endpoints:
|
||||
print('@@@', endpoint)
|
||||
|
||||
# Select a sink
|
||||
sink = protocol.find_remote_sink_by_codec(
|
||||
AVDTP_AUDIO_MEDIA_TYPE, codec_type, vendor_id, codec_id
|
||||
)
|
||||
if sink is None:
|
||||
print(color('!!! no compatible sink found', 'red'))
|
||||
return
|
||||
print(f'### Selected sink: {sink.seid}')
|
||||
|
||||
# Check if the sink supports delay reporting
|
||||
delay_reporting = False
|
||||
for capability in sink.capabilities:
|
||||
if capability.service_category == AVDTP_DELAY_REPORTING_SERVICE_CATEGORY:
|
||||
delay_reporting = True
|
||||
break
|
||||
|
||||
def on_delay_report(delay: int):
|
||||
print(color(f"*** DELAY REPORT: {delay}", "blue"))
|
||||
|
||||
# Adjust the codec capabilities for certain codecs
|
||||
for capability in sink.capabilities:
|
||||
if isinstance(capability, MediaCodecCapabilities):
|
||||
if isinstance(
|
||||
codec_capabilities.media_codec_information, SbcMediaCodecInformation
|
||||
) and isinstance(
|
||||
capability.media_codec_information, SbcMediaCodecInformation
|
||||
):
|
||||
codec_capabilities.media_codec_information.minimum_bitpool_value = (
|
||||
capability.media_codec_information.minimum_bitpool_value
|
||||
)
|
||||
codec_capabilities.media_codec_information.maximum_bitpool_value = (
|
||||
capability.media_codec_information.maximum_bitpool_value
|
||||
)
|
||||
print(color("Source media codec:", "green"), codec_capabilities)
|
||||
|
||||
# Stream the packets
|
||||
packet_pump = MediaPacketPump(packet_source.packets)
|
||||
source = protocol.add_source(codec_capabilities, packet_pump, delay_reporting)
|
||||
source.on("delay_report", on_delay_report)
|
||||
stream = await protocol.create_stream(source, sink)
|
||||
await stream.start()
|
||||
|
||||
await packet_pump.wait_for_completion()
|
||||
|
||||
async def discover(self, device: Device) -> None:
|
||||
@device.listens_to("inquiry_result")
|
||||
def on_inquiry_result(
|
||||
address: Address, class_of_device: int, data: AdvertisingData, rssi: int
|
||||
) -> None:
|
||||
(
|
||||
service_classes,
|
||||
major_device_class,
|
||||
minor_device_class,
|
||||
) = DeviceClass.split_class_of_device(class_of_device)
|
||||
separator = "\n "
|
||||
print(f">>> {color(address.to_string(False), 'yellow')}:")
|
||||
print(f" Device Class (raw): {class_of_device:06X}")
|
||||
major_class_name = DeviceClass.major_device_class_name(major_device_class)
|
||||
print(" Device Major Class: " f"{major_class_name}")
|
||||
minor_class_name = DeviceClass.minor_device_class_name(
|
||||
major_device_class, minor_device_class
|
||||
)
|
||||
print(" Device Minor Class: " f"{minor_class_name}")
|
||||
print(
|
||||
" Device Services: "
|
||||
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
|
||||
)
|
||||
print(f" RSSI: {rssi}")
|
||||
if data.ad_structures:
|
||||
print(f" {data.to_string(separator)}")
|
||||
|
||||
await device.start_discovery()
|
||||
|
||||
async def pair(self, device: Device, address: str) -> None:
|
||||
print(color(f"Connecting to {address}...", "green"))
|
||||
connection = await device.connect(address, transport=BT_BR_EDR_TRANSPORT)
|
||||
|
||||
print(color("Pairing...", "magenta"))
|
||||
await connection.authenticate()
|
||||
print(color("Pairing completed", "magenta"))
|
||||
self.set_done()
|
||||
|
||||
async def inquire(self, device: Device, address: str) -> None:
|
||||
connection = await self.connect(device, address)
|
||||
avdtp_protocol = await self.create_avdtp_protocol(connection)
|
||||
|
||||
# Discover the remote endpoints
|
||||
endpoints = await avdtp_protocol.discover_remote_endpoints()
|
||||
print(f'@@@ Found {len(list(endpoints))} endpoints')
|
||||
for endpoint in endpoints:
|
||||
print('@@@', endpoint)
|
||||
|
||||
self.set_done()
|
||||
|
||||
async def play(
|
||||
self,
|
||||
device: Device,
|
||||
address: Optional[str],
|
||||
audio_format: str,
|
||||
audio_file: str,
|
||||
) -> None:
|
||||
if audio_format == "auto":
|
||||
if audio_file.endswith(".sbc"):
|
||||
audio_format = "sbc"
|
||||
elif audio_file.endswith(".aac") or audio_file.endswith(".adts"):
|
||||
audio_format = "aac"
|
||||
elif audio_file.endswith(".ogg"):
|
||||
audio_format = "opus"
|
||||
else:
|
||||
raise ValueError("Unable to determine audio format from file extension")
|
||||
|
||||
device.on(
|
||||
"connection",
|
||||
lambda connection: AsyncRunner.spawn(on_connection(connection)),
|
||||
)
|
||||
|
||||
async def on_connection(connection: Connection):
|
||||
avdtp_protocol = await self.create_avdtp_protocol(connection)
|
||||
|
||||
with open(audio_file, 'rb') as input_file:
|
||||
# NOTE: this should be using asyncio file reading, but blocking reads
|
||||
# are good enough for this command line app.
|
||||
async def read_audio_data(byte_count):
|
||||
return input_file.read(byte_count)
|
||||
|
||||
# Obtain the codec capabilities from the stream
|
||||
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
|
||||
vendor_id = 0
|
||||
codec_id = 0
|
||||
if audio_format == "sbc":
|
||||
codec_type = A2DP_SBC_CODEC_TYPE
|
||||
codec_capabilities = await sbc_codec_capabilities(read_audio_data)
|
||||
packet_source = SbcPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
elif audio_format == "aac":
|
||||
codec_type = A2DP_MPEG_2_4_AAC_CODEC_TYPE
|
||||
codec_capabilities = await aac_codec_capabilities(read_audio_data)
|
||||
packet_source = AacPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
else:
|
||||
codec_type = A2DP_NON_A2DP_CODEC_TYPE
|
||||
vendor_id = OpusMediaCodecInformation.VENDOR_ID
|
||||
codec_id = OpusMediaCodecInformation.CODEC_ID
|
||||
codec_capabilities = await opus_codec_capabilities(read_audio_data)
|
||||
packet_source = OpusPacketSource(
|
||||
read_audio_data,
|
||||
avdtp_protocol.l2cap_channel.peer_mtu,
|
||||
)
|
||||
|
||||
# Rewind to the start
|
||||
input_file.seek(0)
|
||||
|
||||
try:
|
||||
await self.stream_packets(
|
||||
avdtp_protocol,
|
||||
codec_type,
|
||||
vendor_id,
|
||||
codec_id,
|
||||
packet_source,
|
||||
codec_capabilities,
|
||||
)
|
||||
except Exception as error:
|
||||
print(color(f"!!! Error while streaming: {error}", "red"))
|
||||
|
||||
self.set_done()
|
||||
|
||||
if address:
|
||||
await self.connect(device, address)
|
||||
else:
|
||||
print(color("Waiting for an incoming connection...", "magenta"))
|
||||
|
||||
def set_done(self) -> None:
|
||||
if self.done:
|
||||
self.done.set()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def create_player(context) -> Player:
|
||||
return Player(
|
||||
transport=context.obj["hci_transport"],
|
||||
device_config=context.obj["device_config"],
|
||||
authenticate=context.obj["authenticate"],
|
||||
encrypt=context.obj["encrypt"],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
@click.option("--hci-transport", metavar="TRANSPORT", required=True)
|
||||
@click.option("--device-config", metavar="FILENAME", help="Device configuration file")
|
||||
@click.option(
|
||||
"--authenticate",
|
||||
is_flag=True,
|
||||
help="Request authentication when connecting",
|
||||
default=False,
|
||||
)
|
||||
@click.option(
|
||||
"--encrypt", is_flag=True, help="Request encryption when connecting", default=True
|
||||
)
|
||||
def player_cli(ctx, hci_transport, device_config, authenticate, encrypt):
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["hci_transport"] = hci_transport
|
||||
ctx.obj["device_config"] = device_config
|
||||
ctx.obj["authenticate"] = authenticate
|
||||
ctx.obj["encrypt"] = encrypt
|
||||
|
||||
|
||||
@player_cli.command("discover")
|
||||
@click.pass_context
|
||||
def discover(context):
|
||||
"""Discover speakers or headphones"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(player.discover))
|
||||
|
||||
|
||||
@player_cli.command("inquire")
|
||||
@click.pass_context
|
||||
@click.argument(
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
)
|
||||
def inquire(context, address):
|
||||
"""Connect to a speaker or headphone and inquire about their capabilities"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(lambda device: player.inquire(device, address)))
|
||||
|
||||
|
||||
@player_cli.command("pair")
|
||||
@click.pass_context
|
||||
@click.argument(
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
)
|
||||
def pair(context, address):
|
||||
"""Pair with a speaker or headphone"""
|
||||
player = create_player(context)
|
||||
asyncio.run(player.run(lambda device: player.pair(device, address)))
|
||||
|
||||
|
||||
@player_cli.command("play")
|
||||
@click.pass_context
|
||||
@click.option(
|
||||
"--connect",
|
||||
"address",
|
||||
metavar="ADDRESS",
|
||||
help="Address or name to connect to",
|
||||
)
|
||||
@click.option(
|
||||
"-f",
|
||||
"--audio-format",
|
||||
type=click.Choice(["auto", "sbc", "aac", "opus"]),
|
||||
help="Audio file format (use 'auto' to infer the format from the file extension)",
|
||||
default="auto",
|
||||
)
|
||||
@click.argument("audio_file")
|
||||
def play(context, address, audio_format, audio_file):
|
||||
"""Play and audio file"""
|
||||
player = create_player(context)
|
||||
asyncio.run(
|
||||
player.run(
|
||||
lambda device: player.play(device, address, audio_format, audio_file)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
|
||||
player_cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
520
apps/rfcomm_bridge.py
Normal file
520
apps/rfcomm_bridge.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, Connection
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble import rfcomm
|
||||
from bumble import transport
|
||||
from bumble import utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
|
||||
DEFAULT_MTU = 4096
|
||||
DEFAULT_CLIENT_TCP_PORT = 9544
|
||||
DEFAULT_SERVER_TCP_PORT = 9545
|
||||
|
||||
TRACE_MAX_SIZE = 48
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Tracer:
|
||||
"""
|
||||
Trace data buffers transmitted from one endpoint to another, with stats.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_name: str) -> None:
|
||||
self.channel_name = channel_name
|
||||
self.last_ts: float = 0.0
|
||||
|
||||
def trace_data(self, data: bytes) -> None:
|
||||
now = time.time()
|
||||
elapsed_s = now - self.last_ts if self.last_ts else 0
|
||||
elapsed_ms = int(elapsed_s * 1000)
|
||||
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
|
||||
|
||||
hex_str = data[:TRACE_MAX_SIZE].hex() + (
|
||||
"..." if len(data) > TRACE_MAX_SIZE else ""
|
||||
)
|
||||
print(
|
||||
f"[{self.channel_name}] {len(data):4} bytes "
|
||||
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
|
||||
f" {hex_str}"
|
||||
)
|
||||
|
||||
self.last_ts = now
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServerBridge:
|
||||
"""
|
||||
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
|
||||
The RFCOMM channel may be associated with a UUID published in an SDP service
|
||||
description, or simply be on a system-assigned channel number.
|
||||
When the connection is made, the bridge connects a TCP socket to a remote host and
|
||||
bridges the data in both directions, with flow control.
|
||||
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
|
||||
and waits for a new channel to be connected.
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 4096
|
||||
|
||||
def __init__(
|
||||
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
|
||||
) -> None:
|
||||
self.device: Optional[Device] = None
|
||||
self.channel = channel
|
||||
self.uuid = uuid
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.rfcomm_channel: Optional[rfcomm.DLC] = None
|
||||
self.tcp_tracer: Optional[Tracer]
|
||||
self.rfcomm_tracer: Optional[Tracer]
|
||||
|
||||
if trace:
|
||||
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
|
||||
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
|
||||
else:
|
||||
self.rfcomm_tracer = None
|
||||
self.tcp_tracer = None
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
self.device = device
|
||||
|
||||
# Create and register a server
|
||||
rfcomm_server = rfcomm.Server(self.device)
|
||||
|
||||
# Listen for incoming DLC connections
|
||||
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
|
||||
|
||||
# Setup the SDP to advertise this channel
|
||||
service_record_handle = 0x00010001
|
||||
self.device.sdp_service_records = {
|
||||
service_record_handle: rfcomm.make_service_sdp_records(
|
||||
service_record_handle, self.channel, core.UUID(self.uuid)
|
||||
)
|
||||
}
|
||||
|
||||
# We're ready for a connection
|
||||
self.device.on("connection", self.on_connection)
|
||||
await self.set_available(True)
|
||||
|
||||
print(
|
||||
color(
|
||||
(
|
||||
f"### Listening for RFCOMM connection on {device.public_address}, "
|
||||
f"channel {self.channel}"
|
||||
),
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
|
||||
async def set_available(self, available: bool):
|
||||
# Become discoverable and connectable
|
||||
assert self.device
|
||||
await self.device.set_connectable(available)
|
||||
await self.device.set_discoverable(available)
|
||||
|
||||
def on_connection(self, connection):
|
||||
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
|
||||
connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
# Don't accept new connections until we're disconnected
|
||||
utils.AsyncRunner.spawn(self.set_available(False))
|
||||
|
||||
def on_disconnection(self, reason: int):
|
||||
print(
|
||||
color("@@@ Bluetooth disconnection:", "red"),
|
||||
hci.HCI_Constant.error_name(reason),
|
||||
)
|
||||
|
||||
# We're ready for a new connection
|
||||
utils.AsyncRunner.spawn(self.set_available(True))
|
||||
|
||||
# Called when an RFCOMM channel is established
|
||||
@utils.AsyncRunner.run_in_task()
|
||||
async def on_rfcomm_channel(self, rfcomm_channel):
|
||||
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
|
||||
|
||||
# Connect to the TCP server
|
||||
print(
|
||||
color(
|
||||
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
|
||||
except OSError:
|
||||
print(color("!!! Connection failed", "red"))
|
||||
await rfcomm_channel.disconnect()
|
||||
return
|
||||
|
||||
# Pipe data from RFCOMM to TCP
|
||||
def on_rfcomm_channel_closed():
|
||||
print(color("*** RFCOMM channel closed", "cyan"))
|
||||
writer.close()
|
||||
|
||||
def write_rfcomm_data(data):
|
||||
if self.rfcomm_tracer:
|
||||
self.rfcomm_tracer.trace_data(data)
|
||||
|
||||
writer.write(data)
|
||||
|
||||
rfcomm_channel.sink = write_rfcomm_data
|
||||
rfcomm_channel.on("close", on_rfcomm_channel_closed)
|
||||
|
||||
# Pipe data from TCP to RFCOMM
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(self.READ_CHUNK_SIZE)
|
||||
|
||||
if len(data) == 0:
|
||||
print(color("### TCP end of stream", "yellow"))
|
||||
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
|
||||
await rfcomm_channel.disconnect()
|
||||
return
|
||||
|
||||
if self.tcp_tracer:
|
||||
self.tcp_tracer.trace_data(data)
|
||||
|
||||
rfcomm_channel.write(data)
|
||||
await rfcomm_channel.drain()
|
||||
except Exception as error:
|
||||
print(f"!!! Exception: {error}")
|
||||
break
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
print(color("~~~ Bye bye", "magenta"))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientBridge:
|
||||
"""
|
||||
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
|
||||
TCP connection on a specified port number. When a TCP client connects, an
|
||||
RFCOMM connection to the device is established, and the data is bridged in both
|
||||
directions, with flow control.
|
||||
When the TCP connection is closed by the client, the RFCOMM channel is
|
||||
disconnected, but the connection to the device remains, ready for a new TCP client
|
||||
to connect.
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: int,
|
||||
uuid: str,
|
||||
trace: bool,
|
||||
address: str,
|
||||
tcp_host: str,
|
||||
tcp_port: int,
|
||||
authenticate: bool,
|
||||
encrypt: bool,
|
||||
):
|
||||
self.channel = channel
|
||||
self.uuid = uuid
|
||||
self.trace = trace
|
||||
self.address = address
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt
|
||||
self.device: Optional[Device] = None
|
||||
self.connection: Optional[Connection] = None
|
||||
self.rfcomm_client: Optional[rfcomm.Client]
|
||||
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
|
||||
self.tcp_connected: bool = False
|
||||
|
||||
self.tcp_tracer: Optional[Tracer]
|
||||
self.rfcomm_tracer: Optional[Tracer]
|
||||
|
||||
if trace:
|
||||
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
|
||||
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
|
||||
else:
|
||||
self.rfcomm_tracer = None
|
||||
self.tcp_tracer = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self.connection:
|
||||
return
|
||||
|
||||
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
|
||||
assert self.device
|
||||
self.connection = await self.device.connect(
|
||||
self.address, transport=core.BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
|
||||
self.connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
if self.authenticate:
|
||||
print(color("@@@ Authenticating Bluetooth connection", "blue"))
|
||||
await self.connection.authenticate()
|
||||
print(color("@@@ Bluetooth connection authenticated", "blue"))
|
||||
|
||||
if self.encrypt:
|
||||
print(color("@@@ Encrypting Bluetooth connection", "blue"))
|
||||
await self.connection.encrypt()
|
||||
print(color("@@@ Bluetooth connection encrypted", "blue"))
|
||||
|
||||
self.rfcomm_client = rfcomm.Client(self.connection)
|
||||
try:
|
||||
self.rfcomm_mux = await self.rfcomm_client.start()
|
||||
except BaseException as e:
|
||||
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
|
||||
raise
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
self.device = device
|
||||
await device.set_connectable(False)
|
||||
await device.set_discoverable(False)
|
||||
|
||||
# Called when a TCP connection is established
|
||||
async def on_tcp_connection(reader, writer):
|
||||
print(color("<<< TCP connection", "magenta"))
|
||||
if self.tcp_connected:
|
||||
print(
|
||||
color("!!! TCP connection already active, rejecting new one", "red")
|
||||
)
|
||||
writer.close()
|
||||
return
|
||||
self.tcp_connected = True
|
||||
|
||||
try:
|
||||
await self.pipe(reader, writer)
|
||||
except BaseException as error:
|
||||
print(color("!!! Exception while piping data:", "red"), error)
|
||||
return
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
self.tcp_connected = False
|
||||
|
||||
await asyncio.start_server(
|
||||
on_tcp_connection,
|
||||
host=self.tcp_host if self.tcp_host != "_" else None,
|
||||
port=self.tcp_port,
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
|
||||
)
|
||||
)
|
||||
|
||||
async def pipe(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
# Resolve the channel number from the UUID if needed
|
||||
if self.channel == 0:
|
||||
await self.connect()
|
||||
assert self.connection
|
||||
channel = await rfcomm.find_rfcomm_channel_with_uuid(
|
||||
self.connection, self.uuid
|
||||
)
|
||||
if channel:
|
||||
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
|
||||
else:
|
||||
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
|
||||
return
|
||||
else:
|
||||
channel = self.channel
|
||||
|
||||
# Connect a new RFCOMM channel
|
||||
await self.connect()
|
||||
assert self.rfcomm_mux
|
||||
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
|
||||
try:
|
||||
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
|
||||
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
|
||||
except Exception as error:
|
||||
print(color(f"!!! RFCOMM open failed: {error}", "red"))
|
||||
return
|
||||
|
||||
# Pipe data from RFCOMM to TCP
|
||||
def on_rfcomm_channel_closed():
|
||||
print(color("*** RFCOMM channel closed", "green"))
|
||||
|
||||
def write_rfcomm_data(data):
|
||||
if self.trace:
|
||||
self.rfcomm_tracer.trace_data(data)
|
||||
|
||||
writer.write(data)
|
||||
|
||||
rfcomm_channel.on("close", on_rfcomm_channel_closed)
|
||||
rfcomm_channel.sink = write_rfcomm_data
|
||||
|
||||
# Pipe data from TCP to RFCOMM
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(self.READ_CHUNK_SIZE)
|
||||
|
||||
if len(data) == 0:
|
||||
print(color("### TCP end of stream", "yellow"))
|
||||
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
|
||||
await rfcomm_channel.disconnect()
|
||||
self.tcp_connected = False
|
||||
return
|
||||
|
||||
if self.tcp_tracer:
|
||||
self.tcp_tracer.trace_data(data)
|
||||
|
||||
rfcomm_channel.write(data)
|
||||
await rfcomm_channel.drain()
|
||||
except Exception as error:
|
||||
print(f"!!! Exception: {error}")
|
||||
break
|
||||
|
||||
print(color("~~~ Bye bye", "magenta"))
|
||||
|
||||
def on_disconnection(self, reason: int) -> None:
|
||||
print(
|
||||
color("@@@ Bluetooth disconnection:", "red"),
|
||||
hci.HCI_Constant.error_name(reason),
|
||||
)
|
||||
self.connection = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device_config, hci_transport, bridge):
|
||||
print("<<< connecting to HCI...")
|
||||
async with await transport.open_transport_or_link(hci_transport) as (
|
||||
hci_source,
|
||||
hci_sink,
|
||||
):
|
||||
print("<<< connected")
|
||||
|
||||
if device_config:
|
||||
device = Device.from_config_file_with_hci(
|
||||
device_config, hci_source, hci_sink
|
||||
)
|
||||
else:
|
||||
device = Device.from_config_with_hci(
|
||||
DeviceConfiguration(), hci_source, hci_sink
|
||||
)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Let's go
|
||||
await device.power_on()
|
||||
try:
|
||||
await bridge.start(device)
|
||||
|
||||
# Wait until the transport terminates
|
||||
await hci_source.wait_for_termination()
|
||||
except core.ConnectionError as error:
|
||||
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
|
||||
except Exception as error:
|
||||
print(f"Exception while running bridge: {error}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
@click.option(
|
||||
"--device-config",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Device configuration file",
|
||||
)
|
||||
@click.option(
|
||||
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
|
||||
)
|
||||
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
|
||||
@click.option(
|
||||
"--channel",
|
||||
metavar="CHANNEL_NUMER",
|
||||
help="RFCOMM channel number",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
@click.option(
|
||||
"--uuid",
|
||||
metavar="UUID",
|
||||
help="UUID for the RFCOMM channel",
|
||||
default=DEFAULT_RFCOMM_UUID,
|
||||
)
|
||||
def cli(
|
||||
context,
|
||||
device_config,
|
||||
hci_transport,
|
||||
trace,
|
||||
channel,
|
||||
uuid,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj["device_config"] = device_config
|
||||
context.obj["hci_transport"] = hci_transport
|
||||
context.obj["trace"] = trace
|
||||
context.obj["channel"] = channel
|
||||
context.obj["uuid"] = uuid
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.option("--tcp-host", help="TCP host", default="localhost")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
|
||||
def server(context, tcp_host, tcp_port):
|
||||
bridge = ServerBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
context.obj["trace"],
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument("bluetooth-address")
|
||||
@click.option("--tcp-host", help="TCP host", default="_")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
|
||||
@click.option("--authenticate", is_flag=True, help="Authenticate the connection")
|
||||
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, authenticate, encrypt):
|
||||
bridge = ClientBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
context.obj["trace"],
|
||||
bluetooth_address,
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
authenticate,
|
||||
encrypt,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
|
||||
if __name__ == "__main__":
|
||||
cli(obj={}) # pylint: disable=no-value-for-parameter
|
||||
52
apps/scan.py
52
apps/scan.py
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.smp import AddressResolver
|
||||
from bumble.device import Advertisement
|
||||
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
|
||||
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
|
||||
address.address_type
|
||||
]
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
if address.address_type in (
|
||||
Address.RANDOM_IDENTITY_ADDRESS,
|
||||
Address.PUBLIC_IDENTITY_ADDRESS,
|
||||
):
|
||||
type_color = 'yellow'
|
||||
else:
|
||||
if address.is_static:
|
||||
if address.is_public:
|
||||
type_color = 'cyan'
|
||||
elif address.is_static:
|
||||
type_color = 'green'
|
||||
address_qualifier = '(static)'
|
||||
elif address.is_resolvable:
|
||||
@@ -116,6 +121,7 @@ async def scan(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irks,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -140,9 +146,21 @@ async def scan(
|
||||
|
||||
if device.keystore:
|
||||
resolving_keys = await device.keystore.get_resolving_keys()
|
||||
resolver = AddressResolver(resolving_keys)
|
||||
else:
|
||||
resolver = None
|
||||
resolving_keys = []
|
||||
|
||||
for irk_and_address in irks:
|
||||
if ':' not in irk_and_address:
|
||||
raise ValueError('invalid IRK:ADDRESS value')
|
||||
irk_hex, address_str = irk_and_address.split(':', 1)
|
||||
resolving_keys.append(
|
||||
(
|
||||
bytes.fromhex(irk_hex),
|
||||
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
|
||||
)
|
||||
)
|
||||
|
||||
resolver = AddressResolver(resolving_keys) if resolving_keys else None
|
||||
|
||||
printer = AdvertisementPrinter(min_rssi, resolver)
|
||||
if raw:
|
||||
@@ -187,8 +205,24 @@ async def scan(
|
||||
default=False,
|
||||
help='Listen for raw advertising reports instead of processed ones',
|
||||
)
|
||||
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
|
||||
@click.option('--device-config', help='Device config file for the scanning device')
|
||||
@click.option(
|
||||
'--irk',
|
||||
metavar='<IRK_HEX>:<ADDRESS>',
|
||||
help=(
|
||||
'Use this IRK for resolving private addresses ' '(may be used more than once)'
|
||||
),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
'--keystore-file',
|
||||
metavar='FILE_PATH',
|
||||
help='Keystore file to use when resolving addresses',
|
||||
)
|
||||
@click.option(
|
||||
'--device-config',
|
||||
metavar='FILE_PATH',
|
||||
help='Device config file for the scanning device',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(
|
||||
min_rssi,
|
||||
@@ -198,6 +232,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
@@ -212,6 +247,7 @@ def main(
|
||||
phy,
|
||||
filter_duplicates,
|
||||
raw,
|
||||
irk,
|
||||
keystore_file,
|
||||
device_config,
|
||||
transport,
|
||||
|
||||
90
apps/show.py
90
apps/show.py
@@ -15,7 +15,11 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
|
||||
import click
|
||||
|
||||
from bumble.colors import color
|
||||
@@ -24,6 +28,14 @@ from bumble.transport.common import PacketReader
|
||||
from bumble.helpers import PacketTracer
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopPacketReader:
|
||||
'''
|
||||
@@ -36,12 +48,18 @@ class SnoopPacketReader:
|
||||
DATALINK_BSCP = 1003
|
||||
DATALINK_H5 = 1004
|
||||
|
||||
IDENTIFICATION_PATTERN = b'btsnoop\0'
|
||||
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
|
||||
TIMESTAMP_DELTA = 0x00E03AB44A676000
|
||||
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
|
||||
|
||||
def __init__(self, source):
|
||||
self.source = source
|
||||
self.at_end = False
|
||||
|
||||
# Read the header
|
||||
identification_pattern = source.read(8)
|
||||
if identification_pattern.hex().lower() != '6274736e6f6f7000':
|
||||
if identification_pattern != self.IDENTIFICATION_PATTERN:
|
||||
raise ValueError(
|
||||
'not a valid snoop file, unexpected identification pattern'
|
||||
)
|
||||
@@ -55,19 +73,32 @@ class SnoopPacketReader:
|
||||
# Read the record header
|
||||
header = self.source.read(24)
|
||||
if len(header) < 24:
|
||||
return (0, None)
|
||||
self.at_end = True
|
||||
return (None, 0, None)
|
||||
|
||||
# Parse the header
|
||||
(
|
||||
original_length,
|
||||
included_length,
|
||||
packet_flags,
|
||||
_cumulative_drops,
|
||||
_timestamp_seconds,
|
||||
_timestamp_microsecond,
|
||||
) = struct.unpack('>IIIIII', header)
|
||||
timestamp,
|
||||
) = struct.unpack('>IIIIQ', header)
|
||||
|
||||
# Abort on truncated packets
|
||||
# Skip truncated packets
|
||||
if original_length != included_length:
|
||||
return (0, None)
|
||||
print(
|
||||
color(
|
||||
f"!!! truncated packet ({included_length}/{original_length})", "red"
|
||||
)
|
||||
)
|
||||
self.source.read(included_length)
|
||||
return (None, 0, None)
|
||||
|
||||
# Convert the timestamp to a datetime object.
|
||||
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
|
||||
microseconds=timestamp - self.TIMESTAMP_DELTA
|
||||
)
|
||||
|
||||
if self.data_link_type == self.DATALINK_H1:
|
||||
# The packet is un-encapsulated, look at the flags to figure out its type
|
||||
@@ -89,7 +120,17 @@ class SnoopPacketReader:
|
||||
bytes([packet_type]) + self.source.read(included_length),
|
||||
)
|
||||
|
||||
return (packet_flags & 1, self.source.read(included_length))
|
||||
return (ts_dt, packet_flags & 1, self.source.read(included_length))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Printer:
|
||||
def __init__(self):
|
||||
self.index = 0
|
||||
|
||||
def print(self, message: str) -> None:
|
||||
self.index += 1
|
||||
print(f"[{self.index:8}]{message}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -102,33 +143,48 @@ class SnoopPacketReader:
|
||||
default='h4',
|
||||
help='Format of the input file',
|
||||
)
|
||||
@click.option(
|
||||
'--vendor',
|
||||
type=click.Choice(['android', 'zephyr']),
|
||||
multiple=True,
|
||||
help='Support vendor-specific commands (list one or more)',
|
||||
)
|
||||
@click.argument('filename')
|
||||
# pylint: disable=redefined-builtin
|
||||
def main(format, filename):
|
||||
def main(format, vendor, filename):
|
||||
for vendor_name in vendor:
|
||||
if vendor_name == 'android':
|
||||
import bumble.vendor.android.hci
|
||||
elif vendor_name == 'zephyr':
|
||||
import bumble.vendor.zephyr.hci
|
||||
|
||||
input = open(filename, 'rb')
|
||||
if format == 'h4':
|
||||
packet_reader = PacketReader(input)
|
||||
|
||||
def read_next_packet():
|
||||
return (0, packet_reader.next_packet())
|
||||
return (None, 0, packet_reader.next_packet())
|
||||
|
||||
else:
|
||||
packet_reader = SnoopPacketReader(input)
|
||||
read_next_packet = packet_reader.next_packet
|
||||
|
||||
tracer = PacketTracer(emit_message=print)
|
||||
printer = Printer()
|
||||
tracer = PacketTracer(emit_message=printer.print)
|
||||
|
||||
while True:
|
||||
while not packet_reader.at_end:
|
||||
try:
|
||||
(direction, packet) = read_next_packet()
|
||||
if packet is None:
|
||||
break
|
||||
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
|
||||
|
||||
(timestamp, direction, packet) = read_next_packet()
|
||||
if packet:
|
||||
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp)
|
||||
else:
|
||||
printer.print(color("[TRUNCATED]", "red"))
|
||||
except Exception as error:
|
||||
logger.exception('')
|
||||
print(color(f'!!! {error}', 'red'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
||||
@@ -56,7 +56,7 @@ body, h1, h2, h3, h4, h5, h6 {
|
||||
border-radius: 4px;
|
||||
padding: 4px;
|
||||
margin: 6px;
|
||||
margin-left: 0px;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
th, td {
|
||||
@@ -65,7 +65,7 @@ th, td {
|
||||
}
|
||||
|
||||
.properties td:nth-child(even) {
|
||||
background-color: #D6EEEE;
|
||||
background-color: #d6eeee;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<html>
|
||||
<head>
|
||||
<title>Bumble Speaker</title>
|
||||
<script type="text/javascript" src="speaker.js"></script>
|
||||
<script src="speaker.js"></script>
|
||||
<link rel="stylesheet" href="speaker.css">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
@@ -44,25 +44,18 @@ from bumble.avdtp import (
|
||||
AVDTP_AUDIO_MEDIA_TYPE,
|
||||
Listener,
|
||||
MediaCodecCapabilities,
|
||||
MediaPacket,
|
||||
Protocol,
|
||||
)
|
||||
from bumble.a2dp import (
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE,
|
||||
make_audio_sink_service_sdp_records,
|
||||
A2DP_SBC_CODEC_TYPE,
|
||||
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
SBC_MONO_CHANNEL_MODE,
|
||||
SBC_DUAL_CHANNEL_MODE,
|
||||
SBC_SNR_ALLOCATION_METHOD,
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD,
|
||||
SBC_STEREO_CHANNEL_MODE,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE,
|
||||
SbcMediaCodecInformation,
|
||||
AacMediaCodecInformation,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
from bumble.codecs import AacAudioRtpPacket
|
||||
from bumble.rtp import MediaPacket
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -76,6 +69,7 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioExtractor:
|
||||
@staticmethod
|
||||
@@ -92,7 +86,7 @@ class AudioExtractor:
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacAudioExtractor:
|
||||
def extract_audio(self, packet: MediaPacket) -> bytes:
|
||||
return AacAudioRtpPacket(packet.payload).to_adts()
|
||||
return AacAudioRtpPacket.from_bytes(packet.payload).to_adts()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -195,7 +189,7 @@ class WebSocketOutput(QueuedOutput):
|
||||
except HCI_StatusError:
|
||||
pass
|
||||
peer_name = '' if connection.peer_name is None else connection.peer_name
|
||||
peer_address = str(connection.peer_address).replace('/P', '')
|
||||
peer_address = connection.peer_address.to_string(False)
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=peer_address,
|
||||
@@ -376,7 +370,7 @@ class UiServer:
|
||||
if connection := self.speaker().connection:
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=str(connection.peer_address).replace('/P', ''),
|
||||
peer_address=connection.peer_address.to_string(False),
|
||||
peer_name=connection.peer_name,
|
||||
)
|
||||
|
||||
@@ -450,10 +444,12 @@ class Speaker:
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
|
||||
media_codec_information=AacMediaCodecInformation.from_lists(
|
||||
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
|
||||
sampling_frequencies=[48000, 44100],
|
||||
channels=[1, 2],
|
||||
media_codec_information=AacMediaCodecInformation(
|
||||
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
|
||||
sampling_frequency=AacMediaCodecInformation.SamplingFrequency.SF_48000
|
||||
| AacMediaCodecInformation.SamplingFrequency.SF_44100,
|
||||
channels=AacMediaCodecInformation.Channels.MONO
|
||||
| AacMediaCodecInformation.Channels.STEREO,
|
||||
vbr=1,
|
||||
bitrate=256000,
|
||||
),
|
||||
@@ -463,20 +459,23 @@ class Speaker:
|
||||
return MediaCodecCapabilities(
|
||||
media_type=AVDTP_AUDIO_MEDIA_TYPE,
|
||||
media_codec_type=A2DP_SBC_CODEC_TYPE,
|
||||
media_codec_information=SbcMediaCodecInformation.from_lists(
|
||||
sampling_frequencies=[48000, 44100, 32000, 16000],
|
||||
channel_modes=[
|
||||
SBC_MONO_CHANNEL_MODE,
|
||||
SBC_DUAL_CHANNEL_MODE,
|
||||
SBC_STEREO_CHANNEL_MODE,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE,
|
||||
],
|
||||
block_lengths=[4, 8, 12, 16],
|
||||
subbands=[4, 8],
|
||||
allocation_methods=[
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD,
|
||||
SBC_SNR_ALLOCATION_METHOD,
|
||||
],
|
||||
media_codec_information=SbcMediaCodecInformation(
|
||||
sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
|
||||
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
|
||||
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
|
||||
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
|
||||
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
|
||||
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
|
||||
| SbcMediaCodecInformation.ChannelMode.STEREO
|
||||
| SbcMediaCodecInformation.ChannelMode.JOINT_STEREO,
|
||||
block_length=SbcMediaCodecInformation.BlockLength.BL_4
|
||||
| SbcMediaCodecInformation.BlockLength.BL_8
|
||||
| SbcMediaCodecInformation.BlockLength.BL_12
|
||||
| SbcMediaCodecInformation.BlockLength.BL_16,
|
||||
subbands=SbcMediaCodecInformation.Subbands.S_4
|
||||
| SbcMediaCodecInformation.Subbands.S_8,
|
||||
allocation_method=SbcMediaCodecInformation.AllocationMethod.LOUDNESS
|
||||
| SbcMediaCodecInformation.AllocationMethod.SNR,
|
||||
minimum_bitpool_value=2,
|
||||
maximum_bitpool_value=53,
|
||||
),
|
||||
@@ -641,7 +640,7 @@ class Speaker:
|
||||
self.device.on('connection', self.on_bluetooth_connection)
|
||||
|
||||
# Create a listener to wait for AVDTP connections
|
||||
self.listener = Listener(Listener.create_registrar(self.device))
|
||||
self.listener = Listener.for_device(self.device)
|
||||
self.listener.on('connection', self.on_avdtp_connection)
|
||||
|
||||
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')
|
||||
|
||||
@@ -24,6 +24,7 @@ from bumble.device import Device
|
||||
from bumble.keys import JsonKeyStore
|
||||
from bumble.transport import open_transport
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def unbond_with_keystore(keystore, address):
|
||||
if address is None:
|
||||
|
||||
777
bumble/a2dp.py
777
bumble/a2dp.py
@@ -15,10 +15,18 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Awaitable, Callable
|
||||
from typing_extensions import ClassVar, Self
|
||||
|
||||
|
||||
from .codecs import AacAudioRtpPacket
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
from .sdp import (
|
||||
DataElement,
|
||||
@@ -38,6 +46,7 @@ from .core import (
|
||||
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
|
||||
name_or_number,
|
||||
)
|
||||
from .rtp import MediaPacket
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -99,6 +108,8 @@ SBC_ALLOCATION_METHOD_NAMES = {
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
|
||||
}
|
||||
|
||||
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
|
||||
8000,
|
||||
11025,
|
||||
@@ -126,6 +137,9 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
|
||||
}
|
||||
|
||||
|
||||
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -180,8 +194,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -230,8 +248,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
|
||||
DataElement.unsigned_integer_16(version_int),
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
@@ -239,48 +261,67 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcMediaCodecInformation(
|
||||
namedtuple(
|
||||
'SbcMediaCodecInformation',
|
||||
[
|
||||
'sampling_frequency',
|
||||
'channel_mode',
|
||||
'block_length',
|
||||
'subbands',
|
||||
'allocation_method',
|
||||
'minimum_bitpool_value',
|
||||
'maximum_bitpool_value',
|
||||
],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class SbcMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.3.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
|
||||
CHANNEL_MODE_BITS = {
|
||||
SBC_MONO_CHANNEL_MODE: 1 << 3,
|
||||
SBC_DUAL_CHANNEL_MODE: 1 << 2,
|
||||
SBC_STEREO_CHANNEL_MODE: 1 << 1,
|
||||
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
|
||||
}
|
||||
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
|
||||
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
|
||||
ALLOCATION_METHOD_BITS = {
|
||||
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
|
||||
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
|
||||
}
|
||||
sampling_frequency: SamplingFrequency
|
||||
channel_mode: ChannelMode
|
||||
block_length: BlockLength
|
||||
subbands: Subbands
|
||||
allocation_method: AllocationMethod
|
||||
minimum_bitpool_value: int
|
||||
maximum_bitpool_value: int
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
|
||||
sampling_frequency = (data[0] >> 4) & 0x0F
|
||||
channel_mode = (data[0] >> 0) & 0x0F
|
||||
block_length = (data[1] >> 4) & 0x0F
|
||||
subbands = (data[1] >> 2) & 0x03
|
||||
allocation_method = (data[1] >> 0) & 0x03
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_16000 = 1 << 3
|
||||
SF_32000 = 1 << 2
|
||||
SF_44100 = 1 << 1
|
||||
SF_48000 = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, sampling_frequency: int) -> Self:
|
||||
sampling_frequencies = [
|
||||
16000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
]
|
||||
index = sampling_frequencies.index(sampling_frequency)
|
||||
return cls(1 << (len(sampling_frequencies) - index - 1))
|
||||
|
||||
class ChannelMode(enum.IntFlag):
|
||||
MONO = 1 << 3
|
||||
DUAL_CHANNEL = 1 << 2
|
||||
STEREO = 1 << 1
|
||||
JOINT_STEREO = 1 << 0
|
||||
|
||||
class BlockLength(enum.IntFlag):
|
||||
BL_4 = 1 << 3
|
||||
BL_8 = 1 << 2
|
||||
BL_12 = 1 << 1
|
||||
BL_16 = 1 << 0
|
||||
|
||||
class Subbands(enum.IntFlag):
|
||||
S_4 = 1 << 1
|
||||
S_8 = 1 << 0
|
||||
|
||||
class AllocationMethod(enum.IntFlag):
|
||||
SNR = 1 << 1
|
||||
LOUDNESS = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
sampling_frequency = cls.SamplingFrequency((data[0] >> 4) & 0x0F)
|
||||
channel_mode = cls.ChannelMode((data[0] >> 0) & 0x0F)
|
||||
block_length = cls.BlockLength((data[1] >> 4) & 0x0F)
|
||||
subbands = cls.Subbands((data[1] >> 2) & 0x03)
|
||||
allocation_method = cls.AllocationMethod((data[1] >> 0) & 0x03)
|
||||
minimum_bitpool_value = (data[2] >> 0) & 0xFF
|
||||
maximum_bitpool_value = (data[3] >> 0) & 0xFF
|
||||
return SbcMediaCodecInformation(
|
||||
return cls(
|
||||
sampling_frequency,
|
||||
channel_mode,
|
||||
block_length,
|
||||
@@ -290,52 +331,6 @@ class SbcMediaCodecInformation(
|
||||
maximum_bitpool_value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls,
|
||||
sampling_frequency,
|
||||
channel_mode,
|
||||
block_length,
|
||||
subbands,
|
||||
allocation_method,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
|
||||
block_length=cls.BLOCK_LENGTH_BITS[block_length],
|
||||
subbands=cls.SUBBANDS_BITS[subbands],
|
||||
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
|
||||
minimum_bitpool_value=minimum_bitpool_value,
|
||||
maximum_bitpool_value=maximum_bitpool_value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(
|
||||
cls,
|
||||
sampling_frequencies,
|
||||
channel_modes,
|
||||
block_lengths,
|
||||
subbands,
|
||||
allocation_methods,
|
||||
minimum_bitpool_value,
|
||||
maximum_bitpool_value,
|
||||
):
|
||||
return SbcMediaCodecInformation(
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
),
|
||||
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
|
||||
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
|
||||
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
|
||||
allocation_method=sum(
|
||||
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
|
||||
),
|
||||
minimum_bitpool_value=minimum_bitpool_value,
|
||||
maximum_bitpool_value=maximum_bitpool_value,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[
|
||||
@@ -348,93 +343,74 @@ class SbcMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
|
||||
allocation_methods = ['SNR', 'Loudness']
|
||||
return '\n'.join(
|
||||
# pylint: disable=line-too-long
|
||||
[
|
||||
'SbcMediaCodecInformation(',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
|
||||
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
|
||||
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
|
||||
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
|
||||
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
|
||||
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
|
||||
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacMediaCodecInformation(
|
||||
namedtuple(
|
||||
'AacMediaCodecInformation',
|
||||
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
|
||||
)
|
||||
):
|
||||
@dataclasses.dataclass
|
||||
class AacMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.5.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
OBJECT_TYPE_BITS = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
|
||||
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
|
||||
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
|
||||
}
|
||||
SAMPLING_FREQUENCY_BITS = {
|
||||
8000: 1 << 11,
|
||||
11025: 1 << 10,
|
||||
12000: 1 << 9,
|
||||
16000: 1 << 8,
|
||||
22050: 1 << 7,
|
||||
24000: 1 << 6,
|
||||
32000: 1 << 5,
|
||||
44100: 1 << 4,
|
||||
48000: 1 << 3,
|
||||
64000: 1 << 2,
|
||||
88200: 1 << 1,
|
||||
96000: 1,
|
||||
}
|
||||
CHANNELS_BITS = {1: 1 << 1, 2: 1}
|
||||
object_type: ObjectType
|
||||
sampling_frequency: SamplingFrequency
|
||||
channels: Channels
|
||||
vbr: int
|
||||
bitrate: int
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
|
||||
object_type = data[0]
|
||||
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
channels = (data[2] >> 2) & 0x03
|
||||
rfa = 0
|
||||
class ObjectType(enum.IntFlag):
|
||||
MPEG_2_AAC_LC = 1 << 7
|
||||
MPEG_4_AAC_LC = 1 << 6
|
||||
MPEG_4_AAC_LTP = 1 << 5
|
||||
MPEG_4_AAC_SCALABLE = 1 << 4
|
||||
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_8000 = 1 << 11
|
||||
SF_11025 = 1 << 10
|
||||
SF_12000 = 1 << 9
|
||||
SF_16000 = 1 << 8
|
||||
SF_22050 = 1 << 7
|
||||
SF_24000 = 1 << 6
|
||||
SF_32000 = 1 << 5
|
||||
SF_44100 = 1 << 4
|
||||
SF_48000 = 1 << 3
|
||||
SF_64000 = 1 << 2
|
||||
SF_88200 = 1 << 1
|
||||
SF_96000 = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, sampling_frequency: int) -> Self:
|
||||
sampling_frequencies = [
|
||||
8000,
|
||||
11025,
|
||||
12000,
|
||||
16000,
|
||||
22050,
|
||||
24000,
|
||||
32000,
|
||||
44100,
|
||||
48000,
|
||||
64000,
|
||||
88200,
|
||||
96000,
|
||||
]
|
||||
index = sampling_frequencies.index(sampling_frequency)
|
||||
return cls(1 << (len(sampling_frequencies) - index - 1))
|
||||
|
||||
class Channels(enum.IntFlag):
|
||||
MONO = 1 << 1
|
||||
STEREO = 1 << 0
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> AacMediaCodecInformation:
|
||||
object_type = cls.ObjectType(data[0])
|
||||
sampling_frequency = cls.SamplingFrequency(
|
||||
(data[1] << 4) | ((data[2] >> 4) & 0x0F)
|
||||
)
|
||||
channels = cls.Channels((data[2] >> 2) & 0x03)
|
||||
vbr = (data[3] >> 7) & 0x01
|
||||
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
|
||||
return AacMediaCodecInformation(
|
||||
object_type, sampling_frequency, channels, rfa, vbr, bitrate
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_discrete_values(
|
||||
cls, object_type, sampling_frequency, channels, vbr, bitrate
|
||||
):
|
||||
return AacMediaCodecInformation(
|
||||
object_type=cls.OBJECT_TYPE_BITS[object_type],
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channels=cls.CHANNELS_BITS[channels],
|
||||
rfa=0,
|
||||
vbr=vbr,
|
||||
bitrate=bitrate,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
|
||||
return AacMediaCodecInformation(
|
||||
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
|
||||
sampling_frequency=sum(
|
||||
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
|
||||
),
|
||||
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
|
||||
rfa=0,
|
||||
vbr=vbr,
|
||||
bitrate=bitrate,
|
||||
object_type, sampling_frequency, channels, vbr, bitrate
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -449,51 +425,27 @@ class AacMediaCodecInformation(
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
object_types = [
|
||||
'MPEG_2_AAC_LC',
|
||||
'MPEG_4_AAC_LC',
|
||||
'MPEG_4_AAC_LTP',
|
||||
'MPEG_4_AAC_SCALABLE',
|
||||
'[4]',
|
||||
'[5]',
|
||||
'[6]',
|
||||
'[7]',
|
||||
]
|
||||
channels = [1, 2]
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
'AacMediaCodecInformation(',
|
||||
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
|
||||
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
|
||||
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
|
||||
f' vbr: {self.vbr}',
|
||||
f' bitrate: {self.bitrate}' ')',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorSpecificMediaCodecInformation:
|
||||
'''
|
||||
A2DP spec - 4.7.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
vendor_id: int
|
||||
codec_id: int
|
||||
value: bytes
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
|
||||
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
|
||||
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
|
||||
|
||||
def __init__(self, vendor_id, codec_id, value):
|
||||
self.vendor_id = vendor_id
|
||||
self.codec_id = codec_id
|
||||
self.value = value
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id) + self.value
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# pylint: disable=line-too-long
|
||||
return '\n'.join(
|
||||
[
|
||||
@@ -506,46 +458,102 @@ class VendorSpecificMediaCodecInformation:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
|
||||
vendor_id: int = dataclasses.field(init=False, repr=False)
|
||||
codec_id: int = dataclasses.field(init=False, repr=False)
|
||||
value: bytes = dataclasses.field(init=False, repr=False)
|
||||
channel_mode: ChannelMode
|
||||
frame_size: FrameSize
|
||||
sampling_frequency: SamplingFrequency
|
||||
|
||||
class ChannelMode(enum.IntFlag):
|
||||
MONO = 1 << 0
|
||||
STEREO = 1 << 1
|
||||
DUAL_MONO = 1 << 2
|
||||
|
||||
class FrameSize(enum.IntFlag):
|
||||
FS_10MS = 1 << 0
|
||||
FS_20MS = 1 << 1
|
||||
|
||||
class SamplingFrequency(enum.IntFlag):
|
||||
SF_48000 = 1 << 0
|
||||
|
||||
VENDOR_ID: ClassVar[int] = 0x000000E0
|
||||
CODEC_ID: ClassVar[int] = 0x0001
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.vendor_id = self.VENDOR_ID
|
||||
self.codec_id = self.CODEC_ID
|
||||
self.value = bytes(
|
||||
[
|
||||
self.channel_mode
|
||||
| (self.frame_size << 3)
|
||||
| (self.sampling_frequency << 7)
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
"""Create a new instance from the `value` part of the data, not including
|
||||
the vendor id and codec id"""
|
||||
channel_mode = cls.ChannelMode(data[0] & 0x07)
|
||||
frame_size = cls.FrameSize((data[0] >> 3) & 0x03)
|
||||
sampling_frequency = cls.SamplingFrequency((data[0] >> 7) & 0x01)
|
||||
|
||||
return cls(
|
||||
channel_mode,
|
||||
frame_size,
|
||||
sampling_frequency,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class SbcFrame:
|
||||
def __init__(
|
||||
self, sampling_frequency, block_count, channel_mode, subband_count, payload
|
||||
):
|
||||
self.sampling_frequency = sampling_frequency
|
||||
self.block_count = block_count
|
||||
self.channel_mode = channel_mode
|
||||
self.subband_count = subband_count
|
||||
self.payload = payload
|
||||
sampling_frequency: int
|
||||
block_count: int
|
||||
channel_mode: int
|
||||
allocation_method: int
|
||||
subband_count: int
|
||||
bitpool: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
def sample_count(self):
|
||||
def sample_count(self) -> int:
|
||||
return self.subband_count * self.block_count
|
||||
|
||||
@property
|
||||
def bitrate(self):
|
||||
def bitrate(self) -> int:
|
||||
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
def duration(self) -> float:
|
||||
return self.sample_count / self.sampling_frequency
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'SBC(sf={self.sampling_frequency},'
|
||||
f'cm={self.channel_mode},'
|
||||
f'am={self.allocation_method},'
|
||||
f'br={self.bitrate},'
|
||||
f'sc={self.sample_count},'
|
||||
f'bp={self.bitpool},'
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcParser:
|
||||
def __init__(self, read):
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def frames(self):
|
||||
async def generate_frames():
|
||||
def frames(self) -> AsyncGenerator[SbcFrame, None]:
|
||||
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
|
||||
while True:
|
||||
# Read 4 bytes of header
|
||||
header = await self.read(4)
|
||||
@@ -562,6 +570,7 @@ class SbcParser:
|
||||
blocks = 4 * (1 + ((header[1] >> 4) & 3))
|
||||
channel_mode = (header[1] >> 2) & 3
|
||||
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
|
||||
allocation_method = (header[1] >> 1) & 1
|
||||
subbands = 8 if ((header[1]) & 1) else 4
|
||||
bitpool = header[2]
|
||||
|
||||
@@ -581,7 +590,13 @@ class SbcParser:
|
||||
|
||||
# Emit the next frame
|
||||
yield SbcFrame(
|
||||
sampling_frequency, blocks, channel_mode, subbands, payload
|
||||
sampling_frequency,
|
||||
blocks,
|
||||
channel_mode,
|
||||
allocation_method,
|
||||
subbands,
|
||||
bitpool,
|
||||
payload,
|
||||
)
|
||||
|
||||
return generate_frames()
|
||||
@@ -589,19 +604,15 @@ class SbcParser:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SbcPacketSource:
|
||||
def __init__(self, read, mtu, codec_capabilities):
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
self.codec_capabilities = codec_capabilities
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .avdtp import MediaPacket # Import here to avoid a circular reference
|
||||
|
||||
sequence_number = 0
|
||||
timestamp = 0
|
||||
sample_count = 0
|
||||
frames = []
|
||||
frames_size = 0
|
||||
max_rtp_payload = self.mtu - 12 - 1
|
||||
@@ -609,27 +620,29 @@ class SbcPacketSource:
|
||||
# NOTE: this doesn't support frame fragments
|
||||
sbc_parser = SbcParser(self.read)
|
||||
async for frame in sbc_parser.frames:
|
||||
print(frame)
|
||||
|
||||
if (
|
||||
frames_size + len(frame.payload) > max_rtp_payload
|
||||
or len(frames) == 16
|
||||
or len(frames) == SBC_MAX_FRAMES_IN_RTP_PAYLOAD
|
||||
):
|
||||
# Need to flush what has been accumulated so far
|
||||
logger.debug(f"yielding {len(frames)} frames")
|
||||
|
||||
# Emit a packet
|
||||
sbc_payload = bytes([len(frames)]) + b''.join(
|
||||
sbc_payload = bytes([len(frames) & 0x0F]) + b''.join(
|
||||
[frame.payload for frame in frames]
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp / frame.sampling_frequency
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
timestamp += sum((frame.sample_count for frame in frames))
|
||||
sequence_number &= 0xFFFF
|
||||
sample_count += sum((frame.sample_count for frame in frames))
|
||||
frames = [frame]
|
||||
frames_size = len(frame.payload)
|
||||
else:
|
||||
@@ -638,3 +651,315 @@ class SbcPacketSource:
|
||||
frames_size += len(frame.payload)
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class AacFrame:
|
||||
class Profile(enum.IntEnum):
|
||||
MAIN = 0
|
||||
LC = 1
|
||||
SSR = 2
|
||||
LTP = 3
|
||||
|
||||
profile: Profile
|
||||
sampling_frequency: int
|
||||
channel_configuration: int
|
||||
payload: bytes
|
||||
|
||||
@property
|
||||
def sample_count(self) -> int:
|
||||
return 1024
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.sample_count / self.sampling_frequency
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'AAC(sf={self.sampling_frequency},'
|
||||
f'ch={self.channel_configuration},'
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
ADTS_AAC_SAMPLING_FREQUENCIES = [
|
||||
96000,
|
||||
88200,
|
||||
64000,
|
||||
48000,
|
||||
44100,
|
||||
32000,
|
||||
24000,
|
||||
22050,
|
||||
16000,
|
||||
12000,
|
||||
11025,
|
||||
8000,
|
||||
7350,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacParser:
|
||||
"""Parser for AAC frames in an ADTS stream"""
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def frames(self) -> AsyncGenerator[AacFrame, None]:
|
||||
async def generate_frames() -> AsyncGenerator[AacFrame, None]:
|
||||
while True:
|
||||
header = await self.read(7)
|
||||
if not header:
|
||||
return
|
||||
|
||||
sync_word = (header[0] << 4) | (header[1] >> 4)
|
||||
if sync_word != 0b111111111111:
|
||||
raise ValueError(f"invalid sync word ({sync_word:06x})")
|
||||
layer = (header[1] >> 1) & 0b11
|
||||
profile = AacFrame.Profile((header[2] >> 6) & 0b11)
|
||||
sampling_frequency = ADTS_AAC_SAMPLING_FREQUENCIES[
|
||||
(header[2] >> 2) & 0b1111
|
||||
]
|
||||
channel_configuration = ((header[2] & 0b1) << 2) | (header[3] >> 6)
|
||||
frame_length = (
|
||||
((header[3] & 0b11) << 11) | (header[4] << 3) | (header[5] >> 5)
|
||||
)
|
||||
|
||||
if layer != 0:
|
||||
raise ValueError("layer must be 0")
|
||||
|
||||
payload = await self.read(frame_length - 7)
|
||||
if payload:
|
||||
yield AacFrame(
|
||||
profile, sampling_frequency, channel_configuration, payload
|
||||
)
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacPacketSource:
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
sequence_number = 0
|
||||
sample_count = 0
|
||||
|
||||
aac_parser = AacParser(self.read)
|
||||
async for frame in aac_parser.frames:
|
||||
logger.debug("yielding one AAC frame")
|
||||
|
||||
# Emit a packet
|
||||
aac_payload = bytes(
|
||||
AacAudioRtpPacket.for_simple_aac(
|
||||
frame.sampling_frequency,
|
||||
frame.channel_configuration,
|
||||
frame.payload,
|
||||
)
|
||||
)
|
||||
timestamp_seconds = sample_count / frame.sampling_frequency
|
||||
timestamp = int(1000 * timestamp_seconds)
|
||||
packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, aac_payload
|
||||
)
|
||||
packet.timestamp_seconds = timestamp_seconds
|
||||
yield packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
sample_count += frame.sample_count
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class OpusPacket:
|
||||
class ChannelMode(enum.IntEnum):
|
||||
MONO = 0
|
||||
STEREO = 1
|
||||
DUAL_MONO = 2
|
||||
|
||||
channel_mode: ChannelMode
|
||||
duration: int # Duration in ms.
|
||||
sampling_frequency: int
|
||||
payload: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Opus(ch={self.channel_mode.name}, '
|
||||
f'd={self.duration}ms, '
|
||||
f'size={len(self.payload)})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusParser:
|
||||
"""
|
||||
Parser for Opus packets in an Ogg stream
|
||||
|
||||
See RFC 3533
|
||||
|
||||
NOTE: this parser only supports bitstreams with a single logical stream.
|
||||
"""
|
||||
|
||||
CAPTURE_PATTERN = b'OggS'
|
||||
|
||||
class HeaderType(enum.IntFlag):
|
||||
CONTINUED = 0x01
|
||||
FIRST = 0x02
|
||||
LAST = 0x04
|
||||
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
|
||||
self.read = read
|
||||
|
||||
@property
|
||||
def packets(self) -> AsyncGenerator[OpusPacket, None]:
|
||||
async def generate_frames() -> AsyncGenerator[OpusPacket, None]:
|
||||
packet = b''
|
||||
packet_count = 0
|
||||
expected_bitstream_serial_number = None
|
||||
expected_page_sequence_number = 0
|
||||
channel_mode = OpusPacket.ChannelMode.STEREO
|
||||
|
||||
while True:
|
||||
# Parse the page header
|
||||
header = await self.read(27)
|
||||
if len(header) != 27:
|
||||
logger.debug("end of stream")
|
||||
break
|
||||
|
||||
capture_pattern = header[:4]
|
||||
if capture_pattern != self.CAPTURE_PATTERN:
|
||||
print(capture_pattern.hex())
|
||||
raise ValueError("invalid capture pattern at start of page")
|
||||
|
||||
version = header[4]
|
||||
if version != 0:
|
||||
raise ValueError(f"version {version} not supported")
|
||||
|
||||
header_type = self.HeaderType(header[5])
|
||||
(
|
||||
granule_position,
|
||||
bitstream_serial_number,
|
||||
page_sequence_number,
|
||||
crc_checksum,
|
||||
page_segments,
|
||||
) = struct.unpack_from("<QIIIB", header, 6)
|
||||
segment_table = await self.read(page_segments)
|
||||
|
||||
if header_type & self.HeaderType.FIRST:
|
||||
if expected_bitstream_serial_number is None:
|
||||
# We will only accept pages for the first encountered stream
|
||||
logger.debug("BOS")
|
||||
expected_bitstream_serial_number = bitstream_serial_number
|
||||
expected_page_sequence_number = page_sequence_number
|
||||
|
||||
if (
|
||||
expected_bitstream_serial_number is None
|
||||
or expected_bitstream_serial_number != bitstream_serial_number
|
||||
):
|
||||
logger.debug("skipping page (not the first logical bitstream)")
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
await self.read(lacing_value)
|
||||
continue
|
||||
|
||||
if expected_page_sequence_number != page_sequence_number:
|
||||
raise ValueError(
|
||||
f"expected page sequence number {expected_page_sequence_number}"
|
||||
f" but got {page_sequence_number}"
|
||||
)
|
||||
expected_page_sequence_number = page_sequence_number + 1
|
||||
|
||||
# Assemble the page
|
||||
if not header_type & self.HeaderType.CONTINUED:
|
||||
packet = b''
|
||||
for lacing_value in segment_table:
|
||||
if lacing_value:
|
||||
packet += await self.read(lacing_value)
|
||||
if lacing_value < 255:
|
||||
# End of packet
|
||||
packet_count += 1
|
||||
|
||||
if packet_count == 1:
|
||||
# The first packet contains the identification header
|
||||
logger.debug("first packet (header)")
|
||||
if packet[:8] != b"OpusHead":
|
||||
raise ValueError("first packet is not OpusHead")
|
||||
packet_count = (
|
||||
OpusPacket.ChannelMode.MONO
|
||||
if packet[9] == 1
|
||||
else OpusPacket.ChannelMode.STEREO
|
||||
)
|
||||
|
||||
elif packet_count == 2:
|
||||
# The second packet contains the comment header
|
||||
logger.debug("second packet (tags)")
|
||||
if packet[:8] != b"OpusTags":
|
||||
logger.warning("second packet is not OpusTags")
|
||||
else:
|
||||
yield OpusPacket(channel_mode, 20, 48000, packet)
|
||||
|
||||
packet = b''
|
||||
|
||||
if header_type & self.HeaderType.LAST:
|
||||
logger.debug("EOS")
|
||||
|
||||
return generate_frames()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OpusPacketSource:
|
||||
def __init__(self, read: Callable[[int], Awaitable[bytes]], mtu: int) -> None:
|
||||
self.read = read
|
||||
self.mtu = mtu
|
||||
|
||||
@property
|
||||
def packets(self):
|
||||
async def generate_packets():
|
||||
sequence_number = 0
|
||||
elapsed_ms = 0
|
||||
|
||||
opus_parser = OpusParser(self.read)
|
||||
async for opus_packet in opus_parser.packets:
|
||||
# We only support sending one Opus frame per RTP packet
|
||||
# TODO: check the spec for the first byte value here
|
||||
opus_payload = bytes([1]) + opus_packet.payload
|
||||
elapsed_s = elapsed_ms / 1000
|
||||
timestamp = int(elapsed_s * opus_packet.sampling_frequency)
|
||||
rtp_packet = MediaPacket(
|
||||
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, opus_payload
|
||||
)
|
||||
rtp_packet.timestamp_seconds = elapsed_s
|
||||
yield rtp_packet
|
||||
|
||||
# Prepare for next packets
|
||||
sequence_number += 1
|
||||
sequence_number &= 0xFFFF
|
||||
elapsed_ms += opus_packet.duration
|
||||
|
||||
return generate_packets()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# This map should be left at the end of the file so it can refer to the classes
|
||||
# above
|
||||
# -----------------------------------------------------------------------------
|
||||
A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES = {
|
||||
OpusMediaCodecInformation.VENDOR_ID: {
|
||||
OpusMediaCodecInformation.CODEC_ID: OpusMediaCodecInformation
|
||||
}
|
||||
}
|
||||
|
||||
18
bumble/at.py
18
bumble/at.py
@@ -14,13 +14,19 @@
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
class AtParsingError(core.InvalidPacketError):
|
||||
"""Error raised when parsing AT commands fails."""
|
||||
|
||||
|
||||
def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
"""Split input parameters into tokens.
|
||||
Removes space characters outside of double quote blocks:
|
||||
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
|
||||
are ignored [..], unless they are embedded in numeric or string constants"
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = []
|
||||
in_quotes = False
|
||||
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
token = bytearray()
|
||||
elif char == b'(':
|
||||
if len(token) > 0:
|
||||
raise ValueError("open_paren following regular character")
|
||||
raise AtParsingError("open_paren following regular character")
|
||||
tokens.append(char)
|
||||
elif char == b'"':
|
||||
if len(token) > 0:
|
||||
raise ValueError("quote following regular character")
|
||||
raise AtParsingError("quote following regular character")
|
||||
in_quotes = True
|
||||
token.extend(char)
|
||||
else:
|
||||
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
|
||||
|
||||
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
"""Parse the parameters using the comma and parenthesis separators.
|
||||
Raises ValueError in case of invalid input string."""
|
||||
Raises AtParsingError in case of invalid input string."""
|
||||
|
||||
tokens = tokenize_parameters(buffer)
|
||||
accumulator: List[list] = [[]]
|
||||
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
accumulator.append([])
|
||||
elif token == b')':
|
||||
if len(accumulator) < 2:
|
||||
raise ValueError("close_paren without matching open_paren")
|
||||
raise AtParsingError("close_paren without matching open_paren")
|
||||
accumulator[-1].append(current)
|
||||
current = accumulator.pop()
|
||||
else:
|
||||
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
|
||||
|
||||
accumulator[-1].append(current)
|
||||
if len(accumulator) > 1:
|
||||
raise ValueError("missing close_paren")
|
||||
raise AtParsingError("missing close_paren")
|
||||
return accumulator[0]
|
||||
|
||||
284
bumble/att.py
284
bumble/att.py
@@ -23,13 +23,28 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import struct
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type, TYPE_CHECKING
|
||||
|
||||
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
from bumble import utils
|
||||
from bumble.core import UUID, name_or_number, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,6 +57,7 @@ if TYPE_CHECKING:
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
ATT_CID = 0x04
|
||||
ATT_PSM = 0x001F
|
||||
|
||||
ATT_ERROR_RESPONSE = 0x01
|
||||
ATT_EXCHANGE_MTU_REQUEST = 0x02
|
||||
@@ -132,43 +148,57 @@ ATT_RESPONSES = [
|
||||
ATT_EXECUTE_WRITE_RESPONSE
|
||||
]
|
||||
|
||||
ATT_INVALID_HANDLE_ERROR = 0x01
|
||||
ATT_READ_NOT_PERMITTED_ERROR = 0x02
|
||||
ATT_WRITE_NOT_PERMITTED_ERROR = 0x03
|
||||
ATT_INVALID_PDU_ERROR = 0x04
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = 0x05
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR = 0x06
|
||||
ATT_INVALID_OFFSET_ERROR = 0x07
|
||||
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = 0x08
|
||||
ATT_PREPARE_QUEUE_FULL_ERROR = 0x09
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR = 0x0A
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR = 0x0B
|
||||
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = 0x0C
|
||||
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = 0x0D
|
||||
ATT_UNLIKELY_ERROR_ERROR = 0x0E
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR = 0x0F
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = 0x10
|
||||
ATT_INSUFFICIENT_RESOURCES_ERROR = 0x11
|
||||
class ErrorCode(utils.OpenIntEnum):
|
||||
'''
|
||||
See
|
||||
|
||||
ATT_ERROR_NAMES = {
|
||||
ATT_INVALID_HANDLE_ERROR: 'ATT_INVALID_HANDLE_ERROR',
|
||||
ATT_READ_NOT_PERMITTED_ERROR: 'ATT_READ_NOT_PERMITTED_ERROR',
|
||||
ATT_WRITE_NOT_PERMITTED_ERROR: 'ATT_WRITE_NOT_PERMITTED_ERROR',
|
||||
ATT_INVALID_PDU_ERROR: 'ATT_INVALID_PDU_ERROR',
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR: 'ATT_INSUFFICIENT_AUTHENTICATION_ERROR',
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR: 'ATT_REQUEST_NOT_SUPPORTED_ERROR',
|
||||
ATT_INVALID_OFFSET_ERROR: 'ATT_INVALID_OFFSET_ERROR',
|
||||
ATT_INSUFFICIENT_AUTHORIZATION_ERROR: 'ATT_INSUFFICIENT_AUTHORIZATION_ERROR',
|
||||
ATT_PREPARE_QUEUE_FULL_ERROR: 'ATT_PREPARE_QUEUE_FULL_ERROR',
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR: 'ATT_ATTRIBUTE_NOT_FOUND_ERROR',
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR: 'ATT_ATTRIBUTE_NOT_LONG_ERROR',
|
||||
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR',
|
||||
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR: 'ATT_INVALID_ATTRIBUTE_LENGTH_ERROR',
|
||||
ATT_UNLIKELY_ERROR_ERROR: 'ATT_UNLIKELY_ERROR_ERROR',
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR: 'ATT_INSUFFICIENT_ENCRYPTION_ERROR',
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR: 'ATT_UNSUPPORTED_GROUP_TYPE_ERROR',
|
||||
ATT_INSUFFICIENT_RESOURCES_ERROR: 'ATT_INSUFFICIENT_RESOURCES_ERROR'
|
||||
}
|
||||
* Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
|
||||
* Core Specification Supplement: Common Profile And Service Error Codes
|
||||
'''
|
||||
INVALID_HANDLE = 0x01
|
||||
READ_NOT_PERMITTED = 0x02
|
||||
WRITE_NOT_PERMITTED = 0x03
|
||||
INVALID_PDU = 0x04
|
||||
INSUFFICIENT_AUTHENTICATION = 0x05
|
||||
REQUEST_NOT_SUPPORTED = 0x06
|
||||
INVALID_OFFSET = 0x07
|
||||
INSUFFICIENT_AUTHORIZATION = 0x08
|
||||
PREPARE_QUEUE_FULL = 0x09
|
||||
ATTRIBUTE_NOT_FOUND = 0x0A
|
||||
ATTRIBUTE_NOT_LONG = 0x0B
|
||||
INSUFFICIENT_ENCRYPTION_KEY_SIZE = 0x0C
|
||||
INVALID_ATTRIBUTE_LENGTH = 0x0D
|
||||
UNLIKELY_ERROR = 0x0E
|
||||
INSUFFICIENT_ENCRYPTION = 0x0F
|
||||
UNSUPPORTED_GROUP_TYPE = 0x10
|
||||
INSUFFICIENT_RESOURCES = 0x11
|
||||
DATABASE_OUT_OF_SYNC = 0x12
|
||||
VALUE_NOT_ALLOWED = 0x13
|
||||
# 0x80 – 0x9F: Application Error
|
||||
# 0xE0 – 0xFF: Common Profile and Service Error Codes
|
||||
WRITE_REQUEST_REJECTED = 0xFC
|
||||
CCCD_IMPROPERLY_CONFIGURED = 0xFD
|
||||
PROCEDURE_ALREADY_IN_PROGRESS = 0xFE
|
||||
OUT_OF_RANGE = 0xFF
|
||||
|
||||
# Backward Compatible Constants
|
||||
ATT_INVALID_HANDLE_ERROR = ErrorCode.INVALID_HANDLE
|
||||
ATT_READ_NOT_PERMITTED_ERROR = ErrorCode.READ_NOT_PERMITTED
|
||||
ATT_WRITE_NOT_PERMITTED_ERROR = ErrorCode.WRITE_NOT_PERMITTED
|
||||
ATT_INVALID_PDU_ERROR = ErrorCode.INVALID_PDU
|
||||
ATT_INSUFFICIENT_AUTHENTICATION_ERROR = ErrorCode.INSUFFICIENT_AUTHENTICATION
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR = ErrorCode.REQUEST_NOT_SUPPORTED
|
||||
ATT_INVALID_OFFSET_ERROR = ErrorCode.INVALID_OFFSET
|
||||
ATT_INSUFFICIENT_AUTHORIZATION_ERROR = ErrorCode.INSUFFICIENT_AUTHORIZATION
|
||||
ATT_PREPARE_QUEUE_FULL_ERROR = ErrorCode.PREPARE_QUEUE_FULL
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR = ErrorCode.ATTRIBUTE_NOT_FOUND
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR = ErrorCode.ATTRIBUTE_NOT_LONG
|
||||
ATT_INSUFFICIENT_ENCRYPTION_KEY_SIZE_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION_KEY_SIZE
|
||||
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR = ErrorCode.INVALID_ATTRIBUTE_LENGTH
|
||||
ATT_UNLIKELY_ERROR_ERROR = ErrorCode.UNLIKELY_ERROR
|
||||
ATT_INSUFFICIENT_ENCRYPTION_ERROR = ErrorCode.INSUFFICIENT_ENCRYPTION
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR = ErrorCode.UNSUPPORTED_GROUP_TYPE
|
||||
ATT_INSUFFICIENT_RESOURCES_ERROR = ErrorCode.INSUFFICIENT_RESOURCES
|
||||
|
||||
ATT_DEFAULT_MTU = 23
|
||||
|
||||
@@ -182,6 +212,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
|
||||
# pylint: enable=line-too-long
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -209,7 +240,7 @@ class ATT_PDU:
|
||||
|
||||
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
|
||||
op_code = 0
|
||||
name = None
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
@@ -231,9 +262,9 @@ class ATT_PDU:
|
||||
def pdu_name(op_code):
|
||||
return name_or_number(ATT_PDU_NAMES, op_code, 2)
|
||||
|
||||
@staticmethod
|
||||
def error_name(error_code):
|
||||
return name_or_number(ATT_ERROR_NAMES, error_code, 2)
|
||||
@classmethod
|
||||
def error_name(cls, error_code: int) -> str:
|
||||
return ErrorCode(error_code).name
|
||||
|
||||
@staticmethod
|
||||
def subclass(fields):
|
||||
@@ -261,9 +292,6 @@ class ATT_PDU:
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
@property
|
||||
def is_command(self):
|
||||
return ((self.op_code >> 6) & 1) == 1
|
||||
@@ -273,7 +301,7 @@ class ATT_PDU:
|
||||
return ((self.op_code >> 7) & 1) == 1
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = color(self.name, 'yellow')
|
||||
@@ -641,7 +669,7 @@ class ATT_Write_Command(ATT_PDU):
|
||||
@ATT_PDU.subclass(
|
||||
[
|
||||
('attribute_handle', HANDLE_FIELD_SPEC),
|
||||
('attribute_value', '*')
|
||||
('attribute_value', '*'),
|
||||
# ('authentication_signature', 'TODO')
|
||||
]
|
||||
)
|
||||
@@ -680,7 +708,7 @@ class ATT_Prepare_Write_Response(ATT_PDU):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ATT_PDU.subclass([])
|
||||
@ATT_PDU.subclass([("flags", 1)])
|
||||
class ATT_Execute_Write_Request(ATT_PDU):
|
||||
'''
|
||||
See Bluetooth spec @ Vol 3, Part F - 3.4.6.3 Execute Write Request
|
||||
@@ -719,48 +747,94 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeValue:
|
||||
'''
|
||||
Attribute value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], Any],
|
||||
Callable[[Optional[Connection]], Awaitable[Any]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], Any], None],
|
||||
Callable[[Optional[Connection], Any], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> Union[Awaitable[None], None]:
|
||||
if self._write:
|
||||
return self._write(connection, value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Attribute(EventEmitter):
|
||||
# Permission flags
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
class Permissions(enum.IntFlag):
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
|
||||
PERMISSION_NAMES = {
|
||||
READABLE: 'READABLE',
|
||||
WRITEABLE: 'WRITEABLE',
|
||||
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
|
||||
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
|
||||
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
|
||||
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
|
||||
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
|
||||
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
|
||||
}
|
||||
@classmethod
|
||||
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | Attribute.Permissions[y],
|
||||
permissions_str.replace('|', ',').split(","),
|
||||
Attribute.Permissions(0),
|
||||
)
|
||||
except TypeError as exc:
|
||||
# The check for `p.name is not None` here is needed because for InFlag
|
||||
# enums, the .name property can be None, when the enum value is 0,
|
||||
# so the type hint for .name is Optional[str].
|
||||
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list_str = ",".join(enum_list)
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str}\nGot: {permissions_str}"
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def string_to_permissions(permissions_str: str):
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
|
||||
permissions_str.split(","),
|
||||
0,
|
||||
)
|
||||
except TypeError as exc:
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
|
||||
) from exc
|
||||
# Permission flags(legacy-use only)
|
||||
READABLE = Permissions.READABLE
|
||||
WRITEABLE = Permissions.WRITEABLE
|
||||
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
|
||||
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
|
||||
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
def __init__(self, attribute_type, permissions, value=b''):
|
||||
value: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Any = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
self.end_group_handle = 0
|
||||
if isinstance(permissions, str):
|
||||
self.permissions = self.string_to_permissions(permissions)
|
||||
self.permissions = Attribute.Permissions.from_string(permissions)
|
||||
else:
|
||||
self.permissions = permissions
|
||||
|
||||
@@ -772,28 +846,28 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
self.type = attribute_type
|
||||
|
||||
# Convert the value to a byte array
|
||||
if isinstance(value, str):
|
||||
self.value = bytes(value, 'utf-8')
|
||||
else:
|
||||
self.value = value
|
||||
self.value = value
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection: Connection):
|
||||
async def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
and not connection.encryption
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
||||
) and not connection.authenticated:
|
||||
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
|
||||
and connection is not None
|
||||
and not connection.authenticated
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
@@ -803,9 +877,11 @@ class Attribute(EventEmitter):
|
||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
|
||||
if read := getattr(self.value, 'read', None):
|
||||
if hasattr(self.value, 'read'):
|
||||
try:
|
||||
value = read(connection) # pylint: disable=not-callable
|
||||
value = self.value.read(connection)
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
@@ -813,9 +889,11 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
value = self.value
|
||||
|
||||
self.emit('read', connection, value)
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection: Connection, value_bytes):
|
||||
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
@@ -836,9 +914,11 @@ class Attribute(EventEmitter):
|
||||
|
||||
value = self.decode_value(value_bytes)
|
||||
|
||||
if write := getattr(self.value, 'write', None):
|
||||
if hasattr(self.value, 'write'):
|
||||
try:
|
||||
write(connection, value) # pylint: disable=not-callable
|
||||
result = self.value.write(connection, value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
|
||||
524
bumble/avc.py
Normal file
524
bumble/avc.py
Normal file
@@ -0,0 +1,524 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Dict, Type, Union, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Frame:
|
||||
class SubunitType(enum.IntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.4
|
||||
MONITOR = 0x00
|
||||
AUDIO = 0x01
|
||||
PRINTER = 0x02
|
||||
DISC = 0x03
|
||||
TAPE_RECORDER_OR_PLAYER = 0x04
|
||||
TUNER = 0x05
|
||||
CA = 0x06
|
||||
CAMERA = 0x07
|
||||
PANEL = 0x09
|
||||
BULLETIN_BOARD = 0x0A
|
||||
VENDOR_UNIQUE = 0x1C
|
||||
EXTENDED = 0x1E
|
||||
UNIT = 0x1F
|
||||
|
||||
class OperationCode(OpenIntEnum):
|
||||
# 0x00 - 0x0F: Unit and subunit commands
|
||||
VENDOR_DEPENDENT = 0x00
|
||||
RESERVE = 0x01
|
||||
PLUG_INFO = 0x02
|
||||
|
||||
# 0x10 - 0x3F: Unit commands
|
||||
DIGITAL_OUTPUT = 0x10
|
||||
DIGITAL_INPUT = 0x11
|
||||
CHANNEL_USAGE = 0x12
|
||||
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
|
||||
INPUT_PLUG_SIGNAL_FORMAT = 0x19
|
||||
GENERAL_BUS_SETUP = 0x1F
|
||||
CONNECT_AV = 0x20
|
||||
DISCONNECT_AV = 0x21
|
||||
CONNECTIONS = 0x22
|
||||
CONNECT = 0x24
|
||||
DISCONNECT = 0x25
|
||||
UNIT_INFO = 0x30
|
||||
SUBUNIT_INFO = 0x31
|
||||
|
||||
# 0x40 - 0x7F: Subunit commands
|
||||
PASS_THROUGH = 0x7C
|
||||
GUI_UPDATE = 0x7D
|
||||
PUSH_GUI_DATA = 0x7E
|
||||
USER_ACTION = 0x7F
|
||||
|
||||
# 0xA0 - 0xBF: Unit and subunit commands
|
||||
VERSION = 0xB0
|
||||
POWER = 0xB2
|
||||
|
||||
subunit_type: SubunitType
|
||||
subunit_id: int
|
||||
opcode: OperationCode
|
||||
operands: bytes
|
||||
|
||||
@staticmethod
|
||||
def subclass(subclass):
|
||||
# Infer the opcode from the class name
|
||||
if subclass.__name__.endswith("CommandFrame"):
|
||||
short_name = subclass.__name__.replace("CommandFrame", "")
|
||||
category_class = CommandFrame
|
||||
elif subclass.__name__.endswith("ResponseFrame"):
|
||||
short_name = subclass.__name__.replace("ResponseFrame", "")
|
||||
category_class = ResponseFrame
|
||||
else:
|
||||
raise core.InvalidArgumentError(
|
||||
f"invalid subclass name {subclass.__name__}"
|
||||
)
|
||||
|
||||
uppercase_indexes = [
|
||||
i for i in range(len(short_name)) if short_name[i].isupper()
|
||||
]
|
||||
uppercase_indexes.append(len(short_name))
|
||||
words = [
|
||||
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
|
||||
for i in range(len(uppercase_indexes) - 1)
|
||||
]
|
||||
opcode_name = "_".join(words)
|
||||
opcode = Frame.OperationCode[opcode_name]
|
||||
category_class.subclasses[opcode] = subclass
|
||||
return subclass
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> Frame:
|
||||
if data[0] >> 4 != 0:
|
||||
raise core.InvalidPacketError("first 4 bits must be 0s")
|
||||
|
||||
ctype_or_response = data[0] & 0xF
|
||||
subunit_type = Frame.SubunitType(data[1] >> 3)
|
||||
subunit_id = data[1] & 7
|
||||
|
||||
if subunit_type == Frame.SubunitType.EXTENDED:
|
||||
# Not supported
|
||||
raise NotImplementedError("extended subunit types not supported")
|
||||
|
||||
if subunit_id < 5 or subunit_id == 7:
|
||||
opcode_offset = 2
|
||||
elif subunit_id == 5:
|
||||
# Extended to the next byte
|
||||
extension = data[2]
|
||||
if extension == 0:
|
||||
raise core.InvalidPacketError("extended subunit ID value reserved")
|
||||
if extension == 0xFF:
|
||||
subunit_id = 5 + 254 + data[3]
|
||||
opcode_offset = 4
|
||||
else:
|
||||
subunit_id = 5 + extension
|
||||
opcode_offset = 3
|
||||
elif subunit_id == 6:
|
||||
raise core.InvalidPacketError("reserved subunit ID")
|
||||
else:
|
||||
raise core.InvalidPacketError("invalid subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
|
||||
# Look for a registered subclass
|
||||
if ctype_or_response < 8:
|
||||
# Command
|
||||
ctype = CommandFrame.CommandType(ctype_or_response)
|
||||
if c_subclass := CommandFrame.subclasses.get(opcode):
|
||||
return c_subclass(
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*c_subclass.parse_operands(operands),
|
||||
)
|
||||
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
|
||||
else:
|
||||
# Response
|
||||
response = ResponseFrame.ResponseCode(ctype_or_response)
|
||||
if r_subclass := ResponseFrame.subclasses.get(opcode):
|
||||
return r_subclass(
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
*r_subclass.parse_operands(operands),
|
||||
)
|
||||
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
|
||||
|
||||
def to_bytes(
|
||||
self,
|
||||
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
|
||||
) -> bytes:
|
||||
# TODO: support extended subunit types and ids.
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
ctype_or_response,
|
||||
self.subunit_type << 3 | self.subunit_id,
|
||||
self.opcode,
|
||||
]
|
||||
)
|
||||
+ self.operands
|
||||
)
|
||||
|
||||
def to_string(self, extra: str) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}({extra}"
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"opcode={self.opcode.name}, "
|
||||
f"operands={self.operands.hex()})"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subunit_type: SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
self.subunit_type = subunit_type
|
||||
self.subunit_id = subunit_id
|
||||
self.opcode = opcode
|
||||
self.operands = operands
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommandFrame(Frame):
|
||||
class CommandType(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.1
|
||||
CONTROL = 0x00
|
||||
STATUS = 0x01
|
||||
SPECIFIC_INQUIRY = 0x02
|
||||
NOTIFY = 0x03
|
||||
GENERAL_INQUIRY = 0x04
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
|
||||
ctype: CommandType
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.ctype = ctype
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.ctype)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"ctype={self.ctype.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ResponseFrame(Frame):
|
||||
class ResponseCode(OpenIntEnum):
|
||||
# AV/C Digital Interface Command Set General Specification Version 4.1
|
||||
# Table 7.2
|
||||
NOT_IMPLEMENTED = 0x08
|
||||
ACCEPTED = 0x09
|
||||
REJECTED = 0x0A
|
||||
IN_TRANSITION = 0x0B
|
||||
IMPLEMENTED_OR_STABLE = 0x0C
|
||||
CHANGED = 0x0D
|
||||
INTERIM = 0x0F
|
||||
|
||||
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
|
||||
response: ResponseCode
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
opcode: Frame.OperationCode,
|
||||
operands: bytes,
|
||||
) -> None:
|
||||
super().__init__(subunit_type, subunit_id, opcode, operands)
|
||||
self.response = response
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes(self.response)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string(f"response={self.response.name}, ")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class VendorDependentFrame:
|
||||
company_id: int
|
||||
vendor_dependent_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
struct.unpack(">I", b"\x00" + operands[:3])[0],
|
||||
operands[3:],
|
||||
)
|
||||
|
||||
def make_operands(self) -> bytes:
|
||||
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
|
||||
|
||||
def __init__(self, company_id: int, vendor_dependent_data: bytes):
|
||||
self.company_id = company_id
|
||||
self.vendor_dependent_data = vendor_dependent_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
company_id: int,
|
||||
vendor_dependent_data: bytes,
|
||||
) -> None:
|
||||
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.VENDOR_DEPENDENT,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VendorDependentResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"company_id=0x{self.company_id:06X}, "
|
||||
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PassThroughFrame:
|
||||
"""
|
||||
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
|
||||
"""
|
||||
|
||||
class StateFlag(enum.IntEnum):
|
||||
PRESSED = 0
|
||||
RELEASED = 1
|
||||
|
||||
class OperationId(OpenIntEnum):
|
||||
SELECT = 0x00
|
||||
UP = 0x01
|
||||
DOWN = 0x01
|
||||
LEFT = 0x03
|
||||
RIGHT = 0x04
|
||||
RIGHT_UP = 0x05
|
||||
RIGHT_DOWN = 0x06
|
||||
LEFT_UP = 0x07
|
||||
LEFT_DOWN = 0x08
|
||||
ROOT_MENU = 0x09
|
||||
SETUP_MENU = 0x0A
|
||||
CONTENTS_MENU = 0x0B
|
||||
FAVORITE_MENU = 0x0C
|
||||
EXIT = 0x0D
|
||||
NUMBER_0 = 0x20
|
||||
NUMBER_1 = 0x21
|
||||
NUMBER_2 = 0x22
|
||||
NUMBER_3 = 0x23
|
||||
NUMBER_4 = 0x24
|
||||
NUMBER_5 = 0x25
|
||||
NUMBER_6 = 0x26
|
||||
NUMBER_7 = 0x27
|
||||
NUMBER_8 = 0x28
|
||||
NUMBER_9 = 0x29
|
||||
DOT = 0x2A
|
||||
ENTER = 0x2B
|
||||
CLEAR = 0x2C
|
||||
CHANNEL_UP = 0x30
|
||||
CHANNEL_DOWN = 0x31
|
||||
PREVIOUS_CHANNEL = 0x32
|
||||
SOUND_SELECT = 0x33
|
||||
INPUT_SELECT = 0x34
|
||||
DISPLAY_INFORMATION = 0x35
|
||||
HELP = 0x36
|
||||
PAGE_UP = 0x37
|
||||
PAGE_DOWN = 0x38
|
||||
POWER = 0x40
|
||||
VOLUME_UP = 0x41
|
||||
VOLUME_DOWN = 0x42
|
||||
MUTE = 0x43
|
||||
PLAY = 0x44
|
||||
STOP = 0x45
|
||||
PAUSE = 0x46
|
||||
RECORD = 0x47
|
||||
REWIND = 0x48
|
||||
FAST_FORWARD = 0x49
|
||||
EJECT = 0x4A
|
||||
FORWARD = 0x4B
|
||||
BACKWARD = 0x4C
|
||||
ANGLE = 0x50
|
||||
SUBPICTURE = 0x51
|
||||
F1 = 0x71
|
||||
F2 = 0x72
|
||||
F3 = 0x73
|
||||
F4 = 0x74
|
||||
F5 = 0x75
|
||||
VENDOR_UNIQUE = 0x7E
|
||||
|
||||
state_flag: StateFlag
|
||||
operation_id: OperationId
|
||||
operation_data: bytes
|
||||
|
||||
@staticmethod
|
||||
def parse_operands(operands: bytes) -> Tuple:
|
||||
return (
|
||||
PassThroughFrame.StateFlag(operands[0] >> 7),
|
||||
PassThroughFrame.OperationId(operands[0] & 0x7F),
|
||||
operands[1 : 1 + operands[1]],
|
||||
)
|
||||
|
||||
def make_operands(self):
|
||||
return (
|
||||
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
|
||||
+ self.operation_data
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_flag: StateFlag,
|
||||
operation_id: OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
if len(operation_data) > 255:
|
||||
raise core.InvalidArgumentError("operation data must be <= 255 bytes")
|
||||
self.state_flag = state_flag
|
||||
self.operation_id = operation_id
|
||||
self.operation_data = operation_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
|
||||
def __init__(
|
||||
self,
|
||||
ctype: CommandFrame.CommandType,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
CommandFrame.__init__(
|
||||
self,
|
||||
ctype,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@Frame.subclass
|
||||
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
|
||||
def __init__(
|
||||
self,
|
||||
response: ResponseFrame.ResponseCode,
|
||||
subunit_type: Frame.SubunitType,
|
||||
subunit_id: int,
|
||||
state_flag: PassThroughFrame.StateFlag,
|
||||
operation_id: PassThroughFrame.OperationId,
|
||||
operation_data: bytes,
|
||||
) -> None:
|
||||
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
|
||||
ResponseFrame.__init__(
|
||||
self,
|
||||
response,
|
||||
subunit_type,
|
||||
subunit_id,
|
||||
Frame.OperationCode.PASS_THROUGH,
|
||||
self.make_operands(),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"PassThroughResponseFrame(response={self.response.name}, "
|
||||
f"subunit_type={self.subunit_type.name}, "
|
||||
f"subunit_id=0x{self.subunit_id:02X}, "
|
||||
f"state_flag={self.state_flag.name}, "
|
||||
f"operation_id={self.operation_id.name}, "
|
||||
f"operation_data={self.operation_data.hex()})"
|
||||
)
|
||||
292
bumble/avctp.py
Normal file
292
bumble/avctp.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Callable, cast, Dict, Optional
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble import avc
|
||||
from bumble import core
|
||||
from bumble import l2cap
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AVCTP_PSM = 0x0017
|
||||
AVCTP_BROWSING_PSM = 0x001B
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MessageAssembler:
|
||||
Callback = Callable[[int, bool, bool, int, bytes], None]
|
||||
|
||||
transaction_label: int
|
||||
pid: int
|
||||
c_r: int
|
||||
ipid: int
|
||||
payload: bytes
|
||||
number_of_packets: int
|
||||
packets_received: int
|
||||
|
||||
def __init__(self, callback: Callback) -> None:
|
||||
self.callback = callback
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.packets_received = 0
|
||||
self.transaction_label = -1
|
||||
self.pid = -1
|
||||
self.c_r = -1
|
||||
self.ipid = -1
|
||||
self.payload = b''
|
||||
self.number_of_packets = 0
|
||||
self.packet_count = 0
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.packets_received += 1
|
||||
|
||||
transaction_label = pdu[0] >> 4
|
||||
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
|
||||
c_r = (pdu[0] >> 1) & 1
|
||||
ipid = pdu[0] & 1
|
||||
|
||||
if c_r == 0 and ipid != 0:
|
||||
logger.warning("invalid IPID in command frame")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
pid_offset = 1
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
|
||||
if self.transaction_label >= 0:
|
||||
# We are already in a transaction
|
||||
logger.warning("received START or SINGLE fragment while in transaction")
|
||||
self.reset()
|
||||
self.packets_received = 1
|
||||
|
||||
if packet_type == Protocol.PacketType.START:
|
||||
self.number_of_packets = pdu[1]
|
||||
pid_offset = 2
|
||||
|
||||
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
|
||||
self.payload += pdu[pid_offset + 2 :]
|
||||
|
||||
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
|
||||
if transaction_label != self.transaction_label:
|
||||
logger.warning("transaction label does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if pid != self.pid:
|
||||
logger.warning("PID does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if c_r != self.c_r:
|
||||
logger.warning("C/R does not match")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if self.packets_received > self.number_of_packets:
|
||||
logger.warning("too many fragments in transaction")
|
||||
self.reset()
|
||||
return
|
||||
|
||||
if packet_type == Protocol.PacketType.END:
|
||||
if self.packets_received != self.number_of_packets:
|
||||
logger.warning("premature END")
|
||||
self.reset()
|
||||
return
|
||||
else:
|
||||
self.transaction_label = transaction_label
|
||||
self.c_r = c_r
|
||||
self.ipid = ipid
|
||||
self.pid = pid
|
||||
|
||||
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
|
||||
self.on_message_complete()
|
||||
|
||||
def on_message_complete(self):
|
||||
try:
|
||||
self.callback(
|
||||
self.transaction_label,
|
||||
self.c_r == 0,
|
||||
self.ipid != 0,
|
||||
self.pid,
|
||||
self.payload,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception(color(f"!!! exception in callback: {error}", "red"))
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Protocol:
|
||||
CommandHandler = Callable[[int, avc.CommandFrame], None]
|
||||
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
|
||||
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
|
||||
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
|
||||
next_transaction_label: int
|
||||
message_assembler: MessageAssembler
|
||||
|
||||
class PacketType(IntEnum):
|
||||
SINGLE = 0b00
|
||||
START = 0b01
|
||||
CONTINUE = 0b10
|
||||
END = 0b11
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
self.command_handlers = {}
|
||||
self.response_handlers = {}
|
||||
self.l2cap_channel = l2cap_channel
|
||||
self.message_assembler = MessageAssembler(self.on_message)
|
||||
|
||||
# Register to receive PDUs from the channel
|
||||
l2cap_channel.sink = self.on_pdu
|
||||
l2cap_channel.on("open", self.on_l2cap_channel_open)
|
||||
l2cap_channel.on("close", self.on_l2cap_channel_close)
|
||||
|
||||
def on_l2cap_channel_open(self):
|
||||
logger.debug(color("<<< AVCTP channel open", "magenta"))
|
||||
|
||||
def on_l2cap_channel_close(self):
|
||||
logger.debug(color("<<< AVCTP channel closed", "magenta"))
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
self.message_assembler.on_pdu(pdu)
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"<<< AVCTP Message: pid={pid}, "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"is_command={is_command}, "
|
||||
f"ipid={ipid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
|
||||
# Check for invalid PID responses.
|
||||
if ipid:
|
||||
logger.debug(f"received IPID for PID={pid}")
|
||||
|
||||
# Find the appropriate handler.
|
||||
if is_command:
|
||||
if pid not in self.command_handlers:
|
||||
logger.warning(f"no command handler for PID {pid}")
|
||||
self.send_ipid(transaction_label, pid)
|
||||
return
|
||||
|
||||
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
|
||||
self.command_handlers[pid](transaction_label, command_frame)
|
||||
else:
|
||||
if pid not in self.response_handlers:
|
||||
logger.warning(f"no response handler for PID {pid}")
|
||||
return
|
||||
|
||||
# By convention, for an ipid, send a None payload to the response handler.
|
||||
if ipid:
|
||||
response_frame = None
|
||||
else:
|
||||
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
|
||||
|
||||
self.response_handlers[pid](transaction_label, response_frame)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
# TODO: fragment large messages
|
||||
packet_type = Protocol.PacketType.SINGLE
|
||||
pdu = (
|
||||
struct.pack(
|
||||
">BH",
|
||||
transaction_label << 4
|
||||
| packet_type << 2
|
||||
| (0 if is_command else 1) << 1
|
||||
| (1 if ipid else 0),
|
||||
pid,
|
||||
)
|
||||
+ payload
|
||||
)
|
||||
self.l2cap_channel.send_pdu(pdu)
|
||||
|
||||
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP command: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, True, False, pid, payload)
|
||||
|
||||
def send_response(self, transaction_label: int, pid: int, payload: bytes):
|
||||
logger.debug(
|
||||
">>> AVCTP response: "
|
||||
f"transaction_label={transaction_label}, "
|
||||
f"pid={pid}, "
|
||||
f"payload={payload.hex()}"
|
||||
)
|
||||
self.send_message(transaction_label, False, False, pid, payload)
|
||||
|
||||
def send_ipid(self, transaction_label: int, pid: int) -> None:
|
||||
logger.debug(
|
||||
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
|
||||
)
|
||||
self.send_message(transaction_label, False, True, pid, b'')
|
||||
|
||||
def register_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
self.command_handlers[pid] = handler
|
||||
|
||||
def unregister_command_handler(
|
||||
self, pid: int, handler: Protocol.CommandHandler
|
||||
) -> None:
|
||||
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
|
||||
raise core.InvalidArgumentError("command handler not registered")
|
||||
del self.command_handlers[pid]
|
||||
|
||||
def register_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
self.response_handlers[pid] = handler
|
||||
|
||||
def unregister_response_handler(
|
||||
self, pid: int, handler: Protocol.ResponseHandler
|
||||
) -> None:
|
||||
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
|
||||
raise core.InvalidArgumentError("response handler not registered")
|
||||
del self.response_handlers[pid]
|
||||
613
bumble/avdtp.py
613
bumble/avdtp.py
File diff suppressed because it is too large
Load Diff
1938
bumble/avrcp.py
Normal file
1938
bumble/avrcp.py
Normal file
File diff suppressed because it is too large
Load Diff
316
bumble/codecs.py
316
bumble/codecs.py
@@ -17,6 +17,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -40,7 +43,7 @@ class BitReader:
|
||||
""" "Read up to 32 bits."""
|
||||
|
||||
if bits > 32:
|
||||
raise ValueError('maximum read size is 32')
|
||||
raise core.InvalidArgumentError('maximum read size is 32')
|
||||
|
||||
if self.bits_cached >= bits:
|
||||
# We have enough bits.
|
||||
@@ -53,7 +56,7 @@ class BitReader:
|
||||
feed_size = len(feed_bytes)
|
||||
feed_int = int.from_bytes(feed_bytes, byteorder='big')
|
||||
if 8 * feed_size + self.bits_cached < bits:
|
||||
raise ValueError('trying to read past the data')
|
||||
raise core.InvalidArgumentError('trying to read past the data')
|
||||
self.byte_position += feed_size
|
||||
|
||||
# Combine the new cache and the old cache
|
||||
@@ -68,7 +71,7 @@ class BitReader:
|
||||
|
||||
def read_bytes(self, count: int):
|
||||
if self.bit_position + 8 * count > 8 * len(self.data):
|
||||
raise ValueError('not enough data')
|
||||
raise core.InvalidArgumentError('not enough data')
|
||||
|
||||
if self.bit_position % 8:
|
||||
# Not byte aligned
|
||||
@@ -99,12 +102,40 @@ class BitReader:
|
||||
break
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BitWriter:
|
||||
"""Simple but not optimized bit stream writer."""
|
||||
|
||||
data: int
|
||||
bit_count: int
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = 0
|
||||
self.bit_count = 0
|
||||
|
||||
def write(self, value: int, bit_count: int) -> None:
|
||||
self.data = (self.data << bit_count) | value
|
||||
self.bit_count += bit_count
|
||||
|
||||
def write_bytes(self, data: bytes) -> None:
|
||||
bit_count = 8 * len(data)
|
||||
self.data = (self.data << bit_count) | int.from_bytes(data, 'big')
|
||||
self.bit_count += bit_count
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (self.data << ((8 - (self.bit_count % 8)) % 8)).to_bytes(
|
||||
(self.bit_count + 7) // 8, 'big'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AacAudioRtpPacket:
|
||||
"""AAC payload encapsulated in an RTP packet payload"""
|
||||
|
||||
audio_mux_element: AudioMuxElement
|
||||
|
||||
@staticmethod
|
||||
def latm_value(reader: BitReader) -> int:
|
||||
def read_latm_value(reader: BitReader) -> int:
|
||||
bytes_for_value = reader.read(2)
|
||||
value = 0
|
||||
for _ in range(bytes_for_value + 1):
|
||||
@@ -112,24 +143,33 @@ class AacAudioRtpPacket:
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def program_config_element(reader: BitReader):
|
||||
raise ValueError('program_config_element not supported')
|
||||
def read_audio_object_type(reader: BitReader):
|
||||
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
|
||||
audio_object_type = reader.read(5)
|
||||
if audio_object_type == 31:
|
||||
audio_object_type = 32 + reader.read(6)
|
||||
|
||||
return audio_object_type
|
||||
|
||||
@dataclass
|
||||
class GASpecificConfig:
|
||||
def __init__(
|
||||
self, reader: BitReader, channel_configuration: int, audio_object_type: int
|
||||
) -> None:
|
||||
audio_object_type: int
|
||||
# NOTE: other fields not supported
|
||||
|
||||
@classmethod
|
||||
def from_bits(
|
||||
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
|
||||
) -> Self:
|
||||
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
|
||||
frame_length_flag = reader.read(1)
|
||||
depends_on_core_coder = reader.read(1)
|
||||
if depends_on_core_coder:
|
||||
self.core_coder_delay = reader.read(14)
|
||||
core_coder_delay = reader.read(14)
|
||||
extension_flag = reader.read(1)
|
||||
if not channel_configuration:
|
||||
AacAudioRtpPacket.program_config_element(reader)
|
||||
raise core.InvalidPacketError('program_config_element not supported')
|
||||
if audio_object_type in (6, 20):
|
||||
self.layer_nr = reader.read(3)
|
||||
layer_nr = reader.read(3)
|
||||
if extension_flag:
|
||||
if audio_object_type == 22:
|
||||
num_of_sub_frame = reader.read(5)
|
||||
@@ -140,16 +180,15 @@ class AacAudioRtpPacket:
|
||||
aac_spectral_data_resilience_flags = reader.read(1)
|
||||
extension_flag_3 = reader.read(1)
|
||||
if extension_flag_3 == 1:
|
||||
raise ValueError('extensionFlag3 == 1 not supported')
|
||||
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
|
||||
|
||||
@staticmethod
|
||||
def audio_object_type(reader: BitReader):
|
||||
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
|
||||
audio_object_type = reader.read(5)
|
||||
if audio_object_type == 31:
|
||||
audio_object_type = 32 + reader.read(6)
|
||||
return cls(audio_object_type)
|
||||
|
||||
return audio_object_type
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
assert self.audio_object_type in (1, 2)
|
||||
writer.write(0, 1) # frame_length_flag = 0
|
||||
writer.write(0, 1) # depends_on_core_coder = 0
|
||||
writer.write(0, 1) # extension_flag = 0
|
||||
|
||||
@dataclass
|
||||
class AudioSpecificConfig:
|
||||
@@ -157,6 +196,7 @@ class AacAudioRtpPacket:
|
||||
sampling_frequency_index: int
|
||||
sampling_frequency: int
|
||||
channel_configuration: int
|
||||
ga_specific_config: AacAudioRtpPacket.GASpecificConfig
|
||||
sbr_present_flag: int
|
||||
ps_present_flag: int
|
||||
extension_audio_object_type: int
|
||||
@@ -180,44 +220,73 @@ class AacAudioRtpPacket:
|
||||
7350,
|
||||
]
|
||||
|
||||
def __init__(self, reader: BitReader) -> None:
|
||||
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
|
||||
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
|
||||
self.sampling_frequency_index = reader.read(4)
|
||||
if self.sampling_frequency_index == 0xF:
|
||||
self.sampling_frequency = reader.read(24)
|
||||
else:
|
||||
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
|
||||
self.sampling_frequency_index
|
||||
]
|
||||
self.channel_configuration = reader.read(4)
|
||||
self.sbr_present_flag = -1
|
||||
self.ps_present_flag = -1
|
||||
if self.audio_object_type in (5, 29):
|
||||
self.extension_audio_object_type = 5
|
||||
self.sbc_present_flag = 1
|
||||
if self.audio_object_type == 29:
|
||||
self.ps_present_flag = 1
|
||||
self.extension_sampling_frequency_index = reader.read(4)
|
||||
if self.extension_sampling_frequency_index == 0xF:
|
||||
self.extension_sampling_frequency = reader.read(24)
|
||||
else:
|
||||
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
|
||||
self.extension_sampling_frequency_index
|
||||
]
|
||||
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
|
||||
if self.audio_object_type == 22:
|
||||
self.extension_channel_configuration = reader.read(4)
|
||||
else:
|
||||
self.extension_audio_object_type = 0
|
||||
@classmethod
|
||||
def for_simple_aac(
|
||||
cls,
|
||||
audio_object_type: int,
|
||||
sampling_frequency: int,
|
||||
channel_configuration: int,
|
||||
) -> Self:
|
||||
if sampling_frequency not in cls.SAMPLING_FREQUENCIES:
|
||||
raise ValueError(f'invalid sampling frequency {sampling_frequency}')
|
||||
|
||||
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
|
||||
reader, self.channel_configuration, self.audio_object_type
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(audio_object_type)
|
||||
|
||||
return cls(
|
||||
audio_object_type=audio_object_type,
|
||||
sampling_frequency_index=cls.SAMPLING_FREQUENCIES.index(
|
||||
sampling_frequency
|
||||
),
|
||||
sampling_frequency=sampling_frequency,
|
||||
channel_configuration=channel_configuration,
|
||||
ga_specific_config=ga_specific_config,
|
||||
sbr_present_flag=0,
|
||||
ps_present_flag=0,
|
||||
extension_audio_object_type=0,
|
||||
extension_sampling_frequency_index=0,
|
||||
extension_sampling_frequency=0,
|
||||
extension_channel_configuration=0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
|
||||
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
|
||||
sampling_frequency_index = reader.read(4)
|
||||
if sampling_frequency_index == 0xF:
|
||||
sampling_frequency = reader.read(24)
|
||||
else:
|
||||
sampling_frequency = cls.SAMPLING_FREQUENCIES[sampling_frequency_index]
|
||||
channel_configuration = reader.read(4)
|
||||
sbr_present_flag = 0
|
||||
ps_present_flag = 0
|
||||
extension_sampling_frequency_index = 0
|
||||
extension_sampling_frequency = 0
|
||||
extension_channel_configuration = 0
|
||||
extension_audio_object_type = 0
|
||||
if audio_object_type in (5, 29):
|
||||
extension_audio_object_type = 5
|
||||
sbr_present_flag = 1
|
||||
if audio_object_type == 29:
|
||||
ps_present_flag = 1
|
||||
extension_sampling_frequency_index = reader.read(4)
|
||||
if extension_sampling_frequency_index == 0xF:
|
||||
extension_sampling_frequency = reader.read(24)
|
||||
else:
|
||||
extension_sampling_frequency = cls.SAMPLING_FREQUENCIES[
|
||||
extension_sampling_frequency_index
|
||||
]
|
||||
audio_object_type = AacAudioRtpPacket.read_audio_object_type(reader)
|
||||
if audio_object_type == 22:
|
||||
extension_channel_configuration = reader.read(4)
|
||||
|
||||
if audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
|
||||
ga_specific_config = AacAudioRtpPacket.GASpecificConfig.from_bits(
|
||||
reader, channel_configuration, audio_object_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'audioObjectType {self.audio_object_type} not supported'
|
||||
raise core.InvalidPacketError(
|
||||
f'audioObjectType {audio_object_type} not supported'
|
||||
)
|
||||
|
||||
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
|
||||
@@ -246,13 +315,44 @@ class AacAudioRtpPacket:
|
||||
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
|
||||
# self.extension_channel_configuration = reader.read(4)
|
||||
|
||||
return cls(
|
||||
audio_object_type,
|
||||
sampling_frequency_index,
|
||||
sampling_frequency,
|
||||
channel_configuration,
|
||||
ga_specific_config,
|
||||
sbr_present_flag,
|
||||
ps_present_flag,
|
||||
extension_audio_object_type,
|
||||
extension_sampling_frequency_index,
|
||||
extension_sampling_frequency,
|
||||
extension_channel_configuration,
|
||||
)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
if self.sampling_frequency_index >= 15:
|
||||
raise ValueError(
|
||||
f"unsupported sampling frequency index {self.sampling_frequency_index}"
|
||||
)
|
||||
|
||||
if self.audio_object_type not in (1, 2):
|
||||
raise ValueError(
|
||||
f"unsupported audio object type {self.audio_object_type} "
|
||||
)
|
||||
|
||||
writer.write(self.audio_object_type, 5)
|
||||
writer.write(self.sampling_frequency_index, 4)
|
||||
writer.write(self.channel_configuration, 4)
|
||||
self.ga_specific_config.to_bits(writer)
|
||||
|
||||
@dataclass
|
||||
class StreamMuxConfig:
|
||||
other_data_present: int
|
||||
other_data_len_bits: int
|
||||
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
|
||||
|
||||
def __init__(self, reader: BitReader) -> None:
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
|
||||
audio_mux_version = reader.read(1)
|
||||
if audio_mux_version == 1:
|
||||
@@ -260,31 +360,31 @@ class AacAudioRtpPacket:
|
||||
else:
|
||||
audio_mux_version_a = 0
|
||||
if audio_mux_version_a != 0:
|
||||
raise ValueError('audioMuxVersionA != 0 not supported')
|
||||
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
|
||||
if audio_mux_version == 1:
|
||||
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
|
||||
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
|
||||
stream_cnt = 0
|
||||
all_streams_same_time_framing = reader.read(1)
|
||||
num_sub_frames = reader.read(6)
|
||||
num_program = reader.read(4)
|
||||
if num_program != 0:
|
||||
raise ValueError('num_program != 0 not supported')
|
||||
raise core.InvalidPacketError('num_program != 0 not supported')
|
||||
num_layer = reader.read(3)
|
||||
if num_layer != 0:
|
||||
raise ValueError('num_layer != 0 not supported')
|
||||
raise core.InvalidPacketError('num_layer != 0 not supported')
|
||||
if audio_mux_version == 0:
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
|
||||
reader
|
||||
)
|
||||
else:
|
||||
asc_len = AacAudioRtpPacket.latm_value(reader)
|
||||
asc_len = AacAudioRtpPacket.read_latm_value(reader)
|
||||
marker = reader.bit_position
|
||||
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
|
||||
audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig.from_bits(
|
||||
reader
|
||||
)
|
||||
audio_specific_config_len = reader.bit_position - marker
|
||||
if asc_len < audio_specific_config_len:
|
||||
raise ValueError('audio_specific_config_len > asc_len')
|
||||
raise core.InvalidPacketError('audio_specific_config_len > asc_len')
|
||||
asc_len -= audio_specific_config_len
|
||||
reader.skip(asc_len)
|
||||
frame_length_type = reader.read(3)
|
||||
@@ -293,38 +393,53 @@ class AacAudioRtpPacket:
|
||||
elif frame_length_type == 1:
|
||||
frame_length = reader.read(9)
|
||||
else:
|
||||
raise ValueError(f'frame_length_type {frame_length_type} not supported')
|
||||
raise core.InvalidPacketError(
|
||||
f'frame_length_type {frame_length_type} not supported'
|
||||
)
|
||||
|
||||
self.other_data_present = reader.read(1)
|
||||
if self.other_data_present:
|
||||
other_data_present = reader.read(1)
|
||||
other_data_len_bits = 0
|
||||
if other_data_present:
|
||||
if audio_mux_version == 1:
|
||||
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
|
||||
other_data_len_bits = AacAudioRtpPacket.read_latm_value(reader)
|
||||
else:
|
||||
self.other_data_len_bits = 0
|
||||
while True:
|
||||
self.other_data_len_bits *= 256
|
||||
other_data_len_bits *= 256
|
||||
other_data_len_esc = reader.read(1)
|
||||
self.other_data_len_bits += reader.read(8)
|
||||
other_data_len_bits += reader.read(8)
|
||||
if other_data_len_esc == 0:
|
||||
break
|
||||
crc_check_present = reader.read(1)
|
||||
if crc_check_present:
|
||||
crc_checksum = reader.read(8)
|
||||
|
||||
return cls(other_data_present, other_data_len_bits, audio_specific_config)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
writer.write(0, 1) # audioMuxVersion = 0
|
||||
writer.write(1, 1) # allStreamsSameTimeFraming = 1
|
||||
writer.write(0, 6) # numSubFrames = 0
|
||||
writer.write(0, 4) # numProgram = 0
|
||||
writer.write(0, 3) # numLayer = 0
|
||||
self.audio_specific_config.to_bits(writer)
|
||||
writer.write(0, 3) # frameLengthType = 0
|
||||
writer.write(0, 8) # latmBufferFullness = 0
|
||||
writer.write(0, 1) # otherDataPresent = 0
|
||||
writer.write(0, 1) # crcCheckPresent = 0
|
||||
|
||||
@dataclass
|
||||
class AudioMuxElement:
|
||||
payload: bytes
|
||||
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
|
||||
payload: bytes
|
||||
|
||||
def __init__(self, reader: BitReader, mux_config_present: int):
|
||||
if mux_config_present == 0:
|
||||
raise ValueError('muxConfigPresent == 0 not supported')
|
||||
|
||||
@classmethod
|
||||
def from_bits(cls, reader: BitReader) -> Self:
|
||||
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
|
||||
# (only supports mux_config_present=1)
|
||||
use_same_stream_mux = reader.read(1)
|
||||
if use_same_stream_mux:
|
||||
raise ValueError('useSameStreamMux == 1 not supported')
|
||||
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
|
||||
raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
|
||||
stream_mux_config = AacAudioRtpPacket.StreamMuxConfig.from_bits(reader)
|
||||
|
||||
# We only support:
|
||||
# allStreamsSameTimeFraming == 1
|
||||
@@ -340,19 +455,46 @@ class AacAudioRtpPacket:
|
||||
if tmp != 255:
|
||||
break
|
||||
|
||||
self.payload = reader.read_bytes(mux_slot_length_bytes)
|
||||
payload = reader.read_bytes(mux_slot_length_bytes)
|
||||
|
||||
if self.stream_mux_config.other_data_present:
|
||||
reader.skip(self.stream_mux_config.other_data_len_bits)
|
||||
if stream_mux_config.other_data_present:
|
||||
reader.skip(stream_mux_config.other_data_len_bits)
|
||||
|
||||
# ByteAlign
|
||||
while reader.bit_position % 8:
|
||||
reader.read(1)
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
return cls(stream_mux_config, payload)
|
||||
|
||||
def to_bits(self, writer: BitWriter) -> None:
|
||||
writer.write(0, 1) # useSameStreamMux = 0
|
||||
self.stream_mux_config.to_bits(writer)
|
||||
mux_slot_length_bytes = len(self.payload)
|
||||
while mux_slot_length_bytes > 255:
|
||||
writer.write(255, 8)
|
||||
mux_slot_length_bytes -= 255
|
||||
writer.write(mux_slot_length_bytes, 8)
|
||||
if mux_slot_length_bytes == 255:
|
||||
writer.write(0, 8)
|
||||
writer.write_bytes(self.payload)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
# Parse the bit stream
|
||||
reader = BitReader(data)
|
||||
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
|
||||
return cls(cls.AudioMuxElement.from_bits(reader))
|
||||
|
||||
@classmethod
|
||||
def for_simple_aac(
|
||||
cls, sampling_frequency: int, channel_configuration: int, payload: bytes
|
||||
) -> Self:
|
||||
audio_specific_config = cls.AudioSpecificConfig.for_simple_aac(
|
||||
2, sampling_frequency, channel_configuration
|
||||
)
|
||||
stream_mux_config = cls.StreamMuxConfig(0, 0, audio_specific_config)
|
||||
audio_mux_element = cls.AudioMuxElement(stream_mux_config, payload)
|
||||
|
||||
return cls(audio_mux_element)
|
||||
|
||||
def to_adts(self):
|
||||
# pylint: disable=line-too-long
|
||||
@@ -379,3 +521,11 @@ class AacAudioRtpPacket:
|
||||
)
|
||||
+ self.audio_mux_element.payload
|
||||
)
|
||||
|
||||
def __init__(self, audio_mux_element: AudioMuxElement) -> None:
|
||||
self.audio_mux_element = audio_mux_element
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
writer = BitWriter()
|
||||
self.audio_mux_element.to_bits(writer)
|
||||
return bytes(writer)
|
||||
|
||||
@@ -16,6 +16,10 @@ from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class ColorError(ValueError):
|
||||
"""Error raised when a color spec is invalid."""
|
||||
|
||||
|
||||
# ANSI color names. There is also a "default"
|
||||
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
|
||||
|
||||
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
|
||||
elif isinstance(spec, int) and 0 <= spec <= 255:
|
||||
return _join(base + 8, 5, spec)
|
||||
else:
|
||||
raise ValueError('Invalid color spec "%s"' % spec)
|
||||
raise ColorError('Invalid color spec "%s"' % spec)
|
||||
|
||||
|
||||
def color(
|
||||
@@ -72,7 +76,7 @@ def color(
|
||||
if style_part in STYLES:
|
||||
codes.append(STYLES.index(style_part))
|
||||
else:
|
||||
raise ValueError('Invalid style "%s"' % style_part)
|
||||
raise ColorError('Invalid style "%s"' % style_part)
|
||||
|
||||
if codes:
|
||||
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,8 +15,11 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import itertools
|
||||
import random
|
||||
import struct
|
||||
@@ -40,6 +43,7 @@ from bumble.hci import (
|
||||
HCI_LE_1M_PHY,
|
||||
HCI_SUCCESS,
|
||||
HCI_UNKNOWN_HCI_COMMAND_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
HCI_VERSION_BLUETOOTH_CORE_5_0,
|
||||
Address,
|
||||
@@ -51,15 +55,21 @@ from bumble.hci import (
|
||||
HCI_Connection_Request_Event,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
HCI_Encryption_Change_Event,
|
||||
HCI_Synchronous_Connection_Complete_Event,
|
||||
HCI_LE_Advertising_Report_Event,
|
||||
HCI_LE_CIS_Established_Event,
|
||||
HCI_LE_CIS_Request_Event,
|
||||
HCI_LE_Connection_Complete_Event,
|
||||
HCI_LE_Read_Remote_Features_Complete_Event,
|
||||
HCI_Number_Of_Completed_Packets_Event,
|
||||
HCI_Packet,
|
||||
HCI_Role_Change_Event,
|
||||
)
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.link import LocalLink
|
||||
from bumble.transport.common import TransportSink
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -75,15 +85,27 @@ class DataObject:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class CisLink:
|
||||
handle: int
|
||||
cis_id: int
|
||||
cig_id: int
|
||||
acl_connection: Optional[Connection] = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Connection:
|
||||
def __init__(self, controller, handle, role, peer_address, link, transport):
|
||||
self.controller = controller
|
||||
self.handle = handle
|
||||
self.role = role
|
||||
self.peer_address = peer_address
|
||||
self.link = link
|
||||
controller: Controller
|
||||
handle: int
|
||||
role: int
|
||||
peer_address: Address
|
||||
link: Any
|
||||
transport: int
|
||||
link_type: int
|
||||
|
||||
def __post_init__(self):
|
||||
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
|
||||
def on_hci_acl_data_packet(self, packet):
|
||||
self.assembler.feed_packet(packet)
|
||||
@@ -102,25 +124,27 @@ class Connection:
|
||||
class Controller:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
name: str,
|
||||
host_source=None,
|
||||
host_sink=None,
|
||||
link=None,
|
||||
host_sink: Optional[TransportSink] = None,
|
||||
link: Optional[LocalLink] = None,
|
||||
public_address: Optional[Union[bytes, str, Address]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.hci_sink = None
|
||||
self.link = link
|
||||
|
||||
self.central_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections where this controller is the central
|
||||
self.peripheral_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections where this controller is the peripheral
|
||||
self.classic_connections: Dict[
|
||||
Address, Connection
|
||||
] = {} # Connections in BR/EDR
|
||||
self.central_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the central
|
||||
self.peripheral_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections where this controller is the peripheral
|
||||
self.classic_connections: Dict[Address, Connection] = (
|
||||
{}
|
||||
) # Connections in BR/EDR
|
||||
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
|
||||
|
||||
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
|
||||
self.hci_revision = 0
|
||||
@@ -130,12 +154,14 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
@@ -288,7 +314,7 @@ class Controller:
|
||||
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
|
||||
)
|
||||
if self.host:
|
||||
self.host.on_packet(packet.to_bytes())
|
||||
self.host.on_packet(bytes(packet))
|
||||
|
||||
# This method allows the controller to emulate the same API as a transport source
|
||||
async def wait_for_termination(self):
|
||||
@@ -297,7 +323,7 @@ class Controller:
|
||||
############################################################
|
||||
# Link connections
|
||||
############################################################
|
||||
def allocate_connection_handle(self):
|
||||
def allocate_connection_handle(self) -> int:
|
||||
handle = 0
|
||||
max_handle = 0
|
||||
for connection in itertools.chain(
|
||||
@@ -309,6 +335,13 @@ class Controller:
|
||||
if connection.handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
for cis_handle in itertools.chain(
|
||||
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
|
||||
):
|
||||
max_handle = max(max_handle, cis_handle)
|
||||
if cis_handle == handle:
|
||||
# Already used, continue searching after the current max
|
||||
handle = max_handle + 1
|
||||
return handle
|
||||
|
||||
def find_le_connection_by_address(self, address):
|
||||
@@ -353,12 +386,13 @@ class Controller:
|
||||
if connection is None:
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
peer_address,
|
||||
self.link,
|
||||
BT_LE_TRANSPORT,
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_PERIPHERAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.peripheral_connections[peer_address] = connection
|
||||
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
|
||||
@@ -412,12 +446,13 @@ class Controller:
|
||||
if connection is None:
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
self,
|
||||
connection_handle,
|
||||
BT_CENTRAL_ROLE,
|
||||
peer_address,
|
||||
self.link,
|
||||
BT_LE_TRANSPORT,
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
role=BT_CENTRAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.central_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -534,6 +569,104 @@ class Controller:
|
||||
)
|
||||
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
|
||||
|
||||
def on_link_cis_request(
|
||||
self, central_address: Address, cig_id: int, cis_id: int
|
||||
) -> None:
|
||||
'''
|
||||
Called when an incoming CIS request occurs from a central on the link
|
||||
'''
|
||||
|
||||
connection = self.peripheral_connections.get(central_address)
|
||||
assert connection
|
||||
|
||||
pending_cis_link = CisLink(
|
||||
handle=self.allocate_connection_handle(),
|
||||
cis_id=cis_id,
|
||||
cig_id=cig_id,
|
||||
acl_connection=connection,
|
||||
)
|
||||
self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_LE_CIS_Request_Event(
|
||||
acl_connection_handle=connection.handle,
|
||||
cis_connection_handle=pending_cis_link.handle,
|
||||
cig_id=cig_id,
|
||||
cis_id=cis_id,
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_cis_established(self, cig_id: int, cis_id: int) -> None:
|
||||
'''
|
||||
Called when an incoming CIS established.
|
||||
'''
|
||||
|
||||
cis_link = next(
|
||||
cis_link
|
||||
for cis_link in itertools.chain(
|
||||
self.central_cis_links.values(), self.peripheral_cis_links.values()
|
||||
)
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_LE_CIS_Established_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=cis_link.handle,
|
||||
# CIS parameters are ignored.
|
||||
cig_sync_delay=0,
|
||||
cis_sync_delay=0,
|
||||
transport_latency_c_to_p=0,
|
||||
transport_latency_p_to_c=0,
|
||||
phy_c_to_p=0,
|
||||
phy_p_to_c=0,
|
||||
nse=0,
|
||||
bn_c_to_p=0,
|
||||
bn_p_to_c=0,
|
||||
ft_c_to_p=0,
|
||||
ft_p_to_c=0,
|
||||
max_pdu_c_to_p=0,
|
||||
max_pdu_p_to_c=0,
|
||||
iso_interval=0,
|
||||
)
|
||||
)
|
||||
|
||||
def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None:
|
||||
'''
|
||||
Called when a CIS disconnected.
|
||||
'''
|
||||
|
||||
if cis_link := next(
|
||||
(
|
||||
cis_link
|
||||
for cis_link in self.peripheral_cis_links.values()
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# Remove peripheral CIS on disconnection.
|
||||
self.peripheral_cis_links.pop(cis_link.handle)
|
||||
elif cis_link := next(
|
||||
(
|
||||
cis_link
|
||||
for cis_link in self.central_cis_links.values()
|
||||
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command.
|
||||
cis_link.acl_connection = None
|
||||
else:
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Disconnection_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=cis_link.handle,
|
||||
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Classic link connections
|
||||
############################################################
|
||||
@@ -562,6 +695,7 @@ class Controller:
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
logger.debug(
|
||||
@@ -615,6 +749,42 @@ class Controller:
|
||||
)
|
||||
)
|
||||
|
||||
def on_classic_sco_connection_complete(
|
||||
self, peer_address: Address, status: int, link_type: int
|
||||
):
|
||||
if status == HCI_SUCCESS:
|
||||
# Allocate (or reuse) a connection handle
|
||||
connection_handle = self.allocate_connection_handle()
|
||||
connection = Connection(
|
||||
controller=self,
|
||||
handle=connection_handle,
|
||||
# Role doesn't matter in SCO.
|
||||
role=BT_CENTRAL_ROLE,
|
||||
peer_address=peer_address,
|
||||
link=self.link,
|
||||
transport=BT_BR_EDR_TRANSPORT,
|
||||
link_type=link_type,
|
||||
)
|
||||
self.classic_connections[peer_address] = connection
|
||||
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
|
||||
else:
|
||||
connection_handle = 0
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Synchronous_Connection_Complete_Event(
|
||||
status=status,
|
||||
connection_handle=connection_handle,
|
||||
bd_addr=peer_address,
|
||||
link_type=link_type,
|
||||
# TODO: Provide SCO connection parameters.
|
||||
transmission_interval=0,
|
||||
retransmission_window=0,
|
||||
rx_packet_length=0,
|
||||
tx_packet_length=0,
|
||||
air_mode=0,
|
||||
)
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Advertising support
|
||||
############################################################
|
||||
@@ -717,6 +887,17 @@ class Controller:
|
||||
else:
|
||||
# Remove the connection
|
||||
del self.classic_connections[connection.peer_address]
|
||||
elif cis_link := (
|
||||
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
|
||||
):
|
||||
if self.link:
|
||||
self.link.disconnect_cis(
|
||||
initiator_controller=self,
|
||||
peer_address=cis_link.acl_connection.peer_address,
|
||||
cig_id=cis_link.cig_id,
|
||||
cis_id=cis_link.cis_id,
|
||||
)
|
||||
# Spec requires handle to be kept after disconnection.
|
||||
|
||||
def on_hci_accept_connection_request_command(self, command):
|
||||
'''
|
||||
@@ -734,6 +915,68 @@ class Controller:
|
||||
)
|
||||
self.link.classic_accept_connection(self, command.bd_addr, command.role)
|
||||
|
||||
def on_hci_enhanced_setup_synchronous_connection_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.1.45 Enhanced Setup Synchronous Connection command
|
||||
'''
|
||||
|
||||
if self.link is None:
|
||||
return
|
||||
|
||||
if not (
|
||||
connection := self.find_classic_connection_by_handle(
|
||||
command.connection_handle
|
||||
)
|
||||
):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_sco_connect(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
)
|
||||
|
||||
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.1.46 Enhanced Accept Synchronous Connection Request command
|
||||
'''
|
||||
|
||||
if self.link is None:
|
||||
return
|
||||
|
||||
if not (connection := self.find_classic_connection_by_address(command.bd_addr)):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_SUCCESS,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
self.link.classic_accept_sco_connection(
|
||||
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
|
||||
)
|
||||
|
||||
def on_hci_switch_role_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
|
||||
@@ -908,14 +1151,48 @@ class Controller:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features[:8]
|
||||
|
||||
def on_hci_read_local_extended_features_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.4 Read Local Extended Features Command
|
||||
'''
|
||||
if command.page_number * 8 > len(self.lmp_features):
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
# Status
|
||||
HCI_SUCCESS,
|
||||
# Page number
|
||||
command.page_number,
|
||||
# Max page number
|
||||
len(self.lmp_features) // 8 - 1,
|
||||
]
|
||||
)
|
||||
# Features of the current page
|
||||
+ self.lmp_features[command.page_number * 8 : (command.page_number + 1) * 8]
|
||||
)
|
||||
|
||||
def on_hci_read_buffer_size_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
def on_hci_read_bd_addr_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
|
||||
'''
|
||||
bd_addr = (
|
||||
self._public_address.to_bytes()
|
||||
bytes(self._public_address)
|
||||
if self._public_address is not None
|
||||
else bytes(6)
|
||||
)
|
||||
@@ -996,6 +1273,9 @@ class Controller:
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
|
||||
'''
|
||||
if self.le_scan_enable:
|
||||
return bytes([HCI_COMMAND_DISALLOWED_ERROR])
|
||||
|
||||
self.le_scan_type = command.le_scan_type
|
||||
self.le_scan_interval = command.le_scan_interval
|
||||
self.le_scan_window = command.le_scan_window
|
||||
@@ -1082,6 +1362,18 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
|
||||
'''
|
||||
|
||||
handle = command.connection_handle
|
||||
|
||||
if not self.find_connection_by_handle(handle):
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# First, say that the command is pending
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
@@ -1095,7 +1387,7 @@ class Controller:
|
||||
self.send_hci_packet(
|
||||
HCI_LE_Read_Remote_Features_Complete_Event(
|
||||
status=HCI_SUCCESS,
|
||||
connection_handle=0,
|
||||
connection_handle=handle,
|
||||
le_features=bytes.fromhex('dd40000000000000'),
|
||||
)
|
||||
)
|
||||
@@ -1251,8 +1543,191 @@ class Controller:
|
||||
}
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_advertising_set_random_address_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS, 0])
|
||||
|
||||
def on_hci_le_set_extended_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_scan_response_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
|
||||
Length Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, 0x0672)
|
||||
|
||||
def on_hci_le_read_number_of_supported_advertising_sets_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.58 LE Read Number of Supported
|
||||
Advertising Set Command
|
||||
'''
|
||||
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
|
||||
|
||||
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_transmit_power_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
'''
|
||||
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
|
||||
|
||||
def on_hci_le_set_cig_parameters_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.97 LE Set CIG Parameter Command
|
||||
'''
|
||||
|
||||
# Remove old CIG implicitly.
|
||||
for handle, cis_link in self.central_cis_links.items():
|
||||
if cis_link.cig_id == command.cig_id:
|
||||
self.central_cis_links.pop(handle)
|
||||
|
||||
handles = []
|
||||
for cis_id in command.cis_id:
|
||||
handle = self.allocate_connection_handle()
|
||||
handles.append(handle)
|
||||
self.central_cis_links[handle] = CisLink(
|
||||
cis_id=cis_id,
|
||||
cig_id=command.cig_id,
|
||||
handle=handle,
|
||||
)
|
||||
return struct.pack(
|
||||
'<BBB', HCI_SUCCESS, command.cig_id, len(handles)
|
||||
) + b''.join([struct.pack('<H', handle) for handle in handles])
|
||||
|
||||
def on_hci_le_create_cis_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.99 LE Create CIS Command
|
||||
'''
|
||||
if not self.link:
|
||||
return
|
||||
|
||||
for cis_handle, acl_handle in zip(
|
||||
command.cis_connection_handle, command.acl_connection_handle
|
||||
):
|
||||
if not (connection := self.find_connection_by_handle(acl_handle)):
|
||||
logger.error(f'Cannot find connection with handle={acl_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
if not (cis_link := self.central_cis_links.get(cis_handle)):
|
||||
logger.error(f'Cannot find CIS with handle={cis_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
cis_link.acl_connection = connection
|
||||
|
||||
self.link.create_cis(
|
||||
self,
|
||||
peripheral_address=connection.peer_address,
|
||||
cig_id=cis_link.cig_id,
|
||||
cis_id=cis_link.cis_id,
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_remove_cig_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.100 LE Remove CIG Command
|
||||
'''
|
||||
|
||||
status = HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR
|
||||
|
||||
for cis_handle, cis_link in self.central_cis_links.items():
|
||||
if cis_link.cig_id == command.cig_id:
|
||||
self.central_cis_links.pop(cis_handle)
|
||||
status = HCI_SUCCESS
|
||||
|
||||
return struct.pack('<BH', status, command.cig_id)
|
||||
|
||||
def on_hci_le_accept_cis_request_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.101 LE Accept CIS Request Command
|
||||
'''
|
||||
if not self.link:
|
||||
return
|
||||
|
||||
if not (
|
||||
pending_cis_link := self.peripheral_cis_links.get(command.connection_handle)
|
||||
):
|
||||
logger.error(f'Cannot find CIS with handle={command.connection_handle}')
|
||||
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
|
||||
|
||||
assert pending_cis_link.acl_connection
|
||||
self.link.accept_cis(
|
||||
peripheral_controller=self,
|
||||
central_address=pending_cis_link.acl_connection.peer_address,
|
||||
cig_id=pending_cis_link.cig_id,
|
||||
cis_id=pending_cis_link.cis_id,
|
||||
)
|
||||
|
||||
self.send_hci_packet(
|
||||
HCI_Command_Status_Event(
|
||||
status=HCI_COMMAND_STATUS_PENDING,
|
||||
num_hci_command_packets=1,
|
||||
command_opcode=command.op_code,
|
||||
)
|
||||
)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
898
bumble/core.py
898
bumble/core.py
File diff suppressed because it is too large
Load Diff
184
bumble/crypto.py
184
bumble/crypto.py
@@ -21,24 +21,24 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import operator
|
||||
import platform
|
||||
|
||||
if platform.system() != 'Emscripten':
|
||||
import secrets
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
else:
|
||||
# TODO: implement stubs
|
||||
pass
|
||||
import secrets
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
generate_private_key,
|
||||
ECDH,
|
||||
EllipticCurvePrivateKey,
|
||||
EllipticCurvePublicNumbers,
|
||||
EllipticCurvePrivateNumbers,
|
||||
SECP256R1,
|
||||
)
|
||||
from cryptography.hazmat.primitives import cmac
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -50,16 +50,18 @@ logger = logging.getLogger(__name__)
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class EccKey:
|
||||
def __init__(self, private_key):
|
||||
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
|
||||
self.private_key = private_key
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
def generate(cls) -> EccKey:
|
||||
private_key = generate_private_key(SECP256R1())
|
||||
return cls(private_key)
|
||||
|
||||
@classmethod
|
||||
def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
|
||||
def from_private_key_bytes(
|
||||
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
|
||||
) -> EccKey:
|
||||
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
|
||||
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
|
||||
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
|
||||
@@ -69,7 +71,7 @@ class EccKey:
|
||||
return cls(private_key)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
def x(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
@@ -77,14 +79,14 @@ class EccKey:
|
||||
)
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
def y(self) -> bytes:
|
||||
return (
|
||||
self.private_key.public_key()
|
||||
.public_numbers()
|
||||
.y.to_bytes(32, byteorder='big')
|
||||
)
|
||||
|
||||
def dh(self, public_key_x, public_key_y):
|
||||
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
|
||||
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
|
||||
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
|
||||
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
|
||||
@@ -97,14 +99,33 @@ class EccKey:
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x, y):
|
||||
def generate_prand() -> bytes:
|
||||
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part E - Table 1.2.
|
||||
'''
|
||||
prand_bytes = secrets.token_bytes(6)
|
||||
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x: bytes, y: bytes) -> bytes:
|
||||
assert len(x) == len(y)
|
||||
return bytes(map(operator.xor, x, y))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r():
|
||||
def reverse(input: bytes) -> bytes:
|
||||
'''
|
||||
Returns bytes of input in reversed endianness.
|
||||
'''
|
||||
return input[::-1]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def r() -> bytes:
|
||||
'''
|
||||
Generate 16 bytes of random data
|
||||
'''
|
||||
@@ -112,20 +133,20 @@ def r():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def e(key, data):
|
||||
def e(key: bytes, data: bytes) -> bytes:
|
||||
'''
|
||||
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
|
||||
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
|
||||
'''
|
||||
|
||||
cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
|
||||
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
|
||||
encryptor = cipher.encryptor()
|
||||
return bytes(reversed(encryptor.update(bytes(reversed(data)))))
|
||||
return reverse(encryptor.update(reverse(data)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
|
||||
'''
|
||||
@@ -136,7 +157,16 @@ def ah(k, r): # pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
|
||||
def c1(
|
||||
k: bytes,
|
||||
r: bytes,
|
||||
preq: bytes,
|
||||
pres: bytes,
|
||||
iat: int,
|
||||
rat: int,
|
||||
ia: bytes,
|
||||
ra: bytes,
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
|
||||
LE Legacy Pairing
|
||||
@@ -148,7 +178,7 @@ def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-n
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(k, r1, r2):
|
||||
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
|
||||
Pairing
|
||||
@@ -158,7 +188,7 @@ def s1(k, r1, r2):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def aes_cmac(m, k):
|
||||
def aes_cmac(m: bytes, k: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
|
||||
|
||||
@@ -170,20 +200,16 @@ def aes_cmac(m, k):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f4(u, v, x, z):
|
||||
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
|
||||
Generation Function f4
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
|
||||
)
|
||||
)
|
||||
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f5(w, n1, n2, a1, a2):
|
||||
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
|
||||
Function f5
|
||||
@@ -191,87 +217,83 @@ def f5(w, n1, n2, a1, a2):
|
||||
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
|
||||
'''
|
||||
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
|
||||
t = aes_cmac(bytes(reversed(w)), salt)
|
||||
t = aes_cmac(reverse(w), salt)
|
||||
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
|
||||
return (
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([0])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2))
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
reverse(
|
||||
aes_cmac(
|
||||
bytes([1])
|
||||
+ key_id
|
||||
+ reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2)
|
||||
+ bytes([1, 0]),
|
||||
t,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
|
||||
def f6(
|
||||
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
|
||||
) -> bytes: # pylint: disable=redefined-outer-name
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
|
||||
Generation Function f6
|
||||
'''
|
||||
return bytes(
|
||||
reversed(
|
||||
aes_cmac(
|
||||
bytes(reversed(n1))
|
||||
+ bytes(reversed(n2))
|
||||
+ bytes(reversed(r))
|
||||
+ bytes(reversed(io_cap))
|
||||
+ bytes(reversed(a1))
|
||||
+ bytes(reversed(a2)),
|
||||
bytes(reversed(w)),
|
||||
)
|
||||
return reverse(
|
||||
aes_cmac(
|
||||
reverse(n1)
|
||||
+ reverse(n2)
|
||||
+ reverse(r)
|
||||
+ reverse(io_cap)
|
||||
+ reverse(a1)
|
||||
+ reverse(a2),
|
||||
reverse(w),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def g2(u, v, x, y):
|
||||
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
|
||||
Value Generation Function g2
|
||||
'''
|
||||
return int.from_bytes(
|
||||
aes_cmac(
|
||||
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
|
||||
bytes(reversed(x)),
|
||||
reverse(u) + reverse(v) + reverse(y),
|
||||
reverse(x),
|
||||
)[-4:],
|
||||
byteorder='big',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h6(w, key_id):
|
||||
def h6(w: bytes, key_id: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
|
||||
'''
|
||||
return aes_cmac(key_id, w)
|
||||
return reverse(aes_cmac(key_id, reverse(w)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def h7(salt, w):
|
||||
def h7(salt: bytes, w: bytes) -> bytes:
|
||||
'''
|
||||
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
|
||||
'''
|
||||
return aes_cmac(w, salt)
|
||||
return reverse(aes_cmac(reverse(w), salt))
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -149,7 +151,7 @@ QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class G722Decoder(object):
|
||||
class G722Decoder:
|
||||
"""G.722 decoder with bitrate 64kbit/s.
|
||||
|
||||
For the Blocks in the sub-band decoders, please refer to the G.722
|
||||
@@ -157,7 +159,7 @@ class G722Decoder(object):
|
||||
https://www.itu.int/rec/T-REC-G.722-201209-I
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._x = [0] * 24
|
||||
self._band = [Band(), Band()]
|
||||
# The initial value in BLOCK 3L
|
||||
@@ -165,12 +167,12 @@ class G722Decoder(object):
|
||||
# The initial value in BLOCK 3H
|
||||
self._band[1].det = 8
|
||||
|
||||
def decode_frame(self, encoded_data) -> bytearray:
|
||||
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
|
||||
result_array = bytearray(len(encoded_data) * 4)
|
||||
self.g722_decode(result_array, encoded_data)
|
||||
return result_array
|
||||
|
||||
def g722_decode(self, result_array, encoded_data) -> int:
|
||||
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
|
||||
"""Decode the data frame using g722 decoder."""
|
||||
result_length = 0
|
||||
|
||||
@@ -198,14 +200,16 @@ class G722Decoder(object):
|
||||
|
||||
return result_length
|
||||
|
||||
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
|
||||
def update_decoded_result(
|
||||
self, xout: int, byte_length: int, byte_array: bytearray
|
||||
) -> int:
|
||||
result = (int)(xout >> 11)
|
||||
bytes_result = result.to_bytes(2, 'little', signed=True)
|
||||
byte_array[byte_length] = bytes_result[0]
|
||||
byte_array[byte_length + 1] = bytes_result[1]
|
||||
return byte_length + 2
|
||||
|
||||
def lower_sub_band_decoder(self, lower_bits) -> int:
|
||||
def lower_sub_band_decoder(self, lower_bits: int) -> int:
|
||||
"""Lower sub-band decoder for last six bits."""
|
||||
|
||||
# Block 5L
|
||||
@@ -258,7 +262,7 @@ class G722Decoder(object):
|
||||
|
||||
return rlow
|
||||
|
||||
def higher_sub_band_decoder(self, higher_bits) -> int:
|
||||
def higher_sub_band_decoder(self, higher_bits: int) -> int:
|
||||
"""Higher sub-band decoder for first two bits."""
|
||||
|
||||
# Block 2H
|
||||
@@ -306,14 +310,14 @@ class G722Decoder(object):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Band(object):
|
||||
"""Structure for G722 decode proccessing."""
|
||||
class Band:
|
||||
"""Structure for G722 decode processing."""
|
||||
|
||||
s: int = 0
|
||||
nb: int = 0
|
||||
det: int = 0
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._sp = 0
|
||||
self._sz = 0
|
||||
self._r = [0] * 3
|
||||
|
||||
3868
bumble/device.py
3868
bumble/device.py
File diff suppressed because it is too large
Load Diff
@@ -19,10 +19,17 @@ like loading firmware after a cold start.
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from . import rtk
|
||||
import pathlib
|
||||
import platform
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import rtk, intel
|
||||
from .common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -30,39 +37,51 @@ from . import rtk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_driver_for_host(host):
|
||||
"""Probe all known diver classes until one returns a valid instance for a host,
|
||||
or none is found.
|
||||
async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
"""Probe diver classes until one returns a valid instance for a host, or none is
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
if driver := await rtk.Driver.for_host(host):
|
||||
logger.debug("Instantiated RTK driver")
|
||||
return driver
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
probe_list = [driver_name]
|
||||
else:
|
||||
# Probe all drivers
|
||||
probe_list = driver_classes.keys()
|
||||
|
||||
for driver_name in probe_list:
|
||||
if driver_class := driver_classes.get(driver_name):
|
||||
logger.debug(f"Probing driver class: {driver_name}")
|
||||
if driver := await driver_class.for_host(host):
|
||||
logger.debug(f"Instantiated {driver_name} driver")
|
||||
return driver
|
||||
else:
|
||||
logger.debug(f"Skipping unknown driver class: {driver_name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def project_data_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to an OS-specific directory for bumble data. The directory is created if
|
||||
it doesn't exist.
|
||||
"""
|
||||
import platformdirs
|
||||
|
||||
if platform.system() == 'Darwin':
|
||||
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
|
||||
# out of author & project
|
||||
return platformdirs.user_data_path(
|
||||
appname='com.google.bumble', ensure_exists=True
|
||||
)
|
||||
else:
|
||||
# windows and linux don't use the com qualifier
|
||||
return platformdirs.user_data_path(
|
||||
appname='bumble', appauthor='google', ensure_exists=True
|
||||
)
|
||||
|
||||
47
bumble/drivers/common.py
Normal file
47
bumble/drivers/common.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Common types for drivers.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
671
bumble/drivers/intel.py
Normal file
671
bumble/drivers/intel.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Support for Intel USB controllers.
|
||||
Loosely based on the Fuchsia OS implementation.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import struct
|
||||
from typing import Any, Deque, Optional, TYPE_CHECKING
|
||||
|
||||
from bumble import core
|
||||
from bumble.drivers import common
|
||||
from bumble import hci
|
||||
from bumble import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constant
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
INTEL_USB_PRODUCTS = {
|
||||
(0x8087, 0x0032), # AX210
|
||||
(0x8087, 0x0036), # BE200
|
||||
}
|
||||
|
||||
INTEL_FW_IMAGE_NAMES = [
|
||||
"ibt-0040-0041",
|
||||
"ibt-0040-1020",
|
||||
"ibt-0040-1050",
|
||||
"ibt-0040-2120",
|
||||
"ibt-0040-4150",
|
||||
"ibt-0041-0041",
|
||||
"ibt-0180-0041",
|
||||
"ibt-0180-1050",
|
||||
"ibt-0180-4150",
|
||||
"ibt-0291-0291",
|
||||
"ibt-1040-0041",
|
||||
"ibt-1040-1020",
|
||||
"ibt-1040-1050",
|
||||
"ibt-1040-2120",
|
||||
"ibt-1040-4150",
|
||||
]
|
||||
|
||||
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
|
||||
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
|
||||
|
||||
_MAX_FRAGMENT_SIZE = 252
|
||||
_POST_RESET_DELAY = 0.2
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
|
||||
HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
|
||||
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
|
||||
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
|
||||
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
|
||||
|
||||
hci.HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[
|
||||
("param0", 1),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
("status", hci.STATUS_SPEC),
|
||||
("tlv", "*"),
|
||||
],
|
||||
)
|
||||
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[("data_type", 1), ("data", "*")],
|
||||
return_parameters_fields=[
|
||||
("status", 1),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[
|
||||
("reset_type", 1),
|
||||
("patch_enable", 1),
|
||||
("ddc_reload", 1),
|
||||
("boot_option", 1),
|
||||
("boot_address", 4),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
("data", "*"),
|
||||
],
|
||||
)
|
||||
class HCI_Intel_Reset_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[("data", "*")],
|
||||
return_parameters_fields=[
|
||||
("status", hci.STATUS_SPEC),
|
||||
("params", "*"),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
def intel_firmware_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to a subdir of the project data dir for Intel firmware.
|
||||
The directory is created if it doesn't exist.
|
||||
"""
|
||||
from bumble.drivers import project_data_dir
|
||||
|
||||
p = project_data_dir() / "firmware" / "intel"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
|
||||
def _find_binary_path(file_name: str) -> pathlib.Path | None:
|
||||
# First check if an environment variable is set
|
||||
if INTEL_FIRMWARE_DIR_ENV in os.environ:
|
||||
if (
|
||||
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
|
||||
).is_file():
|
||||
logger.debug(f"{file_name} found in env dir")
|
||||
return path
|
||||
|
||||
# When the environment variable is set, don't look elsewhere
|
||||
return None
|
||||
|
||||
# Then, look where the firmware download tool writes by default
|
||||
if (path := intel_firmware_dir() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in project data dir")
|
||||
return path
|
||||
|
||||
# Then, look in the package's driver directory
|
||||
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in package dir")
|
||||
return path
|
||||
|
||||
# On Linux, check the system's FW directory
|
||||
if (
|
||||
platform.system() == "Linux"
|
||||
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
|
||||
):
|
||||
logger.debug(f"{file_name} found in Linux system FW dir")
|
||||
return path
|
||||
|
||||
# Finally look in the current directory
|
||||
if (path := pathlib.Path.cwd() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in CWD")
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
|
||||
result: list[tuple[ValueType, Any]] = []
|
||||
while len(data) >= 2:
|
||||
value_type = ValueType(data[0])
|
||||
value_length = data[1]
|
||||
value = data[2 : 2 + value_length]
|
||||
typed_value: Any
|
||||
|
||||
if value_type == ValueType.END:
|
||||
break
|
||||
|
||||
if value_type in (ValueType.CNVI, ValueType.CNVR):
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = (
|
||||
(((v >> 0) & 0xF) << 12)
|
||||
| (((v >> 4) & 0xF) << 0)
|
||||
| (((v >> 8) & 0xF) << 4)
|
||||
| (((v >> 24) & 0xF) << 8)
|
||||
)
|
||||
elif value_type == ValueType.HARDWARE_INFO:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = HardwareInfo(
|
||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||
)
|
||||
elif value_type in (
|
||||
ValueType.USB_VENDOR_ID,
|
||||
ValueType.USB_PRODUCT_ID,
|
||||
ValueType.DEVICE_REVISION,
|
||||
):
|
||||
(typed_value,) = struct.unpack("<H", value)
|
||||
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
|
||||
typed_value = ModeOfOperation(value[0])
|
||||
elif value_type in (
|
||||
ValueType.BUILD_TYPE,
|
||||
ValueType.BUILD_NUMBER,
|
||||
ValueType.SECURE_BOOT,
|
||||
ValueType.OTP_LOCK,
|
||||
ValueType.API_LOCK,
|
||||
ValueType.DEBUG_LOCK,
|
||||
ValueType.SECURE_BOOT_ENGINE_TYPE,
|
||||
):
|
||||
typed_value = value[0]
|
||||
elif value_type == ValueType.TIMESTAMP:
|
||||
typed_value = Timestamp(value[0], value[1])
|
||||
elif value_type == ValueType.FIRMWARE_BUILD:
|
||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||
elif value_type == ValueType.BLUETOOTH_ADDRESS:
|
||||
typed_value = hci.Address(
|
||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
else:
|
||||
typed_value = value
|
||||
|
||||
result.append((value_type, typed_value))
|
||||
data = data[2 + value_length :]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class DriverError(core.BaseBumbleError):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"IntelDriverError({self.message})"
|
||||
|
||||
|
||||
class ValueType(utils.OpenIntEnum):
|
||||
END = 0x00
|
||||
CNVI = 0x10
|
||||
CNVR = 0x11
|
||||
HARDWARE_INFO = 0x12
|
||||
DEVICE_REVISION = 0x16
|
||||
CURRENT_MODE_OF_OPERATION = 0x1C
|
||||
USB_VENDOR_ID = 0x17
|
||||
USB_PRODUCT_ID = 0x18
|
||||
TIMESTAMP = 0x1D
|
||||
BUILD_TYPE = 0x1E
|
||||
BUILD_NUMBER = 0x1F
|
||||
SECURE_BOOT = 0x28
|
||||
OTP_LOCK = 0x2A
|
||||
API_LOCK = 0x2B
|
||||
DEBUG_LOCK = 0x2C
|
||||
FIRMWARE_BUILD = 0x2D
|
||||
SECURE_BOOT_ENGINE_TYPE = 0x2F
|
||||
BLUETOOTH_ADDRESS = 0x30
|
||||
|
||||
|
||||
class HardwarePlatform(utils.OpenIntEnum):
|
||||
INTEL_37 = 0x37
|
||||
|
||||
|
||||
class HardwareVariant(utils.OpenIntEnum):
|
||||
# This is a just a partial list.
|
||||
# Add other constants here as new hardware is encountered and tested.
|
||||
TYPHOON_PEAK = 0x17
|
||||
GALE_PEAK = 0x1C
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HardwareInfo:
|
||||
platform: HardwarePlatform
|
||||
variant: HardwareVariant
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Timestamp:
|
||||
week: int
|
||||
year: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FirmwareBuild:
|
||||
build_number: int
|
||||
timestamp: Timestamp
|
||||
|
||||
|
||||
class ModeOfOperation(utils.OpenIntEnum):
|
||||
BOOTLOADER = 0x01
|
||||
INTERMEDIATE = 0x02
|
||||
OPERATIONAL = 0x03
|
||||
|
||||
|
||||
class SecureBootEngineType(utils.OpenIntEnum):
|
||||
RSA = 0x00
|
||||
ECDSA = 0x01
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BootParams:
|
||||
css_header_offset: int
|
||||
css_header_size: int
|
||||
pki_offset: int
|
||||
pki_size: int
|
||||
sig_offset: int
|
||||
sig_size: int
|
||||
write_offset: int
|
||||
|
||||
|
||||
_BOOT_PARAMS = {
|
||||
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
|
||||
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
|
||||
}
|
||||
|
||||
|
||||
class Driver(common.Driver):
|
||||
def __init__(self, host: Host) -> None:
|
||||
self.host = host
|
||||
self.max_in_flight_firmware_load_commands = 1
|
||||
self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
|
||||
collections.deque()
|
||||
)
|
||||
self.can_send_firmware_load_command = asyncio.Event()
|
||||
self.can_send_firmware_load_command.set()
|
||||
self.firmware_load_complete = asyncio.Event()
|
||||
self.reset_complete = asyncio.Event()
|
||||
|
||||
# Parse configuration options from the driver name.
|
||||
self.ddc_addon: Optional[bytes] = None
|
||||
self.ddc_override: Optional[bytes] = None
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver is not None and driver.startswith("intel/"):
|
||||
for key, value in [
|
||||
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
|
||||
]:
|
||||
if key == "ddc_addon":
|
||||
self.ddc_addon = bytes.fromhex(value)
|
||||
elif key == "ddc_override":
|
||||
self.ddc_override = bytes.fromhex(value)
|
||||
|
||||
@staticmethod
|
||||
def check(host: Host) -> bool:
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver == "intel" or driver is not None and driver.startswith("intel/"):
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
|
||||
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
|
||||
logger.debug(
|
||||
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def for_host(cls, host: Host, force: bool = False):
|
||||
# Only instantiate this driver if explicitly selected
|
||||
if not force and not cls.check(host):
|
||||
return None
|
||||
|
||||
return cls(host)
|
||||
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
"""Handler for event packets that are received from an ACL channel"""
|
||||
event = hci.HCI_Event.from_bytes(packet)
|
||||
|
||||
if not isinstance(event, hci.HCI_Command_Complete_Event):
|
||||
self.host.on_hci_event_packet(event)
|
||||
return
|
||||
|
||||
if not event.return_parameters == hci.HCI_SUCCESS:
|
||||
raise DriverError("HCI_Command_Complete_Event error")
|
||||
|
||||
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
|
||||
logger.debug(
|
||||
"max_in_flight_firmware_load_commands update: "
|
||||
f"{event.num_hci_command_packets}"
|
||||
)
|
||||
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
|
||||
logger.debug(f"event: {event}")
|
||||
self.pending_firmware_load_commands.popleft()
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
logger.debug(f"event received, {in_flight} still in flight")
|
||||
if in_flight < self.max_in_flight_firmware_load_commands:
|
||||
self.can_send_firmware_load_command.set()
|
||||
|
||||
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
|
||||
# Wait until we can send.
|
||||
await self.can_send_firmware_load_command.wait()
|
||||
|
||||
# Send the command and adjust counters.
|
||||
self.host.send_hci_packet(command)
|
||||
self.pending_firmware_load_commands.append(command)
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
if in_flight >= self.max_in_flight_firmware_load_commands:
|
||||
logger.debug(f"max commands in flight reached [{in_flight}]")
|
||||
self.can_send_firmware_load_command.clear()
|
||||
|
||||
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
|
||||
while data:
|
||||
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
|
||||
fragment = data[:fragment_size]
|
||||
data = data[fragment_size:]
|
||||
|
||||
await self.send_firmware_load_command(
|
||||
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
|
||||
)
|
||||
|
||||
async def load_firmware(self) -> None:
|
||||
self.host.ready = True
|
||||
device_info = await self.read_device_info()
|
||||
logger.debug(
|
||||
"device info: \n%s",
|
||||
"\n".join(
|
||||
[
|
||||
f" {value_type.name}: {value}"
|
||||
for value_type, value in device_info.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the firmware is already loaded.
|
||||
if (
|
||||
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
|
||||
== ModeOfOperation.OPERATIONAL
|
||||
):
|
||||
logger.debug("firmware already loaded")
|
||||
return
|
||||
|
||||
# We only support some platforms and variants.
|
||||
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
|
||||
if hardware_info is None:
|
||||
raise DriverError("hardware info missing")
|
||||
if hardware_info.platform != HardwarePlatform.INTEL_37:
|
||||
raise DriverError("hardware platform not supported")
|
||||
if hardware_info.variant not in (
|
||||
HardwareVariant.TYPHOON_PEAK,
|
||||
HardwareVariant.GALE_PEAK,
|
||||
):
|
||||
raise DriverError("hardware variant not supported")
|
||||
|
||||
# Compute the firmware name.
|
||||
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
|
||||
raise DriverError("insufficient device info, missing CNVI or CNVR")
|
||||
|
||||
firmware_base_name = (
|
||||
"ibt-"
|
||||
f"{device_info[ValueType.CNVI]:04X}-"
|
||||
f"{device_info[ValueType.CNVR]:04X}"
|
||||
)
|
||||
logger.debug(f"FW base name: {firmware_base_name}")
|
||||
|
||||
firmware_name = f"{firmware_base_name}.sfi"
|
||||
firmware_path = _find_binary_path(firmware_name)
|
||||
if not firmware_path:
|
||||
logger.warning(f"Firmware file {firmware_name} not found")
|
||||
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
|
||||
return None
|
||||
logger.debug(f"loading firmware from {firmware_path}")
|
||||
firmware_image = firmware_path.read_bytes()
|
||||
|
||||
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
|
||||
if engine_type is None:
|
||||
raise DriverError("secure boot engine type missing")
|
||||
if engine_type not in _BOOT_PARAMS:
|
||||
raise DriverError("secure boot engine type not supported")
|
||||
|
||||
boot_params = _BOOT_PARAMS[engine_type]
|
||||
if len(firmware_image) < boot_params.write_offset:
|
||||
raise DriverError("firmware image too small")
|
||||
|
||||
# Register to receive vendor events.
|
||||
def on_vendor_event(event: hci.HCI_Vendor_Event):
|
||||
logger.debug(f"vendor event: {event}")
|
||||
event_type = event.parameters[0]
|
||||
if event_type == 0x02:
|
||||
# Boot event
|
||||
logger.debug("boot complete")
|
||||
self.reset_complete.set()
|
||||
elif event_type == 0x06:
|
||||
# Firmware load event
|
||||
logger.debug("download complete")
|
||||
self.firmware_load_complete.set()
|
||||
else:
|
||||
logger.debug(f"ignoring vendor event type {event_type}")
|
||||
|
||||
self.host.on("vendor_event", on_vendor_event)
|
||||
|
||||
# We need to temporarily intercept packets from the controller,
|
||||
# because they are formatted as HCI event packets but are received
|
||||
# on the ACL channel, so the host parser would get confused.
|
||||
saved_on_packet = self.host.on_packet
|
||||
self.host.on_packet = self.on_packet # type: ignore
|
||||
self.firmware_load_complete.clear()
|
||||
|
||||
# Send the CSS header
|
||||
data = firmware_image[
|
||||
boot_params.css_header_offset : boot_params.css_header_offset
|
||||
+ boot_params.css_header_size
|
||||
]
|
||||
await self.send_firmware_data(0x00, data)
|
||||
|
||||
# Send the PKI header
|
||||
data = firmware_image[
|
||||
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
|
||||
]
|
||||
await self.send_firmware_data(0x03, data)
|
||||
|
||||
# Send the Signature header
|
||||
data = firmware_image[
|
||||
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
|
||||
]
|
||||
await self.send_firmware_data(0x02, data)
|
||||
|
||||
# Send the rest of the image.
|
||||
# The payload consists of command objects, which are sent when they add up
|
||||
# to a multiple of 4 bytes.
|
||||
boot_address = 0
|
||||
offset = boot_params.write_offset
|
||||
fragment_size = 0
|
||||
while offset + 3 < len(firmware_image):
|
||||
(command_opcode,) = struct.unpack_from(
|
||||
"<H", firmware_image, offset + fragment_size
|
||||
)
|
||||
command_size = firmware_image[offset + fragment_size + 2]
|
||||
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
|
||||
(boot_address,) = struct.unpack_from(
|
||||
"<I", firmware_image, offset + fragment_size + 3
|
||||
)
|
||||
logger.debug(
|
||||
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
|
||||
f"boot_address={boot_address}"
|
||||
)
|
||||
fragment_size += 3 + command_size
|
||||
if fragment_size % 4 == 0:
|
||||
await self.send_firmware_data(
|
||||
0x01, firmware_image[offset : offset + fragment_size]
|
||||
)
|
||||
logger.debug(f"sent {fragment_size} bytes")
|
||||
offset += fragment_size
|
||||
fragment_size = 0
|
||||
|
||||
# Wait for the firmware loading to be complete.
|
||||
logger.debug("waiting for firmware to be loaded")
|
||||
await self.firmware_load_complete.wait()
|
||||
logger.debug("firmware loaded")
|
||||
|
||||
# Restore the original packet handler.
|
||||
self.host.on_packet = saved_on_packet # type: ignore
|
||||
|
||||
# Reset
|
||||
self.reset_complete.clear()
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x00,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x00,
|
||||
boot_option=0x01,
|
||||
boot_address=boot_address,
|
||||
)
|
||||
)
|
||||
logger.debug("waiting for reset completion")
|
||||
await self.reset_complete.wait()
|
||||
logger.debug("reset complete")
|
||||
|
||||
# Load the device config if there is one.
|
||||
if self.ddc_override:
|
||||
logger.debug("loading overridden DDC")
|
||||
await self.load_device_config(self.ddc_override)
|
||||
else:
|
||||
ddc_name = f"{firmware_base_name}.ddc"
|
||||
ddc_path = _find_binary_path(ddc_name)
|
||||
if ddc_path:
|
||||
logger.debug(f"loading DDC from {ddc_path}")
|
||||
ddc_data = ddc_path.read_bytes()
|
||||
await self.load_device_config(ddc_data)
|
||||
if self.ddc_addon:
|
||||
logger.debug("loading DDC addon")
|
||||
await self.load_device_config(self.ddc_addon)
|
||||
|
||||
async def load_device_config(self, ddc_data: bytes) -> None:
|
||||
while ddc_data:
|
||||
ddc_len = 1 + ddc_data[0]
|
||||
ddc_payload = ddc_data[:ddc_len]
|
||||
await self.host.send_command(
|
||||
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
|
||||
)
|
||||
ddc_data = ddc_data[ddc_len:]
|
||||
|
||||
async def reboot_bootloader(self) -> None:
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x01,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x01,
|
||||
boot_option=0x00,
|
||||
boot_address=0,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(_POST_RESET_DELAY)
|
||||
|
||||
async def read_device_info(self) -> dict[ValueType, Any]:
|
||||
self.host.ready = True
|
||||
response = await self.host.send_command(hci.HCI_Reset_Command())
|
||||
if not (
|
||||
isinstance(response, hci.HCI_Command_Complete_Event)
|
||||
and response.return_parameters
|
||||
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
|
||||
):
|
||||
# When the controller is in operational mode, the response is a
|
||||
# successful response.
|
||||
# When the controller is in bootloader mode,
|
||||
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
|
||||
# else is a failure.
|
||||
logger.warning(f"unexpected response: {response}")
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
# Read the firmware version.
|
||||
response = await self.host.send_command(
|
||||
HCI_Intel_Read_Version_Command(param0=0xFF)
|
||||
)
|
||||
if not isinstance(response, hci.HCI_Command_Complete_Event):
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
if response.return_parameters.status != 0: # type: ignore
|
||||
raise DriverError("HCI_Intel_Read_Version_Command error")
|
||||
|
||||
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
|
||||
|
||||
# Convert the list to a dict. That's Ok here because we only expect each type
|
||||
# to appear just once.
|
||||
return dict(tlvs)
|
||||
|
||||
async def init_controller(self):
|
||||
await self.load_firmware()
|
||||
@@ -33,16 +33,16 @@ from typing import Tuple
|
||||
import weakref
|
||||
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import (
|
||||
hci_command_op_code,
|
||||
hci_vendor_command_op_code,
|
||||
STATUS_SPEC,
|
||||
HCI_SUCCESS,
|
||||
HCI_COMMAND_NAMES,
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -50,6 +50,10 @@ from bumble.hci import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RtkFirmwareError(core.BaseBumbleError):
|
||||
"""Error raised when RTK firmware initialization fails."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -125,6 +129,7 @@ RTK_USB_PRODUCTS = {
|
||||
(0x2550, 0x8761),
|
||||
(0x2B89, 0x8761),
|
||||
(0x7392, 0xC611),
|
||||
(0x0BDA, 0x877B),
|
||||
# Realtek 8821AE
|
||||
(0x0B05, 0x17DC),
|
||||
(0x13D3, 0x3414),
|
||||
@@ -178,8 +183,10 @@ RTK_USB_PRODUCTS = {
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_command_op_code(0x3F, 0x6D)
|
||||
HCI_COMMAND_NAMES[HCI_RTK_READ_ROM_VERSION_COMMAND] = "HCI_RTK_READ_ROM_VERSION_COMMAND"
|
||||
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
|
||||
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
|
||||
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
|
||||
HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
|
||||
@@ -187,10 +194,6 @@ class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
HCI_RTK_DOWNLOAD_COMMAND = hci_command_op_code(0x3F, 0x20)
|
||||
HCI_COMMAND_NAMES[HCI_RTK_DOWNLOAD_COMMAND] = "HCI_RTK_DOWNLOAD_COMMAND"
|
||||
|
||||
|
||||
@HCI_Command.command(
|
||||
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
|
||||
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
|
||||
@@ -199,10 +202,6 @@ class HCI_RTK_Download_Command(HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_command_op_code(0x3F, 0x66)
|
||||
HCI_COMMAND_NAMES[HCI_RTK_DROP_FIRMWARE_COMMAND] = "HCI_RTK_DROP_FIRMWARE_COMMAND"
|
||||
|
||||
|
||||
@HCI_Command.command()
|
||||
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
|
||||
pass
|
||||
@@ -214,15 +213,15 @@ class Firmware:
|
||||
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
|
||||
|
||||
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
|
||||
raise ValueError("Firmware does not start with epatch signature")
|
||||
raise RtkFirmwareError("Firmware does not start with epatch signature")
|
||||
|
||||
if not firmware.endswith(extension_sig):
|
||||
raise ValueError("Firmware does not end with extension sig")
|
||||
raise RtkFirmwareError("Firmware does not end with extension sig")
|
||||
|
||||
# The firmware should start with a 14 byte header.
|
||||
epatch_header_size = 14
|
||||
if len(firmware) < epatch_header_size:
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Look for the "project ID", starting from the end.
|
||||
offset = len(firmware) - len(extension_sig)
|
||||
@@ -236,7 +235,7 @@ class Firmware:
|
||||
break
|
||||
|
||||
if length == 0:
|
||||
raise ValueError("Invalid 0-length instruction")
|
||||
raise RtkFirmwareError("Invalid 0-length instruction")
|
||||
|
||||
if opcode == 0 and length == 1:
|
||||
project_id = firmware[offset - 1]
|
||||
@@ -245,7 +244,7 @@ class Firmware:
|
||||
offset -= length
|
||||
|
||||
if project_id < 0:
|
||||
raise ValueError("Project ID not found")
|
||||
raise RtkFirmwareError("Project ID not found")
|
||||
|
||||
self.project_id = project_id
|
||||
|
||||
@@ -258,7 +257,7 @@ class Firmware:
|
||||
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
|
||||
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
|
||||
if epatch_header_size + 8 * num_patches > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
chip_id_table_offset = epatch_header_size
|
||||
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
|
||||
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
|
||||
@@ -272,7 +271,7 @@ class Firmware:
|
||||
"<I", firmware, patch_offset_table_offset + 4 * patch_index
|
||||
)
|
||||
if patch_offset + patch_length > len(firmware):
|
||||
raise ValueError("Firmware too short")
|
||||
raise RtkFirmwareError("Firmware too short")
|
||||
|
||||
# Get the SVN version for the patch
|
||||
(svn_version,) = struct.unpack_from(
|
||||
@@ -291,7 +290,7 @@ class Firmware:
|
||||
)
|
||||
|
||||
|
||||
class Driver:
|
||||
class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
@@ -302,6 +301,8 @@ class Driver:
|
||||
fw_name: str = ""
|
||||
config_name: str = ""
|
||||
|
||||
POST_RESET_DELAY: float = 0.2
|
||||
|
||||
DRIVER_INFOS = [
|
||||
# 8723A
|
||||
DriverInfo(
|
||||
@@ -445,6 +446,11 @@ class Driver:
|
||||
# When the environment variable is set, don't look elsewhere
|
||||
return None
|
||||
|
||||
# Then, look where the firmware download tool writes by default
|
||||
if (path := rtk_firmware_dir() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in project data dir")
|
||||
return path
|
||||
|
||||
# Then, look in the package's driver directory
|
||||
if (path := pathlib.Path(__file__).parent / "rtk_fw" / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in package dir")
|
||||
@@ -471,8 +477,12 @@ class Driver:
|
||||
logger.debug("USB metadata not found")
|
||||
return False
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id", None)
|
||||
product_id = host.hci_metadata.get("product_id", None)
|
||||
if host.hci_metadata.get('driver') == 'rtk':
|
||||
# Forced driver
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
@@ -487,9 +497,24 @@ class Driver:
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
try:
|
||||
await host.send_command(
|
||||
HCI_Reset_Command(),
|
||||
check_result=True,
|
||||
response_timeout=cls.POST_RESET_DELAY,
|
||||
)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
except asyncio.exceptions.TimeoutError:
|
||||
logger.warning("timeout waiting for hci reset, retrying")
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True
|
||||
|
||||
command = HCI_Read_Local_Version_Information_Command()
|
||||
response = await host.send_command(command, check_result=True)
|
||||
if response.command_opcode != command.op_code:
|
||||
logger.error("failed to probe local version information")
|
||||
return None
|
||||
|
||||
local_version = response.return_parameters
|
||||
|
||||
logger.debug(
|
||||
@@ -639,9 +664,22 @@ class Driver:
|
||||
):
|
||||
return await self.download_for_rtl8723b()
|
||||
|
||||
raise ValueError("ROM not supported")
|
||||
raise RtkFirmwareError("ROM not supported")
|
||||
|
||||
async def init_controller(self):
|
||||
await self.download_firmware()
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
logger.info(f"loaded FW image {self.driver_info.fw_name}")
|
||||
|
||||
|
||||
def rtk_firmware_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to a subdir of the project data dir for Realtek firmware.
|
||||
The directory is created if it doesn't exist.
|
||||
"""
|
||||
from bumble.drivers import project_data_dir
|
||||
|
||||
p = project_data_dir() / "firmware" / "realtek"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAccessService(Service):
|
||||
def __init__(self, device_name, appearance=(0, 0)):
|
||||
|
||||
340
bumble/gatt.py
340
bumble/gatt.py
@@ -23,16 +23,32 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, List
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
SupportsBytes,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
from .att import Attribute
|
||||
from bumble.colors import color
|
||||
from bumble.core import BaseBumbleError, UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
from bumble.utils import ByteSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -93,20 +109,35 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
|
||||
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
|
||||
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
|
||||
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
|
||||
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
|
||||
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
|
||||
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
|
||||
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
|
||||
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
|
||||
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
|
||||
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
|
||||
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
|
||||
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
|
||||
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
|
||||
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
|
||||
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
|
||||
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
|
||||
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
|
||||
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
|
||||
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
|
||||
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
|
||||
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
|
||||
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
|
||||
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
|
||||
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
|
||||
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
|
||||
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
|
||||
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
|
||||
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
|
||||
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
|
||||
|
||||
# Types
|
||||
# Attribute Types
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
|
||||
@@ -129,6 +160,8 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
|
||||
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
|
||||
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
|
||||
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
|
||||
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
|
||||
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
|
||||
|
||||
# Device Information Service
|
||||
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
|
||||
@@ -156,6 +189,108 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
|
||||
# Battery Service
|
||||
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
|
||||
|
||||
# Telephony And Media Audio Service (TMAS)
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
|
||||
|
||||
# Audio Input Control Service (AICS)
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
|
||||
|
||||
# Volume Control Service (VCS)
|
||||
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
|
||||
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
|
||||
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
|
||||
|
||||
# Volume Offset Control Service (VOCS)
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
|
||||
|
||||
# Coordinated Set Identification Service (CSIS)
|
||||
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
|
||||
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
|
||||
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
|
||||
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
|
||||
|
||||
# Media Control Service (MCS)
|
||||
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
|
||||
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
|
||||
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
|
||||
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
|
||||
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
|
||||
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
|
||||
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
|
||||
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
|
||||
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
|
||||
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
|
||||
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
|
||||
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
|
||||
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
|
||||
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
|
||||
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
|
||||
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
|
||||
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
|
||||
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
|
||||
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
|
||||
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
|
||||
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
|
||||
|
||||
# Telephone Bearer Service (TBS)
|
||||
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
|
||||
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer UCI')
|
||||
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer Technology')
|
||||
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer URI Schemes Supported List')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer Signal Strength')
|
||||
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength Reporting Interval')
|
||||
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer List Current Calls')
|
||||
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control ID')
|
||||
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Status Flags')
|
||||
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Incoming Call Target Bearer URI')
|
||||
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Call State')
|
||||
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call Control Point')
|
||||
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point Optional Opcodes')
|
||||
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Termination Reason')
|
||||
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Incoming Call')
|
||||
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Call Friendly Name')
|
||||
|
||||
# Microphone Control Service (MICS)
|
||||
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
|
||||
|
||||
# Audio Stream Control Service (ASCS)
|
||||
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
|
||||
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
|
||||
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
|
||||
|
||||
# Broadcast Audio Scan Service (BASS)
|
||||
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
|
||||
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
|
||||
|
||||
# Published Audio Capabilities Service (PACS)
|
||||
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
|
||||
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
|
||||
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
|
||||
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
|
||||
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
|
||||
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
|
||||
|
||||
# Gaming Audio Service (GMAS)
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
|
||||
|
||||
# Hearing Access Service
|
||||
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
|
||||
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
|
||||
GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC = UUID.from_16_bits(0x2BDC, 'Active Preset Index')
|
||||
|
||||
# ASHA Service
|
||||
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
|
||||
@@ -177,6 +312,9 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
|
||||
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
|
||||
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
|
||||
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
|
||||
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
|
||||
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
|
||||
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -187,7 +325,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services):
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
@@ -198,6 +336,11 @@ def show_services(services):
|
||||
print(color(' ' + str(descriptor), 'green'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class InvalidServiceError(BaseBumbleError):
|
||||
"""The service is not compliant with the spec/profile"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Service(Attribute):
|
||||
'''
|
||||
@@ -210,25 +353,27 @@ class Service(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
characteristics: List[Characteristic],
|
||||
uuid: Union[str, UUID],
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary=True,
|
||||
included_services: List[Service] = [],
|
||||
):
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
uuid = UUID(uuid)
|
||||
|
||||
super().__init__(
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
if primary
|
||||
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
(
|
||||
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
if primary
|
||||
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
|
||||
),
|
||||
Attribute.READABLE,
|
||||
uuid.to_pdu_bytes(),
|
||||
)
|
||||
self.uuid = uuid
|
||||
self.included_services = included_services[:]
|
||||
self.characteristics = characteristics[:]
|
||||
self.included_services = list(included_services)
|
||||
self.characteristics = list(characteristics)
|
||||
self.primary = primary
|
||||
|
||||
def get_advertising_data(self) -> Optional[bytes]:
|
||||
@@ -239,7 +384,7 @@ class Service(Attribute):
|
||||
"""
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Service(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
@@ -255,10 +400,15 @@ class TemplateService(Service):
|
||||
to expose their UUID as a class property
|
||||
'''
|
||||
|
||||
UUID: Optional[UUID] = None
|
||||
UUID: UUID
|
||||
|
||||
def __init__(self, characteristics, primary=True):
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
def __init__(
|
||||
self,
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary, included_services)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -269,16 +419,16 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
service: Service
|
||||
|
||||
def __init__(self, service):
|
||||
def __init__(self, service: Service) -> None:
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
'<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
|
||||
)
|
||||
super().__init__(
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
|
||||
)
|
||||
self.service = service
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
||||
f'group_starting_handle=0x{self.service.handle:04X}, '
|
||||
@@ -326,7 +476,7 @@ class Characteristic(Attribute):
|
||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# NOTE: we override this method to offer a consistent result between python
|
||||
# versions: the value returned by IntFlag.__str__() changed in version 11.
|
||||
return '|'.join(
|
||||
@@ -348,10 +498,10 @@ class Characteristic(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
uuid: Union[str, bytes, UUID],
|
||||
properties: Characteristic.Properties,
|
||||
permissions,
|
||||
value=b'',
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Any = b'',
|
||||
descriptors: Sequence[Descriptor] = (),
|
||||
):
|
||||
super().__init__(uuid, permissions, value)
|
||||
@@ -369,7 +519,7 @@ class Characteristic(Attribute):
|
||||
def has_properties(self, properties: Characteristic.Properties) -> bool:
|
||||
return self.properties & properties == properties
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
@@ -386,7 +536,11 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
characteristic: Characteristic
|
||||
|
||||
def __init__(self, characteristic, value_handle):
|
||||
def __init__(
|
||||
self,
|
||||
characteristic: Characteristic,
|
||||
value_handle: int,
|
||||
) -> None:
|
||||
declaration_bytes = (
|
||||
struct.pack('<BH', characteristic.properties, value_handle)
|
||||
+ characteristic.uuid.to_pdu_bytes()
|
||||
@@ -397,7 +551,7 @@ class CharacteristicDeclaration(Attribute):
|
||||
self.value_handle = value_handle
|
||||
self.characteristic = characteristic
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
|
||||
f'value_handle=0x{self.value_handle:04X}, '
|
||||
@@ -407,56 +561,43 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue:
|
||||
'''
|
||||
Characteristic value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(self, read=None, write=None):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection):
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(self, connection, value):
|
||||
if self._write:
|
||||
self._write(connection, value)
|
||||
class CharacteristicValue(AttributeValue):
|
||||
"""Same as AttributeValue, for backward compatibility"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter:
|
||||
'''
|
||||
An adapter that can adapt any object with `read_value` and `write_value`
|
||||
methods (like Characteristic and CharacteristicProxy objects) by wrapping
|
||||
those methods with ones that return/accept encoded/decoded values.
|
||||
Objects with async methods are considered proxies, so the adaptation is one
|
||||
where the return value of `read_value` is decoded and the value passed to
|
||||
`write_value` is encoded. Other objects are considered local characteristics
|
||||
so the adaptation is one where the return value of `read_value` is encoded
|
||||
and the value passed to `write_value` is decoded.
|
||||
If the characteristic has a `subscribe` method, it is wrapped with one where
|
||||
the values are decoded before being passed to the subscriber.
|
||||
An adapter that can adapt Characteristic and AttributeProxy objects
|
||||
by wrapping their `read_value()` and `write_value()` methods with ones that
|
||||
return/accept encoded/decoded values.
|
||||
|
||||
For proxies (i.e used by a GATT client), the adaptation is one where the return
|
||||
value of `read_value()` is decoded and the value passed to `write_value()` is
|
||||
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
|
||||
before being passed to the subscriber.
|
||||
|
||||
For local values (i.e hosted by a GATT server) the adaptation is one where the
|
||||
return value of `read_value()` is encoded and the value passed to `write_value()`
|
||||
is decoded.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
read_value: Callable
|
||||
write_value: Callable
|
||||
|
||||
if asyncio.iscoroutinefunction(
|
||||
characteristic.read_value
|
||||
) and asyncio.iscoroutinefunction(characteristic.write_value):
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
else:
|
||||
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers: Dict[Callable, Callable] = (
|
||||
{}
|
||||
) # Map from subscriber to proxy subscriber
|
||||
|
||||
if isinstance(characteristic, Characteristic):
|
||||
self.read_value = self.read_encoded_value
|
||||
self.write_value = self.write_encoded_value
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||
else:
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -475,11 +616,13 @@ class CharacteristicAdapter:
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
|
||||
def read_encoded_value(self, connection):
|
||||
return self.encode_value(self.wrapped_characteristic.read_value(connection))
|
||||
async def read_encoded_value(self, connection):
|
||||
return self.encode_value(
|
||||
await self.wrapped_characteristic.read_value(connection)
|
||||
)
|
||||
|
||||
def write_encoded_value(self, connection, value):
|
||||
return self.wrapped_characteristic.write_value(
|
||||
async def write_encoded_value(self, connection, value):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
connection, self.decode_value(value)
|
||||
)
|
||||
|
||||
@@ -520,7 +663,7 @@ class CharacteristicAdapter:
|
||||
|
||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
wrapped = str(self.wrapped_characteristic)
|
||||
return f'{self.__class__.__name__}({wrapped})'
|
||||
|
||||
@@ -577,7 +720,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
|
||||
The adapted `read_value` and `write_value` methods return/accept a dictionary which
|
||||
is packed/unpacked according to format, with the arguments extracted from the
|
||||
dictionary by key, in the same order as they occur in the `keys` parameter.
|
||||
'''
|
||||
@@ -600,27 +743,56 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value):
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SerializableCharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that converts any class to/from bytes using the class'
|
||||
`to_bytes` and `__bytes__` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic, cls: Type[ByteSerializable]):
|
||||
super().__init__(characteristic)
|
||||
self.cls = cls
|
||||
|
||||
def encode_value(self, value: SupportsBytes) -> bytes:
|
||||
return bytes(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> Any:
|
||||
return self.cls.from_bytes(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Descriptor(Attribute):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
|
||||
'''
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||
f'type={self.type}, '
|
||||
f'value={self.read_value(None).hex()})'
|
||||
f'value={value_str})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
|
||||
|
||||
@@ -28,7 +28,19 @@ import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
Callable,
|
||||
Union,
|
||||
Any,
|
||||
Iterable,
|
||||
Type,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -56,7 +68,7 @@ from .att import (
|
||||
ATT_Error,
|
||||
)
|
||||
from . import core
|
||||
from .core import UUID, InvalidStateError, ProtocolError
|
||||
from .core import UUID, InvalidStateError
|
||||
from .gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
@@ -66,28 +78,48 @@ from .gatt import (
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
TemplateService,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services: Iterable[ServiceProxy]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
for characteristic in service.characteristics:
|
||||
print(color(' ' + str(characteristic), 'magenta'))
|
||||
|
||||
for descriptor in characteristic.descriptors:
|
||||
print(color(' ' + str(descriptor), 'green'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Proxies
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeProxy(EventEmitter):
|
||||
client: Client
|
||||
|
||||
def __init__(self, client, handle, end_group_handle, attribute_type):
|
||||
def __init__(
|
||||
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.client = client
|
||||
self.handle = handle
|
||||
self.end_group_handle = end_group_handle
|
||||
self.type = attribute_type
|
||||
|
||||
async def read_value(self, no_long_read=False):
|
||||
async def read_value(self, no_long_read: bool = False) -> bytes:
|
||||
return self.decode_value(
|
||||
await self.client.read_value(self.handle, no_long_read)
|
||||
)
|
||||
@@ -97,13 +129,13 @@ class AttributeProxy(EventEmitter):
|
||||
self.handle, self.encode_value(value), with_response
|
||||
)
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -113,7 +145,7 @@ class ServiceProxy(AttributeProxy):
|
||||
included_services: List[ServiceProxy]
|
||||
|
||||
@staticmethod
|
||||
def from_client(service_class, client, service_uuid):
|
||||
def from_client(service_class, client: Client, service_uuid: UUID):
|
||||
# The service and its characteristics are considered to have already been
|
||||
# discovered
|
||||
services = client.get_services_by_uuid(service_uuid)
|
||||
@@ -136,14 +168,14 @@ class ServiceProxy(AttributeProxy):
|
||||
def get_characteristics_by_uuid(self, uuid):
|
||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||
|
||||
|
||||
class CharacteristicProxy(AttributeProxy):
|
||||
properties: Characteristic.Properties
|
||||
descriptors: List[DescriptorProxy]
|
||||
subscribers: Dict[Any, Callable]
|
||||
subscribers: Dict[Any, Callable[[bytes], Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -171,7 +203,9 @@ class CharacteristicProxy(AttributeProxy):
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(
|
||||
self, subscriber: Optional[Callable] = None, prefer_notify=True
|
||||
self,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
@@ -189,13 +223,13 @@ class CharacteristicProxy(AttributeProxy):
|
||||
|
||||
return await self.client.subscribe(self, subscriber, prefer_notify)
|
||||
|
||||
async def unsubscribe(self, subscriber=None):
|
||||
async def unsubscribe(self, subscriber=None, force=False):
|
||||
if subscriber in self.subscribers:
|
||||
subscriber = self.subscribers.pop(subscriber)
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber)
|
||||
return await self.client.unsubscribe(self, subscriber, force)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'uuid={self.uuid}, '
|
||||
@@ -207,7 +241,7 @@ class DescriptorProxy(AttributeProxy):
|
||||
def __init__(self, client, handle, descriptor_type):
|
||||
super().__init__(client, handle, 0, descriptor_type)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -216,8 +250,10 @@ class ProfileServiceProxy:
|
||||
Base class for profile-specific service proxies
|
||||
'''
|
||||
|
||||
SERVICE_CLASS: Type[TemplateService]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client):
|
||||
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
|
||||
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
||||
|
||||
|
||||
@@ -227,30 +263,38 @@ class ProfileServiceProxy:
|
||||
class Client:
|
||||
services: List[ServiceProxy]
|
||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||
notification_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
indication_subscribers: Dict[
|
||||
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
|
||||
]
|
||||
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||
pending_request: Optional[ATT_PDU]
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
self.mtu_exchange_done = False
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
self.notification_subscribers = (
|
||||
{}
|
||||
) # Notification subscribers, by attribute handle
|
||||
self.indication_subscribers = {} # Indication subscribers, by attribute handle
|
||||
self.notification_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.indication_subscribers = {} # Subscriber set, by attribute handle
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
def send_gatt_pdu(self, pdu):
|
||||
connection.on('disconnection', self.on_disconnection)
|
||||
|
||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||
|
||||
async def send_command(self, command):
|
||||
async def send_command(self, command: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
self.send_gatt_pdu(command.to_bytes())
|
||||
self.send_gatt_pdu(bytes(command))
|
||||
|
||||
async def send_request(self, request):
|
||||
async def send_request(self, request: ATT_PDU):
|
||||
logger.debug(
|
||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||
)
|
||||
@@ -266,7 +310,7 @@ class Client:
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(request.to_bytes())
|
||||
self.send_gatt_pdu(bytes(request))
|
||||
response = await asyncio.wait_for(
|
||||
self.pending_response, GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
@@ -279,19 +323,19 @@ class Client:
|
||||
|
||||
return response
|
||||
|
||||
def send_confirmation(self, confirmation):
|
||||
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
|
||||
logger.debug(
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
self.send_gatt_pdu(confirmation.to_bytes())
|
||||
self.send_gatt_pdu(bytes(confirmation))
|
||||
|
||||
async def request_mtu(self, mtu):
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
if mtu > 0xFFFF:
|
||||
raise ValueError('MTU must be <= 0xFFFF')
|
||||
raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
|
||||
|
||||
# We can only send one request per connection
|
||||
if self.mtu_exchange_done:
|
||||
@@ -301,22 +345,19 @@ class Client:
|
||||
self.mtu_exchange_done = True
|
||||
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ProtocolError(
|
||||
response.error_code,
|
||||
'att',
|
||||
ATT_PDU.error_name(response.error_code),
|
||||
response,
|
||||
)
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
# Compute the final MTU
|
||||
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
|
||||
|
||||
return self.connection.att_mtu
|
||||
|
||||
def get_services_by_uuid(self, uuid):
|
||||
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
|
||||
return [service for service in self.services if service.uuid == uuid]
|
||||
|
||||
def get_characteristics_by_uuid(self, uuid, service=None):
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy] = None
|
||||
) -> List[CharacteristicProxy]:
|
||||
services = [service] if service else self.services
|
||||
return [
|
||||
c
|
||||
@@ -324,9 +365,7 @@ class Client:
|
||||
if c.uuid == uuid
|
||||
]
|
||||
|
||||
def get_attribute_grouping(
|
||||
self, attribute_handle: int
|
||||
) -> Optional[
|
||||
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
|
||||
Union[
|
||||
ServiceProxy,
|
||||
Tuple[ServiceProxy, CharacteristicProxy],
|
||||
@@ -363,7 +402,7 @@ class Client:
|
||||
if not already_known:
|
||||
self.services.append(service)
|
||||
|
||||
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
|
||||
async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||
'''
|
||||
@@ -435,7 +474,7 @@ class Client:
|
||||
|
||||
return services
|
||||
|
||||
async def discover_service(self, uuid):
|
||||
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
||||
'''
|
||||
@@ -468,7 +507,7 @@ class Client:
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
# TODO raise appropriate exception
|
||||
return
|
||||
return []
|
||||
break
|
||||
|
||||
for attribute_handle, end_group_handle in response.handles_information:
|
||||
@@ -480,7 +519,7 @@ class Client:
|
||||
logger.warning(
|
||||
f'bogus handle values: {attribute_handle} {end_group_handle}'
|
||||
)
|
||||
return
|
||||
return []
|
||||
|
||||
# Create a service proxy for this service
|
||||
service = ServiceProxy(
|
||||
@@ -657,8 +696,8 @@ class Client:
|
||||
async def discover_descriptors(
|
||||
self,
|
||||
characteristic: Optional[CharacteristicProxy] = None,
|
||||
start_handle=None,
|
||||
end_handle=None,
|
||||
start_handle: Optional[int] = None,
|
||||
end_handle: Optional[int] = None,
|
||||
) -> List[DescriptorProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
|
||||
@@ -721,7 +760,7 @@ class Client:
|
||||
|
||||
return descriptors
|
||||
|
||||
async def discover_attributes(self):
|
||||
async def discover_attributes(self) -> List[AttributeProxy]:
|
||||
'''
|
||||
Discover all attributes, regardless of type
|
||||
'''
|
||||
@@ -764,7 +803,12 @@ class Client:
|
||||
|
||||
return attributes
|
||||
|
||||
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
|
||||
async def subscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
) -> None:
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -801,6 +845,7 @@ class Client:
|
||||
subscriber_set = subscribers.setdefault(characteristic.handle, set())
|
||||
if subscriber is not None:
|
||||
subscriber_set.add(subscriber)
|
||||
|
||||
# Add the characteristic as a subscriber, which will result in the
|
||||
# characteristic emitting an 'update' event when a notification or indication
|
||||
# is received
|
||||
@@ -808,7 +853,18 @@ class Client:
|
||||
|
||||
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
|
||||
|
||||
async def unsubscribe(self, characteristic, subscriber=None):
|
||||
async def unsubscribe(
|
||||
self,
|
||||
characteristic: CharacteristicProxy,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
Unsubscribe from a characteristic.
|
||||
|
||||
If `force` is True, this will write zeros to the CCCD when there are no
|
||||
subscribers left, even if there were already no registered subscribers.
|
||||
'''
|
||||
# If we haven't already discovered the descriptors for this characteristic,
|
||||
# do it now
|
||||
if not characteristic.descriptors_discovered:
|
||||
@@ -822,29 +878,51 @@ class Client:
|
||||
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
|
||||
return
|
||||
|
||||
# Check if the characteristic has subscribers
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
if not force:
|
||||
return
|
||||
|
||||
# Remove the subscriber(s)
|
||||
if subscriber is not None:
|
||||
# Remove matching subscriber from subscriber sets
|
||||
for subscriber_set in (
|
||||
self.notification_subscribers,
|
||||
self.indication_subscribers,
|
||||
):
|
||||
subscribers = subscriber_set.get(characteristic.handle, [])
|
||||
if subscriber in subscribers:
|
||||
if (
|
||||
subscribers := subscriber_set.get(characteristic.handle)
|
||||
) and subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# The characteristic itself is added as subscriber. If it is the
|
||||
# last remaining subscriber, we remove it, such that the clean up
|
||||
# works correctly. Otherwise the CCCD never is set back to 0.
|
||||
if len(subscribers) == 1 and characteristic in subscribers:
|
||||
subscribers.remove(characteristic)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
del subscriber_set[characteristic.handle]
|
||||
else:
|
||||
# Remove all subscribers for this attribute from the sets!
|
||||
# Remove all subscribers for this attribute from the sets
|
||||
self.notification_subscribers.pop(characteristic.handle, None)
|
||||
self.indication_subscribers.pop(characteristic.handle, None)
|
||||
|
||||
if not self.notification_subscribers and not self.indication_subscribers:
|
||||
# Update the CCCD
|
||||
if not (
|
||||
characteristic.handle in self.notification_subscribers
|
||||
or characteristic.handle in self.indication_subscribers
|
||||
):
|
||||
# No more subscribers left
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
|
||||
async def read_value(self, attribute, no_long_read=False):
|
||||
async def read_value(
|
||||
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
|
||||
) -> bytes:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
||||
|
||||
@@ -859,12 +937,7 @@ class Client:
|
||||
if response is None:
|
||||
raise TimeoutError('read timeout')
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ProtocolError(
|
||||
response.error_code,
|
||||
'att',
|
||||
ATT_PDU.error_name(response.error_code),
|
||||
response,
|
||||
)
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
# If the value is the max size for the MTU, try to read more unless the caller
|
||||
# specifically asked not to do that
|
||||
@@ -886,12 +959,7 @@ class Client:
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
):
|
||||
break
|
||||
raise ProtocolError(
|
||||
response.error_code,
|
||||
'att',
|
||||
ATT_PDU.error_name(response.error_code),
|
||||
response,
|
||||
)
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
|
||||
part = response.part_attribute_value
|
||||
attribute_value += part
|
||||
@@ -905,7 +973,9 @@ class Client:
|
||||
# Return the value as bytes
|
||||
return attribute_value
|
||||
|
||||
async def read_characteristics_by_uuid(self, uuid, service):
|
||||
async def read_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy]
|
||||
) -> List[bytes]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
||||
'''
|
||||
@@ -960,7 +1030,12 @@ class Client:
|
||||
|
||||
return characteristics_values
|
||||
|
||||
async def write_value(self, attribute, value, with_response=False):
|
||||
async def write_value(
|
||||
self,
|
||||
attribute: Union[int, AttributeProxy],
|
||||
value: bytes,
|
||||
with_response: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
|
||||
Value
|
||||
@@ -977,12 +1052,7 @@ class Client:
|
||||
)
|
||||
)
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
raise ProtocolError(
|
||||
response.error_code,
|
||||
'att',
|
||||
ATT_PDU.error_name(response.error_code),
|
||||
response,
|
||||
)
|
||||
raise ATT_Error(error_code=response.error_code, message=response)
|
||||
else:
|
||||
await self.send_command(
|
||||
ATT_Write_Command(
|
||||
@@ -990,7 +1060,11 @@ class Client:
|
||||
)
|
||||
)
|
||||
|
||||
def on_gatt_pdu(self, att_pdu):
|
||||
def on_disconnection(self, _) -> None:
|
||||
if self.pending_response and not self.pending_response.done():
|
||||
self.pending_response.cancel()
|
||||
|
||||
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||
)
|
||||
@@ -1000,7 +1074,7 @@ class Client:
|
||||
logger.warning('!!! unexpected response, there is no pending request')
|
||||
return
|
||||
|
||||
# Sanity check: the response should match the pending request unless it is
|
||||
# The response should match the pending request unless it is
|
||||
# an error response
|
||||
if att_pdu.op_code != ATT_ERROR_RESPONSE:
|
||||
expected_response_name = self.pending_request.name.replace(
|
||||
@@ -1013,6 +1087,7 @@ class Client:
|
||||
return
|
||||
|
||||
# Return the response to the coroutine that is waiting for it
|
||||
assert self.pending_response is not None
|
||||
self.pending_response.set_result(att_pdu)
|
||||
else:
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
@@ -1032,7 +1107,7 @@ class Client:
|
||||
def on_att_handle_value_notification(self, notification):
|
||||
# Call all subscribers
|
||||
subscribers = self.notification_subscribers.get(
|
||||
notification.attribute_handle, []
|
||||
notification.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received notification with no subscriber')
|
||||
@@ -1046,7 +1121,9 @@ class Client:
|
||||
|
||||
def on_att_handle_value_indication(self, indication):
|
||||
# Call all subscribers
|
||||
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
|
||||
subscribers = self.indication_subscribers.get(
|
||||
indication.attribute_handle, set()
|
||||
)
|
||||
if not subscribers:
|
||||
logger.warning('!!! received indication with no subscriber')
|
||||
|
||||
@@ -1060,7 +1137,7 @@ class Client:
|
||||
# Confirm that we received the indication
|
||||
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
||||
|
||||
def cache_value(self, attribute_handle: int, value: bytes):
|
||||
def cache_value(self, attribute_handle: int, value: bytes) -> None:
|
||||
self.cached_values[attribute_handle] = (
|
||||
datetime.now(),
|
||||
value,
|
||||
|
||||
@@ -23,16 +23,27 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID
|
||||
from .att import (
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
@@ -42,6 +53,7 @@ from .att import (
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
ATT_REQUESTS,
|
||||
ATT_PDU,
|
||||
ATT_UNLIKELY_ERROR_ERROR,
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
ATT_Error,
|
||||
@@ -58,7 +70,7 @@ from .att import (
|
||||
ATT_Write_Response,
|
||||
Attribute,
|
||||
)
|
||||
from .gatt import (
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||
@@ -66,13 +78,17 @@ from .gatt import (
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
CharacteristicAdapter,
|
||||
CharacteristicDeclaration,
|
||||
CharacteristicValue,
|
||||
IncludedServiceDeclaration,
|
||||
Descriptor,
|
||||
Service,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -91,8 +107,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
attributes: List[Attribute]
|
||||
services: List[Service]
|
||||
attributes_by_handle: Dict[int, Attribute]
|
||||
subscribers: Dict[int, Dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.services = []
|
||||
@@ -107,16 +128,16 @@ class Server(EventEmitter):
|
||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||
self.pending_confirmations = defaultdict(lambda: None)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle, pdu):
|
||||
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||
|
||||
def next_handle(self):
|
||||
def next_handle(self) -> int:
|
||||
return 1 + len(self.attributes)
|
||||
|
||||
def get_advertising_service_data(self):
|
||||
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
|
||||
return {
|
||||
attribute: data
|
||||
for attribute in self.attributes
|
||||
@@ -124,7 +145,7 @@ class Server(EventEmitter):
|
||||
and (data := attribute.get_advertising_data())
|
||||
}
|
||||
|
||||
def get_attribute(self, handle):
|
||||
def get_attribute(self, handle: int) -> Optional[Attribute]:
|
||||
attribute = self.attributes_by_handle.get(handle)
|
||||
if attribute:
|
||||
return attribute
|
||||
@@ -173,12 +194,17 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
(attribute, self.get_attribute(attribute.characteristic.handle))
|
||||
(
|
||||
attribute,
|
||||
self.get_attribute(attribute.characteristic.handle),
|
||||
) # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(service_handle.handle, service_handle.end_group_handle + 1),
|
||||
)
|
||||
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
if attribute is not None
|
||||
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
and isinstance(attribute, CharacteristicDeclaration)
|
||||
and attribute.characteristic.uuid == characteristic_uuid
|
||||
),
|
||||
None,
|
||||
@@ -197,7 +223,7 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
attribute # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(
|
||||
@@ -205,12 +231,12 @@ class Server(EventEmitter):
|
||||
characteristic_value.end_group_handle + 1,
|
||||
),
|
||||
)
|
||||
if attribute.type == descriptor_uuid
|
||||
if attribute is not None and attribute.type == descriptor_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def add_attribute(self, attribute):
|
||||
def add_attribute(self, attribute: Attribute) -> None:
|
||||
# Assign a handle to this attribute
|
||||
attribute.handle = self.next_handle()
|
||||
attribute.end_group_handle = (
|
||||
@@ -220,7 +246,7 @@ class Server(EventEmitter):
|
||||
# Add this attribute to the list
|
||||
self.attributes.append(attribute)
|
||||
|
||||
def add_service(self, service: Service):
|
||||
def add_service(self, service: Service) -> None:
|
||||
# Add the service attribute to the DB
|
||||
self.add_attribute(service)
|
||||
|
||||
@@ -285,11 +311,13 @@ class Server(EventEmitter):
|
||||
service.end_group_handle = self.attributes[-1].handle
|
||||
self.services.append(service)
|
||||
|
||||
def add_services(self, services):
|
||||
def add_services(self, services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(self, connection, characteristic):
|
||||
def read_cccd(
|
||||
self, connection: Optional[Connection], characteristic: Characteristic
|
||||
) -> bytes:
|
||||
if connection is None:
|
||||
return bytes([0, 0])
|
||||
|
||||
@@ -300,13 +328,18 @@ class Server(EventEmitter):
|
||||
|
||||
return cccd or bytes([0, 0])
|
||||
|
||||
def write_cccd(self, connection, characteristic, value):
|
||||
def write_cccd(
|
||||
self,
|
||||
connection: Connection,
|
||||
characteristic: Characteristic,
|
||||
value: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'Subscription update for connection=0x{connection.handle:04X}, '
|
||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
# Check parameters
|
||||
if len(value) != 2:
|
||||
logger.warning('CCCD value not 2 bytes long')
|
||||
return
|
||||
@@ -327,13 +360,19 @@ class Server(EventEmitter):
|
||||
indicate_enabled,
|
||||
)
|
||||
|
||||
def send_response(self, connection, response):
|
||||
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(response))
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -352,7 +391,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -370,7 +409,13 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -389,7 +434,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -411,15 +456,13 @@ class Server(EventEmitter):
|
||||
assert self.pending_confirmations[connection.handle] is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_confirmations[
|
||||
connection.handle
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
pending_confirmation = self.pending_confirmations[connection.handle] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
await asyncio.wait_for(
|
||||
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(indication))
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
||||
@@ -427,8 +470,12 @@ class Server(EventEmitter):
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
|
||||
async def notify_or_indicate_subscribers(
|
||||
self, indicate, attribute, value=None, force=False
|
||||
):
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection
|
||||
@@ -450,13 +497,23 @@ class Server(EventEmitter):
|
||||
]
|
||||
)
|
||||
|
||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection):
|
||||
def on_disconnection(self, connection: Connection) -> None:
|
||||
if connection.handle in self.subscribers:
|
||||
del self.subscribers[connection.handle]
|
||||
if connection.handle in self.indication_semaphores:
|
||||
@@ -464,7 +521,7 @@ class Server(EventEmitter):
|
||||
if connection.handle in self.pending_confirmations:
|
||||
del self.pending_confirmations[connection.handle]
|
||||
|
||||
def on_gatt_pdu(self, connection, att_pdu):
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
@@ -506,7 +563,7 @@ class Server(EventEmitter):
|
||||
#######################################################
|
||||
# ATT handlers
|
||||
#######################################################
|
||||
def on_att_request(self, connection, pdu):
|
||||
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
|
||||
'''
|
||||
Handler for requests without a more specific handler
|
||||
'''
|
||||
@@ -605,7 +662,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_find_by_type_value_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
@@ -613,13 +671,13 @@ class Server(EventEmitter):
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
for attribute in (
|
||||
async for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
and attribute.type == request.attribute_type
|
||||
and attribute.read_value(connection) == request.attribute_value
|
||||
and (await attribute.read_value(connection)) == request.attribute_value
|
||||
and pdu_space_available >= 4
|
||||
):
|
||||
# TODO: check permissions
|
||||
@@ -657,7 +715,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
@@ -679,9 +738,8 @@ class Server(EventEmitter):
|
||||
and attribute.handle <= request.ending_handle
|
||||
and pdu_space_available
|
||||
):
|
||||
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
@@ -723,14 +781,15 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -748,14 +807,15 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_blob_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -792,7 +852,8 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_group_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
'''
|
||||
@@ -820,7 +881,7 @@ class Server(EventEmitter):
|
||||
):
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
@@ -859,12 +920,13 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_write_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
'''
|
||||
|
||||
# Check that the attribute exists
|
||||
# Check that the attribute exists
|
||||
attribute = self.get_attribute(request.attribute_handle)
|
||||
if attribute is None:
|
||||
self.send_response(
|
||||
@@ -891,13 +953,22 @@ class Server(EventEmitter):
|
||||
)
|
||||
return
|
||||
|
||||
# Accept the value
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
try:
|
||||
# Accept the value
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
attribute_handle_in_error=request.attribute_handle,
|
||||
error_code=error.error_code,
|
||||
)
|
||||
else:
|
||||
# Done
|
||||
response = ATT_Write_Response()
|
||||
self.send_response(connection, response)
|
||||
|
||||
# Done
|
||||
self.send_response(connection, ATT_Write_Response())
|
||||
|
||||
def on_att_write_command(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
'''
|
||||
@@ -915,9 +986,9 @@ class Server(EventEmitter):
|
||||
|
||||
# Accept the value
|
||||
try:
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! ignoring exception: {error}')
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
'''
|
||||
|
||||
3685
bumble/hci.py
3685
bumble/hci.py
File diff suppressed because it is too large
Load Diff
@@ -15,30 +15,46 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, MutableMapping
|
||||
import datetime
|
||||
from typing import cast, Any, Optional
|
||||
import logging
|
||||
|
||||
from .colors import color
|
||||
from .att import ATT_CID, ATT_PDU
|
||||
from .smp import SMP_CID, SMP_Command
|
||||
from .core import name_or_number
|
||||
from .l2cap import (
|
||||
from bumble import avc
|
||||
from bumble import avctp
|
||||
from bumble import avdtp
|
||||
from bumble import avrcp
|
||||
from bumble import crypto
|
||||
from bumble import rfcomm
|
||||
from bumble import sdp
|
||||
from bumble.colors import color
|
||||
from bumble.att import ATT_CID, ATT_PDU
|
||||
from bumble.smp import SMP_CID, SMP_Command
|
||||
from bumble.core import name_or_number
|
||||
from bumble.l2cap import (
|
||||
L2CAP_PDU,
|
||||
L2CAP_CONNECTION_REQUEST,
|
||||
L2CAP_CONNECTION_RESPONSE,
|
||||
L2CAP_SIGNALING_CID,
|
||||
L2CAP_LE_SIGNALING_CID,
|
||||
L2CAP_Control_Frame,
|
||||
L2CAP_Connection_Request,
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from .hci import (
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
HCI_AclDataPacketAssembler,
|
||||
HCI_Packet,
|
||||
HCI_Event,
|
||||
HCI_AclDataPacket,
|
||||
HCI_Disconnection_Complete_Event,
|
||||
)
|
||||
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
|
||||
from .sdp import SDP_PDU, SDP_PSM
|
||||
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -48,26 +64,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
PSM_NAMES = {
|
||||
RFCOMM_PSM: 'RFCOMM',
|
||||
SDP_PSM: 'SDP',
|
||||
AVDTP_PSM: 'AVDTP'
|
||||
rfcomm.RFCOMM_PSM: 'RFCOMM',
|
||||
sdp.SDP_PSM: 'SDP',
|
||||
avdtp.AVDTP_PSM: 'AVDTP',
|
||||
avctp.AVCTP_PSM: 'AVCTP',
|
||||
# TODO: add more PSM values
|
||||
}
|
||||
|
||||
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PacketTracer:
|
||||
class AclStream:
|
||||
def __init__(self, analyzer):
|
||||
psms: MutableMapping[int, int]
|
||||
peer: Optional[PacketTracer.AclStream]
|
||||
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
|
||||
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
|
||||
|
||||
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
|
||||
self.analyzer = analyzer
|
||||
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
|
||||
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
|
||||
self.psms = {} # PSM, by source_cid
|
||||
self.peer = None # ACL stream in the other direction
|
||||
self.peer = None
|
||||
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
def on_acl_pdu(self, pdu):
|
||||
def on_acl_pdu(self, pdu: bytes) -> None:
|
||||
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
if l2cap_pdu.cid == ATT_CID:
|
||||
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
|
||||
@@ -81,46 +107,59 @@ class PacketTracer:
|
||||
|
||||
# Check if this signals a new channel
|
||||
if control_frame.code == L2CAP_CONNECTION_REQUEST:
|
||||
self.psms[control_frame.source_cid] = control_frame.psm
|
||||
connection_request = cast(L2CAP_Connection_Request, control_frame)
|
||||
self.psms[connection_request.source_cid] = connection_request.psm
|
||||
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
|
||||
connection_response = cast(L2CAP_Connection_Response, control_frame)
|
||||
if (
|
||||
control_frame.result
|
||||
connection_response.result
|
||||
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
|
||||
):
|
||||
if self.peer:
|
||||
if psm := self.peer.psms.get(control_frame.source_cid):
|
||||
# Found a pending connection
|
||||
self.psms[control_frame.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
control_frame.source_cid
|
||||
] = AVDTP_MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
control_frame.destination_cid
|
||||
] = AVDTP_MessageAssembler(
|
||||
self.peer.on_avdtp_message
|
||||
)
|
||||
if self.peer and (
|
||||
psm := self.peer.psms.get(connection_response.source_cid)
|
||||
):
|
||||
# Found a pending connection
|
||||
self.psms[connection_response.destination_cid] = psm
|
||||
|
||||
# For AVDTP connections, create a packet assembler for
|
||||
# each direction
|
||||
if psm == avdtp.AVDTP_PSM:
|
||||
self.avdtp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avdtp.MessageAssembler(self.on_avdtp_message)
|
||||
self.peer.avdtp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.avctp_assemblers[
|
||||
connection_response.source_cid
|
||||
] = avctp.MessageAssembler(self.on_avctp_message)
|
||||
self.peer.avctp_assemblers[
|
||||
connection_response.destination_cid
|
||||
] = avctp.MessageAssembler(self.peer.on_avctp_message)
|
||||
else:
|
||||
# Try to find the PSM associated with this PDU
|
||||
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
|
||||
if psm == SDP_PSM:
|
||||
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
if psm == sdp.SDP_PSM:
|
||||
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(sdp_pdu)
|
||||
elif psm == RFCOMM_PSM:
|
||||
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
elif psm == rfcomm.RFCOMM_PSM:
|
||||
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
|
||||
self.analyzer.emit(rfcomm_frame)
|
||||
elif psm == AVDTP_PSM:
|
||||
elif psm == avdtp.AVDTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
|
||||
if assembler:
|
||||
assembler.on_pdu(l2cap_pdu.payload)
|
||||
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
|
||||
avdtp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
elif psm == avctp.AVCTP_PSM:
|
||||
self.analyzer.emit(
|
||||
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
|
||||
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
|
||||
)
|
||||
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
|
||||
avctp_assembler.on_pdu(l2cap_pdu.payload)
|
||||
else:
|
||||
psm_string = name_or_number(PSM_NAMES, psm)
|
||||
self.analyzer.emit(
|
||||
@@ -130,22 +169,49 @@ class PacketTracer:
|
||||
else:
|
||||
self.analyzer.emit(l2cap_pdu)
|
||||
|
||||
def on_avdtp_message(self, transaction_label, message):
|
||||
def on_avdtp_message(
|
||||
self, transaction_label: int, message: avdtp.Message
|
||||
) -> None:
|
||||
self.analyzer.emit(
|
||||
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet):
|
||||
def on_avctp_message(
|
||||
self,
|
||||
transaction_label: int,
|
||||
is_command: bool,
|
||||
ipid: bool,
|
||||
pid: int,
|
||||
payload: bytes,
|
||||
):
|
||||
if pid == avrcp.AVRCP_PID:
|
||||
avc_frame = avc.Frame.from_bytes(payload)
|
||||
details = str(avc_frame)
|
||||
else:
|
||||
details = payload.hex()
|
||||
|
||||
c_r = 'Command' if is_command else 'Response'
|
||||
self.analyzer.emit(
|
||||
f'{color("AVCTP", "green")} '
|
||||
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
|
||||
f'{"#" if ipid else ""}'
|
||||
f'{details}'
|
||||
)
|
||||
|
||||
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.packet_assembler.feed_packet(packet)
|
||||
|
||||
class Analyzer:
|
||||
def __init__(self, label, emit_message):
|
||||
acl_streams: MutableMapping[int, PacketTracer.AclStream]
|
||||
peer: PacketTracer.Analyzer
|
||||
|
||||
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
|
||||
self.label = label
|
||||
self.emit_message = emit_message
|
||||
self.acl_streams = {} # ACL streams, by connection handle
|
||||
self.peer = None # Analyzer in the other direction
|
||||
self.packet_timestamp: Optional[datetime.datetime] = None
|
||||
|
||||
def start_acl_stream(self, connection_handle):
|
||||
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
|
||||
logger.info(
|
||||
f'[{self.label}] +++ Creating ACL stream for connection '
|
||||
f'0x{connection_handle:04X}'
|
||||
@@ -160,7 +226,7 @@ class PacketTracer:
|
||||
|
||||
return stream
|
||||
|
||||
def end_acl_stream(self, connection_handle):
|
||||
def end_acl_stream(self, connection_handle: int) -> None:
|
||||
if connection_handle in self.acl_streams:
|
||||
logger.info(
|
||||
f'[{self.label}] --- Removing ACL stream for connection '
|
||||
@@ -171,34 +237,52 @@ class PacketTracer:
|
||||
# Let the other forwarder know so it can cleanup its stream as well
|
||||
self.peer.end_acl_stream(connection_handle)
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(
|
||||
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
|
||||
) -> None:
|
||||
self.packet_timestamp = timestamp
|
||||
self.emit(packet)
|
||||
|
||||
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
|
||||
acl_packet = cast(HCI_AclDataPacket, packet)
|
||||
# Look for an existing stream for this handle, create one if it is the
|
||||
# first ACL packet for that connection handle
|
||||
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
|
||||
stream = self.start_acl_stream(packet.connection_handle)
|
||||
stream.feed_packet(packet)
|
||||
if (
|
||||
stream := self.acl_streams.get(acl_packet.connection_handle)
|
||||
) is None:
|
||||
stream = self.start_acl_stream(acl_packet.connection_handle)
|
||||
stream.feed_packet(acl_packet)
|
||||
elif packet.hci_packet_type == HCI_EVENT_PACKET:
|
||||
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(packet.connection_handle)
|
||||
event_packet = cast(HCI_Event, packet)
|
||||
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
|
||||
self.end_acl_stream(
|
||||
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
|
||||
)
|
||||
|
||||
def emit(self, message):
|
||||
self.emit_message(f'[{self.label}] {message}')
|
||||
def emit(self, message: Any) -> None:
|
||||
if self.packet_timestamp:
|
||||
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
|
||||
else:
|
||||
prefix = ""
|
||||
self.emit_message(f'{prefix}[{self.label}] {message}')
|
||||
|
||||
def trace(self, packet, direction=0):
|
||||
def trace(
|
||||
self,
|
||||
packet: HCI_Packet,
|
||||
direction: int = 0,
|
||||
timestamp: Optional[datetime.datetime] = None,
|
||||
) -> None:
|
||||
if direction == 0:
|
||||
self.host_to_controller_analyzer.on_packet(packet)
|
||||
self.host_to_controller_analyzer.on_packet(timestamp, packet)
|
||||
else:
|
||||
self.controller_to_host_analyzer.on_packet(packet)
|
||||
self.controller_to_host_analyzer.on_packet(timestamp, packet)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message=logger.info,
|
||||
):
|
||||
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
|
||||
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
|
||||
emit_message: Callable[..., None] = logger.info,
|
||||
) -> None:
|
||||
self.host_to_controller_analyzer = PacketTracer.Analyzer(
|
||||
host_to_controller_label, emit_message
|
||||
)
|
||||
@@ -207,3 +291,15 @@ class PacketTracer:
|
||||
)
|
||||
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
|
||||
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
|
||||
|
||||
|
||||
def generate_irk() -> bytes:
|
||||
return crypto.r()
|
||||
|
||||
|
||||
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
|
||||
rpa_bytes = bytes(rpa)
|
||||
prand_given = rpa_bytes[3:]
|
||||
hash_given = rpa_bytes[:3]
|
||||
hash_local = crypto.ah(irk, prand_given)
|
||||
return hash_local[:3] == hash_given
|
||||
|
||||
1675
bumble/hfp.py
1675
bumble/hfp.py
File diff suppressed because it is too large
Load Diff
551
bumble/hid.py
Normal file
551
bumble/hid.py
Normal file
@@ -0,0 +1,551 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, Callable
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap, device
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
from bumble.hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
# fmt: on
|
||||
HID_CONTROL_PSM = 0x0011
|
||||
HID_INTERRUPT_PSM = 0x0013
|
||||
|
||||
|
||||
class Message:
|
||||
message_type: MessageType
|
||||
|
||||
# Report types
|
||||
class ReportType(enum.IntEnum):
|
||||
OTHER_REPORT = 0x00
|
||||
INPUT_REPORT = 0x01
|
||||
OUTPUT_REPORT = 0x02
|
||||
FEATURE_REPORT = 0x03
|
||||
|
||||
# Handshake parameters
|
||||
class Handshake(enum.IntEnum):
|
||||
SUCCESSFUL = 0x00
|
||||
NOT_READY = 0x01
|
||||
ERR_INVALID_REPORT_ID = 0x02
|
||||
ERR_UNSUPPORTED_REQUEST = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
ERR_UNKNOWN = 0x0E
|
||||
ERR_FATAL = 0x0F
|
||||
|
||||
# Message Type
|
||||
class MessageType(enum.IntEnum):
|
||||
HANDSHAKE = 0x00
|
||||
CONTROL = 0x01
|
||||
GET_REPORT = 0x04
|
||||
SET_REPORT = 0x05
|
||||
GET_PROTOCOL = 0x06
|
||||
SET_PROTOCOL = 0x07
|
||||
DATA = 0x0A
|
||||
|
||||
# Protocol modes
|
||||
class ProtocolMode(enum.IntEnum):
|
||||
BOOT_PROTOCOL = 0x00
|
||||
REPORT_PROTOCOL = 0x01
|
||||
|
||||
# Control Operations
|
||||
class ControlCommand(enum.IntEnum):
|
||||
SUSPEND = 0x03
|
||||
EXIT_SUSPEND = 0x04
|
||||
VIRTUAL_CABLE_UNPLUG = 0x05
|
||||
|
||||
# Class Method to derive header
|
||||
@classmethod
|
||||
def header(cls, lower_bits: int = 0x00) -> bytes:
|
||||
return bytes([(cls.message_type << 4) | lower_bits])
|
||||
|
||||
|
||||
# HIDP messages
|
||||
@dataclass
|
||||
class GetReportMessage(Message):
|
||||
report_type: int
|
||||
report_id: int
|
||||
buffer_size: int
|
||||
message_type = Message.MessageType.GET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
packet_bytes = bytearray()
|
||||
packet_bytes.append(self.report_id)
|
||||
if self.buffer_size == 0:
|
||||
return self.header(self.report_type) + packet_bytes
|
||||
else:
|
||||
return (
|
||||
self.header(0x08 | self.report_type)
|
||||
+ packet_bytes
|
||||
+ struct.pack("<H", self.buffer_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetReportMessage(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.SET_REPORT
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendControlData(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetProtocolMessage(Message):
|
||||
message_type = Message.MessageType.GET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetProtocolMessage(Message):
|
||||
protocol_mode: int
|
||||
message_type = Message.MessageType.SET_PROTOCOL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.protocol_mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Suspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExitSuspend(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.EXIT_SUSPEND)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VirtualCableUnplug(Message):
|
||||
message_type = Message.MessageType.CONTROL
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
|
||||
|
||||
|
||||
# Device sends input report, host sends output report.
|
||||
@dataclass
|
||||
class SendData(Message):
|
||||
data: bytes
|
||||
report_type: int
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHandshakeMessage(Message):
|
||||
result_code: int
|
||||
message_type = Message.MessageType.HANDSHAKE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.result_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class HID(ABC, EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
|
||||
def __init__(self, device: device.Device, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.remote_device_bd_address: Optional[Address] = None
|
||||
self.device = device
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
try:
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_CONTROL_PSM
|
||||
)
|
||||
channel.sink = self.on_ctrl_pdu
|
||||
self.l2cap_ctrl_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
async def connect_interrupt_channel(self) -> None:
|
||||
# Create a new L2CAP connection - interrupt channel
|
||||
try:
|
||||
channel = await self.device.l2cap_channel_manager.connect(
|
||||
self.connection, HID_INTERRUPT_PSM
|
||||
)
|
||||
channel.sink = self.on_intr_pdu
|
||||
self.l2cap_intr_channel = channel
|
||||
except ProtocolError:
|
||||
logging.exception(f'L2CAP connection failed.')
|
||||
raise
|
||||
|
||||
async def disconnect_interrupt_channel(self) -> None:
|
||||
if self.l2cap_intr_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
async def disconnect_control_channel(self) -> None:
|
||||
if self.l2cap_ctrl_channel is None:
|
||||
raise InvalidStateError('invalid state')
|
||||
channel = self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = l2cap_channel
|
||||
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
|
||||
else:
|
||||
self.l2cap_intr_channel = l2cap_channel
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = None
|
||||
else:
|
||||
self.l2cap_intr_channel = None
|
||||
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
|
||||
|
||||
@abstractmethod
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
pass
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("interrupt_data", pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
self.l2cap_ctrl_channel.send_pdu(msg)
|
||||
|
||||
def send_pdu_on_intr(self, msg: bytes) -> None:
|
||||
assert self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel.send_pdu(msg)
|
||||
|
||||
def send_data(self, data: bytes) -> None:
|
||||
if self.role == HID.Role.HOST:
|
||||
report_type = Message.ReportType.OUTPUT_REPORT
|
||||
else:
|
||||
report_type = Message.ReportType.INPUT_REPORT
|
||||
msg = SendData(data, report_type)
|
||||
hid_message = bytes(msg)
|
||||
if self.l2cap_intr_channel is not None:
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
|
||||
def virtual_cable_unplug(self) -> None:
|
||||
msg = VirtualCableUnplug()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Device(HID):
|
||||
class GetSetReturn(enum.IntEnum):
|
||||
FAILURE = 0x00
|
||||
REPORT_ID_NOT_FOUND = 0x01
|
||||
ERR_UNSUPPORTED_REQUEST = 0x02
|
||||
ERR_UNKNOWN = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
@dataclass
|
||||
class GetSetStatus:
|
||||
data: bytes = b''
|
||||
status: int = 0
|
||||
|
||||
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
|
||||
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
|
||||
if message_type == Message.MessageType.GET_REPORT:
|
||||
logger.debug('<<< HID GET REPORT')
|
||||
self.handle_get_report(pdu)
|
||||
elif message_type == Message.MessageType.SET_REPORT:
|
||||
logger.debug('<<< HID SET REPORT')
|
||||
self.handle_set_report(pdu)
|
||||
elif message_type == Message.MessageType.GET_PROTOCOL:
|
||||
logger.debug('<<< HID GET PROTOCOL')
|
||||
self.handle_get_protocol(pdu)
|
||||
elif message_type == Message.MessageType.SET_PROTOCOL:
|
||||
logger.debug('<<< HID SET PROTOCOL')
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def send_handshake_message(self, result_code: int) -> None:
|
||||
msg = SendHandshakeMessage(result_code)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def send_control_data(self, report_type: int, data: bytes):
|
||||
msg = SendControlData(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def handle_get_report(self, pdu: bytes):
|
||||
if self.get_report_cb is None:
|
||||
logger.debug("GetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||
report_id = pdu[1]
|
||||
logger.debug(f"buffer_flag: {buffer_flag}")
|
||||
if buffer_flag == 1:
|
||||
buffer_size = (pdu[3] << 8) | pdu[2]
|
||||
else:
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
data = bytearray()
|
||||
data.append(report_id)
|
||||
data.extend(ret.data)
|
||||
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
|
||||
self.send_control_data(report_type=report_type, data=data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(
|
||||
self, cb: Callable[[int, int, int], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
def handle_set_report(self, pdu: bytes):
|
||||
if self.set_report_cb is None:
|
||||
logger.debug("SetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
report_id = pdu[1]
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
|
||||
def handle_get_protocol(self, pdu: bytes):
|
||||
if self.get_protocol_cb is None:
|
||||
logger.debug("GetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
def handle_set_protocol(self, pdu: bytes):
|
||||
if self.set_protocol_cb is None:
|
||||
logger.debug("SetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(
|
||||
self, cb: Callable[[int], Device.GetSetStatus]
|
||||
) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(HID):
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.HOST)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes) -> None:
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self) -> None:
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int) -> None:
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def suspend(self) -> None:
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def exit_suspend(self) -> None:
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
855
bumble/host.py
855
bumble/host.py
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,8 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from .colors import color
|
||||
from .hci import Address
|
||||
@@ -128,10 +129,10 @@ class PairingKeys:
|
||||
|
||||
def print(self, prefix=''):
|
||||
keys_dict = self.to_dict()
|
||||
for (container_property, value) in keys_dict.items():
|
||||
for container_property, value in keys_dict.items():
|
||||
if isinstance(value, dict):
|
||||
print(f'{prefix}{color(container_property, "cyan")}:')
|
||||
for (key_property, key_value) in value.items():
|
||||
for key_property, key_value in value.items():
|
||||
print(f'{prefix} {color(key_property, "green")}: {key_value}')
|
||||
else:
|
||||
print(f'{prefix}{color(container_property, "cyan")}: {value}')
|
||||
@@ -158,7 +159,7 @@ class KeyStore:
|
||||
async def get_resolving_keys(self):
|
||||
all_keys = await self.get_all()
|
||||
resolving_keys = []
|
||||
for (name, keys) in all_keys:
|
||||
for name, keys in all_keys:
|
||||
if keys.irk is not None:
|
||||
if keys.address_type is None:
|
||||
address_type = Address.RANDOM_DEVICE_ADDRESS
|
||||
@@ -171,7 +172,7 @@ class KeyStore:
|
||||
async def print(self, prefix=''):
|
||||
entries = await self.get_all()
|
||||
separator = ''
|
||||
for (name, keys) in entries:
|
||||
for name, keys in entries:
|
||||
print(separator + prefix + color(name, 'yellow'))
|
||||
keys.print(prefix=prefix + ' ')
|
||||
separator = '\n'
|
||||
@@ -253,8 +254,10 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
logger.debug(f'JSON keystore: {self.filename}')
|
||||
|
||||
@staticmethod
|
||||
def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
|
||||
@classmethod
|
||||
def from_device(
|
||||
cls: Type[Self], device: Device, filename: Optional[str] = None
|
||||
) -> Self:
|
||||
if not filename:
|
||||
# Extract the filename from the config if there is one
|
||||
if device.config.keystore is not None:
|
||||
@@ -270,7 +273,7 @@ class JsonKeyStore(KeyStore):
|
||||
else:
|
||||
namespace = JsonKeyStore.DEFAULT_NAMESPACE
|
||||
|
||||
return JsonKeyStore(namespace, filename)
|
||||
return cls(namespace, filename)
|
||||
|
||||
async def load(self):
|
||||
# Try to open the file, without failing. If the file does not exist, it
|
||||
|
||||
687
bumble/l2cap.py
687
bumble/l2cap.py
File diff suppressed because it is too large
Load Diff
123
bumble/link.py
123
bumble/link.py
@@ -19,16 +19,25 @@ import logging
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
|
||||
from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
InvalidStateError,
|
||||
)
|
||||
from bumble.colors import color
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_SUCCESS,
|
||||
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
|
||||
HCI_CONNECTION_TIMEOUT_ERROR,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_Connection_Complete_Event,
|
||||
)
|
||||
from bumble import controller
|
||||
|
||||
from typing import Optional, Set
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -57,6 +66,8 @@ class LocalLink:
|
||||
Link bus for controllers to communicate with each other
|
||||
'''
|
||||
|
||||
controllers: Set[controller.Controller]
|
||||
|
||||
def __init__(self):
|
||||
self.controllers = set()
|
||||
self.pending_connection = None
|
||||
@@ -79,7 +90,9 @@ class LocalLink:
|
||||
return controller
|
||||
return None
|
||||
|
||||
def find_classic_controller(self, address):
|
||||
def find_classic_controller(
|
||||
self, address: Address
|
||||
) -> Optional[controller.Controller]:
|
||||
for controller in self.controllers:
|
||||
if controller.public_address == address:
|
||||
return controller
|
||||
@@ -109,6 +122,8 @@ class LocalLink:
|
||||
elif transport == BT_BR_EDR_TRANSPORT:
|
||||
destination_controller = self.find_classic_controller(destination_address)
|
||||
source_address = sender_controller.public_address
|
||||
else:
|
||||
raise ValueError("unsupported transport type")
|
||||
|
||||
if destination_controller is not None:
|
||||
destination_controller.on_link_acl_data(source_address, transport, data)
|
||||
@@ -188,6 +203,60 @@ class LocalLink:
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
|
||||
|
||||
def create_cis(
|
||||
self,
|
||||
central_controller: controller.Controller,
|
||||
peripheral_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
|
||||
)
|
||||
if peripheral_controller := self.find_controller(peripheral_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_request,
|
||||
central_controller.random_address,
|
||||
cig_id,
|
||||
cis_id,
|
||||
)
|
||||
|
||||
def accept_cis(
|
||||
self,
|
||||
peripheral_controller: controller.Controller,
|
||||
central_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
|
||||
)
|
||||
if central_controller := self.find_controller(central_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
central_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peripheral_controller.on_link_cis_established, cig_id, cis_id
|
||||
)
|
||||
|
||||
def disconnect_cis(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
peer_address: Address,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
|
||||
)
|
||||
if peer_controller := self.find_controller(peer_address):
|
||||
asyncio.get_running_loop().call_soon(
|
||||
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
asyncio.get_running_loop().call_soon(
|
||||
peer_controller.on_link_cis_disconnected, cig_id, cis_id
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Classic handlers
|
||||
############################################################
|
||||
@@ -271,6 +340,52 @@ class LocalLink:
|
||||
initiator_controller.public_address, int(not (initiator_new_role))
|
||||
)
|
||||
|
||||
def classic_sco_connect(
|
||||
self,
|
||||
initiator_controller: controller.Controller,
|
||||
responder_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
|
||||
)
|
||||
responder_controller = self.find_classic_controller(responder_address)
|
||||
# Initiator controller should handle it.
|
||||
assert responder_controller
|
||||
|
||||
responder_controller.on_classic_connection_request(
|
||||
initiator_controller.public_address,
|
||||
link_type,
|
||||
)
|
||||
|
||||
def classic_accept_sco_connection(
|
||||
self,
|
||||
responder_controller: controller.Controller,
|
||||
initiator_address: Address,
|
||||
link_type: int,
|
||||
):
|
||||
logger.debug(
|
||||
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
|
||||
)
|
||||
initiator_controller = self.find_classic_controller(initiator_address)
|
||||
if initiator_controller is None:
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address,
|
||||
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
|
||||
link_type,
|
||||
)
|
||||
return
|
||||
|
||||
async def task():
|
||||
initiator_controller.on_classic_sco_connection_complete(
|
||||
responder_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
asyncio.create_task(task())
|
||||
responder_controller.on_classic_sco_connection_complete(
|
||||
initiator_controller.public_address, HCI_SUCCESS, link_type
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class RemoteLink:
|
||||
@@ -297,12 +412,12 @@ class RemoteLink:
|
||||
|
||||
def add_controller(self, controller):
|
||||
if self.controller:
|
||||
raise ValueError('controller already set')
|
||||
raise InvalidStateError('controller already set')
|
||||
self.controller = controller
|
||||
|
||||
def remove_controller(self, controller):
|
||||
if self.controller != controller:
|
||||
raise ValueError('controller mismatch')
|
||||
raise InvalidStateError('controller mismatch')
|
||||
self.controller = None
|
||||
|
||||
def get_pending_connection(self):
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .hci import (
|
||||
@@ -35,7 +37,60 @@ from .smp import (
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
|
||||
SMP_LINK_KEY_DISTRIBUTION_FLAG,
|
||||
OobContext,
|
||||
OobLegacyContext,
|
||||
OobSharedData,
|
||||
)
|
||||
from .core import AdvertisingData, LeRole
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobData:
|
||||
"""OOB data that can be sent from one device to another."""
|
||||
|
||||
address: Optional[Address] = None
|
||||
role: Optional[LeRole] = None
|
||||
shared_data: Optional[OobSharedData] = None
|
||||
legacy_context: Optional[OobLegacyContext] = None
|
||||
|
||||
@classmethod
|
||||
def from_ad(cls, ad: AdvertisingData) -> OobData:
|
||||
instance = cls()
|
||||
shared_data_c: Optional[bytes] = None
|
||||
shared_data_r: Optional[bytes] = None
|
||||
for ad_type, ad_data in ad.ad_structures:
|
||||
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
|
||||
instance.address = Address(ad_data)
|
||||
elif ad_type == AdvertisingData.LE_ROLE:
|
||||
instance.role = LeRole(ad_data[0])
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
|
||||
shared_data_c = ad_data
|
||||
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
|
||||
shared_data_r = ad_data
|
||||
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
|
||||
instance.legacy_context = OobLegacyContext(tk=ad_data)
|
||||
if shared_data_c and shared_data_r:
|
||||
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
|
||||
|
||||
return instance
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
ad_structures = []
|
||||
if self.address is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
|
||||
)
|
||||
if self.role is not None:
|
||||
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
|
||||
if self.shared_data is not None:
|
||||
ad_structures.extend(self.shared_data.to_ad().ad_structures)
|
||||
if self.legacy_context is not None:
|
||||
ad_structures.append(
|
||||
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
|
||||
)
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -84,16 +139,19 @@ class PairingDelegate:
|
||||
io_capability: IoCapability
|
||||
local_initiator_key_distribution: KeyDistribution
|
||||
local_responder_key_distribution: KeyDistribution
|
||||
maximum_encryption_key_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
maximum_encryption_key_size: int = 16,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
self.local_responder_key_distribution = local_responder_key_distribution
|
||||
self.maximum_encryption_key_size = maximum_encryption_key_size
|
||||
|
||||
@property
|
||||
def classic_io_capability(self) -> int:
|
||||
@@ -173,6 +231,14 @@ class PairingConfig:
|
||||
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
|
||||
RANDOM = Address.RANDOM_DEVICE_ADDRESS
|
||||
|
||||
@dataclass
|
||||
class OobConfig:
|
||||
"""Config for OOB pairing."""
|
||||
|
||||
our_context: Optional[OobContext]
|
||||
peer_data: Optional[OobSharedData]
|
||||
legacy_context: Optional[OobLegacyContext]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sc: bool = True,
|
||||
@@ -180,17 +246,20 @@ class PairingConfig:
|
||||
bonding: bool = True,
|
||||
delegate: Optional[PairingDelegate] = None,
|
||||
identity_address_type: Optional[AddressType] = None,
|
||||
oob: Optional[OobConfig] = None,
|
||||
) -> None:
|
||||
self.sc = sc
|
||||
self.mitm = mitm
|
||||
self.bonding = bonding
|
||||
self.delegate = delegate or PairingDelegate()
|
||||
self.identity_address_type = identity_address_type
|
||||
self.oob = oob
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'PairingConfig(sc={self.sc}, '
|
||||
f'mitm={self.mitm}, bonding={self.bonding}, '
|
||||
f'identity_address_type={self.identity_address_type}, '
|
||||
f'delegate[{self.delegate.io_capability}])'
|
||||
f'delegate[{self.delegate.io_capability}]), '
|
||||
f'oob[{self.oob}])'
|
||||
)
|
||||
|
||||
@@ -25,8 +25,10 @@ import grpc.aio
|
||||
from .config import Config
|
||||
from .device import PandoraDevice
|
||||
from .host import HostService
|
||||
from .l2cap import L2CAPService
|
||||
from .security import SecurityService, SecurityStorageService
|
||||
from pandora.host_grpc_aio import add_HostServicer_to_server
|
||||
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
|
||||
from pandora.security_grpc_aio import (
|
||||
add_SecurityServicer_to_server,
|
||||
add_SecurityStorageServicer_to_server,
|
||||
@@ -77,6 +79,7 @@ async def serve(
|
||||
add_SecurityStorageServicer_to_server(
|
||||
SecurityStorageService(bumble.device, config), server
|
||||
)
|
||||
add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server)
|
||||
|
||||
# call hooks if any.
|
||||
for hook in _SERVICERS_HOOKS:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble.pairing import PairingConfig, PairingDelegate
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Generic & dependency free Bumble (reference) device."""
|
||||
|
||||
from __future__ import annotations
|
||||
from bumble import transport
|
||||
from bumble.core import (
|
||||
BT_GENERIC_AUDIO_SERVICE,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import bumble.device
|
||||
import grpc
|
||||
@@ -27,12 +28,15 @@ from bumble.core import (
|
||||
BT_PERIPHERAL_ROLE,
|
||||
UUID,
|
||||
AdvertisingData,
|
||||
Appearance,
|
||||
ConnectionError,
|
||||
)
|
||||
from bumble.device import (
|
||||
DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
Advertisement,
|
||||
AdvertisingParameters,
|
||||
AdvertisingEventProperties,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
)
|
||||
@@ -42,13 +46,17 @@ from bumble.hci import (
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
Phy,
|
||||
)
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.host_grpc_aio import HostServicer
|
||||
from pandora import host_pb2
|
||||
from pandora.host_pb2 import (
|
||||
NOT_CONNECTABLE,
|
||||
NOT_DISCOVERABLE,
|
||||
DISCOVERABLE_LIMITED,
|
||||
DISCOVERABLE_GENERAL,
|
||||
PRIMARY_1M,
|
||||
PRIMARY_CODED,
|
||||
SECONDARY_1M,
|
||||
@@ -64,6 +72,7 @@ from pandora.host_pb2 import (
|
||||
ConnectResponse,
|
||||
DataTypes,
|
||||
DisconnectRequest,
|
||||
DiscoverabilityMode,
|
||||
InquiryResponse,
|
||||
PrimaryPhy,
|
||||
ReadLocalAddressResponse,
|
||||
@@ -93,6 +102,25 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
|
||||
3: SECONDARY_CODED,
|
||||
}
|
||||
|
||||
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
|
||||
PRIMARY_1M: Phy.LE_1M,
|
||||
PRIMARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
|
||||
SECONDARY_NONE: Phy.LE_1M,
|
||||
SECONDARY_1M: Phy.LE_1M,
|
||||
SECONDARY_2M: Phy.LE_2M,
|
||||
SECONDARY_CODED: Phy.LE_CODED,
|
||||
}
|
||||
|
||||
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
|
||||
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
|
||||
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
|
||||
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
|
||||
}
|
||||
|
||||
|
||||
class HostService(HostServicer):
|
||||
waited_connections: Set[int]
|
||||
@@ -260,9 +288,9 @@ class HostService(HostServicer):
|
||||
self.log.debug(f"WaitDisconnection: {connection_handle}")
|
||||
|
||||
if connection := self.device.lookup_connection(connection_handle):
|
||||
disconnection_future: asyncio.Future[
|
||||
None
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
disconnection_future: asyncio.Future[None] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_disconnection(_: None) -> None:
|
||||
disconnection_future.set_result(None)
|
||||
@@ -280,14 +308,118 @@ class HostService(HostServicer):
|
||||
async def Advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if not request.legacy:
|
||||
raise NotImplementedError(
|
||||
"TODO: add support for extended advertising in Bumble"
|
||||
try:
|
||||
if request.legacy:
|
||||
async for rsp in self.legacy_advertise(request, context):
|
||||
yield rsp
|
||||
else:
|
||||
async for rsp in self.extended_advertise(request, context):
|
||||
yield rsp
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def extended_advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
advertising_data = bytes(self.unpack_data_types(request.data))
|
||||
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
|
||||
scannable = len(scan_response_data) != 0
|
||||
|
||||
advertising_event_properties = AdvertisingEventProperties(
|
||||
is_connectable=request.connectable,
|
||||
is_scannable=scannable,
|
||||
is_directed=request.target is not None,
|
||||
is_high_duty_cycle_directed_connectable=False,
|
||||
is_legacy=False,
|
||||
is_anonymous=False,
|
||||
include_tx_power=False,
|
||||
)
|
||||
|
||||
peer_address = Address.ANY
|
||||
if request.target:
|
||||
# Need to reverse bytes order since Bumble Address is using MSB.
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
else:
|
||||
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
|
||||
advertising_parameters = AdvertisingParameters(
|
||||
advertising_event_properties=advertising_event_properties,
|
||||
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
|
||||
peer_address=peer_address,
|
||||
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
|
||||
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
|
||||
request.secondary_phy
|
||||
],
|
||||
)
|
||||
if advertising_interval := request.interval:
|
||||
advertising_parameters.primary_advertising_interval_min = int(
|
||||
advertising_interval
|
||||
)
|
||||
if request.interval:
|
||||
raise NotImplementedError("TODO: add support for `request.interval`")
|
||||
if request.interval_range:
|
||||
raise NotImplementedError("TODO: add support for `request.interval_range`")
|
||||
advertising_parameters.primary_advertising_interval_max = int(
|
||||
advertising_interval
|
||||
)
|
||||
if interval_range := request.interval_range:
|
||||
advertising_parameters.primary_advertising_interval_max += int(
|
||||
interval_range
|
||||
)
|
||||
|
||||
advertising_set = await self.device.create_advertising_set(
|
||||
advertising_parameters=advertising_parameters,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
if request.connectable:
|
||||
|
||||
def on_connection(connection: bumble.device.Connection) -> None:
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
self.device.on('connection', on_connection)
|
||||
|
||||
try:
|
||||
# Advertise until RPC is canceled
|
||||
while True:
|
||||
if not advertising_set.enabled:
|
||||
self.log.debug('Advertise (extended)')
|
||||
await advertising_set.start()
|
||||
|
||||
if not request.connectable:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
connection = await pending_connection
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
|
||||
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
|
||||
yield AdvertiseResponse(connection=Connection(cookie=cookie))
|
||||
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
try:
|
||||
self.log.debug('Stop Advertise (extended)')
|
||||
await advertising_set.stop()
|
||||
await advertising_set.remove()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def legacy_advertise(
|
||||
self, request: AdvertiseRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[AdvertiseResponse, None]:
|
||||
if advertising_interval := request.interval:
|
||||
self.device.config.advertising_interval_min = int(advertising_interval)
|
||||
self.device.config.advertising_interval_max = int(advertising_interval)
|
||||
if interval_range := request.interval_range:
|
||||
self.device.config.advertising_interval_max += int(interval_range)
|
||||
if request.primary_phy:
|
||||
raise NotImplementedError("TODO: add support for `request.primary_phy`")
|
||||
if request.secondary_phy:
|
||||
@@ -355,14 +487,10 @@ class HostService(HostServicer):
|
||||
target_bytes = bytes(reversed(request.target))
|
||||
if request.target_variant() == "public":
|
||||
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
|
||||
else:
|
||||
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
|
||||
advertising_type = (
|
||||
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
|
||||
) # FIXME: HIGH_DUTY ?
|
||||
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
|
||||
|
||||
if request.connectable:
|
||||
|
||||
@@ -389,9 +517,9 @@ class HostService(HostServicer):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
pending_connection: asyncio.Future[
|
||||
bumble.device.Connection
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
pending_connection: asyncio.Future[bumble.device.Connection] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
self.log.debug('Wait for LE connection...')
|
||||
connection = await pending_connection
|
||||
@@ -420,23 +548,31 @@ class HostService(HostServicer):
|
||||
self, request: ScanRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[ScanningResponse, None]:
|
||||
# TODO: modify `start_scanning` to accept floats instead of int for ms values
|
||||
if request.phys:
|
||||
raise NotImplementedError("TODO: add support for `request.phys`")
|
||||
|
||||
self.log.debug('Scan')
|
||||
|
||||
scanning_phys = []
|
||||
if PRIMARY_1M in request.phys:
|
||||
scanning_phys.append(int(Phy.LE_1M))
|
||||
if PRIMARY_CODED in request.phys:
|
||||
scanning_phys.append(int(Phy.LE_CODED))
|
||||
if not scanning_phys:
|
||||
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
|
||||
|
||||
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
|
||||
handler = self.device.on('advertisement', scan_queue.put_nowait)
|
||||
await self.device.start_scanning(
|
||||
legacy=request.legacy,
|
||||
active=not request.passive,
|
||||
own_address_type=request.own_address_type,
|
||||
scan_interval=int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL,
|
||||
scan_window=int(request.window)
|
||||
if request.window
|
||||
else DEVICE_DEFAULT_SCAN_WINDOW,
|
||||
scan_interval=(
|
||||
int(request.interval)
|
||||
if request.interval
|
||||
else DEVICE_DEFAULT_SCAN_INTERVAL
|
||||
),
|
||||
scan_window=(
|
||||
int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW
|
||||
),
|
||||
scanning_phys=scanning_phys,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -649,9 +785,11 @@ class HostService(HostServicer):
|
||||
*struct.pack('<H', dt.peripheral_connection_interval_min),
|
||||
*struct.pack(
|
||||
'<H',
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min,
|
||||
(
|
||||
dt.peripheral_connection_interval_max
|
||||
if dt.peripheral_connection_interval_max
|
||||
else dt.peripheral_connection_interval_min
|
||||
),
|
||||
),
|
||||
]
|
||||
),
|
||||
@@ -733,6 +871,16 @@ class HostService(HostServicer):
|
||||
)
|
||||
)
|
||||
|
||||
flag_map = {
|
||||
NOT_DISCOVERABLE: 0x00,
|
||||
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
|
||||
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
|
||||
}
|
||||
|
||||
if dt.le_discoverability_mode:
|
||||
flags = flag_map[dt.le_discoverability_mode]
|
||||
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
|
||||
|
||||
return AdvertisingData(ad_structures)
|
||||
|
||||
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
|
||||
@@ -841,8 +989,8 @@ class HostService(HostServicer):
|
||||
dt.random_target_addresses.extend(
|
||||
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
|
||||
)
|
||||
if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = i
|
||||
if appearance := cast(Appearance, ad.get(AdvertisingData.APPEARANCE)):
|
||||
dt.appearance = int(appearance)
|
||||
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
|
||||
dt.advertising_interval = i
|
||||
if s := cast(str, ad.get(AdvertisingData.URI)):
|
||||
|
||||
310
bumble/pandora/l2cap.py
Normal file
310
bumble/pandora/l2cap.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import grpc
|
||||
import json
|
||||
import logging
|
||||
|
||||
from asyncio import Queue as AsyncQueue, Future
|
||||
|
||||
from . import utils
|
||||
from .config import Config
|
||||
from bumble.core import OutOfResourcesError, InvalidArgumentError
|
||||
from bumble.device import Device
|
||||
from bumble.l2cap import (
|
||||
ClassicChannel,
|
||||
ClassicChannelServer,
|
||||
ClassicChannelSpec,
|
||||
LeCreditBasedChannel,
|
||||
LeCreditBasedChannelServer,
|
||||
LeCreditBasedChannelSpec,
|
||||
)
|
||||
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
|
||||
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
|
||||
from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
|
||||
COMMAND_NOT_UNDERSTOOD,
|
||||
INVALID_CID_IN_REQUEST,
|
||||
Channel as PandoraChannel,
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
CreditBasedChannelRequest,
|
||||
DisconnectRequest,
|
||||
DisconnectResponse,
|
||||
ReceiveRequest,
|
||||
ReceiveResponse,
|
||||
SendRequest,
|
||||
SendResponse,
|
||||
WaitConnectionRequest,
|
||||
WaitConnectionResponse,
|
||||
WaitDisconnectionRequest,
|
||||
WaitDisconnectionResponse,
|
||||
)
|
||||
from typing import AsyncGenerator, Dict, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChannelContext:
|
||||
close_future: Future
|
||||
sdu_queue: AsyncQueue
|
||||
|
||||
|
||||
class L2CAPService(L2CAPServicer):
|
||||
def __init__(self, device: Device, config: Config) -> None:
|
||||
self.log = utils.BumbleServerLoggerAdapter(
|
||||
logging.getLogger(), {'service_name': 'L2CAP', 'device': device}
|
||||
)
|
||||
self.device = device
|
||||
self.config = config
|
||||
self.channels: Dict[bytes, ChannelContext] = {}
|
||||
|
||||
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
|
||||
close_future = asyncio.get_running_loop().create_future()
|
||||
sdu_queue: AsyncQueue = AsyncQueue()
|
||||
|
||||
def on_channel_sdu(sdu):
|
||||
sdu_queue.put_nowait(sdu)
|
||||
|
||||
def on_close():
|
||||
close_future.set_result(None)
|
||||
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_close)
|
||||
|
||||
return ChannelContext(close_future, sdu_queue)
|
||||
|
||||
@utils.rpc
|
||||
async def WaitConnection(
|
||||
self, request: WaitConnectionRequest, context: grpc.ServicerContext
|
||||
) -> WaitConnectionResponse:
|
||||
self.log.debug('WaitConnection')
|
||||
if not request.connection:
|
||||
raise ValueError('A valid connection field must be set')
|
||||
|
||||
# find connection on device based on connection cookie value
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
|
||||
if not connection:
|
||||
raise ValueError('The connection specified is invalid.')
|
||||
|
||||
oneof = request.WhichOneof('type')
|
||||
self.log.debug(f'WaitConnection channel request type: {oneof}.')
|
||||
channel_type = getattr(request, oneof)
|
||||
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
|
||||
l2cap_server: Optional[
|
||||
Union[ClassicChannelServer, LeCreditBasedChannelServer]
|
||||
] = None
|
||||
if isinstance(channel_type, CreditBasedChannelRequest):
|
||||
spec = LeCreditBasedChannelSpec(
|
||||
psm=channel_type.spsm,
|
||||
max_credits=channel_type.initial_credit,
|
||||
mtu=channel_type.mtu,
|
||||
mps=channel_type.mps,
|
||||
)
|
||||
if channel_type.spsm in self.device.l2cap_channel_manager.le_coc_servers:
|
||||
l2cap_server = self.device.l2cap_channel_manager.le_coc_servers[
|
||||
channel_type.spsm
|
||||
]
|
||||
else:
|
||||
spec = ClassicChannelSpec(
|
||||
psm=channel_type.psm,
|
||||
mtu=channel_type.mtu,
|
||||
)
|
||||
if channel_type.psm in self.device.l2cap_channel_manager.servers:
|
||||
l2cap_server = self.device.l2cap_channel_manager.servers[
|
||||
channel_type.psm
|
||||
]
|
||||
|
||||
self.log.info(f'Listening for L2CAP connection on PSM {spec.psm}')
|
||||
channel_future: Future[PandoraChannel] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_l2cap_channel(l2cap_channel: L2capChannel):
|
||||
try:
|
||||
channel_context = self.register_event(l2cap_channel)
|
||||
pandora_channel: PandoraChannel = self.craft_pandora_channel(
|
||||
connection_handle, l2cap_channel
|
||||
)
|
||||
self.channels[pandora_channel.cookie.value] = channel_context
|
||||
channel_future.set_result(pandora_channel)
|
||||
except Exception as e:
|
||||
self.log.error(f'Failed to set channel future: {e}')
|
||||
|
||||
if l2cap_server is None:
|
||||
l2cap_server = self.device.create_l2cap_server(
|
||||
spec=spec, handler=on_l2cap_channel
|
||||
)
|
||||
else:
|
||||
l2cap_server.on('connection', on_l2cap_channel)
|
||||
|
||||
try:
|
||||
self.log.debug('Waiting for a channel connection.')
|
||||
pandora_channel: PandoraChannel = await channel_future
|
||||
|
||||
return WaitConnectionResponse(channel=pandora_channel)
|
||||
except Exception as e:
|
||||
self.log.warning(f'Exception: {e}')
|
||||
|
||||
return WaitConnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
|
||||
@utils.rpc
|
||||
async def WaitDisconnection(
|
||||
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
|
||||
) -> WaitDisconnectionResponse:
|
||||
try:
|
||||
self.log.debug('WaitDisconnection')
|
||||
|
||||
await self.lookup_context(request.channel).close_future
|
||||
self.log.debug("return WaitDisconnectionResponse")
|
||||
return WaitDisconnectionResponse(success=empty_pb2.Empty())
|
||||
except KeyError as e:
|
||||
self.log.warning(f'WaitDisconnection: Unable to find the channel: {e}')
|
||||
return WaitDisconnectionResponse(error=INVALID_CID_IN_REQUEST)
|
||||
except Exception as e:
|
||||
self.log.exception(f'WaitDisonnection failed: {e}')
|
||||
return WaitDisconnectionResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
|
||||
@utils.rpc
|
||||
async def Receive(
|
||||
self, request: ReceiveRequest, context: grpc.ServicerContext
|
||||
) -> AsyncGenerator[ReceiveResponse, None]:
|
||||
self.log.debug('Receive')
|
||||
oneof = request.WhichOneof('source')
|
||||
self.log.debug(f'Source: {oneof}.')
|
||||
pandora_channel = getattr(request, oneof)
|
||||
|
||||
sdu_queue = self.lookup_context(pandora_channel).sdu_queue
|
||||
|
||||
while sdu := await sdu_queue.get():
|
||||
self.log.debug(f'Receive: Received {len(sdu)} bytes -> {sdu.decode()}')
|
||||
response = ReceiveResponse(data=sdu)
|
||||
yield response
|
||||
|
||||
@utils.rpc
|
||||
async def Connect(
|
||||
self, request: ConnectRequest, context: grpc.ServicerContext
|
||||
) -> ConnectResponse:
|
||||
self.log.debug('Connect')
|
||||
|
||||
if not request.connection:
|
||||
raise ValueError('A valid connection field must be set')
|
||||
|
||||
# find connection on device based on connection cookie value
|
||||
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
|
||||
connection = self.device.lookup_connection(connection_handle)
|
||||
|
||||
if not connection:
|
||||
raise ValueError('The connection specified is invalid.')
|
||||
|
||||
oneof = request.WhichOneof('type')
|
||||
self.log.debug(f'Channel request type: {oneof}.')
|
||||
channel_type = getattr(request, oneof)
|
||||
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
|
||||
if isinstance(channel_type, CreditBasedChannelRequest):
|
||||
spec = LeCreditBasedChannelSpec(
|
||||
psm=channel_type.spsm,
|
||||
max_credits=channel_type.initial_credit,
|
||||
mtu=channel_type.mtu,
|
||||
mps=channel_type.mps,
|
||||
)
|
||||
else:
|
||||
spec = ClassicChannelSpec(
|
||||
psm=channel_type.psm,
|
||||
mtu=channel_type.mtu,
|
||||
)
|
||||
|
||||
try:
|
||||
self.log.info(f'Opening L2CAP channel on PSM = {spec.psm}')
|
||||
l2cap_channel = await connection.create_l2cap_channel(spec=spec)
|
||||
channel_context = self.register_event(l2cap_channel)
|
||||
pandora_channel = self.craft_pandora_channel(
|
||||
connection_handle, l2cap_channel
|
||||
)
|
||||
self.channels[pandora_channel.cookie.value] = channel_context
|
||||
|
||||
return ConnectResponse(channel=pandora_channel)
|
||||
|
||||
except OutOfResourcesError as e:
|
||||
self.log.error(e)
|
||||
return ConnectResponse(error=INVALID_CID_IN_REQUEST)
|
||||
except InvalidArgumentError as e:
|
||||
self.log.error(e)
|
||||
return ConnectResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
|
||||
@utils.rpc
|
||||
async def Disconnect(
|
||||
self, request: DisconnectRequest, context: grpc.ServicerContext
|
||||
) -> DisconnectResponse:
|
||||
try:
|
||||
self.log.debug('Disconnect')
|
||||
l2cap_channel = self.lookup_channel(request.channel)
|
||||
if not l2cap_channel:
|
||||
self.log.warning('Disconnect: Unable to find the channel')
|
||||
return DisconnectResponse(error=INVALID_CID_IN_REQUEST)
|
||||
|
||||
await l2cap_channel.disconnect()
|
||||
return DisconnectResponse(success=empty_pb2.Empty())
|
||||
except Exception as e:
|
||||
self.log.exception(f'Disonnect failed: {e}')
|
||||
return DisconnectResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
|
||||
@utils.rpc
|
||||
async def Send(
|
||||
self, request: SendRequest, context: grpc.ServicerContext
|
||||
) -> SendResponse:
|
||||
self.log.debug('Send')
|
||||
try:
|
||||
oneof = request.WhichOneof('sink')
|
||||
self.log.debug(f'Sink: {oneof}.')
|
||||
pandora_channel = getattr(request, oneof)
|
||||
|
||||
l2cap_channel = self.lookup_channel(pandora_channel)
|
||||
if not l2cap_channel:
|
||||
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
if isinstance(l2cap_channel, ClassicChannel):
|
||||
l2cap_channel.send_pdu(request.data)
|
||||
else:
|
||||
l2cap_channel.write(request.data)
|
||||
return SendResponse(success=empty_pb2.Empty())
|
||||
except Exception as e:
|
||||
self.log.exception(f'Disonnect failed: {e}')
|
||||
return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
|
||||
|
||||
def craft_pandora_channel(
|
||||
self,
|
||||
connection_handle: int,
|
||||
l2cap_channel: L2capChannel,
|
||||
) -> PandoraChannel:
|
||||
parameters = {
|
||||
"connection_handle": connection_handle,
|
||||
"source_cid": l2cap_channel.source_cid,
|
||||
}
|
||||
cookie = any_pb2.Any()
|
||||
cookie.value = json.dumps(parameters).encode()
|
||||
return PandoraChannel(cookie=cookie)
|
||||
|
||||
def lookup_channel(self, pandora_channel: PandoraChannel) -> L2capChannel:
|
||||
(connection_handle, source_cid) = json.loads(
|
||||
pandora_channel.cookie.value
|
||||
).values()
|
||||
|
||||
return self.device.l2cap_channel_manager.channels[connection_handle][source_cid]
|
||||
|
||||
def lookup_context(self, pandora_channel: PandoraChannel) -> ChannelContext:
|
||||
return self.channels[pandora_channel.cookie.value]
|
||||
@@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import grpc
|
||||
import logging
|
||||
|
||||
@@ -27,8 +29,8 @@ from bumble.core import (
|
||||
)
|
||||
from bumble.device import Connection as BumbleConnection, Device
|
||||
from bumble.hci import HCI_Error
|
||||
from bumble.utils import EventWatcher
|
||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||
from contextlib import suppress
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||
@@ -108,7 +110,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
@@ -123,7 +125,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(numeric_comparison=number))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
|
||||
return answer.confirm
|
||||
@@ -138,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -155,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
|
||||
|
||||
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
|
||||
self.service.event_queue.put_nowait(event)
|
||||
answer = await anext(self.service.event_answer) # pytype: disable=name-error
|
||||
answer = await anext(self.service.event_answer) # type: ignore
|
||||
assert answer.event == event
|
||||
if answer.answer_variant() is None:
|
||||
return None
|
||||
@@ -232,7 +234,11 @@ class SecurityService(SecurityServicer):
|
||||
sc=config.pairing_sc_enable,
|
||||
mitm=config.pairing_mitm_enable,
|
||||
bonding=config.pairing_bonding_enable,
|
||||
identity_address_type=config.identity_address_type,
|
||||
identity_address_type=(
|
||||
PairingConfig.AddressType.PUBLIC
|
||||
if connection.self_address.is_public
|
||||
else config.identity_address_type
|
||||
),
|
||||
delegate=PairingDelegate(
|
||||
connection,
|
||||
self,
|
||||
@@ -294,23 +300,35 @@ class SecurityService(SecurityServicer):
|
||||
try:
|
||||
self.log.debug('Pair...')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
wait_for_security: asyncio.Future[
|
||||
bool
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
||||
security_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
connection.request_pairing()
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
|
||||
await wait_for_security
|
||||
else:
|
||||
await connection.pair()
|
||||
@watcher.on(connection, 'pairing')
|
||||
def on_pairing(*_: Any) -> None:
|
||||
security_result.set_result('success')
|
||||
|
||||
self.log.debug('Paired')
|
||||
@watcher.on(connection, 'pairing_failure')
|
||||
def on_pairing_failure(*_: Any) -> None:
|
||||
security_result.set_result('pairing_failure')
|
||||
|
||||
@watcher.on(connection, 'disconnection')
|
||||
def on_disconnection(*_: Any) -> None:
|
||||
security_result.set_result('connection_died')
|
||||
|
||||
if (
|
||||
connection.transport == BT_LE_TRANSPORT
|
||||
and connection.role == BT_PERIPHERAL_ROLE
|
||||
):
|
||||
connection.request_pairing()
|
||||
else:
|
||||
await connection.pair()
|
||||
|
||||
result = await security_result
|
||||
|
||||
self.log.debug(f'Pairing session complete, status={result}')
|
||||
if result != 'success':
|
||||
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||
except asyncio.CancelledError:
|
||||
self.log.warning("Connection died during encryption")
|
||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||
@@ -365,10 +383,11 @@ class SecurityService(SecurityServicer):
|
||||
connection.transport
|
||||
] == request.level_variant()
|
||||
|
||||
wait_for_security: asyncio.Future[
|
||||
str
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
wait_for_security: asyncio.Future[str] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||
pair_task: Optional[asyncio.Future[None]] = None
|
||||
|
||||
async def authenticate() -> None:
|
||||
assert connection
|
||||
@@ -415,6 +434,10 @@ class SecurityService(SecurityServicer):
|
||||
if authenticate_task is None:
|
||||
authenticate_task = asyncio.create_task(authenticate())
|
||||
|
||||
def pair(*_: Any) -> None:
|
||||
if self.need_pairing(connection, level):
|
||||
pair_task = asyncio.create_task(connection.pair())
|
||||
|
||||
listeners: Dict[str, Callable[..., None]] = {
|
||||
'disconnection': set_failure('connection_died'),
|
||||
'pairing_failure': set_failure('pairing_failure'),
|
||||
@@ -423,23 +446,23 @@ class SecurityService(SecurityServicer):
|
||||
'pairing': try_set_success,
|
||||
'connection_authentication': try_set_success,
|
||||
'connection_encryption_change': on_encryption_change,
|
||||
'classic_pairing': try_set_success,
|
||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
watcher.on(connection, event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
@@ -450,6 +473,15 @@ class SecurityService(SecurityServicer):
|
||||
pass
|
||||
self.log.debug('Authenticated')
|
||||
|
||||
# wait for `pair` to finish if any
|
||||
if pair_task is not None:
|
||||
self.log.debug('Wait for authentication...')
|
||||
try:
|
||||
await pair_task # type: ignore
|
||||
except:
|
||||
pass
|
||||
self.log.debug('paired')
|
||||
|
||||
return WaitSecurityResponse(**kwargs)
|
||||
|
||||
def reached_security_level(
|
||||
@@ -521,7 +553,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
||||
self.log.debug(f"DeleteBond: {address}")
|
||||
|
||||
if self.device.keystore is not None:
|
||||
with suppress(KeyError):
|
||||
with contextlib.suppress(KeyError):
|
||||
await self.device.keystore.delete(str(address))
|
||||
|
||||
return empty_pb2.Empty()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import grpc
|
||||
|
||||
504
bumble/profiles/aics.py
Normal file
504
bumble/profiles/aics.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LE Audio - Audio Input Control Service"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from bumble import gatt
|
||||
from bumble.device import Connection
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.gatt import (
|
||||
Characteristic,
|
||||
SerializableCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
CharacteristicValue,
|
||||
UTF8CharacteristicAdapter,
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE,
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from bumble.utils import OpenIntEnum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
CHANGE_COUNTER_MAX_VALUE = 0xFF
|
||||
GAIN_SETTINGS_MIN_VALUE = 0
|
||||
GAIN_SETTINGS_MAX_VALUE = 255
|
||||
|
||||
|
||||
class ErrorCode(OpenIntEnum):
|
||||
'''
|
||||
Cf. 1.6 Application error codes
|
||||
'''
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
MUTE_DISABLED = 0x82
|
||||
VALUE_OUT_OF_RANGE = 0x83
|
||||
GAIN_MODE_CHANGE_NOT_ALLOWED = 0x84
|
||||
|
||||
|
||||
class Mute(OpenIntEnum):
|
||||
'''
|
||||
Cf. 2.2.1.2 Mute Field
|
||||
'''
|
||||
|
||||
NOT_MUTED = 0x00
|
||||
MUTED = 0x01
|
||||
DISABLED = 0x02
|
||||
|
||||
|
||||
class GainMode(OpenIntEnum):
|
||||
'''
|
||||
Cf. 2.2.1.3 Gain Mode
|
||||
'''
|
||||
|
||||
MANUAL_ONLY = 0x00
|
||||
AUTOMATIC_ONLY = 0x01
|
||||
MANUAL = 0x02
|
||||
AUTOMATIC = 0x03
|
||||
|
||||
|
||||
class AudioInputStatus(OpenIntEnum):
|
||||
'''
|
||||
Cf. 3.4 Audio Input Status
|
||||
'''
|
||||
|
||||
INACTIVE = 0x00
|
||||
ACTIVE = 0x01
|
||||
|
||||
|
||||
class AudioInputControlPointOpCode(OpenIntEnum):
|
||||
'''
|
||||
Cf. 3.5.1 Audio Input Control Point procedure requirements
|
||||
'''
|
||||
|
||||
SET_GAIN_SETTING = 0x01
|
||||
UNMUTE = 0x02
|
||||
MUTE = 0x03
|
||||
SET_MANUAL_GAIN_MODE = 0x04
|
||||
SET_AUTOMATIC_GAIN_MODE = 0x05
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class AudioInputState:
|
||||
'''
|
||||
Cf. 2.2.1 Audio Input State
|
||||
'''
|
||||
|
||||
gain_settings: int = 0
|
||||
mute: Mute = Mute.NOT_MUTED
|
||||
gain_mode: GainMode = GainMode.MANUAL
|
||||
change_counter: int = 0
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[self.gain_settings, self.mute, self.gain_mode, self.change_counter]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
gain_settings, mute, gain_mode, change_counter = struct.unpack("BBBB", data)
|
||||
return cls(gain_settings, mute, gain_mode, change_counter)
|
||||
|
||||
def update_gain_settings_unit(self, gain_settings_unit: int) -> None:
|
||||
self.gain_settings_unit = gain_settings_unit
|
||||
|
||||
def increment_gain_settings(self, gain_settings_unit: int) -> None:
|
||||
self.gain_settings += gain_settings_unit
|
||||
self.increment_change_counter()
|
||||
|
||||
def decrement_gain_settings(self) -> None:
|
||||
self.gain_settings -= self.gain_settings_unit
|
||||
self.increment_change_counter()
|
||||
|
||||
def increment_change_counter(self):
|
||||
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
|
||||
|
||||
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
|
||||
assert self.attribute_value is not None
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=bytes(self)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GainSettingsProperties:
|
||||
'''
|
||||
Cf. 3.2 Gain Settings Properties
|
||||
'''
|
||||
|
||||
gain_settings_unit: int = 1
|
||||
gain_settings_minimum: int = GAIN_SETTINGS_MIN_VALUE
|
||||
gain_settings_maximum: int = GAIN_SETTINGS_MAX_VALUE
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
|
||||
struct.unpack('BBB', data)
|
||||
)
|
||||
return GainSettingsProperties(
|
||||
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[
|
||||
self.gain_settings_unit,
|
||||
self.gain_settings_minimum,
|
||||
self.gain_settings_maximum,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioInputControlPoint:
|
||||
'''
|
||||
Cf. 3.5.2 Audio Input Control Point
|
||||
'''
|
||||
|
||||
audio_input_state: AudioInputState
|
||||
gain_settings_properties: GainSettingsProperties
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = AudioInputControlPointOpCode(value[0])
|
||||
|
||||
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
|
||||
gain_settings_operand = value[2]
|
||||
await self._set_gain_settings(connection, gain_settings_operand)
|
||||
elif opcode == AudioInputControlPointOpCode.UNMUTE:
|
||||
await self._unmute(connection)
|
||||
elif opcode == AudioInputControlPointOpCode.MUTE:
|
||||
change_counter_operand = value[1]
|
||||
await self._mute(connection, change_counter_operand)
|
||||
elif opcode == AudioInputControlPointOpCode.SET_MANUAL_GAIN_MODE:
|
||||
await self._set_manual_gain_mode(connection)
|
||||
elif opcode == AudioInputControlPointOpCode.SET_AUTOMATIC_GAIN_MODE:
|
||||
await self._set_automatic_gain_mode(connection)
|
||||
else:
|
||||
logger.error(f"OpCode value is incorrect: {opcode}")
|
||||
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
|
||||
|
||||
async def _set_gain_settings(
|
||||
self, connection: Connection, gain_settings_operand: int
|
||||
) -> None:
|
||||
'''Cf. 3.5.2.1 Set Gain Settings Procedure'''
|
||||
|
||||
gain_mode = self.audio_input_state.gain_mode
|
||||
|
||||
logger.error(f"set_gain_setting: gain_mode: {gain_mode}")
|
||||
if not (gain_mode == GainMode.MANUAL or gain_mode == GainMode.MANUAL_ONLY):
|
||||
logger.warning(
|
||||
"GainMode should be either MANUAL or MANUAL_ONLY Cf Spec Audio Input Control Service 3.5.2.1"
|
||||
)
|
||||
return
|
||||
|
||||
if (
|
||||
gain_settings_operand < self.gain_settings_properties.gain_settings_minimum
|
||||
or gain_settings_operand
|
||||
> self.gain_settings_properties.gain_settings_maximum
|
||||
):
|
||||
logger.error("gain_settings value out of range")
|
||||
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
|
||||
|
||||
if self.audio_input_state.gain_settings != gain_settings_operand:
|
||||
self.audio_input_state.gain_settings = gain_settings_operand
|
||||
await self.audio_input_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
async def _unmute(self, connection: Connection):
|
||||
'''Cf. 3.5.2.2 Unmute procedure'''
|
||||
|
||||
logger.error(f'unmute: {self.audio_input_state.mute}')
|
||||
mute = self.audio_input_state.mute
|
||||
if mute == Mute.DISABLED:
|
||||
logger.error("unmute: Cannot change Mute value, Mute state is DISABLED")
|
||||
raise ATT_Error(ErrorCode.MUTE_DISABLED)
|
||||
|
||||
if mute == Mute.NOT_MUTED:
|
||||
return
|
||||
|
||||
self.audio_input_state.mute = Mute.NOT_MUTED
|
||||
self.audio_input_state.increment_change_counter()
|
||||
await self.audio_input_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
async def _mute(self, connection: Connection, change_counter_operand: int) -> None:
|
||||
'''Cf. 3.5.5.2 Mute procedure'''
|
||||
|
||||
change_counter = self.audio_input_state.change_counter
|
||||
mute = self.audio_input_state.mute
|
||||
if mute == Mute.DISABLED:
|
||||
logger.error("mute: Cannot change Mute value, Mute state is DISABLED")
|
||||
raise ATT_Error(ErrorCode.MUTE_DISABLED)
|
||||
|
||||
if change_counter != change_counter_operand:
|
||||
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
if mute == Mute.MUTED:
|
||||
return
|
||||
|
||||
self.audio_input_state.mute = Mute.MUTED
|
||||
self.audio_input_state.increment_change_counter()
|
||||
await self.audio_input_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
async def _set_manual_gain_mode(self, connection: Connection) -> None:
|
||||
'''Cf. 3.5.2.4 Set Manual Gain Mode procedure'''
|
||||
|
||||
gain_mode = self.audio_input_state.gain_mode
|
||||
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
|
||||
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
|
||||
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
|
||||
|
||||
if gain_mode == GainMode.MANUAL:
|
||||
return
|
||||
|
||||
self.audio_input_state.gain_mode = GainMode.MANUAL
|
||||
self.audio_input_state.increment_change_counter()
|
||||
await self.audio_input_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
async def _set_automatic_gain_mode(self, connection: Connection) -> None:
|
||||
'''Cf. 3.5.2.5 Set Automatic Gain Mode'''
|
||||
|
||||
gain_mode = self.audio_input_state.gain_mode
|
||||
if gain_mode in (GainMode.AUTOMATIC_ONLY, GainMode.MANUAL_ONLY):
|
||||
logger.error(f"Cannot change gain_mode, bad state: {gain_mode}")
|
||||
raise ATT_Error(ErrorCode.GAIN_MODE_CHANGE_NOT_ALLOWED)
|
||||
|
||||
if gain_mode == GainMode.AUTOMATIC:
|
||||
return
|
||||
|
||||
self.audio_input_state.gain_mode = GainMode.AUTOMATIC
|
||||
self.audio_input_state.increment_change_counter()
|
||||
await self.audio_input_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioInputDescription:
|
||||
'''
|
||||
Cf. 3.6 Audio Input Description
|
||||
'''
|
||||
|
||||
audio_input_description: str = "Bluetooth"
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> str:
|
||||
return self.audio_input_description
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: str) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_input_description = value
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=value
|
||||
)
|
||||
|
||||
|
||||
class AICSService(TemplateService):
|
||||
UUID = GATT_AUDIO_INPUT_CONTROL_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_input_state: Optional[AudioInputState] = None,
|
||||
gain_settings_properties: Optional[GainSettingsProperties] = None,
|
||||
audio_input_type: str = "local",
|
||||
audio_input_status: Optional[AudioInputStatus] = None,
|
||||
audio_input_description: Optional[AudioInputDescription] = None,
|
||||
):
|
||||
self.audio_input_state = (
|
||||
AudioInputState() if audio_input_state is None else audio_input_state
|
||||
)
|
||||
self.gain_settings_properties = (
|
||||
GainSettingsProperties()
|
||||
if gain_settings_properties is None
|
||||
else gain_settings_properties
|
||||
)
|
||||
self.audio_input_status = (
|
||||
AudioInputStatus.ACTIVE
|
||||
if audio_input_status is None
|
||||
else audio_input_status
|
||||
)
|
||||
self.audio_input_description = (
|
||||
AudioInputDescription()
|
||||
if audio_input_description is None
|
||||
else audio_input_description
|
||||
)
|
||||
|
||||
self.audio_input_control_point: AudioInputControlPoint = AudioInputControlPoint(
|
||||
self.audio_input_state, self.gain_settings_properties
|
||||
)
|
||||
|
||||
self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=self.audio_input_state,
|
||||
),
|
||||
AudioInputState,
|
||||
)
|
||||
self.audio_input_state.attribute_value = (
|
||||
self.audio_input_state_characteristic.value
|
||||
)
|
||||
|
||||
self.gain_settings_properties_characteristic = (
|
||||
SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=self.gain_settings_properties,
|
||||
),
|
||||
GainSettingsProperties,
|
||||
)
|
||||
)
|
||||
|
||||
self.audio_input_type_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes(audio_input_type, 'utf-8'),
|
||||
)
|
||||
|
||||
self.audio_input_status_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes([self.audio_input_status]),
|
||||
)
|
||||
|
||||
self.audio_input_control_point_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(write=self.audio_input_control_point.on_write),
|
||||
)
|
||||
|
||||
self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(
|
||||
write=self.audio_input_description.on_write,
|
||||
read=self.audio_input_description.on_read,
|
||||
),
|
||||
)
|
||||
)
|
||||
self.audio_input_description.attribute_value = (
|
||||
self.audio_input_control_point_characteristic.value
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
self.audio_input_state_characteristic, # type: ignore
|
||||
self.gain_settings_properties_characteristic, # type: ignore
|
||||
self.audio_input_type_characteristic, # type: ignore
|
||||
self.audio_input_status_characteristic, # type: ignore
|
||||
self.audio_input_control_point_characteristic, # type: ignore
|
||||
self.audio_input_description_characteristic, # type: ignore
|
||||
],
|
||||
primary=False,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class AICSServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = AICSService
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
|
||||
self.audio_input_state = SerializableCharacteristicAdapter(
|
||||
characteristics[0], AudioInputState
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Gain Settings Attribute Characteristic not found"
|
||||
)
|
||||
self.gain_settings_properties = SerializableCharacteristicAdapter(
|
||||
characteristics[0], GainSettingsProperties
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Status Characteristic not found"
|
||||
)
|
||||
self.audio_input_status = PackedCharacteristicAdapter(characteristics[0], 'B')
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Control Point Characteristic not found"
|
||||
)
|
||||
self.audio_input_control_point = characteristics[0]
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Description Characteristic not found"
|
||||
)
|
||||
self.audio_input_description = UTF8CharacteristicAdapter(characteristics[0])
|
||||
727
bumble/profiles/ascs.py
Normal file
727
bumble/profiles/ascs.py
Normal file
@@ -0,0 +1,727 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for
|
||||
|
||||
"""LE Audio - Audio Stream Control Service"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from bumble import colors
|
||||
from bumble.profiles.bap import CodecSpecificConfiguration
|
||||
from bumble.profiles import le_audio
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASE Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASE_Operation:
|
||||
'''
|
||||
See Audio Stream Control Service - 5 ASE Control operations.
|
||||
'''
|
||||
|
||||
classes: Dict[int, Type[ASE_Operation]] = {}
|
||||
op_code: int
|
||||
name: str
|
||||
fields: Optional[Sequence[Any]] = None
|
||||
ase_id: List[int]
|
||||
|
||||
class Opcode(enum.IntEnum):
|
||||
# fmt: off
|
||||
CONFIG_CODEC = 0x01
|
||||
CONFIG_QOS = 0x02
|
||||
ENABLE = 0x03
|
||||
RECEIVER_START_READY = 0x04
|
||||
DISABLE = 0x05
|
||||
RECEIVER_STOP_READY = 0x06
|
||||
UPDATE_METADATA = 0x07
|
||||
RELEASE = 0x08
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu: bytes) -> ASE_Operation:
|
||||
op_code = pdu[0]
|
||||
|
||||
cls = ASE_Operation.classes.get(op_code)
|
||||
if cls is None:
|
||||
instance = ASE_Operation(pdu)
|
||||
instance.name = ASE_Operation.Opcode(op_code).name
|
||||
instance.op_code = op_code
|
||||
return instance
|
||||
self = cls.__new__(cls)
|
||||
ASE_Operation.__init__(self, pdu)
|
||||
if self.fields is not None:
|
||||
self.init_from_bytes(pdu, 1)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def subclass(fields):
|
||||
def inner(cls: Type[ASE_Operation]):
|
||||
try:
|
||||
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
|
||||
cls.name = operation.name
|
||||
cls.op_code = operation
|
||||
except:
|
||||
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
|
||||
cls.fields = fields
|
||||
|
||||
# Register a factory for this class
|
||||
ASE_Operation.classes[cls.op_code] = cls
|
||||
|
||||
return cls
|
||||
|
||||
return inner
|
||||
|
||||
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
|
||||
if self.fields is not None and kwargs:
|
||||
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
|
||||
if pdu is None:
|
||||
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
|
||||
kwargs, self.fields
|
||||
)
|
||||
self.pdu = pdu
|
||||
|
||||
def init_from_bytes(self, pdu: bytes, offset: int):
|
||||
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.pdu
|
||||
|
||||
def __str__(self) -> str:
|
||||
result = f'{colors.color(self.name, "yellow")} '
|
||||
if fields := getattr(self, 'fields', None):
|
||||
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
|
||||
else:
|
||||
if len(self.pdu) > 1:
|
||||
result += f': {self.pdu.hex()}'
|
||||
return result
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('target_latency', 1),
|
||||
('target_phy', 1),
|
||||
('codec_id', hci.CodingFormat.parse_from_bytes),
|
||||
('codec_specific_configuration', 'v'),
|
||||
],
|
||||
]
|
||||
)
|
||||
class ASE_Config_Codec(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.1 - Config Codec Operation
|
||||
'''
|
||||
|
||||
target_latency: List[int]
|
||||
target_phy: List[int]
|
||||
codec_id: List[hci.CodingFormat]
|
||||
codec_specific_configuration: List[bytes]
|
||||
|
||||
|
||||
@ASE_Operation.subclass(
|
||||
[
|
||||
[
|
||||
('ase_id', 1),
|
||||
('cig_id', 1),
|
||||
('cis_id', 1),
|
||||
('sdu_interval', 3),
|
||||
('framing', 1),
|
||||
('phy', 1),
|
||||
('max_sdu', 2),
|
||||
('retransmission_number', 1),
|
||||
('max_transport_latency', 2),
|
||||
('presentation_delay', 3),
|
||||
],
|
||||
]
|
||||
)
|
||||
class ASE_Config_QOS(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.2 - Config Qos Operation
|
||||
'''
|
||||
|
||||
cig_id: List[int]
|
||||
cis_id: List[int]
|
||||
sdu_interval: List[int]
|
||||
framing: List[int]
|
||||
phy: List[int]
|
||||
max_sdu: List[int]
|
||||
retransmission_number: List[int]
|
||||
max_transport_latency: List[int]
|
||||
presentation_delay: List[int]
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
class ASE_Enable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.3 - Enable Operation
|
||||
'''
|
||||
|
||||
metadata: bytes
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Receiver_Start_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Disable(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.5 - Disable Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Receiver_Stop_Ready(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
|
||||
'''
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
|
||||
class ASE_Update_Metadata(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.7 - Update Metadata Operation
|
||||
'''
|
||||
|
||||
metadata: List[bytes]
|
||||
|
||||
|
||||
@ASE_Operation.subclass([[('ase_id', 1)]])
|
||||
class ASE_Release(ASE_Operation):
|
||||
'''
|
||||
See Audio Stream Control Service 5.8 - Release Operation
|
||||
'''
|
||||
|
||||
|
||||
class AseResponseCode(enum.IntEnum):
|
||||
# fmt: off
|
||||
SUCCESS = 0x00
|
||||
UNSUPPORTED_OPCODE = 0x01
|
||||
INVALID_LENGTH = 0x02
|
||||
INVALID_ASE_ID = 0x03
|
||||
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
|
||||
INVALID_ASE_DIRECTION = 0x05
|
||||
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
|
||||
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
|
||||
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
|
||||
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
|
||||
UNSUPPORTED_METADATA = 0x0A
|
||||
REJECTED_METADATA = 0x0B
|
||||
INVALID_METADATA = 0x0C
|
||||
INSUFFICIENT_RESOURCES = 0x0D
|
||||
UNSPECIFIED_ERROR = 0x0E
|
||||
|
||||
|
||||
class AseReasonCode(enum.IntEnum):
|
||||
# fmt: off
|
||||
NONE = 0x00
|
||||
CODEC_ID = 0x01
|
||||
CODEC_SPECIFIC_CONFIGURATION = 0x02
|
||||
SDU_INTERVAL = 0x03
|
||||
FRAMING = 0x04
|
||||
PHY = 0x05
|
||||
MAXIMUM_SDU_SIZE = 0x06
|
||||
RETRANSMISSION_NUMBER = 0x07
|
||||
MAX_TRANSPORT_LATENCY = 0x08
|
||||
PRESENTATION_DELAY = 0x09
|
||||
INVALID_ASE_CIS_MAPPING = 0x0A
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioRole(enum.IntEnum):
|
||||
SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AseStateMachine(gatt.Characteristic):
|
||||
class State(enum.IntEnum):
|
||||
# fmt: off
|
||||
IDLE = 0x00
|
||||
CODEC_CONFIGURED = 0x01
|
||||
QOS_CONFIGURED = 0x02
|
||||
ENABLING = 0x03
|
||||
STREAMING = 0x04
|
||||
DISABLING = 0x05
|
||||
RELEASING = 0x06
|
||||
|
||||
cis_link: Optional[device.CisLink] = None
|
||||
|
||||
# Additional parameters in CODEC_CONFIGURED State
|
||||
preferred_framing = 0 # Unframed PDU supported
|
||||
preferred_phy = 0
|
||||
preferred_retransmission_number = 13
|
||||
preferred_max_transport_latency = 100
|
||||
supported_presentation_delay_min = 0
|
||||
supported_presentation_delay_max = 0
|
||||
preferred_presentation_delay_min = 0
|
||||
preferred_presentation_delay_max = 0
|
||||
codec_id = hci.CodingFormat(hci.CodecID.LC3)
|
||||
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
|
||||
|
||||
# Additional parameters in QOS_CONFIGURED State
|
||||
cig_id = 0
|
||||
cis_id = 0
|
||||
sdu_interval = 0
|
||||
framing = 0
|
||||
phy = 0
|
||||
max_sdu = 0
|
||||
retransmission_number = 0
|
||||
max_transport_latency = 0
|
||||
presentation_delay = 0
|
||||
|
||||
# Additional parameters in ENABLING, STREAMING, DISABLING State
|
||||
metadata = le_audio.Metadata()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: AudioRole,
|
||||
ase_id: int,
|
||||
service: AudioStreamControlService,
|
||||
) -> None:
|
||||
self.service = service
|
||||
self.ase_id = ase_id
|
||||
self._state = AseStateMachine.State.IDLE
|
||||
self.role = role
|
||||
|
||||
uuid = (
|
||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||
if role == AudioRole.SINK
|
||||
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||
)
|
||||
super().__init__(
|
||||
uuid=uuid,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=gatt.CharacteristicValue(read=self.on_read),
|
||||
)
|
||||
|
||||
self.service.device.on('cis_request', self.on_cis_request)
|
||||
self.service.device.on('cis_establishment', self.on_cis_establishment)
|
||||
|
||||
def on_cis_request(
|
||||
self,
|
||||
acl_connection: device.Connection,
|
||||
cis_handle: int,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
) -> None:
|
||||
if (
|
||||
cig_id == self.cig_id
|
||||
and cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
acl_connection.abort_on(
|
||||
'flush', self.service.device.accept_cis_request(cis_handle)
|
||||
)
|
||||
|
||||
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
|
||||
if (
|
||||
cis_link.cig_id == self.cig_id
|
||||
and cis_link.cis_id == self.cis_id
|
||||
and self.state == self.State.ENABLING
|
||||
):
|
||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||
|
||||
async def post_cis_established():
|
||||
await cis_link.setup_data_path(direction=self.role)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.STREAMING
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
cis_link.acl_connection.abort_on('flush', post_cis_established())
|
||||
self.cis_link = cis_link
|
||||
|
||||
def on_cis_disconnection(self, _reason) -> None:
|
||||
self.cis_link = None
|
||||
|
||||
def on_config_codec(
|
||||
self,
|
||||
target_latency: int,
|
||||
target_phy: int,
|
||||
codec_id: hci.CodingFormat,
|
||||
codec_specific_configuration: bytes,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
self.State.IDLE,
|
||||
self.State.CODEC_CONFIGURED,
|
||||
self.State.QOS_CONFIGURED,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.max_transport_latency = target_latency
|
||||
self.phy = target_phy
|
||||
self.codec_id = codec_id
|
||||
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||
self.codec_specific_configuration = codec_specific_configuration
|
||||
else:
|
||||
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
|
||||
codec_specific_configuration
|
||||
)
|
||||
|
||||
self.state = self.State.CODEC_CONFIGURED
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_config_qos(
|
||||
self,
|
||||
cig_id: int,
|
||||
cis_id: int,
|
||||
sdu_interval: int,
|
||||
framing: int,
|
||||
phy: int,
|
||||
max_sdu: int,
|
||||
retransmission_number: int,
|
||||
max_transport_latency: int,
|
||||
presentation_delay: int,
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.CODEC_CONFIGURED,
|
||||
AseStateMachine.State.QOS_CONFIGURED,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.cig_id = cig_id
|
||||
self.cis_id = cis_id
|
||||
self.sdu_interval = sdu_interval
|
||||
self.framing = framing
|
||||
self.phy = phy
|
||||
self.max_sdu = max_sdu
|
||||
self.retransmission_number = retransmission_number
|
||||
self.max_transport_latency = max_transport_latency
|
||||
self.presentation_delay = presentation_delay
|
||||
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.QOS_CONFIGURED:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
self.state = self.State.ENABLING
|
||||
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state != AseStateMachine.State.ENABLING:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.STREAMING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
else:
|
||||
self.state = self.State.DISABLING
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if (
|
||||
self.role != AudioRole.SOURCE
|
||||
or self.state != AseStateMachine.State.DISABLING
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.QOS_CONFIGURED
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_update_metadata(
|
||||
self, metadata: bytes
|
||||
) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state not in (
|
||||
AseStateMachine.State.ENABLING,
|
||||
AseStateMachine.State.STREAMING,
|
||||
):
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.metadata = le_audio.Metadata.from_bytes(metadata)
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
|
||||
if self.state == AseStateMachine.State.IDLE:
|
||||
return (
|
||||
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
|
||||
AseReasonCode.NONE,
|
||||
)
|
||||
self.state = self.State.RELEASING
|
||||
|
||||
async def remove_cis_async():
|
||||
if self.cis_link:
|
||||
await self.cis_link.remove_data_path(self.role)
|
||||
self.state = self.State.IDLE
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
self.service.device.abort_on('flush', remove_cis_async())
|
||||
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
|
||||
self._state = new_state
|
||||
self.emit('state_change')
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
|
||||
|
||||
if self.state == self.State.CODEC_CONFIGURED:
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
additional_parameters = (
|
||||
struct.pack(
|
||||
'<BBBH',
|
||||
self.preferred_framing,
|
||||
self.preferred_phy,
|
||||
self.preferred_retransmission_number,
|
||||
self.preferred_max_transport_latency,
|
||||
)
|
||||
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
|
||||
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
|
||||
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
|
||||
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
|
||||
+ bytes(self.codec_id)
|
||||
+ bytes([len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
)
|
||||
elif self.state == self.State.QOS_CONFIGURED:
|
||||
additional_parameters = (
|
||||
bytes([self.cig_id, self.cis_id])
|
||||
+ self.sdu_interval.to_bytes(3, 'little')
|
||||
+ struct.pack(
|
||||
'<BBHBH',
|
||||
self.framing,
|
||||
self.phy,
|
||||
self.max_sdu,
|
||||
self.retransmission_number,
|
||||
self.max_transport_latency,
|
||||
)
|
||||
+ self.presentation_delay.to_bytes(3, 'little')
|
||||
)
|
||||
elif self.state in (
|
||||
self.State.ENABLING,
|
||||
self.State.STREAMING,
|
||||
self.State.DISABLING,
|
||||
):
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
additional_parameters = (
|
||||
bytes([self.cig_id, self.cis_id, len(metadata_bytes)]) + metadata_bytes
|
||||
)
|
||||
else:
|
||||
additional_parameters = b''
|
||||
|
||||
return bytes([self.ase_id, self.state]) + additional_parameters
|
||||
|
||||
@value.setter
|
||||
def value(self, _new_value):
|
||||
# Readonly. Do nothing in the setter.
|
||||
pass
|
||||
|
||||
def on_read(self, _: Optional[device.Connection]) -> bytes:
|
||||
return self.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
|
||||
f'state={self._state.name})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioStreamControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
|
||||
|
||||
ase_state_machines: Dict[int, AseStateMachine]
|
||||
ase_control_point: gatt.Characteristic
|
||||
_active_client: Optional[device.Connection] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: device.Device,
|
||||
source_ase_id: Sequence[int] = (),
|
||||
sink_ase_id: Sequence[int] = (),
|
||||
) -> None:
|
||||
self.device = device
|
||||
self.ase_state_machines = {
|
||||
**{
|
||||
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
|
||||
for id in sink_ase_id
|
||||
},
|
||||
**{
|
||||
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
|
||||
for id in source_ase_id
|
||||
},
|
||||
} # ASE state machines, by ASE ID
|
||||
|
||||
self.ase_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
|
||||
)
|
||||
|
||||
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
|
||||
|
||||
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
|
||||
if ase := self.ase_state_machines.get(ase_id):
|
||||
handler = getattr(ase, 'on_' + opcode.name.lower())
|
||||
return (ase_id, *handler(*args))
|
||||
else:
|
||||
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
|
||||
|
||||
def _on_client_disconnected(self, _reason: int) -> None:
|
||||
for ase in self.ase_state_machines.values():
|
||||
ase.state = AseStateMachine.State.IDLE
|
||||
self._active_client = None
|
||||
|
||||
def on_write_ase_control_point(self, connection, data):
|
||||
if not self._active_client and connection:
|
||||
self._active_client = connection
|
||||
connection.once('disconnection', self._on_client_disconnected)
|
||||
|
||||
operation = ASE_Operation.from_bytes(data)
|
||||
responses = []
|
||||
logger.debug(f'*** ASCS Write {operation} ***')
|
||||
|
||||
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.target_latency,
|
||||
operation.target_phy,
|
||||
operation.codec_id,
|
||||
operation.codec_specific_configuration,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.cig_id,
|
||||
operation.cis_id,
|
||||
operation.sdu_interval,
|
||||
operation.framing,
|
||||
operation.phy,
|
||||
operation.max_sdu,
|
||||
operation.retransmission_number,
|
||||
operation.max_transport_latency,
|
||||
operation.presentation_delay,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.ENABLE,
|
||||
ASE_Operation.Opcode.UPDATE_METADATA,
|
||||
):
|
||||
for ase_id, *args in zip(
|
||||
operation.ase_id,
|
||||
operation.metadata,
|
||||
):
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, args))
|
||||
elif operation.op_code in (
|
||||
ASE_Operation.Opcode.RECEIVER_START_READY,
|
||||
ASE_Operation.Opcode.DISABLE,
|
||||
ASE_Operation.Opcode.RECEIVER_STOP_READY,
|
||||
ASE_Operation.Opcode.RELEASE,
|
||||
):
|
||||
for ase_id in operation.ase_id:
|
||||
responses.append(self.on_operation(operation.op_code, ase_id, []))
|
||||
|
||||
control_point_notification = bytes(
|
||||
[operation.op_code, len(responses)]
|
||||
) + b''.join(map(bytes, responses))
|
||||
self.device.abort_on(
|
||||
'flush',
|
||||
self.device.notify_subscribers(
|
||||
self.ase_control_point, control_point_notification
|
||||
),
|
||||
)
|
||||
|
||||
for ase_id, *_ in responses:
|
||||
if ase := self.ase_state_machines.get(ase_id):
|
||||
self.device.abort_on(
|
||||
'flush',
|
||||
self.device.notify_subscribers(ase, ase.value),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = AudioStreamControlService
|
||||
|
||||
sink_ase: List[gatt_client.CharacteristicProxy]
|
||||
source_ase: List[gatt_client.CharacteristicProxy]
|
||||
ase_control_point: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.sink_ase = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_ASE_CHARACTERISTIC
|
||||
)
|
||||
self.source_ase = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
|
||||
)
|
||||
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
295
bumble/profiles/asha.py
Normal file
295
bumble/profiles/asha.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import enum
|
||||
import struct
|
||||
import logging
|
||||
from typing import List, Optional, Callable, Union, Any
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble import utils
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
class DeviceCapabilities(enum.IntFlag):
|
||||
IS_RIGHT = 0x01
|
||||
IS_DUAL = 0x02
|
||||
CSIS_SUPPORTED = 0x04
|
||||
|
||||
|
||||
class FeatureMap(enum.IntFlag):
|
||||
LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED = 0x01
|
||||
|
||||
|
||||
class AudioType(utils.OpenIntEnum):
|
||||
UNKNOWN = 0x00
|
||||
RINGTONE = 0x01
|
||||
PHONE_CALL = 0x02
|
||||
MEDIA = 0x03
|
||||
|
||||
|
||||
class OpCode(utils.OpenIntEnum):
|
||||
START = 1
|
||||
STOP = 2
|
||||
STATUS = 3
|
||||
|
||||
|
||||
class Codec(utils.OpenIntEnum):
|
||||
G_722_16KHZ = 1
|
||||
|
||||
|
||||
class SupportedCodecs(enum.IntFlag):
|
||||
G_722_16KHZ = 1 << Codec.G_722_16KHZ
|
||||
|
||||
|
||||
class PeripheralStatus(utils.OpenIntEnum):
|
||||
"""Status update on the other peripheral."""
|
||||
|
||||
OTHER_PERIPHERAL_DISCONNECTED = 1
|
||||
OTHER_PERIPHERAL_CONNECTED = 2
|
||||
CONNECTION_PARAMETER_UPDATED = 3
|
||||
|
||||
|
||||
class AudioStatus(utils.OpenIntEnum):
|
||||
"""Status report field for the audio control point."""
|
||||
|
||||
OK = 0
|
||||
UNKNOWN_COMMAND = -1
|
||||
ILLEGAL_PARAMETERS = -2
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AshaService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_ASHA_SERVICE
|
||||
|
||||
audio_sink: Optional[Callable[[bytes], Any]]
|
||||
active_codec: Optional[Codec] = None
|
||||
audio_type: Optional[AudioType] = None
|
||||
volume: Optional[int] = None
|
||||
other_state: Optional[int] = None
|
||||
connection: Optional[Connection] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capability: int,
|
||||
hisyncid: Union[List[int], bytes],
|
||||
device: Device,
|
||||
psm: int = 0,
|
||||
audio_sink: Optional[Callable[[bytes], Any]] = None,
|
||||
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
|
||||
protocol_version: int = 0x01,
|
||||
render_delay_milliseconds: int = 0,
|
||||
supported_codecs: int = SupportedCodecs.G_722_16KHZ,
|
||||
) -> None:
|
||||
if len(hisyncid) != 8:
|
||||
_logger.warning('HiSyncId should have a length of 8, got %d', len(hisyncid))
|
||||
|
||||
self.hisyncid = bytes(hisyncid)
|
||||
self.capability = capability
|
||||
self.device = device
|
||||
self.audio_out_data = b''
|
||||
self.psm = psm # a non-zero psm is mainly for testing purpose
|
||||
self.audio_sink = audio_sink
|
||||
self.protocol_version = protocol_version
|
||||
|
||||
self.read_only_properties_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.READ,
|
||||
gatt.Characteristic.READABLE,
|
||||
struct.pack(
|
||||
"<BB8sBH2sH",
|
||||
protocol_version,
|
||||
capability,
|
||||
self.hisyncid,
|
||||
feature_map,
|
||||
render_delay_milliseconds,
|
||||
b'\x00\x00',
|
||||
supported_codecs,
|
||||
),
|
||||
)
|
||||
|
||||
self.audio_control_point_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(write=self._on_audio_control_point_write),
|
||||
)
|
||||
self.audio_status_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
|
||||
gatt.Characteristic.READABLE,
|
||||
bytes([AudioStatus.OK]),
|
||||
)
|
||||
self.volume_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(write=self._on_volume_write),
|
||||
)
|
||||
|
||||
# let the server find a free PSM
|
||||
self.psm = device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
|
||||
handler=self._on_connection,
|
||||
).psm
|
||||
self.le_psm_out_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.READ,
|
||||
gatt.Characteristic.READABLE,
|
||||
struct.pack('<H', self.psm),
|
||||
)
|
||||
|
||||
characteristics = [
|
||||
self.read_only_properties_characteristic,
|
||||
self.audio_control_point_characteristic,
|
||||
self.audio_status_characteristic,
|
||||
self.volume_characteristic,
|
||||
self.le_psm_out_characteristic,
|
||||
]
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
# Advertisement only uses 4 least significant bytes of the HiSyncId.
|
||||
return bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
bytes(gatt.GATT_ASHA_SERVICE)
|
||||
+ bytes([self.protocol_version, self.capability])
|
||||
+ self.hisyncid[:4],
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Handler for audio control commands
|
||||
async def _on_audio_control_point_write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> None:
|
||||
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == OpCode.START:
|
||||
# Start
|
||||
self.active_codec = Codec(value[1])
|
||||
self.audio_type = AudioType(value[2])
|
||||
self.volume = value[3]
|
||||
self.other_state = value[4]
|
||||
_logger.debug(
|
||||
f'### START: codec={self.active_codec.name}, '
|
||||
f'audio_type={self.audio_type.name}, '
|
||||
f'volume={self.volume}, '
|
||||
f'other_state={self.other_state}'
|
||||
)
|
||||
self.emit('started')
|
||||
elif opcode == OpCode.STOP:
|
||||
_logger.debug('### STOP')
|
||||
self.active_codec = None
|
||||
self.audio_type = None
|
||||
self.volume = None
|
||||
self.other_state = None
|
||||
self.emit('stopped')
|
||||
elif opcode == OpCode.STATUS:
|
||||
_logger.debug('### STATUS: %s', PeripheralStatus(value[1]).name)
|
||||
|
||||
if self.connection is None and connection:
|
||||
self.connection = connection
|
||||
|
||||
def on_disconnection(_reason) -> None:
|
||||
self.connection = None
|
||||
self.active_codec = None
|
||||
self.audio_type = None
|
||||
self.volume = None
|
||||
self.other_state = None
|
||||
self.emit('disconnected')
|
||||
|
||||
connection.once('disconnection', on_disconnection)
|
||||
|
||||
# OPCODE_STATUS does not need audio status point update
|
||||
if opcode != OpCode.STATUS:
|
||||
await self.device.notify_subscribers(
|
||||
self.audio_status_characteristic, force=True
|
||||
)
|
||||
|
||||
# Handler for volume control
|
||||
def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
_logger.debug(f'--- VOLUME Write:{value[0]}')
|
||||
self.volume = value[0]
|
||||
self.emit('volume_changed')
|
||||
|
||||
# Register an L2CAP CoC server
|
||||
def _on_connection(self, channel: l2cap.LeCreditBasedChannel) -> None:
|
||||
def on_data(data: bytes) -> None:
|
||||
if self.audio_sink: # pylint: disable=not-callable
|
||||
self.audio_sink(data)
|
||||
|
||||
channel.sink = on_data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AshaServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = AshaService
|
||||
read_only_properties_characteristic: gatt_client.CharacteristicProxy
|
||||
audio_control_point_characteristic: gatt_client.CharacteristicProxy
|
||||
audio_status_point_characteristic: gatt_client.CharacteristicProxy
|
||||
volume_characteristic: gatt_client.CharacteristicProxy
|
||||
psm_characteristic: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
for uuid, attribute_name in (
|
||||
(
|
||||
gatt.GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
'read_only_properties_characteristic',
|
||||
),
|
||||
(
|
||||
gatt.GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
'audio_control_point_characteristic',
|
||||
),
|
||||
(
|
||||
gatt.GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
'audio_status_point_characteristic',
|
||||
),
|
||||
(
|
||||
gatt.GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
'volume_characteristic',
|
||||
),
|
||||
(
|
||||
gatt.GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
'psm_characteristic',
|
||||
),
|
||||
):
|
||||
if not (
|
||||
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
|
||||
):
|
||||
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
|
||||
setattr(self, attribute_name, characteristics[0])
|
||||
@@ -1,188 +0,0 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from typing import List
|
||||
from ..core import AdvertisingData
|
||||
from ..device import Device, Connection
|
||||
from ..gatt import (
|
||||
GATT_ASHA_SERVICE,
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
)
|
||||
from ..utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AshaService(TemplateService):
|
||||
UUID = GATT_ASHA_SERVICE
|
||||
OPCODE_START = 1
|
||||
OPCODE_STOP = 2
|
||||
OPCODE_STATUS = 3
|
||||
PROTOCOL_VERSION = 0x01
|
||||
RESERVED_FOR_FUTURE_USE = [00, 00]
|
||||
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
|
||||
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
|
||||
RENDER_DELAY = [00, 00]
|
||||
|
||||
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
|
||||
self.hisyncid = hisyncid
|
||||
self.capability = capability # Device Capabilities [Left, Monaural]
|
||||
self.device = device
|
||||
self.audio_out_data = b''
|
||||
self.psm = psm # a non-zero psm is mainly for testing purpose
|
||||
|
||||
# Handler for volume control
|
||||
def on_volume_write(connection, value):
|
||||
logger.info(f'--- VOLUME Write:{value[0]}')
|
||||
self.emit('volume', connection, value[0])
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection: Connection, value):
|
||||
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == AshaService.OPCODE_START:
|
||||
# Start
|
||||
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
|
||||
logger.info(
|
||||
f'### START: codec={value[1]}, '
|
||||
f'audio_type={audio_type}, '
|
||||
f'volume={value[3]}, '
|
||||
f'otherstate={value[4]}'
|
||||
)
|
||||
self.emit(
|
||||
'start',
|
||||
connection,
|
||||
{
|
||||
'codec': value[1],
|
||||
'audiotype': value[2],
|
||||
'volume': value[3],
|
||||
'otherstate': value[4],
|
||||
},
|
||||
)
|
||||
elif opcode == AshaService.OPCODE_STOP:
|
||||
logger.info('### STOP')
|
||||
self.emit('stop', connection)
|
||||
elif opcode == AshaService.OPCODE_STATUS:
|
||||
logger.info(f'### STATUS: connected={value[1]}')
|
||||
|
||||
# OPCODE_STATUS does not need audio status point update
|
||||
if opcode != AshaService.OPCODE_STATUS:
|
||||
AsyncRunner.spawn(
|
||||
device.notify_subscribers(
|
||||
self.audio_status_characteristic, force=True
|
||||
)
|
||||
)
|
||||
|
||||
self.read_only_properties_characteristic = Characteristic(
|
||||
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes(
|
||||
[
|
||||
AshaService.PROTOCOL_VERSION, # Version
|
||||
self.capability,
|
||||
]
|
||||
)
|
||||
+ bytes(self.hisyncid)
|
||||
+ bytes(AshaService.FEATURE_MAP)
|
||||
+ bytes(AshaService.RENDER_DELAY)
|
||||
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
|
||||
+ bytes(AshaService.SUPPORTED_CODEC_ID),
|
||||
)
|
||||
|
||||
self.audio_control_point_characteristic = Characteristic(
|
||||
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
Characteristic.Properties.WRITE
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_audio_control_point_write),
|
||||
)
|
||||
self.audio_status_characteristic = Characteristic(
|
||||
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([0]),
|
||||
)
|
||||
self.volume_characteristic = Characteristic(
|
||||
GATT_ASHA_VOLUME_CHARACTERISTIC,
|
||||
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_volume_write),
|
||||
)
|
||||
|
||||
# Register an L2CAP CoC server
|
||||
def on_coc(channel):
|
||||
def on_data(data):
|
||||
logging.debug(f'<<< data received:{data}')
|
||||
|
||||
self.emit('data', channel.connection, data)
|
||||
self.audio_out_data += data
|
||||
|
||||
channel.sink = on_data
|
||||
|
||||
# let the server find a free PSM
|
||||
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
|
||||
self.le_psm_out_characteristic = Characteristic(
|
||||
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', self.psm),
|
||||
)
|
||||
|
||||
characteristics = [
|
||||
self.read_only_properties_characteristic,
|
||||
self.audio_control_point_characteristic,
|
||||
self.audio_status_characteristic,
|
||||
self.volume_characteristic,
|
||||
self.le_psm_out_characteristic,
|
||||
]
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
def get_advertising_data(self):
|
||||
# Advertisement only uses 4 least significant bytes of the HiSyncId.
|
||||
return bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
bytes(GATT_ASHA_SERVICE)
|
||||
+ bytes(
|
||||
[
|
||||
AshaService.PROTOCOL_VERSION,
|
||||
self.capability,
|
||||
]
|
||||
)
|
||||
+ bytes(self.hisyncid[:4]),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
623
bumble/profiles/bap.py
Normal file
623
bumble/profiles/bap.py
Normal file
@@ -0,0 +1,623 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
import functools
|
||||
import logging
|
||||
from typing import List
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble import gatt
|
||||
from bumble import utils
|
||||
from bumble.profiles import le_audio
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AudioLocation(enum.IntFlag):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.1 - Audio Location'''
|
||||
|
||||
# fmt: off
|
||||
NOT_ALLOWED = 0x00000000
|
||||
FRONT_LEFT = 0x00000001
|
||||
FRONT_RIGHT = 0x00000002
|
||||
FRONT_CENTER = 0x00000004
|
||||
LOW_FREQUENCY_EFFECTS_1 = 0x00000008
|
||||
BACK_LEFT = 0x00000010
|
||||
BACK_RIGHT = 0x00000020
|
||||
FRONT_LEFT_OF_CENTER = 0x00000040
|
||||
FRONT_RIGHT_OF_CENTER = 0x00000080
|
||||
BACK_CENTER = 0x00000100
|
||||
LOW_FREQUENCY_EFFECTS_2 = 0x00000200
|
||||
SIDE_LEFT = 0x00000400
|
||||
SIDE_RIGHT = 0x00000800
|
||||
TOP_FRONT_LEFT = 0x00001000
|
||||
TOP_FRONT_RIGHT = 0x00002000
|
||||
TOP_FRONT_CENTER = 0x00004000
|
||||
TOP_CENTER = 0x00008000
|
||||
TOP_BACK_LEFT = 0x00010000
|
||||
TOP_BACK_RIGHT = 0x00020000
|
||||
TOP_SIDE_LEFT = 0x00040000
|
||||
TOP_SIDE_RIGHT = 0x00080000
|
||||
TOP_BACK_CENTER = 0x00100000
|
||||
BOTTOM_FRONT_CENTER = 0x00200000
|
||||
BOTTOM_FRONT_LEFT = 0x00400000
|
||||
BOTTOM_FRONT_RIGHT = 0x00800000
|
||||
FRONT_LEFT_WIDE = 0x01000000
|
||||
FRONT_RIGHT_WIDE = 0x02000000
|
||||
LEFT_SURROUND = 0x04000000
|
||||
RIGHT_SURROUND = 0x08000000
|
||||
|
||||
@property
|
||||
def channel_count(self) -> int:
|
||||
return bin(self.value).count('1')
|
||||
|
||||
|
||||
class AudioInputType(enum.IntEnum):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type'''
|
||||
|
||||
# fmt: off
|
||||
UNSPECIFIED = 0x00
|
||||
BLUETOOTH = 0x01
|
||||
MICROPHONE = 0x02
|
||||
ANALOG = 0x03
|
||||
DIGITAL = 0x04
|
||||
RADIO = 0x05
|
||||
STREAMING = 0x06
|
||||
AMBIENT = 0x07
|
||||
|
||||
|
||||
class ContextType(enum.IntFlag):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.3 - Context Type'''
|
||||
|
||||
# fmt: off
|
||||
PROHIBITED = 0x0000
|
||||
UNSPECIFIED = 0x0001
|
||||
CONVERSATIONAL = 0x0002
|
||||
MEDIA = 0x0004
|
||||
GAME = 0x0008
|
||||
INSTRUCTIONAL = 0x0010
|
||||
VOICE_ASSISTANTS = 0x0020
|
||||
LIVE = 0x0040
|
||||
SOUND_EFFECTS = 0x0080
|
||||
NOTIFICATIONS = 0x0100
|
||||
RINGTONE = 0x0200
|
||||
ALERTS = 0x0400
|
||||
EMERGENCY_ALARM = 0x0800
|
||||
|
||||
|
||||
class SamplingFrequency(utils.OpenIntEnum):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
|
||||
|
||||
# fmt: off
|
||||
FREQ_8000 = 0x01
|
||||
FREQ_11025 = 0x02
|
||||
FREQ_16000 = 0x03
|
||||
FREQ_22050 = 0x04
|
||||
FREQ_24000 = 0x05
|
||||
FREQ_32000 = 0x06
|
||||
FREQ_44100 = 0x07
|
||||
FREQ_48000 = 0x08
|
||||
FREQ_88200 = 0x09
|
||||
FREQ_96000 = 0x0A
|
||||
FREQ_176400 = 0x0B
|
||||
FREQ_192000 = 0x0C
|
||||
FREQ_384000 = 0x0D
|
||||
# fmt: on
|
||||
|
||||
@classmethod
|
||||
def from_hz(cls, frequency: int) -> SamplingFrequency:
|
||||
return {
|
||||
8000: SamplingFrequency.FREQ_8000,
|
||||
11025: SamplingFrequency.FREQ_11025,
|
||||
16000: SamplingFrequency.FREQ_16000,
|
||||
22050: SamplingFrequency.FREQ_22050,
|
||||
24000: SamplingFrequency.FREQ_24000,
|
||||
32000: SamplingFrequency.FREQ_32000,
|
||||
44100: SamplingFrequency.FREQ_44100,
|
||||
48000: SamplingFrequency.FREQ_48000,
|
||||
88200: SamplingFrequency.FREQ_88200,
|
||||
96000: SamplingFrequency.FREQ_96000,
|
||||
176400: SamplingFrequency.FREQ_176400,
|
||||
192000: SamplingFrequency.FREQ_192000,
|
||||
384000: SamplingFrequency.FREQ_384000,
|
||||
}[frequency]
|
||||
|
||||
@property
|
||||
def hz(self) -> int:
|
||||
return {
|
||||
SamplingFrequency.FREQ_8000: 8000,
|
||||
SamplingFrequency.FREQ_11025: 11025,
|
||||
SamplingFrequency.FREQ_16000: 16000,
|
||||
SamplingFrequency.FREQ_22050: 22050,
|
||||
SamplingFrequency.FREQ_24000: 24000,
|
||||
SamplingFrequency.FREQ_32000: 32000,
|
||||
SamplingFrequency.FREQ_44100: 44100,
|
||||
SamplingFrequency.FREQ_48000: 48000,
|
||||
SamplingFrequency.FREQ_88200: 88200,
|
||||
SamplingFrequency.FREQ_96000: 96000,
|
||||
SamplingFrequency.FREQ_176400: 176400,
|
||||
SamplingFrequency.FREQ_192000: 192000,
|
||||
SamplingFrequency.FREQ_384000: 384000,
|
||||
}[self]
|
||||
|
||||
|
||||
class SupportedSamplingFrequency(enum.IntFlag):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.4.1 - Sample Frequency'''
|
||||
|
||||
# fmt: off
|
||||
FREQ_8000 = 1 << (SamplingFrequency.FREQ_8000 - 1)
|
||||
FREQ_11025 = 1 << (SamplingFrequency.FREQ_11025 - 1)
|
||||
FREQ_16000 = 1 << (SamplingFrequency.FREQ_16000 - 1)
|
||||
FREQ_22050 = 1 << (SamplingFrequency.FREQ_22050 - 1)
|
||||
FREQ_24000 = 1 << (SamplingFrequency.FREQ_24000 - 1)
|
||||
FREQ_32000 = 1 << (SamplingFrequency.FREQ_32000 - 1)
|
||||
FREQ_44100 = 1 << (SamplingFrequency.FREQ_44100 - 1)
|
||||
FREQ_48000 = 1 << (SamplingFrequency.FREQ_48000 - 1)
|
||||
FREQ_88200 = 1 << (SamplingFrequency.FREQ_88200 - 1)
|
||||
FREQ_96000 = 1 << (SamplingFrequency.FREQ_96000 - 1)
|
||||
FREQ_176400 = 1 << (SamplingFrequency.FREQ_176400 - 1)
|
||||
FREQ_192000 = 1 << (SamplingFrequency.FREQ_192000 - 1)
|
||||
FREQ_384000 = 1 << (SamplingFrequency.FREQ_384000 - 1)
|
||||
# fmt: on
|
||||
|
||||
@classmethod
|
||||
def from_hz(cls, frequencies: Sequence[int]) -> SupportedSamplingFrequency:
|
||||
MAPPING = {
|
||||
8000: SupportedSamplingFrequency.FREQ_8000,
|
||||
11025: SupportedSamplingFrequency.FREQ_11025,
|
||||
16000: SupportedSamplingFrequency.FREQ_16000,
|
||||
22050: SupportedSamplingFrequency.FREQ_22050,
|
||||
24000: SupportedSamplingFrequency.FREQ_24000,
|
||||
32000: SupportedSamplingFrequency.FREQ_32000,
|
||||
44100: SupportedSamplingFrequency.FREQ_44100,
|
||||
48000: SupportedSamplingFrequency.FREQ_48000,
|
||||
88200: SupportedSamplingFrequency.FREQ_88200,
|
||||
96000: SupportedSamplingFrequency.FREQ_96000,
|
||||
176400: SupportedSamplingFrequency.FREQ_176400,
|
||||
192000: SupportedSamplingFrequency.FREQ_192000,
|
||||
384000: SupportedSamplingFrequency.FREQ_384000,
|
||||
}
|
||||
|
||||
return functools.reduce(
|
||||
lambda x, y: x | MAPPING[y],
|
||||
frequencies,
|
||||
cls(0),
|
||||
)
|
||||
|
||||
|
||||
class FrameDuration(enum.IntEnum):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.5.2 - Frame Duration'''
|
||||
|
||||
# fmt: off
|
||||
DURATION_7500_US = 0x00
|
||||
DURATION_10000_US = 0x01
|
||||
|
||||
@property
|
||||
def us(self) -> int:
|
||||
return {
|
||||
FrameDuration.DURATION_7500_US: 7500,
|
||||
FrameDuration.DURATION_10000_US: 10000,
|
||||
}[self]
|
||||
|
||||
|
||||
class SupportedFrameDuration(enum.IntFlag):
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration'''
|
||||
|
||||
# fmt: off
|
||||
DURATION_7500_US_SUPPORTED = 0b0001
|
||||
DURATION_10000_US_SUPPORTED = 0b0010
|
||||
DURATION_7500_US_PREFERRED = 0b0001
|
||||
DURATION_10000_US_PREFERRED = 0b0010
|
||||
|
||||
|
||||
class AnnouncementType(utils.OpenIntEnum):
|
||||
'''Basic Audio Profile, 3.5.3. Additional Audio Stream Control Service requirements'''
|
||||
|
||||
# fmt: off
|
||||
GENERAL = 0x00
|
||||
TARGETED = 0x01
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UnicastServerAdvertisingData:
|
||||
"""Advertising Data for ASCS."""
|
||||
|
||||
announcement_type: AnnouncementType = AnnouncementType.TARGETED
|
||||
available_audio_contexts: ContextType = ContextType.MEDIA
|
||||
metadata: bytes = b''
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
struct.pack(
|
||||
'<2sBIB',
|
||||
bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
|
||||
self.announcement_type,
|
||||
self.available_audio_contexts,
|
||||
len(self.metadata),
|
||||
)
|
||||
+ self.metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bits_to_channel_counts(data: int) -> List[int]:
|
||||
pos = 0
|
||||
counts = []
|
||||
while data != 0:
|
||||
# Bit 0 = count 1
|
||||
# Bit 1 = count 2, and so on
|
||||
pos += 1
|
||||
if data & 1:
|
||||
counts.append(pos)
|
||||
data >>= 1
|
||||
return counts
|
||||
|
||||
|
||||
def channel_counts_to_bits(counts: Sequence[int]) -> int:
|
||||
return sum(set([1 << (count - 1) for count in counts]))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Structures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodecSpecificCapabilities:
|
||||
'''See:
|
||||
* Bluetooth Assigned Numbers, 6.12.4 - Codec Specific Capabilities LTV Structures
|
||||
* Basic Audio Profile, 4.3.1 - Codec_Specific_Capabilities LTV requirements
|
||||
'''
|
||||
|
||||
class Type(enum.IntEnum):
|
||||
# fmt: off
|
||||
SAMPLING_FREQUENCY = 0x01
|
||||
FRAME_DURATION = 0x02
|
||||
AUDIO_CHANNEL_COUNT = 0x03
|
||||
OCTETS_PER_FRAME = 0x04
|
||||
CODEC_FRAMES_PER_SDU = 0x05
|
||||
|
||||
supported_sampling_frequencies: SupportedSamplingFrequency
|
||||
supported_frame_durations: SupportedFrameDuration
|
||||
supported_audio_channel_count: Sequence[int]
|
||||
min_octets_per_codec_frame: int
|
||||
max_octets_per_codec_frame: int
|
||||
supported_max_codec_frames_per_sdu: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities:
|
||||
offset = 0
|
||||
# Allowed default values.
|
||||
supported_audio_channel_count = [1]
|
||||
supported_max_codec_frames_per_sdu = 1
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
offset += 2
|
||||
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
||||
offset += length - 1
|
||||
|
||||
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
|
||||
supported_sampling_frequencies = SupportedSamplingFrequency(value)
|
||||
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
|
||||
supported_frame_durations = SupportedFrameDuration(value)
|
||||
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
|
||||
supported_audio_channel_count = bits_to_channel_counts(value)
|
||||
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
|
||||
min_octets_per_sample = value & 0xFFFF
|
||||
max_octets_per_sample = value >> 16
|
||||
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
|
||||
supported_max_codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
return CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=supported_sampling_frequencies,
|
||||
supported_frame_durations=supported_frame_durations,
|
||||
supported_audio_channel_count=supported_audio_channel_count,
|
||||
min_octets_per_codec_frame=min_octets_per_sample,
|
||||
max_octets_per_codec_frame=max_octets_per_sample,
|
||||
supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(
|
||||
'<BBHBBBBBBBBHHBBB',
|
||||
3,
|
||||
CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY,
|
||||
self.supported_sampling_frequencies,
|
||||
2,
|
||||
CodecSpecificCapabilities.Type.FRAME_DURATION,
|
||||
self.supported_frame_durations,
|
||||
2,
|
||||
CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT,
|
||||
channel_counts_to_bits(self.supported_audio_channel_count),
|
||||
5,
|
||||
CodecSpecificCapabilities.Type.OCTETS_PER_FRAME,
|
||||
self.min_octets_per_codec_frame,
|
||||
self.max_octets_per_codec_frame,
|
||||
2,
|
||||
CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.supported_max_codec_frames_per_sdu,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodecSpecificConfiguration:
|
||||
'''See:
|
||||
* Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures
|
||||
* Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements
|
||||
'''
|
||||
|
||||
class Type(utils.OpenIntEnum):
|
||||
# fmt: off
|
||||
SAMPLING_FREQUENCY = 0x01
|
||||
FRAME_DURATION = 0x02
|
||||
AUDIO_CHANNEL_ALLOCATION = 0x03
|
||||
OCTETS_PER_FRAME = 0x04
|
||||
CODEC_FRAMES_PER_SDU = 0x05
|
||||
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
|
||||
offset = 0
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
offset += 2
|
||||
value = int.from_bytes(data[offset : offset + length - 1], 'little')
|
||||
offset += length - 1
|
||||
|
||||
if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY:
|
||||
sampling_frequency = SamplingFrequency(value)
|
||||
elif type == CodecSpecificConfiguration.Type.FRAME_DURATION:
|
||||
frame_duration = FrameDuration(value)
|
||||
elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION:
|
||||
audio_channel_allocation = AudioLocation(value)
|
||||
elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME:
|
||||
octets_per_codec_frame = value
|
||||
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
|
||||
codec_frames_per_sdu = value
|
||||
|
||||
return CodecSpecificConfiguration(
|
||||
sampling_frequency=sampling_frequency,
|
||||
frame_duration=frame_duration,
|
||||
audio_channel_allocation=audio_channel_allocation,
|
||||
octets_per_codec_frame=octets_per_codec_frame,
|
||||
codec_frames_per_sdu=codec_frames_per_sdu,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return b''.join(
|
||||
[
|
||||
struct.pack(fmt, length, tag, value)
|
||||
for fmt, length, tag, value in [
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
),
|
||||
(
|
||||
'<BBI',
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
),
|
||||
(
|
||||
'<BBH',
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
),
|
||||
]
|
||||
if value is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BroadcastAudioAnnouncement:
|
||||
broadcast_id: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data[:3], 'little'))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.broadcast_id.to_bytes(3, 'little')
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
(
|
||||
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
|
||||
+ bytes(self)
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BasicAudioAnnouncement:
|
||||
@dataclasses.dataclass
|
||||
class BIS:
|
||||
index: int
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([self.index, len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Subgroup:
|
||||
codec_id: hci.CodingFormat
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
metadata: le_audio.Metadata
|
||||
bis: List[BasicAudioAnnouncement.BIS]
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([len(self.bis)])
|
||||
+ bytes(self.codec_id)
|
||||
+ bytes([len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
+ bytes([len(metadata_bytes)])
|
||||
+ metadata_bytes
|
||||
+ b''.join(map(bytes, self.bis))
|
||||
)
|
||||
|
||||
presentation_delay: int
|
||||
subgroups: List[BasicAudioAnnouncement.Subgroup]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
presentation_delay = int.from_bytes(data[:3], 'little')
|
||||
subgroups = []
|
||||
offset = 4
|
||||
for _ in range(data[3]):
|
||||
num_bis = data[offset]
|
||||
offset += 1
|
||||
codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
|
||||
offset += 5
|
||||
codec_specific_configuration_length = data[offset]
|
||||
offset += 1
|
||||
codec_specific_configuration = data[
|
||||
offset : offset + codec_specific_configuration_length
|
||||
]
|
||||
offset += codec_specific_configuration_length
|
||||
metadata_length = data[offset]
|
||||
offset += 1
|
||||
metadata = le_audio.Metadata.from_bytes(
|
||||
data[offset : offset + metadata_length]
|
||||
)
|
||||
offset += metadata_length
|
||||
|
||||
bis = []
|
||||
for _ in range(num_bis):
|
||||
bis_index = data[offset]
|
||||
offset += 1
|
||||
bis_codec_specific_configuration_length = data[offset]
|
||||
offset += 1
|
||||
bis_codec_specific_configuration = data[
|
||||
offset : offset + bis_codec_specific_configuration_length
|
||||
]
|
||||
offset += bis_codec_specific_configuration_length
|
||||
bis.append(
|
||||
cls.BIS(
|
||||
bis_index,
|
||||
CodecSpecificConfiguration.from_bytes(
|
||||
bis_codec_specific_configuration
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
subgroups.append(
|
||||
cls.Subgroup(
|
||||
codec_id,
|
||||
CodecSpecificConfiguration.from_bytes(codec_specific_configuration),
|
||||
metadata,
|
||||
bis,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(presentation_delay, subgroups)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (
|
||||
self.presentation_delay.to_bytes(3, 'little')
|
||||
+ bytes([len(self.subgroups)])
|
||||
+ b''.join(map(bytes, self.subgroups))
|
||||
)
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
(
|
||||
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
|
||||
+ bytes(self)
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
437
bumble/profiles/bass.py
Normal file
437
bumble/profiles/bass.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for
|
||||
|
||||
"""LE Audio - Broadcast Audio Scan Service"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from typing import ClassVar, List, Optional, Sequence
|
||||
|
||||
from bumble import core
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
from bumble import utils
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
class ApplicationError(utils.OpenIntEnum):
|
||||
OPCODE_NOT_SUPPORTED = 0x80
|
||||
INVALID_SOURCE_ID = 0x81
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
|
||||
return bytes([len(subgroups)]) + b"".join(
|
||||
struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
|
||||
+ subgroup.metadata
|
||||
for subgroup in subgroups
|
||||
)
|
||||
|
||||
|
||||
def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
|
||||
num_subgroups = data[0]
|
||||
offset = 1
|
||||
subgroups = []
|
||||
for _ in range(num_subgroups):
|
||||
bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
|
||||
metadata_length = data[offset + 4]
|
||||
metadata = data[offset + 5 : offset + 5 + metadata_length]
|
||||
offset += 5 + metadata_length
|
||||
subgroups.append(SubgroupInfo(bis_sync, metadata))
|
||||
|
||||
return subgroups
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
|
||||
DO_NOT_SYNCHRONIZE_TO_PA = 0x00
|
||||
SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
|
||||
SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SubgroupInfo:
|
||||
ANY_BIS: ClassVar[int] = 0xFFFFFFFF
|
||||
|
||||
bis_sync: int
|
||||
metadata: bytes
|
||||
|
||||
|
||||
class ControlPointOperation:
|
||||
class OpCode(utils.OpenIntEnum):
|
||||
REMOTE_SCAN_STOPPED = 0x00
|
||||
REMOTE_SCAN_STARTED = 0x01
|
||||
ADD_SOURCE = 0x02
|
||||
MODIFY_SOURCE = 0x03
|
||||
SET_BROADCAST_CODE = 0x04
|
||||
REMOVE_SOURCE = 0x05
|
||||
|
||||
op_code: OpCode
|
||||
parameters: bytes
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> ControlPointOperation:
|
||||
op_code = data[0]
|
||||
|
||||
if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
|
||||
return RemoteScanStoppedOperation()
|
||||
|
||||
if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
|
||||
return RemoteScanStartedOperation()
|
||||
|
||||
if op_code == cls.OpCode.ADD_SOURCE:
|
||||
return AddSourceOperation.from_parameters(data[1:])
|
||||
|
||||
if op_code == cls.OpCode.MODIFY_SOURCE:
|
||||
return ModifySourceOperation.from_parameters(data[1:])
|
||||
|
||||
if op_code == cls.OpCode.SET_BROADCAST_CODE:
|
||||
return SetBroadcastCodeOperation.from_parameters(data[1:])
|
||||
|
||||
if op_code == cls.OpCode.REMOVE_SOURCE:
|
||||
return RemoveSourceOperation.from_parameters(data[1:])
|
||||
|
||||
raise core.InvalidArgumentError("invalid op code")
|
||||
|
||||
def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
|
||||
self.op_code = op_code
|
||||
self.parameters = parameters
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.op_code]) + self.parameters
|
||||
|
||||
|
||||
class RemoteScanStoppedOperation(ControlPointOperation):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)
|
||||
|
||||
|
||||
class RemoteScanStartedOperation(ControlPointOperation):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)
|
||||
|
||||
|
||||
class AddSourceOperation(ControlPointOperation):
|
||||
@classmethod
|
||||
def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
|
||||
instance = cls.__new__(cls)
|
||||
instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
|
||||
instance.parameters = parameters
|
||||
instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
|
||||
parameters, 1
|
||||
)[1]
|
||||
instance.advertising_sid = parameters[7]
|
||||
instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
|
||||
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
|
||||
instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
|
||||
instance.subgroups = decode_subgroups(parameters[14:])
|
||||
return instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
advertiser_address: hci.Address,
|
||||
advertising_sid: int,
|
||||
broadcast_id: int,
|
||||
pa_sync: PeriodicAdvertisingSyncParams,
|
||||
pa_interval: int,
|
||||
subgroups: Sequence[SubgroupInfo],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ControlPointOperation.OpCode.ADD_SOURCE,
|
||||
struct.pack(
|
||||
"<B6sB3sBH",
|
||||
advertiser_address.address_type,
|
||||
bytes(advertiser_address),
|
||||
advertising_sid,
|
||||
broadcast_id.to_bytes(3, "little"),
|
||||
pa_sync,
|
||||
pa_interval,
|
||||
)
|
||||
+ encode_subgroups(subgroups),
|
||||
)
|
||||
self.advertiser_address = advertiser_address
|
||||
self.advertising_sid = advertising_sid
|
||||
self.broadcast_id = broadcast_id
|
||||
self.pa_sync = pa_sync
|
||||
self.pa_interval = pa_interval
|
||||
self.subgroups = list(subgroups)
|
||||
|
||||
|
||||
class ModifySourceOperation(ControlPointOperation):
|
||||
@classmethod
|
||||
def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
|
||||
instance = cls.__new__(cls)
|
||||
instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
|
||||
instance.parameters = parameters
|
||||
instance.source_id = parameters[0]
|
||||
instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
|
||||
instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
|
||||
instance.subgroups = decode_subgroups(parameters[4:])
|
||||
return instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_id: int,
|
||||
pa_sync: PeriodicAdvertisingSyncParams,
|
||||
pa_interval: int,
|
||||
subgroups: Sequence[SubgroupInfo],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ControlPointOperation.OpCode.MODIFY_SOURCE,
|
||||
struct.pack("<BBH", source_id, pa_sync, pa_interval)
|
||||
+ encode_subgroups(subgroups),
|
||||
)
|
||||
self.source_id = source_id
|
||||
self.pa_sync = pa_sync
|
||||
self.pa_interval = pa_interval
|
||||
self.subgroups = list(subgroups)
|
||||
|
||||
|
||||
class SetBroadcastCodeOperation(ControlPointOperation):
|
||||
@classmethod
|
||||
def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
|
||||
instance = cls.__new__(cls)
|
||||
instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
|
||||
instance.parameters = parameters
|
||||
instance.source_id = parameters[0]
|
||||
instance.broadcast_code = parameters[1:17]
|
||||
return instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_id: int,
|
||||
broadcast_code: bytes,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ControlPointOperation.OpCode.SET_BROADCAST_CODE,
|
||||
bytes([source_id]) + broadcast_code,
|
||||
)
|
||||
self.source_id = source_id
|
||||
self.broadcast_code = broadcast_code
|
||||
|
||||
if len(self.broadcast_code) != 16:
|
||||
raise core.InvalidArgumentError("broadcast_code must be 16 bytes")
|
||||
|
||||
|
||||
class RemoveSourceOperation(ControlPointOperation):
|
||||
@classmethod
|
||||
def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
|
||||
instance = cls.__new__(cls)
|
||||
instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
|
||||
instance.parameters = parameters
|
||||
instance.source_id = parameters[0]
|
||||
return instance
|
||||
|
||||
def __init__(self, source_id: int) -> None:
|
||||
super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
|
||||
self.source_id = source_id
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BroadcastReceiveState:
|
||||
class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
|
||||
NOT_SYNCHRONIZED_TO_PA = 0x00
|
||||
SYNCINFO_REQUEST = 0x01
|
||||
SYNCHRONIZED_TO_PA = 0x02
|
||||
FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
|
||||
NO_PAST = 0x04
|
||||
|
||||
class BigEncryption(utils.OpenIntEnum):
|
||||
NOT_ENCRYPTED = 0x00
|
||||
BROADCAST_CODE_REQUIRED = 0x01
|
||||
DECRYPTING = 0x02
|
||||
BAD_CODE = 0x03
|
||||
|
||||
source_id: int
|
||||
source_address: hci.Address
|
||||
source_adv_sid: int
|
||||
broadcast_id: int
|
||||
pa_sync_state: PeriodicAdvertisingSyncState
|
||||
big_encryption: BigEncryption
|
||||
bad_code: bytes
|
||||
subgroups: List[SubgroupInfo]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
|
||||
source_id = data[0]
|
||||
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
|
||||
source_adv_sid = data[8]
|
||||
broadcast_id = int.from_bytes(data[9:12], "little")
|
||||
pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
|
||||
big_encryption = cls.BigEncryption(data[13])
|
||||
if big_encryption == cls.BigEncryption.BAD_CODE:
|
||||
bad_code = data[14:30]
|
||||
subgroups = decode_subgroups(data[30:])
|
||||
else:
|
||||
bad_code = b""
|
||||
subgroups = decode_subgroups(data[14:])
|
||||
|
||||
return cls(
|
||||
source_id,
|
||||
source_address,
|
||||
source_adv_sid,
|
||||
broadcast_id,
|
||||
pa_sync_state,
|
||||
big_encryption,
|
||||
bad_code,
|
||||
subgroups,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (
|
||||
struct.pack(
|
||||
"<BB6sB3sBB",
|
||||
self.source_id,
|
||||
self.source_address.address_type,
|
||||
bytes(self.source_address),
|
||||
self.source_adv_sid,
|
||||
self.broadcast_id.to_bytes(3, "little"),
|
||||
self.pa_sync_state,
|
||||
self.big_encryption,
|
||||
)
|
||||
+ self.bad_code
|
||||
+ encode_subgroups(self.subgroups)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BroadcastAudioScanService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE
|
||||
|
||||
def __init__(self):
|
||||
self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
|
||||
gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(
|
||||
write=self.on_broadcast_audio_scan_control_point_write
|
||||
),
|
||||
)
|
||||
|
||||
self.broadcast_receive_state_characteristic = gatt.Characteristic(
|
||||
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
|
||||
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
|
||||
gatt.Characteristic.Permissions.READABLE
|
||||
| gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
b"12", # TEST
|
||||
)
|
||||
|
||||
super().__init__([self.battery_level_characteristic])
|
||||
|
||||
def on_broadcast_audio_scan_control_point_write(
|
||||
self, connection: device.Connection, value: bytes
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = BroadcastAudioScanService
|
||||
|
||||
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
|
||||
broadcast_receive_states: List[gatt.SerializableCharacteristicAdapter]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Audio Scan Control Point characteristic not found"
|
||||
)
|
||||
self.broadcast_audio_scan_control_point = characteristics[0]
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Receive State characteristic not found"
|
||||
)
|
||||
self.broadcast_receive_states = [
|
||||
gatt.SerializableCharacteristicAdapter(
|
||||
characteristic, BroadcastReceiveState
|
||||
)
|
||||
for characteristic in characteristics
|
||||
]
|
||||
|
||||
async def send_control_point_operation(
|
||||
self, operation: ControlPointOperation
|
||||
) -> None:
|
||||
await self.broadcast_audio_scan_control_point.write_value(
|
||||
bytes(operation), with_response=True
|
||||
)
|
||||
|
||||
async def remote_scan_started(self) -> None:
|
||||
await self.send_control_point_operation(RemoteScanStartedOperation())
|
||||
|
||||
async def remote_scan_stopped(self) -> None:
|
||||
await self.send_control_point_operation(RemoteScanStoppedOperation())
|
||||
|
||||
async def add_source(
|
||||
self,
|
||||
advertiser_address: hci.Address,
|
||||
advertising_sid: int,
|
||||
broadcast_id: int,
|
||||
pa_sync: PeriodicAdvertisingSyncParams,
|
||||
pa_interval: int,
|
||||
subgroups: Sequence[SubgroupInfo],
|
||||
) -> None:
|
||||
await self.send_control_point_operation(
|
||||
AddSourceOperation(
|
||||
advertiser_address,
|
||||
advertising_sid,
|
||||
broadcast_id,
|
||||
pa_sync,
|
||||
pa_interval,
|
||||
subgroups,
|
||||
)
|
||||
)
|
||||
|
||||
async def modify_source(
|
||||
self,
|
||||
source_id: int,
|
||||
pa_sync: PeriodicAdvertisingSyncParams,
|
||||
pa_interval: int,
|
||||
subgroups: Sequence[SubgroupInfo],
|
||||
) -> None:
|
||||
await self.send_control_point_operation(
|
||||
ModifySourceOperation(
|
||||
source_id,
|
||||
pa_sync,
|
||||
pa_interval,
|
||||
subgroups,
|
||||
)
|
||||
)
|
||||
|
||||
async def remove_source(self, source_id: int) -> None:
|
||||
await self.send_control_point_operation(RemoveSourceOperation(source_id))
|
||||
52
bumble/profiles/cap.py
Normal file
52
bumble/profiles/cap.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
|
||||
) -> None:
|
||||
self.coordinated_set_identification_service = (
|
||||
coordinated_set_identification_service
|
||||
)
|
||||
super().__init__(
|
||||
characteristics=[],
|
||||
included_services=[coordinated_set_identification_service],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CommonAudioServiceService
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
257
bumble/profiles/csip.py
Normal file
257
bumble/profiles/csip.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
|
||||
|
||||
|
||||
class SirkType(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
|
||||
|
||||
ENCRYPTED = 0x00
|
||||
PLAINTEXT = 0x01
|
||||
|
||||
|
||||
class MemberLock(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
|
||||
|
||||
UNLOCKED = 0x01
|
||||
LOCKED = 0x02
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Crypto Toolbox
|
||||
# -----------------------------------------------------------------------------
|
||||
def s1(m: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
|
||||
'''
|
||||
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
|
||||
|
||||
|
||||
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.4 k1 derivation function.
|
||||
'''
|
||||
t = crypto.aes_cmac(n[::-1], salt[::-1])
|
||||
return crypto.aes_cmac(p[::-1], t)[::-1]
|
||||
|
||||
|
||||
def sef(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
|
||||
|
||||
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
|
||||
* Plaintext in encryption
|
||||
* Cipher in decryption
|
||||
'''
|
||||
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
|
||||
|
||||
|
||||
def sih(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
|
||||
'''
|
||||
return crypto.e(k, r + bytes(13))[:3]
|
||||
|
||||
|
||||
def generate_rsi(sirk: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
|
||||
'''
|
||||
prand = crypto.generate_prand()
|
||||
return sih(sirk, prand) + prand
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
set_identity_resolving_key: bytes,
|
||||
set_identity_resolving_key_type: SirkType,
|
||||
coordinated_set_size: Optional[int] = None,
|
||||
set_member_lock: Optional[MemberLock] = None,
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||
raise core.InvalidArgumentError(
|
||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||
)
|
||||
|
||||
characteristics = []
|
||||
|
||||
self.set_identity_resolving_key = set_identity_resolving_key
|
||||
self.set_identity_resolving_key_type = set_identity_resolving_key_type
|
||||
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self.on_sirk_read),
|
||||
)
|
||||
characteristics.append(self.set_identity_resolving_key_characteristic)
|
||||
|
||||
if coordinated_set_size is not None:
|
||||
self.coordinated_set_size_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', coordinated_set_size),
|
||||
)
|
||||
characteristics.append(self.coordinated_set_size_characteristic)
|
||||
|
||||
if set_member_lock is not None:
|
||||
self.set_member_lock_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=struct.pack('B', set_member_lock),
|
||||
)
|
||||
characteristics.append(self.set_member_lock_characteristic)
|
||||
|
||||
if set_member_rank is not None:
|
||||
self.set_member_rank_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', set_member_rank),
|
||||
)
|
||||
characteristics.append(self.set_member_rank_characteristic)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await connection.device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||
|
||||
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
|
||||
generate_rsi(self.set_identity_resolving_key),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CoordinatedSetIdentificationService
|
||||
|
||||
set_identity_resolving_key: gatt_client.CharacteristicProxy
|
||||
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
|
||||
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
|
||||
):
|
||||
self.coordinated_set_size = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_lock = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_rank = characteristics[0]
|
||||
|
||||
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
raise core.InvalidPacketError('Invalid SIRK value')
|
||||
|
||||
sirk_type = SirkType(response[0])
|
||||
if sirk_type == SirkType.PLAINTEXT:
|
||||
sirk = response[1:]
|
||||
else:
|
||||
connection = self.service_proxy.client.connection
|
||||
device = connection.device
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise core.InvalidOperationError('LTK or LinkKey is not present')
|
||||
|
||||
sirk = sef(key, response[1:])
|
||||
|
||||
return (sirk_type, sirk)
|
||||
@@ -19,8 +19,8 @@
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..gatt import (
|
||||
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy
|
||||
from bumble.gatt import (
|
||||
GATT_DEVICE_INFORMATION_SERVICE,
|
||||
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
|
||||
@@ -59,12 +59,15 @@ class DeviceInformationService(TemplateService):
|
||||
firmware_revision: Optional[str] = None,
|
||||
software_revision: Optional[str] = None,
|
||||
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
|
||||
ieee_regulatory_certification_data_list: Optional[bytes] = None
|
||||
ieee_regulatory_certification_data_list: Optional[bytes] = None,
|
||||
# TODO: pnp_id
|
||||
):
|
||||
characteristics = [
|
||||
Characteristic(
|
||||
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field
|
||||
uuid,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes(field, 'utf-8'),
|
||||
)
|
||||
for (field, uuid) in (
|
||||
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
|
||||
@@ -104,10 +107,19 @@ class DeviceInformationService(TemplateService):
|
||||
class DeviceInformationServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = DeviceInformationService
|
||||
|
||||
def __init__(self, service_proxy):
|
||||
manufacturer_name: Optional[UTF8CharacteristicAdapter]
|
||||
model_number: Optional[UTF8CharacteristicAdapter]
|
||||
serial_number: Optional[UTF8CharacteristicAdapter]
|
||||
hardware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
firmware_revision: Optional[UTF8CharacteristicAdapter]
|
||||
software_revision: Optional[UTF8CharacteristicAdapter]
|
||||
system_id: Optional[DelegatedCharacteristicAdapter]
|
||||
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
for (field, uuid) in (
|
||||
for field, uuid in (
|
||||
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
|
||||
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
|
||||
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
|
||||
|
||||
110
bumble/profiles/gap.py
Normal file
110
bumble/profiles/gap.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generic Access Profile"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from bumble.core import Appearance
|
||||
from bumble.gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicAdapter,
|
||||
DelegatedCharacteristicAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
GATT_GENERIC_ACCESS_SERVICE,
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAccessService(TemplateService):
|
||||
UUID = GATT_GENERIC_ACCESS_SERVICE
|
||||
|
||||
def __init__(
|
||||
self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
|
||||
):
|
||||
if isinstance(appearance, int):
|
||||
appearance_int = appearance
|
||||
elif isinstance(appearance, tuple):
|
||||
appearance_int = (appearance[0] << 6) | appearance[1]
|
||||
elif isinstance(appearance, Appearance):
|
||||
appearance_int = int(appearance)
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
self.device_name_characteristic = Characteristic(
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
device_name.encode('utf-8')[:248],
|
||||
)
|
||||
|
||||
self.appearance_characteristic = Characteristic(
|
||||
GATT_APPEARANCE_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', appearance_int),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[self.device_name_characteristic, self.appearance_characteristic]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAccessServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = GenericAccessService
|
||||
|
||||
device_name: Optional[CharacteristicAdapter]
|
||||
appearance: Optional[DelegatedCharacteristicAdapter]
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_DEVICE_NAME_CHARACTERISTIC
|
||||
):
|
||||
self.device_name = UTF8CharacteristicAdapter(characteristics[0])
|
||||
else:
|
||||
self.device_name = None
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_APPEARANCE_CHARACTERISTIC
|
||||
):
|
||||
self.appearance = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: Appearance.from_int(
|
||||
struct.unpack_from('<H', value, 0)[0],
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.appearance = None
|
||||
198
bumble/profiles/gmap.py
Normal file
198
bumble/profiles/gmap.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LE Audio - Gaming Audio Profile"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from bumble.gatt import (
|
||||
TemplateService,
|
||||
DelegatedCharacteristicAdapter,
|
||||
Characteristic,
|
||||
GATT_GAMING_AUDIO_SERVICE,
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
InvalidServiceError,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from enum import IntFlag
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class GmapRole(IntFlag):
|
||||
UNICAST_GAME_GATEWAY = 1 << 0
|
||||
UNICAST_GAME_TERMINAL = 1 << 1
|
||||
BROADCAST_GAME_SENDER = 1 << 2
|
||||
BROADCAST_GAME_RECEIVER = 1 << 3
|
||||
|
||||
|
||||
class UggFeatures(IntFlag):
|
||||
UGG_MULTIPLEX = 1 << 0
|
||||
UGG_96_KBPS_SOURCE = 1 << 1
|
||||
UGG_MULTISINK = 1 << 2
|
||||
|
||||
|
||||
class UgtFeatures(IntFlag):
|
||||
UGT_SOURCE = 1 << 0
|
||||
UGT_80_KBPS_SOURCE = 1 << 1
|
||||
UGT_SINK = 1 << 2
|
||||
UGT_64_KBPS_SINK = 1 << 3
|
||||
UGT_MULTIPLEX = 1 << 4
|
||||
UGT_MULTISINK = 1 << 5
|
||||
UGT_MULTISOURCE = 1 << 6
|
||||
|
||||
|
||||
class BgsFeatures(IntFlag):
|
||||
BGS_96_KBPS = 1 << 0
|
||||
|
||||
|
||||
class BgrFeatures(IntFlag):
|
||||
BGR_MULTISINK = 1 << 0
|
||||
BGR_MULTIPLEX = 1 << 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioService(TemplateService):
|
||||
UUID = GATT_GAMING_AUDIO_SERVICE
|
||||
|
||||
gmap_role: Characteristic
|
||||
ugg_features: Optional[Characteristic] = None
|
||||
ugt_features: Optional[Characteristic] = None
|
||||
bgs_features: Optional[Characteristic] = None
|
||||
bgr_features: Optional[Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gmap_role: GmapRole,
|
||||
ugg_features: Optional[UggFeatures] = None,
|
||||
ugt_features: Optional[UgtFeatures] = None,
|
||||
bgs_features: Optional[BgsFeatures] = None,
|
||||
bgr_features: Optional[BgrFeatures] = None,
|
||||
) -> None:
|
||||
characteristics = []
|
||||
|
||||
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
|
||||
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
|
||||
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
|
||||
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
|
||||
|
||||
self.gmap_role = Characteristic(
|
||||
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', gmap_role),
|
||||
)
|
||||
characteristics.append(self.gmap_role)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
|
||||
self.ugg_features = Characteristic(
|
||||
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugg_features),
|
||||
)
|
||||
characteristics.append(self.ugg_features)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
|
||||
self.ugt_features = Characteristic(
|
||||
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugt_features),
|
||||
)
|
||||
characteristics.append(self.ugt_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
|
||||
self.bgs_features = Characteristic(
|
||||
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgs_features),
|
||||
)
|
||||
characteristics.append(self.bgs_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
|
||||
self.bgr_features = Characteristic(
|
||||
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgr_features),
|
||||
)
|
||||
characteristics.append(self.bgr_features)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = GamingAudioService
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError("GMAP Role Characteristic not found")
|
||||
self.gmap_role = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: GmapRole(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugg_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: UggFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugt_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: UgtFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgs_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: BgsFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgr_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: BgrFeatures(value[0]),
|
||||
)
|
||||
674
bumble/profiles/hap.py
Normal file
674
bumble/profiles/hap.py
Normal file
@@ -0,0 +1,674 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
from bumble import att, gatt, gatt_client
|
||||
from bumble.core import InvalidArgumentError, InvalidStateError
|
||||
from bumble.device import Device, Connection
|
||||
from bumble.utils import AsyncRunner, OpenIntEnum
|
||||
from bumble.hci import Address
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
class ErrorCode(OpenIntEnum):
|
||||
'''See Hearing Access Service 2.4. Attribute Profile error codes.'''
|
||||
|
||||
INVALID_OPCODE = 0x80
|
||||
WRITE_NAME_NOT_ALLOWED = 0x81
|
||||
PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
|
||||
PRESET_OPERATION_NOT_POSSIBLE = 0x83
|
||||
INVALID_PARAMETERS_LENGTH = 0x84
|
||||
|
||||
|
||||
class HearingAidType(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
BINAURAL_HEARING_AID = 0b00
|
||||
MONAURAL_HEARING_AID = 0b01
|
||||
BANDED_HEARING_AID = 0b10
|
||||
|
||||
|
||||
class PresetSynchronizationSupport(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
|
||||
PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
|
||||
|
||||
|
||||
class IndependentPresets(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
IDENTICAL_PRESET_RECORD = 0b0
|
||||
DIFFERENT_PRESET_RECORD = 0b1
|
||||
|
||||
|
||||
class DynamicPresets(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
|
||||
PRESET_RECORDS_MAY_CHANGE = 0b1
|
||||
|
||||
|
||||
class WritablePresetsSupport(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
|
||||
WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
|
||||
|
||||
|
||||
class HearingAidPresetControlPointOpcode(OpenIntEnum):
|
||||
'''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''
|
||||
|
||||
# fmt: off
|
||||
READ_PRESETS_REQUEST = 0x01
|
||||
READ_PRESET_RESPONSE = 0x02
|
||||
PRESET_CHANGED = 0x03
|
||||
WRITE_PRESET_NAME = 0x04
|
||||
SET_ACTIVE_PRESET = 0x05
|
||||
SET_NEXT_PRESET = 0x06
|
||||
SET_PREVIOUS_PRESET = 0x07
|
||||
SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY = 0x08
|
||||
SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY = 0x09
|
||||
SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
|
||||
|
||||
|
||||
@dataclass
|
||||
class HearingAidFeatures:
|
||||
'''See Hearing Access Service 3.1. Hearing Aid Features.'''
|
||||
|
||||
hearing_aid_type: HearingAidType
|
||||
preset_synchronization_support: PresetSynchronizationSupport
|
||||
independent_presets: IndependentPresets
|
||||
dynamic_presets: DynamicPresets
|
||||
writable_presets_support: WritablePresetsSupport
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes(
|
||||
[
|
||||
(self.hearing_aid_type << 0)
|
||||
| (self.preset_synchronization_support << 2)
|
||||
| (self.independent_presets << 3)
|
||||
| (self.dynamic_presets << 4)
|
||||
| (self.writable_presets_support << 5)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
|
||||
return HearingAidFeatures(
|
||||
HearingAidType(data & 0b11),
|
||||
PresetSynchronizationSupport(data >> 2 & 0b1),
|
||||
IndependentPresets(data >> 3 & 0b1),
|
||||
DynamicPresets(data >> 4 & 0b1),
|
||||
WritablePresetsSupport(data >> 5 & 0b1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetChangedOperation:
|
||||
'''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''
|
||||
|
||||
class ChangeId(OpenIntEnum):
|
||||
# fmt: off
|
||||
GENERIC_UPDATE = 0x00
|
||||
PRESET_RECORD_DELETED = 0x01
|
||||
PRESET_RECORD_AVAILABLE = 0x02
|
||||
PRESET_RECORD_UNAVAILABLE = 0x03
|
||||
|
||||
@dataclass
|
||||
class Generic:
|
||||
prev_index: int
|
||||
preset_record: PresetRecord
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.prev_index]) + bytes(self.preset_record)
|
||||
|
||||
change_id: ChangeId
|
||||
additional_parameters: Union[Generic, int]
|
||||
|
||||
def to_bytes(self, is_last: bool) -> bytes:
|
||||
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
|
||||
additional_parameters_bytes = bytes(self.additional_parameters)
|
||||
else:
|
||||
additional_parameters_bytes = bytes([self.additional_parameters])
|
||||
|
||||
return (
|
||||
bytes(
|
||||
[
|
||||
HearingAidPresetControlPointOpcode.PRESET_CHANGED,
|
||||
self.change_id,
|
||||
is_last,
|
||||
]
|
||||
)
|
||||
+ additional_parameters_bytes
|
||||
)
|
||||
|
||||
|
||||
class PresetChangedOperationDeleted(PresetChangedOperation):
|
||||
def __init__(self, index) -> None:
|
||||
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
|
||||
self.additional_parameters = index
|
||||
|
||||
|
||||
class PresetChangedOperationAvailable(PresetChangedOperation):
|
||||
def __init__(self, index) -> None:
|
||||
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
|
||||
self.additional_parameters = index
|
||||
|
||||
|
||||
class PresetChangedOperationUnavailable(PresetChangedOperation):
|
||||
def __init__(self, index) -> None:
|
||||
self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
|
||||
self.additional_parameters = index
|
||||
|
||||
|
||||
@dataclass
|
||||
class PresetRecord:
|
||||
'''See Hearing Access Service 2.8. Preset record.'''
|
||||
|
||||
@dataclass
|
||||
class Property:
|
||||
class Writable(OpenIntEnum):
|
||||
CANNOT_BE_WRITTEN = 0b0
|
||||
CAN_BE_WRITTEN = 0b1
|
||||
|
||||
class IsAvailable(OpenIntEnum):
|
||||
IS_UNAVAILABLE = 0b0
|
||||
IS_AVAILABLE = 0b1
|
||||
|
||||
writable: Writable = Writable.CAN_BE_WRITTEN
|
||||
is_available: IsAvailable = IsAvailable.IS_AVAILABLE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.writable | (self.is_available << 1)])
|
||||
|
||||
index: int
|
||||
name: str
|
||||
properties: Property = field(default_factory=Property)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8')
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return (
|
||||
self.properties.is_available
|
||||
== PresetRecord.Property.IsAvailable.IS_AVAILABLE
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class HearingAccessService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_HEARING_ACCESS_SERVICE
|
||||
|
||||
hearing_aid_features_characteristic: gatt.Characteristic
|
||||
hearing_aid_preset_control_point: gatt.Characteristic
|
||||
active_preset_index_characteristic: gatt.Characteristic
|
||||
active_preset_index: int
|
||||
active_preset_index_per_device: Dict[Address, int]
|
||||
|
||||
device: Device
|
||||
|
||||
server_features: HearingAidFeatures
|
||||
preset_records: Dict[int, PresetRecord] # key is the preset index
|
||||
read_presets_request_in_progress: bool
|
||||
|
||||
preset_changed_operations_history_per_device: Dict[
|
||||
Address, List[PresetChangedOperation]
|
||||
]
|
||||
|
||||
# Keep an updated list of connected client to send notification to
|
||||
currently_connected_clients: Set[Connection]
|
||||
|
||||
def __init__(
|
||||
self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
|
||||
) -> None:
|
||||
self.active_preset_index_per_device = {}
|
||||
self.read_presets_request_in_progress = False
|
||||
self.preset_changed_operations_history_per_device = {}
|
||||
self.currently_connected_clients = set()
|
||||
|
||||
self.device = device
|
||||
self.server_features = features
|
||||
if len(presets) < 1:
|
||||
raise InvalidArgumentError(f'Invalid presets: {presets}')
|
||||
|
||||
self.preset_records = {}
|
||||
for p in presets:
|
||||
if len(p.name.encode()) < 1 or len(p.name.encode()) > 40:
|
||||
raise InvalidArgumentError(f'Invalid name: {p.name}')
|
||||
|
||||
self.preset_records[p.index] = p
|
||||
|
||||
# associate the lowest index as the current active preset at startup
|
||||
self.active_preset_index = sorted(self.preset_records.keys())[0]
|
||||
|
||||
@device.on('connection') # type: ignore
|
||||
def on_connection(connection: Connection) -> None:
|
||||
@connection.on('disconnection') # type: ignore
|
||||
def on_disconnection(_reason) -> None:
|
||||
self.currently_connected_clients.remove(connection)
|
||||
|
||||
@connection.on('pairing') # type: ignore
|
||||
def on_pairing(*_: Any) -> None:
|
||||
self.on_incoming_paired_connection(connection)
|
||||
|
||||
if connection.peer_resolvable_address:
|
||||
self.on_incoming_paired_connection(connection)
|
||||
|
||||
self.hearing_aid_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes(self.server_features),
|
||||
)
|
||||
self.hearing_aid_preset_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.INDICATE
|
||||
),
|
||||
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(
|
||||
write=self._on_write_hearing_aid_preset_control_point
|
||||
),
|
||||
)
|
||||
self.active_preset_index_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.hearing_aid_features_characteristic,
|
||||
self.hearing_aid_preset_control_point,
|
||||
self.active_preset_index_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
def on_incoming_paired_connection(self, connection: Connection):
|
||||
'''Setup initial operations to handle a remote bonded HAP device'''
|
||||
# TODO Should we filter on HAP device only ?
|
||||
self.currently_connected_clients.add(connection)
|
||||
if (
|
||||
connection.peer_address
|
||||
not in self.preset_changed_operations_history_per_device
|
||||
):
|
||||
self.preset_changed_operations_history_per_device[
|
||||
connection.peer_address
|
||||
] = []
|
||||
return
|
||||
|
||||
async def on_connection_async() -> None:
|
||||
# Send all the PresetChangedOperation that occur when not connected
|
||||
await self._preset_changed_operation(connection)
|
||||
# Update the active preset index if needed
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
connection.abort_on('disconnection', on_connection_async())
|
||||
|
||||
def _on_read_active_preset_index(
|
||||
self, __connection__: Optional[Connection]
|
||||
) -> bytes:
|
||||
return bytes([self.active_preset_index])
|
||||
|
||||
# TODO this need to be triggered when device is unbonded
|
||||
def on_forget(self, addr: Address) -> None:
|
||||
self.preset_changed_operations_history_per_device.pop(addr)
|
||||
|
||||
async def _on_write_hearing_aid_preset_control_point(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
|
||||
opcode = HearingAidPresetControlPointOpcode(value[0])
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
await handler(connection, value)
|
||||
|
||||
async def _on_read_presets_request(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
|
||||
logging.warning(f'HAS require MTU >= 49: {connection}')
|
||||
|
||||
if self.read_presets_request_in_progress:
|
||||
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
|
||||
self.read_presets_request_in_progress = True
|
||||
|
||||
start_index = value[1]
|
||||
if start_index == 0x00:
|
||||
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
|
||||
|
||||
num_presets = value[2]
|
||||
if num_presets == 0x00:
|
||||
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
|
||||
|
||||
# Sending `num_presets` presets ordered by increasing index field, starting from start_index
|
||||
presets = [
|
||||
self.preset_records[key]
|
||||
for key in sorted(self.preset_records.keys())
|
||||
if self.preset_records[key].index >= start_index
|
||||
]
|
||||
del presets[num_presets:]
|
||||
if len(presets) == 0:
|
||||
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
|
||||
|
||||
AsyncRunner.spawn(self._read_preset_response(connection, presets))
|
||||
|
||||
async def _read_preset_response(
|
||||
self, connection: Connection, presets: List[PresetRecord]
|
||||
):
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
|
||||
try:
|
||||
for i, preset in enumerate(presets):
|
||||
await connection.device.indicate_subscriber(
|
||||
connection,
|
||||
self.hearing_aid_preset_control_point,
|
||||
value=bytes(
|
||||
[
|
||||
HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
|
||||
i == len(presets) - 1,
|
||||
]
|
||||
)
|
||||
+ bytes(preset),
|
||||
)
|
||||
|
||||
finally:
|
||||
# indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
|
||||
self.read_presets_request_in_progress = False
|
||||
|
||||
async def generic_update(self, op: PresetChangedOperation) -> None:
|
||||
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
|
||||
await self._notifyPresetOperations(op)
|
||||
|
||||
async def delete_preset(self, index: int) -> None:
|
||||
'''Server API to delete a preset. It should not be the current active preset'''
|
||||
|
||||
if index == self.active_preset_index:
|
||||
raise InvalidStateError('Cannot delete active preset')
|
||||
|
||||
del self.preset_records[index]
|
||||
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
|
||||
|
||||
async def available_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset available'''
|
||||
|
||||
preset = self.preset_records[index]
|
||||
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
|
||||
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
|
||||
|
||||
async def unavailable_preset(self, index: int) -> None:
|
||||
'''Server API to make a preset unavailable. It should not be the current active preset'''
|
||||
|
||||
if index == self.active_preset_index:
|
||||
raise InvalidStateError('Cannot set active preset as unavailable')
|
||||
|
||||
preset = self.preset_records[index]
|
||||
preset.properties.is_available = (
|
||||
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
|
||||
)
|
||||
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
|
||||
|
||||
async def _preset_changed_operation(self, connection: Connection) -> None:
|
||||
'''Send all PresetChangedOperation saved for a given connection'''
|
||||
op_list = self.preset_changed_operations_history_per_device.get(
|
||||
connection.peer_address, []
|
||||
)
|
||||
|
||||
# Notification will be sent in index order
|
||||
def get_op_index(op: PresetChangedOperation) -> int:
|
||||
if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
|
||||
return op.additional_parameters.prev_index
|
||||
return op.additional_parameters
|
||||
|
||||
op_list.sort(key=get_op_index)
|
||||
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
|
||||
while len(op_list) > 0:
|
||||
try:
|
||||
await connection.device.indicate_subscriber(
|
||||
connection,
|
||||
self.hearing_aid_preset_control_point,
|
||||
value=op_list[0].to_bytes(len(op_list) == 1),
|
||||
)
|
||||
# Remove item once sent, and keep the non sent item in the list
|
||||
op_list.pop(0)
|
||||
except TimeoutError:
|
||||
break
|
||||
|
||||
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
|
||||
for historyList in self.preset_changed_operations_history_per_device.values():
|
||||
historyList.append(op)
|
||||
|
||||
for connection in self.currently_connected_clients:
|
||||
await self._preset_changed_operation(connection)
|
||||
|
||||
async def _on_write_preset_name(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
assert connection
|
||||
|
||||
if self.read_presets_request_in_progress:
|
||||
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
|
||||
|
||||
index = value[1]
|
||||
preset = self.preset_records.get(index, None)
|
||||
if (
|
||||
not preset
|
||||
or preset.properties.writable
|
||||
== PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)
|
||||
|
||||
name = value[2:].decode('utf-8')
|
||||
if not name or len(name) > 40:
|
||||
raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)
|
||||
|
||||
preset.name = name
|
||||
|
||||
await self.generic_update(
|
||||
PresetChangedOperation(
|
||||
PresetChangedOperation.ChangeId.GENERIC_UPDATE,
|
||||
PresetChangedOperation.Generic(index, preset),
|
||||
)
|
||||
)
|
||||
|
||||
async def notify_active_preset_for_connection(self, connection: Connection) -> None:
|
||||
if (
|
||||
self.active_preset_index_per_device.get(connection.peer_address, 0x00)
|
||||
== self.active_preset_index
|
||||
):
|
||||
# Nothing to do, peer is already updated
|
||||
return
|
||||
|
||||
await connection.device.notify_subscriber(
|
||||
connection,
|
||||
attribute=self.active_preset_index_characteristic,
|
||||
value=bytes([self.active_preset_index]),
|
||||
)
|
||||
self.active_preset_index_per_device[connection.peer_address] = (
|
||||
self.active_preset_index
|
||||
)
|
||||
|
||||
async def notify_active_preset(self) -> None:
|
||||
for connection in self.currently_connected_clients:
|
||||
await self.notify_active_preset_for_connection(connection)
|
||||
|
||||
async def set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
index = value[1]
|
||||
preset = self.preset_records.get(index, None)
|
||||
if (
|
||||
not preset
|
||||
or preset.properties.is_available
|
||||
!= PresetRecord.Property.IsAvailable.IS_AVAILABLE
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
if index == self.active_preset_index:
|
||||
# Already at correct value
|
||||
return
|
||||
|
||||
self.active_preset_index = index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_active_preset(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
await self.set_active_preset(connection, value)
|
||||
|
||||
async def set_next_or_previous_preset(
|
||||
self, connection: Optional[Connection], is_previous
|
||||
):
|
||||
'''Set the next or the previous preset as active'''
|
||||
assert connection
|
||||
|
||||
if self.active_preset_index == 0x00:
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
first_preset: Optional[PresetRecord] = None # To loop to first preset
|
||||
next_preset: Optional[PresetRecord] = None
|
||||
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
|
||||
if not record.is_available():
|
||||
continue
|
||||
if first_preset == None:
|
||||
first_preset = record
|
||||
if is_previous:
|
||||
if index >= self.active_preset_index:
|
||||
continue
|
||||
elif index <= self.active_preset_index:
|
||||
continue
|
||||
next_preset = record
|
||||
break
|
||||
|
||||
if not first_preset: # If no other preset are available
|
||||
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
|
||||
|
||||
if next_preset:
|
||||
self.active_preset_index = next_preset.index
|
||||
else:
|
||||
self.active_preset_index = first_preset.index
|
||||
await self.notify_active_preset()
|
||||
|
||||
async def _on_set_next_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, False)
|
||||
|
||||
async def _on_set_previous_preset(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
) -> None:
|
||||
await self.set_next_or_previous_preset(connection, True)
|
||||
|
||||
async def _on_set_active_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_active_preset(connection, value)
|
||||
# TODO (low priority) inform other server of the change
|
||||
|
||||
async def _on_set_next_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_next_or_previous_preset(connection, False)
|
||||
# TODO (low priority) inform other server of the change
|
||||
|
||||
async def _on_set_previous_preset_synchronized_locally(
|
||||
self, connection: Optional[Connection], __value__: bytes
|
||||
):
|
||||
if (
|
||||
self.server_features.preset_synchronization_support
|
||||
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
|
||||
):
|
||||
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
|
||||
await self.set_next_or_previous_preset(connection, True)
|
||||
# TODO (low priority) inform other server of the change
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = HearingAccessService
|
||||
|
||||
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
|
||||
preset_control_point_indications: asyncio.Queue
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.server_features = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
|
||||
self.hearing_aid_preset_control_point = (
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
)
|
||||
|
||||
self.active_preset_index = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
|
||||
async def setup_subscription(self):
|
||||
self.preset_control_point_indications = asyncio.Queue()
|
||||
self.active_preset_index_notification = asyncio.Queue()
|
||||
|
||||
def on_active_preset_index_notification(data: bytes):
|
||||
self.active_preset_index_notification.put_nowait(data)
|
||||
|
||||
def on_preset_control_point_indication(data: bytes):
|
||||
self.preset_control_point_indications.put_nowait(data)
|
||||
|
||||
await self.hearing_aid_preset_control_point.subscribe(
|
||||
functools.partial(on_preset_control_point_indication), prefer_notify=False
|
||||
)
|
||||
|
||||
await self.active_preset_index.subscribe(
|
||||
functools.partial(on_active_preset_index_notification)
|
||||
)
|
||||
@@ -19,6 +19,7 @@
|
||||
from enum import IntEnum
|
||||
import struct
|
||||
|
||||
from bumble import core
|
||||
from ..gatt_client import ProfileServiceProxy
|
||||
from ..att import ATT_Error
|
||||
from ..gatt import (
|
||||
@@ -29,6 +30,7 @@ from ..gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
SerializableCharacteristicAdapter,
|
||||
DelegatedCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
)
|
||||
@@ -42,12 +44,12 @@ class HeartRateService(TemplateService):
|
||||
RESET_ENERGY_EXPENDED = 0x01
|
||||
|
||||
class BodySensorLocation(IntEnum):
|
||||
OTHER = (0,)
|
||||
CHEST = (1,)
|
||||
WRIST = (2,)
|
||||
FINGER = (3,)
|
||||
HAND = (4,)
|
||||
EAR_LOBE = (5,)
|
||||
OTHER = 0
|
||||
CHEST = 1
|
||||
WRIST = 2
|
||||
FINGER = 3
|
||||
HAND = 4
|
||||
EAR_LOBE = 5
|
||||
FOOT = 6
|
||||
|
||||
class HeartRateMeasurement:
|
||||
@@ -59,17 +61,17 @@ class HeartRateService(TemplateService):
|
||||
rr_intervals=None,
|
||||
):
|
||||
if heart_rate < 0 or heart_rate > 0xFFFF:
|
||||
raise ValueError('heart_rate out of range')
|
||||
raise core.InvalidArgumentError('heart_rate out of range')
|
||||
|
||||
if energy_expended is not None and (
|
||||
energy_expended < 0 or energy_expended > 0xFFFF
|
||||
):
|
||||
raise ValueError('energy_expended out of range')
|
||||
raise core.InvalidArgumentError('energy_expended out of range')
|
||||
|
||||
if rr_intervals:
|
||||
for rr_interval in rr_intervals:
|
||||
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
|
||||
raise ValueError('rr_intervals out of range')
|
||||
raise core.InvalidArgumentError('rr_intervals out of range')
|
||||
|
||||
self.heart_rate = heart_rate
|
||||
self.sensor_contact_detected = sensor_contact_detected
|
||||
@@ -149,15 +151,14 @@ class HeartRateService(TemplateService):
|
||||
body_sensor_location=None,
|
||||
reset_energy_expended=None,
|
||||
):
|
||||
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
0,
|
||||
CharacteristicValue(read=read_heart_rate_measurement),
|
||||
),
|
||||
# pylint: disable=unnecessary-lambda
|
||||
encode=lambda value: bytes(value),
|
||||
HeartRateService.HeartRateMeasurement,
|
||||
)
|
||||
characteristics = [self.heart_rate_measurement_characteristic]
|
||||
|
||||
@@ -203,9 +204,8 @@ class HeartRateServiceProxy(ProfileServiceProxy):
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
|
||||
):
|
||||
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=HeartRateService.HeartRateMeasurement.from_bytes,
|
||||
self.heart_rate_measurement = SerializableCharacteristicAdapter(
|
||||
characteristics[0], HeartRateService.HeartRateMeasurement
|
||||
)
|
||||
else:
|
||||
self.heart_rate_measurement = None
|
||||
|
||||
83
bumble/profiles/le_audio.py
Normal file
83
bumble/profiles/le_audio.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import struct
|
||||
from typing import List, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble import utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Metadata:
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
|
||||
|
||||
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
|
||||
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
|
||||
again outside the lib.
|
||||
'''
|
||||
|
||||
class Tag(utils.OpenIntEnum):
|
||||
# fmt: off
|
||||
PREFERRED_AUDIO_CONTEXTS = 0x01
|
||||
STREAMING_AUDIO_CONTEXTS = 0x02
|
||||
PROGRAM_INFO = 0x03
|
||||
LANGUAGE = 0x04
|
||||
CCID_LIST = 0x05
|
||||
PARENTAL_RATING = 0x06
|
||||
PROGRAM_INFO_URI = 0x07
|
||||
AUDIO_ACTIVE_STATE = 0x08
|
||||
BROADCAST_AUDIO_IMMEDIATE_RENDERING_FLAG = 0x09
|
||||
ASSISTED_LISTENING_STREAM = 0x0A
|
||||
BROADCAST_NAME = 0x0B
|
||||
EXTENDED_METADATA = 0xFE
|
||||
VENDOR_SPECIFIC = 0xFF
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Entry:
|
||||
tag: Metadata.Tag
|
||||
data: bytes
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([len(self.data) + 1, self.tag]) + self.data
|
||||
|
||||
entries: List[Entry] = dataclasses.field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
entries = []
|
||||
offset = 0
|
||||
length = len(data)
|
||||
while offset < length:
|
||||
entry_length = data[offset]
|
||||
offset += 1
|
||||
entries.append(cls.Entry.from_bytes(data[offset : offset + entry_length]))
|
||||
offset += entry_length
|
||||
|
||||
return cls(entries)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return b''.join([bytes(entry) for entry in self.entries])
|
||||
448
bumble/profiles/mcp.py
Normal file
448
bumble/profiles/mcp.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from bumble import core
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import utils
|
||||
|
||||
from typing import Type, Optional, ClassVar, Dict, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PlayingOrder(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.15. Playing Order.'''
|
||||
|
||||
SINGLE_ONCE = 0x01
|
||||
SINGLE_REPEAT = 0x02
|
||||
IN_ORDER_ONCE = 0x03
|
||||
IN_ORDER_REPEAT = 0x04
|
||||
OLDEST_ONCE = 0x05
|
||||
OLDEST_REPEAT = 0x06
|
||||
NEWEST_ONCE = 0x07
|
||||
NEWEST_REPEAT = 0x08
|
||||
SHUFFLE_ONCE = 0x09
|
||||
SHUFFLE_REPEAT = 0x0A
|
||||
|
||||
|
||||
class PlayingOrderSupported(enum.IntFlag):
|
||||
'''See Media Control Service 3.16. Playing Orders Supported.'''
|
||||
|
||||
SINGLE_ONCE = 0x0001
|
||||
SINGLE_REPEAT = 0x0002
|
||||
IN_ORDER_ONCE = 0x0004
|
||||
IN_ORDER_REPEAT = 0x0008
|
||||
OLDEST_ONCE = 0x0010
|
||||
OLDEST_REPEAT = 0x0020
|
||||
NEWEST_ONCE = 0x0040
|
||||
NEWEST_REPEAT = 0x0080
|
||||
SHUFFLE_ONCE = 0x0100
|
||||
SHUFFLE_REPEAT = 0x0200
|
||||
|
||||
|
||||
class MediaState(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.17. Media State.'''
|
||||
|
||||
INACTIVE = 0x00
|
||||
PLAYING = 0x01
|
||||
PAUSED = 0x02
|
||||
SEEKING = 0x03
|
||||
|
||||
|
||||
class MediaControlPointOpcode(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.18. Media Control Point.'''
|
||||
|
||||
PLAY = 0x01
|
||||
PAUSE = 0x02
|
||||
FAST_REWIND = 0x03
|
||||
FAST_FORWARD = 0x04
|
||||
STOP = 0x05
|
||||
MOVE_RELATIVE = 0x10
|
||||
PREVIOUS_SEGMENT = 0x20
|
||||
NEXT_SEGMENT = 0x21
|
||||
FIRST_SEGMENT = 0x22
|
||||
LAST_SEGMENT = 0x23
|
||||
GOTO_SEGMENT = 0x24
|
||||
PREVIOUS_TRACK = 0x30
|
||||
NEXT_TRACK = 0x31
|
||||
FIRST_TRACK = 0x32
|
||||
LAST_TRACK = 0x33
|
||||
GOTO_TRACK = 0x34
|
||||
PREVIOUS_GROUP = 0x40
|
||||
NEXT_GROUP = 0x41
|
||||
FIRST_GROUP = 0x42
|
||||
LAST_GROUP = 0x43
|
||||
GOTO_GROUP = 0x44
|
||||
|
||||
|
||||
class MediaControlPointResultCode(enum.IntFlag):
|
||||
'''See Media Control Service 3.18.2. Media Control Point Notification.'''
|
||||
|
||||
SUCCESS = 0x01
|
||||
OPCODE_NOT_SUPPORTED = 0x02
|
||||
MEDIA_PLAYER_INACTIVE = 0x03
|
||||
COMMAND_CANNOT_BE_COMPLETED = 0x04
|
||||
|
||||
|
||||
class MediaControlPointOpcodeSupported(enum.IntFlag):
|
||||
'''See Media Control Service 3.19. Media Control Point Opcodes Supported.'''
|
||||
|
||||
PLAY = 0x00000001
|
||||
PAUSE = 0x00000002
|
||||
FAST_REWIND = 0x00000004
|
||||
FAST_FORWARD = 0x00000008
|
||||
STOP = 0x00000010
|
||||
MOVE_RELATIVE = 0x00000020
|
||||
PREVIOUS_SEGMENT = 0x00000040
|
||||
NEXT_SEGMENT = 0x00000080
|
||||
FIRST_SEGMENT = 0x00000100
|
||||
LAST_SEGMENT = 0x00000200
|
||||
GOTO_SEGMENT = 0x00000400
|
||||
PREVIOUS_TRACK = 0x00000800
|
||||
NEXT_TRACK = 0x00001000
|
||||
FIRST_TRACK = 0x00002000
|
||||
LAST_TRACK = 0x00004000
|
||||
GOTO_TRACK = 0x00008000
|
||||
PREVIOUS_GROUP = 0x00010000
|
||||
NEXT_GROUP = 0x00020000
|
||||
FIRST_GROUP = 0x00040000
|
||||
LAST_GROUP = 0x00080000
|
||||
GOTO_GROUP = 0x00100000
|
||||
|
||||
|
||||
class SearchControlPointItemType(utils.OpenIntEnum):
|
||||
'''See Media Control Service 3.20. Search Control Point.'''
|
||||
|
||||
TRACK_NAME = 0x01
|
||||
ARTIST_NAME = 0x02
|
||||
ALBUM_NAME = 0x03
|
||||
GROUP_NAME = 0x04
|
||||
EARLIEST_YEAR = 0x05
|
||||
LATEST_YEAR = 0x06
|
||||
GENRE = 0x07
|
||||
ONLY_TRACKS = 0x08
|
||||
ONLY_GROUPS = 0x09
|
||||
|
||||
|
||||
class ObjectType(utils.OpenIntEnum):
|
||||
'''See Media Control Service 4.4.1. Object Type field.'''
|
||||
|
||||
TASK = 0
|
||||
GROUP = 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ObjectId(int):
|
||||
'''See Media Control Service 4.4.2. Object ID field.'''
|
||||
|
||||
@classmethod
|
||||
def create_from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data, byteorder='little', signed=False))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes(6, 'little')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GroupObjectType:
|
||||
'''See Media Control Service 4.4. Group Object Type.'''
|
||||
|
||||
object_type: ObjectType
|
||||
object_id: ObjectId
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(
|
||||
object_type=ObjectType(data[0]),
|
||||
object_id=ObjectId.create_from_bytes(data[1:]),
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.object_type]) + bytes(self.object_id)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaControlService(gatt.TemplateService):
|
||||
'''Media Control Service server implementation, only for testing currently.'''
|
||||
|
||||
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
|
||||
|
||||
def __init__(self, media_player_name: Optional[str] = None) -> None:
|
||||
self.track_position = 0
|
||||
|
||||
self.media_player_name_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=media_player_name or 'Bumble Player',
|
||||
)
|
||||
self.track_changed_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_title_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_duration_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.track_position_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.media_state_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.media_control_point_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE
|
||||
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self.on_media_control_point),
|
||||
)
|
||||
self.media_control_point_opcodes_supported_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
self.content_control_id_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=b'',
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
[
|
||||
self.media_player_name_characteristic,
|
||||
self.track_changed_characteristic,
|
||||
self.track_title_characteristic,
|
||||
self.track_duration_characteristic,
|
||||
self.track_position_characteristic,
|
||||
self.media_state_characteristic,
|
||||
self.media_control_point_characteristic,
|
||||
self.media_control_point_opcodes_supported_characteristic,
|
||||
self.content_control_id_characteristic,
|
||||
]
|
||||
)
|
||||
|
||||
async def on_media_control_point(
|
||||
self, connection: Optional[device.Connection], data: bytes
|
||||
) -> None:
|
||||
if not connection:
|
||||
raise core.InvalidStateError()
|
||||
|
||||
opcode = MediaControlPointOpcode(data[0])
|
||||
|
||||
await connection.device.notify_subscriber(
|
||||
connection,
|
||||
self.media_control_point_characteristic,
|
||||
value=bytes([opcode, MediaControlPointResultCode.SUCCESS]),
|
||||
)
|
||||
|
||||
|
||||
class GenericMediaControlService(MediaControlService):
|
||||
UUID = gatt.GATT_GENERIC_MEDIA_CONTROL_SERVICE
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaControlServiceProxy(
|
||||
gatt_client.ProfileServiceProxy, utils.CompositeEventEmitter
|
||||
):
|
||||
SERVICE_CLASS = MediaControlService
|
||||
|
||||
_CHARACTERISTICS: ClassVar[Dict[str, core.UUID]] = {
|
||||
'media_player_name': gatt.GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC,
|
||||
'media_player_icon_object_id': gatt.GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC,
|
||||
'media_player_icon_url': gatt.GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC,
|
||||
'track_changed': gatt.GATT_TRACK_CHANGED_CHARACTERISTIC,
|
||||
'track_title': gatt.GATT_TRACK_TITLE_CHARACTERISTIC,
|
||||
'track_duration': gatt.GATT_TRACK_DURATION_CHARACTERISTIC,
|
||||
'track_position': gatt.GATT_TRACK_POSITION_CHARACTERISTIC,
|
||||
'playback_speed': gatt.GATT_PLAYBACK_SPEED_CHARACTERISTIC,
|
||||
'seeking_speed': gatt.GATT_SEEKING_SPEED_CHARACTERISTIC,
|
||||
'current_track_segments_object_id': gatt.GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC,
|
||||
'current_track_object_id': gatt.GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||
'next_track_object_id': gatt.GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC,
|
||||
'parent_group_object_id': gatt.GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||
'current_group_object_id': gatt.GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC,
|
||||
'playing_order': gatt.GATT_PLAYING_ORDER_CHARACTERISTIC,
|
||||
'playing_orders_supported': gatt.GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC,
|
||||
'media_state': gatt.GATT_MEDIA_STATE_CHARACTERISTIC,
|
||||
'media_control_point': gatt.GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC,
|
||||
'media_control_point_opcodes_supported': gatt.GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC,
|
||||
'search_control_point': gatt.GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC,
|
||||
'search_results_object_id': gatt.GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC,
|
||||
'content_control_id': gatt.GATT_CONTENT_CONTROL_ID_CHARACTERISTIC,
|
||||
}
|
||||
|
||||
media_player_name: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_player_icon_url: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_changed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_title: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_duration: Optional[gatt_client.CharacteristicProxy] = None
|
||||
track_position: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playback_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
seeking_speed: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_segments_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
next_track_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
parent_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
current_group_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_order: Optional[gatt_client.CharacteristicProxy] = None
|
||||
playing_orders_supported: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_state: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
media_control_point_opcodes_supported: Optional[gatt_client.CharacteristicProxy] = (
|
||||
None
|
||||
)
|
||||
search_control_point: Optional[gatt_client.CharacteristicProxy] = None
|
||||
search_results_object_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
content_control_id: Optional[gatt_client.CharacteristicProxy] = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
media_control_point_notifications: asyncio.Queue[bytes]
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
utils.CompositeEventEmitter.__init__(self)
|
||||
self.service_proxy = service_proxy
|
||||
self.lock = asyncio.Lock()
|
||||
self.media_control_point_notifications = asyncio.Queue()
|
||||
|
||||
for field, uuid in self._CHARACTERISTICS.items():
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
|
||||
setattr(self, field, characteristics[0])
|
||||
|
||||
async def subscribe_characteristics(self) -> None:
|
||||
if self.media_control_point:
|
||||
await self.media_control_point.subscribe(self._on_media_control_point)
|
||||
if self.media_state:
|
||||
await self.media_state.subscribe(self._on_media_state)
|
||||
if self.track_changed:
|
||||
await self.track_changed.subscribe(self._on_track_changed)
|
||||
if self.track_title:
|
||||
await self.track_title.subscribe(self._on_track_title)
|
||||
if self.track_duration:
|
||||
await self.track_duration.subscribe(self._on_track_duration)
|
||||
if self.track_position:
|
||||
await self.track_position.subscribe(self._on_track_position)
|
||||
|
||||
async def write_control_point(
|
||||
self, opcode: MediaControlPointOpcode
|
||||
) -> MediaControlPointResultCode:
|
||||
'''Writes a Media Control Point Opcode to peer and waits for the notification.
|
||||
|
||||
The write operation will be executed when there isn't other pending commands.
|
||||
|
||||
Args:
|
||||
opcode: opcode defined in `MediaControlPointOpcode`.
|
||||
|
||||
Returns:
|
||||
Response code provided in `MediaControlPointResultCode`
|
||||
|
||||
Raises:
|
||||
InvalidOperationError: Server does not have Media Control Point Characteristic.
|
||||
InvalidStateError: Server replies a notification with mismatched opcode.
|
||||
'''
|
||||
if not self.media_control_point:
|
||||
raise core.InvalidOperationError("Peer does not have media control point")
|
||||
|
||||
async with self.lock:
|
||||
await self.media_control_point.write_value(
|
||||
bytes([opcode]),
|
||||
with_response=False,
|
||||
)
|
||||
|
||||
(
|
||||
response_opcode,
|
||||
response_code,
|
||||
) = await self.media_control_point_notifications.get()
|
||||
if response_opcode != opcode:
|
||||
raise core.InvalidStateError(
|
||||
f"Expected {opcode} notification, but get {response_opcode}"
|
||||
)
|
||||
return MediaControlPointResultCode(response_code)
|
||||
|
||||
def _on_media_control_point(self, data: bytes) -> None:
|
||||
self.media_control_point_notifications.put_nowait(data)
|
||||
|
||||
def _on_media_state(self, data: bytes) -> None:
|
||||
self.emit('media_state', MediaState(data[0]))
|
||||
|
||||
def _on_track_changed(self, data: bytes) -> None:
|
||||
del data
|
||||
self.emit('track_changed')
|
||||
|
||||
def _on_track_title(self, data: bytes) -> None:
|
||||
self.emit('track_title', data.decode("utf-8"))
|
||||
|
||||
def _on_track_duration(self, data: bytes) -> None:
|
||||
self.emit('track_duration', struct.unpack_from('<i', data)[0])
|
||||
|
||||
def _on_track_position(self, data: bytes) -> None:
|
||||
self.emit('track_position', struct.unpack_from('<i', data)[0])
|
||||
|
||||
|
||||
class GenericMediaControlServiceProxy(MediaControlServiceProxy):
|
||||
SERVICE_CLASS = GenericMediaControlService
|
||||
210
bumble/profiles/pacs.py
Normal file
210
bumble/profiles/pacs.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for
|
||||
|
||||
"""LE Audio - Published Audio Capabilities Service"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from bumble.profiles.bap import AudioLocation, CodecSpecificCapabilities, ContextType
|
||||
from bumble.profiles import le_audio
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import hci
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PacRecord:
|
||||
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
|
||||
|
||||
coding_format: hci.CodingFormat
|
||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> PacRecord:
|
||||
offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0)
|
||||
codec_specific_capabilities_size = data[offset]
|
||||
|
||||
offset += 1
|
||||
codec_specific_capabilities_bytes = data[
|
||||
offset : offset + codec_specific_capabilities_size
|
||||
]
|
||||
offset += codec_specific_capabilities_size
|
||||
metadata_size = data[offset]
|
||||
offset += 1
|
||||
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
|
||||
|
||||
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
|
||||
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
|
||||
codec_specific_capabilities = codec_specific_capabilities_bytes
|
||||
else:
|
||||
codec_specific_capabilities = CodecSpecificCapabilities.from_bytes(
|
||||
codec_specific_capabilities_bytes
|
||||
)
|
||||
|
||||
return PacRecord(
|
||||
coding_format=coding_format,
|
||||
codec_specific_capabilities=codec_specific_capabilities,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
return (
|
||||
bytes(self.coding_format)
|
||||
+ bytes([len(capabilities_bytes)])
|
||||
+ capabilities_bytes
|
||||
+ bytes([len(metadata_bytes)])
|
||||
+ metadata_bytes
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
|
||||
|
||||
sink_pac: Optional[gatt.Characteristic]
|
||||
sink_audio_locations: Optional[gatt.Characteristic]
|
||||
source_pac: Optional[gatt.Characteristic]
|
||||
source_audio_locations: Optional[gatt.Characteristic]
|
||||
available_audio_contexts: gatt.Characteristic
|
||||
supported_audio_contexts: gatt.Characteristic
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
supported_source_context: ContextType,
|
||||
supported_sink_context: ContextType,
|
||||
available_source_context: ContextType,
|
||||
available_sink_context: ContextType,
|
||||
sink_pac: Sequence[PacRecord] = (),
|
||||
sink_audio_locations: Optional[AudioLocation] = None,
|
||||
source_pac: Sequence[PacRecord] = (),
|
||||
source_audio_locations: Optional[AudioLocation] = None,
|
||||
) -> None:
|
||||
characteristics = []
|
||||
|
||||
self.supported_audio_contexts = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('<HH', supported_sink_context, supported_source_context),
|
||||
)
|
||||
characteristics.append(self.supported_audio_contexts)
|
||||
|
||||
self.available_audio_contexts = gatt.Characteristic(
|
||||
uuid=gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('<HH', available_sink_context, available_source_context),
|
||||
)
|
||||
characteristics.append(self.available_audio_contexts)
|
||||
|
||||
if sink_pac:
|
||||
self.sink_pac = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SINK_PAC_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=bytes([len(sink_pac)]) + b''.join(map(bytes, sink_pac)),
|
||||
)
|
||||
characteristics.append(self.sink_pac)
|
||||
|
||||
if sink_audio_locations is not None:
|
||||
self.sink_audio_locations = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('<I', sink_audio_locations),
|
||||
)
|
||||
characteristics.append(self.sink_audio_locations)
|
||||
|
||||
if source_pac:
|
||||
self.source_pac = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SOURCE_PAC_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=bytes([len(source_pac)]) + b''.join(map(bytes, source_pac)),
|
||||
)
|
||||
characteristics.append(self.source_pac)
|
||||
|
||||
if source_audio_locations is not None:
|
||||
self.source_audio_locations = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('<I', source_audio_locations),
|
||||
)
|
||||
characteristics.append(self.source_audio_locations)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = PublishedAudioCapabilitiesService
|
||||
|
||||
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
available_audio_contexts: gatt_client.CharacteristicProxy
|
||||
supported_audio_contexts: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.sink_pac = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.source_pac = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.sink_audio_locations = characteristics[0]
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.source_audio_locations = characteristics[0]
|
||||
46
bumble/profiles/pbp.py
Normal file
46
bumble/profiles/pbp.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble.profiles import le_audio
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PublicBroadcastAnnouncement:
|
||||
class Features(enum.IntFlag):
|
||||
ENCRYPTED = 1 << 0
|
||||
STANDARD_QUALITY_CONFIGURATION = 1 << 1
|
||||
HIGH_QUALITY_CONFIGURATION = 1 << 2
|
||||
|
||||
features: Features
|
||||
metadata: le_audio.Metadata
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
features = cls.Features(data[0])
|
||||
metadata_length = data[1]
|
||||
metadata_ltv = data[1 : 1 + metadata_length]
|
||||
return cls(
|
||||
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
|
||||
)
|
||||
89
bumble/profiles/tmap.py
Normal file
89
bumble/profiles/tmap.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LE Audio - Telephony and Media Audio Profile"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from bumble.gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
InvalidServiceError,
|
||||
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Role(enum.IntFlag):
|
||||
CALL_GATEWAY = 1 << 0
|
||||
CALL_TERMINAL = 1 << 1
|
||||
UNICAST_MEDIA_SENDER = 1 << 2
|
||||
UNICAST_MEDIA_RECEIVER = 1 << 3
|
||||
BROADCAST_MEDIA_SENDER = 1 << 4
|
||||
BROADCAST_MEDIA_RECEIVER = 1 << 5
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TelephonyAndMediaAudioService(TemplateService):
|
||||
UUID = GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE
|
||||
|
||||
def __init__(self, role: Role):
|
||||
self.role_characteristic = Characteristic(
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', int(role)),
|
||||
)
|
||||
|
||||
super().__init__([self.role_characteristic])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = TelephonyAndMediaAudioService
|
||||
|
||||
role: DelegatedCharacteristicAdapter
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError('TMAP Role characteristic not found')
|
||||
|
||||
self.role = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda value: Role(
|
||||
struct.unpack_from('<H', value, 0)[0],
|
||||
),
|
||||
)
|
||||
230
bumble/profiles/vcp.py
Normal file
230
bumble/profiles/vcp.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
|
||||
from bumble import att
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
MIN_VOLUME = 0
|
||||
MAX_VOLUME = 255
|
||||
|
||||
|
||||
class ErrorCode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service 1.6. Application error codes.
|
||||
'''
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
|
||||
|
||||
class VolumeFlags(enum.IntFlag):
|
||||
'''
|
||||
See Volume Control Service 3.3. Volume Flags.
|
||||
'''
|
||||
|
||||
VOLUME_SETTING_PERSISTED = 0x01
|
||||
# RFU
|
||||
|
||||
|
||||
class VolumeControlPointOpcode(enum.IntEnum):
|
||||
'''
|
||||
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
|
||||
'''
|
||||
|
||||
# fmt: off
|
||||
RELATIVE_VOLUME_DOWN = 0x00
|
||||
RELATIVE_VOLUME_UP = 0x01
|
||||
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
|
||||
UNMUTE_RELATIVE_VOLUME_UP = 0x03
|
||||
SET_ABSOLUTE_VOLUME = 0x04
|
||||
UNMUTE = 0x05
|
||||
MUTE = 0x06
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
|
||||
|
||||
volume_state: gatt.Characteristic
|
||||
volume_control_point: gatt.Characteristic
|
||||
volume_flags: gatt.Characteristic
|
||||
|
||||
volume_setting: int
|
||||
muted: int
|
||||
change_counter: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_size: int = 16,
|
||||
volume_setting: int = 0,
|
||||
muted: int = 0,
|
||||
change_counter: int = 0,
|
||||
volume_flags: int = 0,
|
||||
included_services: Sequence[gatt.Service] = (),
|
||||
) -> None:
|
||||
self.step_size = step_size
|
||||
self.volume_setting = volume_setting
|
||||
self.muted = muted
|
||||
self.change_counter = change_counter
|
||||
|
||||
self.volume_state = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
|
||||
)
|
||||
self.volume_control_point = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
|
||||
)
|
||||
self.volume_flags = gatt.Characteristic(
|
||||
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=bytes([volume_flags]),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
self.volume_state,
|
||||
self.volume_control_point,
|
||||
self.volume_flags,
|
||||
],
|
||||
included_services=list(included_services),
|
||||
)
|
||||
|
||||
@property
|
||||
def volume_state_bytes(self) -> bytes:
|
||||
return bytes([self.volume_setting, self.muted, self.change_counter])
|
||||
|
||||
@volume_state_bytes.setter
|
||||
def volume_state_bytes(self, new_value: bytes) -> None:
|
||||
self.volume_setting, self.muted, self.change_counter = new_value
|
||||
|
||||
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
|
||||
return self.volume_state_bytes
|
||||
|
||||
def _on_write_volume_control_point(
|
||||
self, connection: Optional[device.Connection], value: bytes
|
||||
) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = VolumeControlPointOpcode(value[0])
|
||||
change_counter = value[1]
|
||||
|
||||
if change_counter != self.change_counter:
|
||||
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
handler = getattr(self, '_on_' + opcode.name.lower())
|
||||
if handler(*value[2:]):
|
||||
self.change_counter = (self.change_counter + 1) % 256
|
||||
connection.abort_on(
|
||||
'disconnection',
|
||||
connection.device.notify_subscribers(
|
||||
attribute=self.volume_state,
|
||||
value=self.volume_state_bytes,
|
||||
),
|
||||
)
|
||||
self.emit(
|
||||
'volume_state', self.volume_setting, self.muted, self.change_counter
|
||||
)
|
||||
|
||||
def _on_relative_volume_down(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_relative_volume_up(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
return self.volume_setting != old_volume
|
||||
|
||||
def _on_unmute_relative_volume_down(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_unmute_relative_volume_up(self) -> bool:
|
||||
old_volume, old_muted_state = self.volume_setting, self.muted
|
||||
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
|
||||
self.muted = 0
|
||||
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
|
||||
|
||||
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
|
||||
old_volume_setting = self.volume_setting
|
||||
self.volume_setting = volume_setting
|
||||
return old_volume_setting != self.volume_setting
|
||||
|
||||
def _on_unmute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 0
|
||||
return self.muted != old_muted_state
|
||||
|
||||
def _on_mute(self) -> bool:
|
||||
old_muted_state = self.muted
|
||||
self.muted = 1
|
||||
return self.muted != old_muted_state
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeControlService
|
||||
|
||||
volume_control_point: gatt_client.CharacteristicProxy
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.volume_state = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
|
||||
)[0],
|
||||
'BBB',
|
||||
)
|
||||
|
||||
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
self.volume_flags = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
330
bumble/profiles/vocs.py
Normal file
330
bumble/profiles/vocs.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from bumble.device import Connection
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.gatt import (
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
CharacteristicValue,
|
||||
UTF8CharacteristicAdapter,
|
||||
InvalidServiceError,
|
||||
GATT_VOLUME_OFFSET_CONTROL_SERVICE,
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from bumble.utils import OpenIntEnum
|
||||
from bumble.profiles.bap import AudioLocation
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
MIN_VOLUME_OFFSET = -255
|
||||
MAX_VOLUME_OFFSET = 255
|
||||
CHANGE_COUNTER_MAX_VALUE = 0xFF
|
||||
|
||||
|
||||
class SetVolumeOffsetOpCode(OpenIntEnum):
|
||||
SET_VOLUME_OFFSET = 0x01
|
||||
|
||||
|
||||
class ErrorCode(OpenIntEnum):
|
||||
"""
|
||||
See Volume Offset Control Service 1.6. Application error codes.
|
||||
"""
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
VALUE_OUT_OF_RANGE = 0x82
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class VolumeOffsetState:
|
||||
volume_offset: int = 0
|
||||
change_counter: int = 0
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<hB', self.volume_offset, self.change_counter)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
volume_offset, change_counter = struct.unpack('<hB', data)
|
||||
return cls(volume_offset, change_counter)
|
||||
|
||||
def increment_change_counter(self) -> None:
|
||||
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
|
||||
|
||||
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
|
||||
assert self.attribute_value is not None
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=bytes(self)
|
||||
)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocsAudioLocation:
|
||||
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<I', self.audio_location)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
audio_location = AudioLocation(struct.unpack('<I', data)[0])
|
||||
return cls(audio_location)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=value
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolumeOffsetControlPoint:
|
||||
volume_offset_state: VolumeOffsetState
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = value[0]
|
||||
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
|
||||
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
|
||||
|
||||
change_counter, volume_offset = struct.unpack('<Bh', value[1:])
|
||||
await self._set_volume_offset(connection, change_counter, volume_offset)
|
||||
|
||||
async def _set_volume_offset(
|
||||
self,
|
||||
connection: Connection,
|
||||
change_counter_operand: int,
|
||||
volume_offset_operand: int,
|
||||
) -> None:
|
||||
change_counter = self.volume_offset_state.change_counter
|
||||
|
||||
if change_counter != change_counter_operand:
|
||||
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
if not MIN_VOLUME_OFFSET <= volume_offset_operand <= MAX_VOLUME_OFFSET:
|
||||
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
|
||||
|
||||
self.volume_offset_state.volume_offset = volume_offset_operand
|
||||
self.volume_offset_state.increment_change_counter()
|
||||
await self.volume_offset_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioOutputDescription:
|
||||
audio_output_description: str = ''
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
return cls(audio_output_description=data.decode('utf-8'))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.audio_output_description.encode('utf-8')
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_output_description = value.decode('utf-8')
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=value
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeOffsetControlService(TemplateService):
|
||||
UUID = GATT_VOLUME_OFFSET_CONTROL_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
volume_offset_state: Optional[VolumeOffsetState] = None,
|
||||
audio_location: Optional[VocsAudioLocation] = None,
|
||||
audio_output_description: Optional[AudioOutputDescription] = None,
|
||||
) -> None:
|
||||
|
||||
self.volume_offset_state = (
|
||||
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
|
||||
)
|
||||
|
||||
self.audio_location = (
|
||||
VocsAudioLocation() if audio_location is None else audio_location
|
||||
)
|
||||
|
||||
self.audio_output_description = (
|
||||
AudioOutputDescription()
|
||||
if audio_output_description is None
|
||||
else audio_output_description
|
||||
)
|
||||
|
||||
self.volume_offset_control_point: VolumeOffsetControlPoint = (
|
||||
VolumeOffsetControlPoint(self.volume_offset_state)
|
||||
)
|
||||
|
||||
self.volume_offset_state_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.volume_offset_state.on_read),
|
||||
),
|
||||
encode=lambda value: bytes(value),
|
||||
)
|
||||
|
||||
self.audio_location_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
),
|
||||
permissions=(
|
||||
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
),
|
||||
value=CharacteristicValue(
|
||||
read=self.audio_location.on_read,
|
||||
write=self.audio_location.on_write,
|
||||
),
|
||||
),
|
||||
encode=lambda value: bytes(value),
|
||||
decode=VocsAudioLocation.from_bytes,
|
||||
)
|
||||
self.audio_location.attribute_value = self.audio_location_characteristic.value
|
||||
|
||||
self.volume_offset_control_point_characteristic = Characteristic(
|
||||
uuid=GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(write=self.volume_offset_control_point.on_write),
|
||||
)
|
||||
|
||||
self.audio_output_description_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
),
|
||||
permissions=(
|
||||
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
),
|
||||
value=CharacteristicValue(
|
||||
read=self.audio_output_description.on_read,
|
||||
write=self.audio_output_description.on_write,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.audio_output_description.attribute_value = (
|
||||
self.audio_output_description_characteristic.value
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
self.volume_offset_state_characteristic, # type: ignore
|
||||
self.audio_location_characteristic, # type: ignore
|
||||
self.volume_offset_control_point_characteristic, # type: ignore
|
||||
self.audio_output_description_characteristic, # type: ignore
|
||||
],
|
||||
primary=False,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeOffsetControlServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeOffsetControlService
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError("Volume Offset State characteristic not found")
|
||||
self.volume_offset_state = DelegatedCharacteristicAdapter(
|
||||
characteristics[0], decode=VolumeOffsetState.from_bytes
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError("Audio Location characteristic not found")
|
||||
self.audio_location = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
encode=lambda value: bytes(value),
|
||||
decode=VocsAudioLocation.from_bytes,
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError(
|
||||
"Volume Offset Control Point characteristic not found"
|
||||
)
|
||||
self.volume_offset_control_point = characteristics[0]
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError(
|
||||
"Audio Output Description characteristic not found"
|
||||
)
|
||||
self.audio_output_description = UTF8CharacteristicAdapter(characteristics[0])
|
||||
774
bumble/rfcomm.py
774
bumble/rfcomm.py
File diff suppressed because it is too large
Load Diff
110
bumble/rtp.py
Normal file
110
bumble/rtp.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import struct
|
||||
from typing import List
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class MediaPacket:
|
||||
@staticmethod
|
||||
def from_bytes(data: bytes) -> MediaPacket:
|
||||
version = (data[0] >> 6) & 0x03
|
||||
padding = (data[0] >> 5) & 0x01
|
||||
extension = (data[0] >> 4) & 0x01
|
||||
csrc_count = data[0] & 0x0F
|
||||
marker = (data[1] >> 7) & 0x01
|
||||
payload_type = data[1] & 0x7F
|
||||
sequence_number = struct.unpack_from('>H', data, 2)[0]
|
||||
timestamp = struct.unpack_from('>I', data, 4)[0]
|
||||
ssrc = struct.unpack_from('>I', data, 8)[0]
|
||||
csrc_list = [
|
||||
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
|
||||
]
|
||||
payload = data[12 + csrc_count * 4 :]
|
||||
|
||||
return MediaPacket(
|
||||
version,
|
||||
padding,
|
||||
extension,
|
||||
marker,
|
||||
sequence_number,
|
||||
timestamp,
|
||||
ssrc,
|
||||
csrc_list,
|
||||
payload_type,
|
||||
payload,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version: int,
|
||||
padding: int,
|
||||
extension: int,
|
||||
marker: int,
|
||||
sequence_number: int,
|
||||
timestamp: int,
|
||||
ssrc: int,
|
||||
csrc_list: List[int],
|
||||
payload_type: int,
|
||||
payload: bytes,
|
||||
) -> None:
|
||||
self.version = version
|
||||
self.padding = padding
|
||||
self.extension = extension
|
||||
self.marker = marker
|
||||
self.sequence_number = sequence_number & 0xFFFF
|
||||
self.timestamp = timestamp & 0xFFFFFFFF
|
||||
self.timestamp_seconds = 0.0
|
||||
self.ssrc = ssrc
|
||||
self.csrc_list = csrc_list
|
||||
self.payload_type = payload_type
|
||||
self.payload = payload
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
header = bytes(
|
||||
[
|
||||
self.version << 6
|
||||
| self.padding << 5
|
||||
| self.extension << 4
|
||||
| len(self.csrc_list),
|
||||
self.marker << 7 | self.payload_type,
|
||||
]
|
||||
) + struct.pack(
|
||||
'>HII',
|
||||
self.sequence_number,
|
||||
self.timestamp,
|
||||
self.ssrc,
|
||||
)
|
||||
for csrc in self.csrc_list:
|
||||
header += struct.pack('>I', csrc)
|
||||
return header + self.payload
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'RTP(v={self.version},'
|
||||
f'p={self.padding},'
|
||||
f'x={self.extension},'
|
||||
f'm={self.marker},'
|
||||
f'pt={self.payload_type},'
|
||||
f'sn={self.sequence_number},'
|
||||
f'ts={self.timestamp},'
|
||||
f'ssrc={self.ssrc},'
|
||||
f'csrcs={self.csrc_list},'
|
||||
f'payload_size={len(self.payload)})'
|
||||
)
|
||||
452
bumble/sdp.py
452
bumble/sdp.py
@@ -16,14 +16,24 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from typing import Dict, List, Type
|
||||
from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import core
|
||||
from .colors import color
|
||||
from .core import InvalidStateError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
from bumble import core, l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
InvalidStateError,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -94,7 +104,8 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
|
||||
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
|
||||
|
||||
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
|
||||
|
||||
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
|
||||
# used by AVRCP, HFP and A2DP
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
|
||||
|
||||
@@ -112,7 +123,8 @@ SDP_ATTRIBUTE_ID_NAMES = {
|
||||
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
|
||||
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
|
||||
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
|
||||
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
|
||||
}
|
||||
|
||||
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
|
||||
@@ -164,7 +176,7 @@ class DataElement:
|
||||
UUID: lambda x: DataElement(
|
||||
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
|
||||
),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
|
||||
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
|
||||
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
|
||||
SEQUENCE: lambda x: DataElement(
|
||||
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
|
||||
@@ -183,7 +195,9 @@ class DataElement:
|
||||
self.bytes = None
|
||||
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
|
||||
if value_size is None:
|
||||
raise ValueError('integer types must have a value size specified')
|
||||
raise InvalidArgumentError(
|
||||
'integer types must have a value size specified'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def nil() -> DataElement:
|
||||
@@ -226,7 +240,7 @@ class DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value: str) -> DataElement:
|
||||
def text_string(value: bytes) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -234,11 +248,11 @@ class DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value: List[DataElement]) -> DataElement:
|
||||
def sequence(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value: List[DataElement]) -> DataElement:
|
||||
def alternative(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -259,7 +273,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>Q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_from_bytes(data):
|
||||
@@ -275,7 +289,7 @@ class DataElement:
|
||||
if len(data) == 8:
|
||||
return struct.unpack('>q', data)[0]
|
||||
|
||||
raise ValueError(f'invalid integer length {len(data)}')
|
||||
raise InvalidPacketError(f'invalid integer length {len(data)}')
|
||||
|
||||
@staticmethod
|
||||
def list_from_bytes(data):
|
||||
@@ -336,9 +350,6 @@ class DataElement:
|
||||
] # Keep a copy so we can re-serialize to an exact replica
|
||||
return result
|
||||
|
||||
def to_bytes(self):
|
||||
return bytes(self)
|
||||
|
||||
def __bytes__(self):
|
||||
# Return early if we have a cache
|
||||
if self.bytes:
|
||||
@@ -348,7 +359,7 @@ class DataElement:
|
||||
data = b''
|
||||
elif self.type == DataElement.UNSIGNED_INTEGER:
|
||||
if self.value < 0:
|
||||
raise ValueError('UNSIGNED_INTEGER cannot be negative')
|
||||
raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
|
||||
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('B', self.value)
|
||||
@@ -359,7 +370,7 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>Q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.SIGNED_INTEGER:
|
||||
if self.value_size == 1:
|
||||
data = struct.pack('b', self.value)
|
||||
@@ -370,10 +381,10 @@ class DataElement:
|
||||
elif self.value_size == 8:
|
||||
data = struct.pack('>q', self.value)
|
||||
else:
|
||||
raise ValueError('invalid value_size')
|
||||
raise InvalidArgumentError('invalid value_size')
|
||||
elif self.type == DataElement.UUID:
|
||||
data = bytes(reversed(bytes(self.value)))
|
||||
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
|
||||
elif self.type == DataElement.URL:
|
||||
data = self.value.encode('utf8')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
data = bytes([1 if self.value else 0])
|
||||
@@ -386,7 +397,7 @@ class DataElement:
|
||||
size_bytes = b''
|
||||
if self.type == DataElement.NIL:
|
||||
if size != 0:
|
||||
raise ValueError('NIL must be empty')
|
||||
raise InvalidArgumentError('NIL must be empty')
|
||||
size_index = 0
|
||||
elif self.type in (
|
||||
DataElement.UNSIGNED_INTEGER,
|
||||
@@ -404,7 +415,7 @@ class DataElement:
|
||||
elif size == 16:
|
||||
size_index = 4
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type in (
|
||||
DataElement.TEXT_STRING,
|
||||
DataElement.SEQUENCE,
|
||||
@@ -421,11 +432,13 @@ class DataElement:
|
||||
size_index = 7
|
||||
size_bytes = struct.pack('>I', size)
|
||||
else:
|
||||
raise ValueError('invalid data size')
|
||||
raise InvalidArgumentError('invalid data size')
|
||||
elif self.type == DataElement.BOOLEAN:
|
||||
if size != 1:
|
||||
raise ValueError('boolean must be 1 byte')
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
else:
|
||||
raise RuntimeError("internal error - self.type not supported")
|
||||
|
||||
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
return self.bytes
|
||||
@@ -466,7 +479,9 @@ class ServiceAttribute:
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def list_from_data_elements(elements):
|
||||
def list_from_data_elements(
|
||||
elements: Sequence[DataElement],
|
||||
) -> list[ServiceAttribute]:
|
||||
attribute_list = []
|
||||
for i in range(0, len(elements) // 2):
|
||||
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
||||
@@ -478,7 +493,9 @@ class ServiceAttribute:
|
||||
return attribute_list
|
||||
|
||||
@staticmethod
|
||||
def find_attribute_in_list(attribute_list, attribute_id):
|
||||
def find_attribute_in_list(
|
||||
attribute_list: Iterable[ServiceAttribute], attribute_id: int
|
||||
) -> Optional[DataElement]:
|
||||
return next(
|
||||
(
|
||||
attribute.value
|
||||
@@ -493,7 +510,7 @@ class ServiceAttribute:
|
||||
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
|
||||
|
||||
@staticmethod
|
||||
def is_uuid_in_value(uuid, value):
|
||||
def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
|
||||
# Find if a uuid matches a value, either directly or recursing into sequences
|
||||
if value.type == DataElement.UUID:
|
||||
return value.value == uuid
|
||||
@@ -525,7 +542,12 @@ class SDP_PDU:
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
|
||||
'''
|
||||
|
||||
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
|
||||
RESPONSE_PDU_IDS = {
|
||||
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
|
||||
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
|
||||
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
|
||||
}
|
||||
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
|
||||
name = None
|
||||
pdu_id = 0
|
||||
|
||||
@@ -547,7 +569,9 @@ class SDP_PDU:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def parse_service_record_handle_list_preceded_by_count(data, offset):
|
||||
def parse_service_record_handle_list_preceded_by_count(
|
||||
data: bytes, offset: int
|
||||
) -> tuple[int, list[int]]:
|
||||
count = struct.unpack_from('>H', data, offset - 2)[0]
|
||||
handle_list = [
|
||||
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
||||
@@ -609,11 +633,8 @@ class SDP_PDU:
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
|
||||
@@ -631,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
|
||||
'''
|
||||
|
||||
error_code: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -645,6 +668,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
|
||||
'''
|
||||
|
||||
service_search_pattern: DataElement
|
||||
maximum_service_record_count: int
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -663,6 +690,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
||||
'''
|
||||
|
||||
service_record_handle_list: list[int]
|
||||
total_service_record_count: int
|
||||
current_service_record_count: int
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -678,6 +710,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
|
||||
'''
|
||||
|
||||
service_record_handle: int
|
||||
maximum_attribute_byte_count: int
|
||||
attribute_id_list: DataElement
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -692,6 +729,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_list_byte_count: int
|
||||
attribute_list: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -707,6 +748,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
|
||||
'''
|
||||
|
||||
service_search_pattern: DataElement
|
||||
maximum_attribute_byte_count: int
|
||||
attribute_id_list: DataElement
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -721,26 +767,103 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_lists_byte_count: int
|
||||
attribute_lists: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.pending_request = None
|
||||
self.channel = None
|
||||
def __init__(self, connection: Connection, mtu: int = 0) -> None:
|
||||
self.connection = connection
|
||||
self.channel: Optional[l2cap.ClassicChannel] = None
|
||||
self.mtu = mtu
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request: Optional[SDP_PDU] = None
|
||||
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
|
||||
self.next_transaction_id = 0
|
||||
|
||||
async def connect(self, connection):
|
||||
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
|
||||
self.channel = result
|
||||
async def connect(self) -> None:
|
||||
self.channel = await self.connection.create_l2cap_channel(
|
||||
spec=(
|
||||
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
|
||||
if self.mtu
|
||||
else l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
)
|
||||
)
|
||||
self.channel.sink = self.on_pdu
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.channel:
|
||||
await self.channel.disconnect()
|
||||
self.channel = None
|
||||
|
||||
async def search_services(self, uuids):
|
||||
def make_transaction_id(self) -> int:
|
||||
transaction_id = self.next_transaction_id
|
||||
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
|
||||
return transaction_id
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if not self.pending_request:
|
||||
logger.warning('received response with no pending request')
|
||||
return
|
||||
assert self.pending_response is not None
|
||||
|
||||
response = SDP_PDU.from_bytes(pdu)
|
||||
|
||||
# Check that the transaction ID is what we expect
|
||||
if self.pending_request.transaction_id != response.transaction_id:
|
||||
logger.warning(
|
||||
f"received response with transaction ID {response.transaction_id} "
|
||||
f"but expected {self.pending_request.transaction_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if the response is an error
|
||||
if isinstance(response, SDP_ErrorResponse):
|
||||
self.pending_response.set_exception(
|
||||
ProtocolError(error_code=response.error_code)
|
||||
)
|
||||
return
|
||||
|
||||
# Check that the type of the response matches the request
|
||||
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
|
||||
logger.warning("response type mismatch")
|
||||
return
|
||||
|
||||
self.pending_response.set_result(response)
|
||||
|
||||
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
|
||||
assert self.channel is not None
|
||||
async with self.request_semaphore:
|
||||
assert self.pending_request is None
|
||||
assert self.pending_response is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_response = asyncio.get_running_loop().create_future()
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.channel.send_pdu(bytes(request))
|
||||
return await self.pending_response
|
||||
finally:
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
|
||||
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
|
||||
"""
|
||||
Search for services by UUID.
|
||||
|
||||
Args:
|
||||
uuids: service the UUIDs to search for.
|
||||
|
||||
Returns:
|
||||
A list of matching service record handles.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
service_search_pattern = DataElement.sequence(
|
||||
[DataElement.uuid(uuid) for uuid in uuids]
|
||||
@@ -751,16 +874,16 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_service_record_count=0xFFFF,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchResponse)
|
||||
service_record_handle_list += response.service_record_handle_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -770,20 +893,39 @@ class Client:
|
||||
|
||||
return service_record_handle_list
|
||||
|
||||
async def search_attributes(self, uuids, attribute_ids):
|
||||
async def search_attributes(
|
||||
self,
|
||||
uuids: Iterable[core.UUID],
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[list[ServiceAttribute]]:
|
||||
"""
|
||||
Search for attributes by UUID and attribute IDs.
|
||||
|
||||
Args:
|
||||
uuids: the service UUIDs to search for.
|
||||
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
|
||||
(use (0, 0xFFFF) to include all attributes)
|
||||
|
||||
Returns:
|
||||
A list of list of attributes, one list per matching service.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
service_search_pattern = DataElement.sequence(
|
||||
[DataElement.uuid(uuid) for uuid in uuids]
|
||||
)
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
(
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
for attribute_id in attribute_ids
|
||||
]
|
||||
)
|
||||
@@ -793,17 +935,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
|
||||
accumulator += response.attribute_lists
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -823,17 +965,35 @@ class Client:
|
||||
if sequence.type == DataElement.SEQUENCE
|
||||
]
|
||||
|
||||
async def get_attributes(self, service_record_handle, attribute_ids):
|
||||
async def get_attributes(
|
||||
self,
|
||||
service_record_handle: int,
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[ServiceAttribute]:
|
||||
"""
|
||||
Get attributes for a service.
|
||||
|
||||
Args:
|
||||
service_record_handle: the handle for a service
|
||||
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
|
||||
|
||||
Returns:
|
||||
A list of attributes.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
raise InvalidStateError('L2CAP not connected')
|
||||
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
(
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
for attribute_id in attribute_ids
|
||||
]
|
||||
)
|
||||
@@ -843,17 +1003,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_record_handle=service_record_handle,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceAttributeResponse)
|
||||
accumulator += response.attribute_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -869,25 +1029,38 @@ class Client:
|
||||
|
||||
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
CONTINUATION_STATE = bytes([0x01, 0x00])
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', list[ServiceAttribute])
|
||||
service_records: dict[int, Service]
|
||||
current_response: Union[None, bytes, tuple[int, list[int]]]
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
self.service_records = {} # Service records maps, by record handle
|
||||
self.channel = None
|
||||
self.current_response = None
|
||||
self.current_response = None # Current response data, used for continuations
|
||||
|
||||
def register(self, l2cap_channel_manager):
|
||||
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
|
||||
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||
l2cap_channel_manager.create_classic_server(
|
||||
spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
|
||||
)
|
||||
|
||||
def send_response(self, response):
|
||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||
self.channel.send_pdu(response)
|
||||
|
||||
def match_services(self, search_pattern):
|
||||
def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
|
||||
# Find the services for which the attributes in the pattern is a subset of the
|
||||
# service's attribute values (NOTE: the value search recurses into sequences)
|
||||
matching_services = {}
|
||||
@@ -928,7 +1101,7 @@ class Server:
|
||||
try:
|
||||
handler(sdp_pdu)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=sdp_pdu.transaction_id,
|
||||
@@ -944,6 +1117,31 @@ class Server:
|
||||
)
|
||||
)
|
||||
|
||||
def check_continuation(
|
||||
self,
|
||||
continuation_state: bytes,
|
||||
transaction_id: int,
|
||||
) -> Optional[bool]:
|
||||
# Check if this is a valid continuation
|
||||
if len(continuation_state) > 1:
|
||||
if (
|
||||
self.current_response is None
|
||||
or continuation_state != self.CONTINUATION_STATE
|
||||
):
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return None
|
||||
return True
|
||||
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
|
||||
return False
|
||||
|
||||
def get_next_response_payload(self, maximum_size):
|
||||
if len(self.current_response) > maximum_size:
|
||||
payload = self.current_response[:maximum_size]
|
||||
@@ -957,7 +1155,9 @@ class Server:
|
||||
return (payload, continuation_state)
|
||||
|
||||
@staticmethod
|
||||
def get_service_attributes(service, attribute_ids):
|
||||
def get_service_attributes(
|
||||
service: Service, attribute_ids: Iterable[DataElement]
|
||||
) -> DataElement:
|
||||
attributes = []
|
||||
for attribute_id in attribute_ids:
|
||||
if attribute_id.value_size == 4:
|
||||
@@ -982,47 +1182,48 @@ class Server:
|
||||
|
||||
return attribute_list
|
||||
|
||||
def on_sdp_service_search_request(self, request):
|
||||
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(request.service_search_pattern)
|
||||
service_record_handles = list(matching_services.keys())
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
|
||||
# Only return up to the maximum requested
|
||||
service_record_handles_subset = service_record_handles[
|
||||
: request.maximum_service_record_count
|
||||
]
|
||||
|
||||
# Serialize to a byte array, and remember the total count
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
self.current_response = (
|
||||
len(service_record_handles),
|
||||
service_record_handles_subset,
|
||||
)
|
||||
|
||||
# Respond, keeping any unsent handles for later
|
||||
service_record_handles = self.current_response[1][
|
||||
: request.maximum_service_record_count
|
||||
assert isinstance(self.current_response, tuple)
|
||||
assert self.channel is not None
|
||||
total_service_record_count, service_record_handles = self.current_response
|
||||
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
|
||||
service_record_handles_remaining = service_record_handles[
|
||||
maximum_service_record_count:
|
||||
]
|
||||
service_record_handles = service_record_handles[:maximum_service_record_count]
|
||||
self.current_response = (
|
||||
self.current_response[0],
|
||||
self.current_response[1][request.maximum_service_record_count :],
|
||||
total_service_record_count,
|
||||
service_record_handles_remaining,
|
||||
)
|
||||
continuation_state = (
|
||||
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
|
||||
Server.CONTINUATION_STATE
|
||||
if service_record_handles_remaining
|
||||
else bytes([0])
|
||||
)
|
||||
service_record_handle_list = b''.join(
|
||||
[struct.pack('>I', handle) for handle in service_record_handles]
|
||||
@@ -1030,28 +1231,25 @@ class Server:
|
||||
self.send_response(
|
||||
SDP_ServiceSearchResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
total_service_record_count=self.current_response[0],
|
||||
total_service_record_count=total_service_record_count,
|
||||
current_service_record_count=len(service_record_handles),
|
||||
service_record_handle_list=service_record_handle_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
def on_sdp_service_attribute_request(self, request):
|
||||
def on_sdp_service_attribute_request(
|
||||
self, request: SDP_ServiceAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Check that the service exists
|
||||
service = self.service_records.get(request.service_record_handle)
|
||||
if service is None:
|
||||
@@ -1073,32 +1271,34 @@ class Server:
|
||||
self.current_response = bytes(attribute_list)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
attribute_list, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_list_response, continuation_state = self.get_next_response_payload(
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_list_byte_count=len(attribute_list),
|
||||
attribute_list=attribute_list,
|
||||
attribute_list_byte_count=len(attribute_list_response),
|
||||
attribute_list=attribute_list_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
def on_sdp_service_search_attribute_request(self, request):
|
||||
def on_sdp_service_search_attribute_request(
|
||||
self, request: SDP_ServiceSearchAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if not self.current_response:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(
|
||||
request.service_search_pattern
|
||||
@@ -1118,14 +1318,18 @@ class Server:
|
||||
self.current_response = bytes(attribute_lists)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
attribute_lists, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_lists_response, continuation_state = self.get_next_response_payload(
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceSearchAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_lists_byte_count=len(attribute_lists),
|
||||
attribute_lists=attribute_lists,
|
||||
attribute_lists_byte_count=len(attribute_lists_response),
|
||||
attribute_lists=attribute_lists_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
387
bumble/smp.py
387
bumble/smp.py
@@ -27,6 +27,7 @@ import logging
|
||||
import asyncio
|
||||
import enum
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -37,6 +38,7 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
@@ -52,6 +54,8 @@ from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_CENTRAL_ROLE,
|
||||
BT_LE_TRANSPORT,
|
||||
AdvertisingData,
|
||||
InvalidArgumentError,
|
||||
ProtocolError,
|
||||
name_or_number,
|
||||
)
|
||||
@@ -184,8 +188,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
|
||||
SMP_CT2_AUTHREQ = 0b00100000
|
||||
|
||||
# Crypto salt
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
|
||||
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
|
||||
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -294,11 +298,8 @@ class SMP_Command:
|
||||
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = color(self.name, 'yellow')
|
||||
@@ -562,6 +563,54 @@ class PairingMethod(enum.IntEnum):
|
||||
CTKD_OVER_CLASSIC = 4
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobContext:
|
||||
"""Cryptographic context for LE SC OOB pairing."""
|
||||
|
||||
ecc_key: crypto.EccKey
|
||||
r: bytes
|
||||
|
||||
def __init__(
|
||||
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
|
||||
) -> None:
|
||||
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
|
||||
self.r = crypto.r() if r is None else r
|
||||
|
||||
def share(self) -> OobSharedData:
|
||||
pkx = self.ecc_key.x[::-1]
|
||||
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OobLegacyContext:
|
||||
"""Cryptographic context for LE Legacy OOB pairing."""
|
||||
|
||||
tk: bytes
|
||||
|
||||
def __init__(self, tk: Optional[bytes] = None) -> None:
|
||||
self.tk = crypto.r() if tk is None else tk
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class OobSharedData:
|
||||
"""Shareable data for LE SC OOB pairing."""
|
||||
|
||||
c: bytes
|
||||
r: bytes
|
||||
|
||||
def to_ad(self) -> AdvertisingData:
|
||||
return AdvertisingData(
|
||||
[
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
|
||||
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Session:
|
||||
# I/O Capability to pairing method decision matrix
|
||||
@@ -626,6 +675,13 @@ class Session:
|
||||
},
|
||||
}
|
||||
|
||||
ea: bytes
|
||||
eb: bytes
|
||||
ltk: bytes
|
||||
preq: bytes
|
||||
pres: bytes
|
||||
tk: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: Manager,
|
||||
@@ -635,17 +691,11 @@ class Session:
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.preq: Optional[bytes] = None
|
||||
self.pres: Optional[bytes] = None
|
||||
self.ea = None
|
||||
self.eb = None
|
||||
self.tk = bytes(16)
|
||||
self.r = bytes(16)
|
||||
self.stk = None
|
||||
self.ltk = None
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key = None
|
||||
self.link_key: Optional[bytes] = None
|
||||
self.maximum_encryption_key_size: int = 0
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
@@ -658,7 +708,7 @@ class Session:
|
||||
self.peer_bd_addr: Optional[Address] = None
|
||||
self.peer_signature_key = None
|
||||
self.peer_expected_distributions: List[Type[SMP_Command]] = []
|
||||
self.dh_key = None
|
||||
self.dh_key = b''
|
||||
self.confirm_value = None
|
||||
self.passkey: Optional[int] = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
@@ -686,12 +736,16 @@ class Session:
|
||||
|
||||
# Create a future that can be used to wait for the session to complete
|
||||
if self.is_initiator:
|
||||
self.pairing_result: Optional[
|
||||
asyncio.Future[None]
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
self.pairing_result: Optional[asyncio.Future[None]] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
else:
|
||||
self.pairing_result = None
|
||||
|
||||
self.maximum_encryption_key_size = (
|
||||
pairing_config.delegate.maximum_encryption_key_size
|
||||
)
|
||||
|
||||
# Key Distribution (default values before negotiation)
|
||||
self.initiator_key_distribution = (
|
||||
pairing_config.delegate.local_initiator_key_distribution
|
||||
@@ -711,12 +765,17 @@ class Session:
|
||||
self.io_capability = pairing_config.delegate.io_capability
|
||||
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
|
||||
|
||||
# OOB (not supported yet)
|
||||
self.oob = False
|
||||
# OOB
|
||||
self.oob_data_flag = (
|
||||
1 if pairing_config.oob and pairing_config.oob.peer_data else 0
|
||||
)
|
||||
|
||||
# Set up addresses
|
||||
self_address = connection.self_address
|
||||
self_address = connection.self_resolvable_address or connection.self_address
|
||||
peer_address = connection.peer_resolvable_address or connection.peer_address
|
||||
logger.debug(
|
||||
f"pairing with self_address={self_address}, peer_address={peer_address}"
|
||||
)
|
||||
if self.is_initiator:
|
||||
self.ia = bytes(self_address)
|
||||
self.iat = 1 if self_address.is_random else 0
|
||||
@@ -728,9 +787,35 @@ class Session:
|
||||
self.ia = bytes(peer_address)
|
||||
self.iat = 1 if peer_address.is_random else 0
|
||||
|
||||
# Select the ECC key, TK and r initial value
|
||||
if pairing_config.oob:
|
||||
self.peer_oob_data = pairing_config.oob.peer_data
|
||||
if pairing_config.sc:
|
||||
if pairing_config.oob.our_context is None:
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a context when sc is True"
|
||||
)
|
||||
self.r = pairing_config.oob.our_context.r
|
||||
self.ecc_key = pairing_config.oob.our_context.ecc_key
|
||||
if pairing_config.oob.legacy_context is not None:
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
if pairing_config.oob.legacy_context is None:
|
||||
raise InvalidArgumentError(
|
||||
"oob pairing config requires a legacy context when sc is False"
|
||||
)
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = pairing_config.oob.legacy_context.tk
|
||||
else:
|
||||
self.peer_oob_data = None
|
||||
self.r = bytes(16)
|
||||
self.ecc_key = manager.ecc_key
|
||||
self.tk = bytes(16)
|
||||
|
||||
@property
|
||||
def pkx(self) -> Tuple[bytes, bytes]:
|
||||
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
|
||||
return (self.ecc_key.x[::-1], self.peer_public_key_x)
|
||||
|
||||
@property
|
||||
def pka(self) -> bytes:
|
||||
@@ -767,7 +852,10 @@ class Session:
|
||||
return None
|
||||
|
||||
def decide_pairing_method(
|
||||
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
|
||||
self,
|
||||
auth_req: int,
|
||||
initiator_io_capability: int,
|
||||
responder_io_capability: int,
|
||||
) -> None:
|
||||
if self.connection.transport == BT_BR_EDR_TRANSPORT:
|
||||
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
|
||||
@@ -908,9 +996,9 @@ class Session:
|
||||
|
||||
command = SMP_Pairing_Request_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
maximum_encryption_key_size=self.maximum_encryption_key_size,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
responder_key_distribution=self.responder_key_distribution,
|
||||
)
|
||||
@@ -920,9 +1008,9 @@ class Session:
|
||||
def send_pairing_response_command(self) -> None:
|
||||
response = SMP_Pairing_Response_Command(
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=0,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
maximum_encryption_key_size=self.maximum_encryption_key_size,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
responder_key_distribution=self.responder_key_distribution,
|
||||
)
|
||||
@@ -930,8 +1018,10 @@ class Session:
|
||||
self.send_command(response)
|
||||
|
||||
def send_pairing_confirm_command(self) -> None:
|
||||
self.r = crypto.r()
|
||||
logger.debug(f'generated random: {self.r.hex()}')
|
||||
|
||||
if self.pairing_method != PairingMethod.OOB:
|
||||
self.r = crypto.r()
|
||||
logger.debug(f'generated random: {self.r.hex()}')
|
||||
|
||||
if self.sc:
|
||||
|
||||
@@ -981,8 +1071,8 @@ class Session:
|
||||
def send_public_key_command(self) -> None:
|
||||
self.send_command(
|
||||
SMP_Pairing_Public_Key_Command(
|
||||
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
|
||||
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
|
||||
public_key_x=self.ecc_key.x[::-1],
|
||||
public_key_y=self.ecc_key.y[::-1],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -994,11 +1084,19 @@ class Session:
|
||||
)
|
||||
|
||||
def send_identity_address_command(self) -> None:
|
||||
identity_address = {
|
||||
None: self.connection.self_address,
|
||||
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
|
||||
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
|
||||
}[self.pairing_config.identity_address_type]
|
||||
if self.pairing_config.identity_address_type == Address.PUBLIC_DEVICE_ADDRESS:
|
||||
identity_address = self.manager.device.public_address
|
||||
elif self.pairing_config.identity_address_type == Address.RANDOM_DEVICE_ADDRESS:
|
||||
identity_address = self.manager.device.static_address
|
||||
else:
|
||||
# No identity address type set. If the controller has a public address, it
|
||||
# will be more responsible to be the identity address.
|
||||
if self.manager.device.public_address != Address.ANY:
|
||||
logger.debug("No identity address type set, using PUBLIC")
|
||||
identity_address = self.manager.device.public_address
|
||||
else:
|
||||
logger.debug("No identity address type set, using RANDOM")
|
||||
identity_address = self.manager.device.static_address
|
||||
self.send_command(
|
||||
SMP_Identity_Address_Information_Command(
|
||||
addr_type=identity_address.address_type,
|
||||
@@ -1010,7 +1108,7 @@ class Session:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
@@ -1018,18 +1116,56 @@ class Session:
|
||||
)
|
||||
)
|
||||
|
||||
async def derive_ltk(self) -> None:
|
||||
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
|
||||
assert link_key is not None
|
||||
@classmethod
|
||||
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Long Term Key from Link Key.
|
||||
|
||||
Args:
|
||||
link_key: BR/EDR Link Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
LE Long Tern Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
|
||||
if self.ct2
|
||||
if ct2
|
||||
else crypto.h6(link_key, b'tmp2')
|
||||
)
|
||||
self.ltk = crypto.h6(ilk, b'brle')
|
||||
return crypto.h6(ilk, b'brle')
|
||||
|
||||
@classmethod
|
||||
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
|
||||
'''Derives Link Key from Long Term Key.
|
||||
|
||||
Args:
|
||||
ltk: LE Long Term Key bytes in little-endian.
|
||||
ct2: whether ct2 is supported on both devices.
|
||||
Returns:
|
||||
BR/EDR Link Key bytes in little-endian.
|
||||
'''
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
|
||||
if ct2
|
||||
else crypto.h6(ltk, b'tmp1')
|
||||
)
|
||||
return crypto.h6(ilk, b'lebr')
|
||||
|
||||
async def get_link_key_and_derive_ltk(self) -> None:
|
||||
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
|
||||
self.link_key = await self.manager.device.get_link_key(
|
||||
self.connection.peer_address
|
||||
)
|
||||
if self.link_key is None:
|
||||
logging.warning(
|
||||
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
|
||||
)
|
||||
self.send_pairing_failed(
|
||||
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
|
||||
)
|
||||
else:
|
||||
self.ltk = self.derive_ltk(self.link_key, self.ct2)
|
||||
|
||||
def distribute_keys(self) -> None:
|
||||
|
||||
# Distribute the keys as required
|
||||
if self.is_initiator:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1038,7 +1174,7 @@ class Session:
|
||||
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
elif not self.sc:
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
@@ -1068,12 +1204,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
else:
|
||||
# CTKD: Derive LTK from LinkKey
|
||||
@@ -1082,7 +1213,7 @@ class Session:
|
||||
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
|
||||
):
|
||||
self.ctkd_task = self.connection.abort_on(
|
||||
'disconnection', self.derive_ltk()
|
||||
'disconnection', self.get_link_key_and_derive_ltk()
|
||||
)
|
||||
# Distribute the LTK, EDIV and RAND
|
||||
elif not self.sc:
|
||||
@@ -1112,12 +1243,7 @@ class Session:
|
||||
|
||||
# CTKD, calculate BR/EDR link key
|
||||
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
|
||||
ilk = (
|
||||
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
|
||||
if self.ct2
|
||||
else crypto.h6(self.ltk, b'tmp1')
|
||||
)
|
||||
self.link_key = crypto.h6(ilk, b'lebr')
|
||||
self.link_key = self.derive_link_key(self.ltk, self.ct2)
|
||||
|
||||
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
|
||||
# Set our expectations for what to wait for in the key distribution phase
|
||||
@@ -1295,7 +1421,7 @@ class Session:
|
||||
try:
|
||||
handler(command)
|
||||
except Exception as error:
|
||||
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
|
||||
response = SMP_Pairing_Failed_Command(
|
||||
reason=SMP_UNSPECIFIED_REASON_ERROR
|
||||
)
|
||||
@@ -1332,15 +1458,28 @@ class Session:
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if command.oob_data_flag != 0:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req,
|
||||
command.io_capability,
|
||||
self.io_capability,
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, command.io_capability, self.io_capability
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
@@ -1389,15 +1528,26 @@ class Session:
|
||||
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
|
||||
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
|
||||
|
||||
# Check for OOB
|
||||
if self.sc and command.oob_data_flag:
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
# Infer the pairing method
|
||||
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
|
||||
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
|
||||
):
|
||||
# Use OOB
|
||||
self.pairing_method = PairingMethod.OOB
|
||||
if not self.sc and self.tk is None:
|
||||
# For legacy OOB, TK is required.
|
||||
logger.warning("legacy OOB without TK")
|
||||
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
|
||||
return
|
||||
if command.oob_data_flag == 0:
|
||||
# The peer doesn't have OOB data, use r=0
|
||||
self.r = bytes(16)
|
||||
else:
|
||||
# Decide which pairing method to use from the IO capability
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
|
||||
# Decide which pairing method to use
|
||||
self.decide_pairing_method(
|
||||
command.auth_req, self.io_capability, command.io_capability
|
||||
)
|
||||
logger.debug(f'pairing method: {self.pairing_method.name}')
|
||||
|
||||
# Key distribution
|
||||
@@ -1548,12 +1698,13 @@ class Session:
|
||||
if self.passkey_step < 20:
|
||||
self.send_pairing_confirm_command()
|
||||
return
|
||||
else:
|
||||
elif self.pairing_method != PairingMethod.OOB:
|
||||
return
|
||||
else:
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
@@ -1597,8 +1748,23 @@ class Session:
|
||||
assert self.passkey
|
||||
ra = self.passkey.to_bytes(16, byteorder='little')
|
||||
rb = ra
|
||||
elif self.pairing_method == PairingMethod.OOB:
|
||||
if self.is_initiator:
|
||||
if self.peer_oob_data:
|
||||
rb = self.peer_oob_data.r
|
||||
ra = self.r
|
||||
else:
|
||||
rb = bytes(16)
|
||||
ra = self.r
|
||||
else:
|
||||
if self.peer_oob_data:
|
||||
ra = self.peer_oob_data.r
|
||||
rb = self.r
|
||||
else:
|
||||
ra = bytes(16)
|
||||
rb = self.r
|
||||
|
||||
else:
|
||||
# OOB not implemented yet
|
||||
return
|
||||
|
||||
assert self.preq and self.pres
|
||||
@@ -1650,18 +1816,33 @@ class Session:
|
||||
self.peer_public_key_y = command.public_key_y
|
||||
|
||||
# Compute the DH key
|
||||
self.dh_key = bytes(
|
||||
reversed(
|
||||
self.manager.ecc_key.dh(
|
||||
bytes(reversed(command.public_key_x)),
|
||||
bytes(reversed(command.public_key_y)),
|
||||
)
|
||||
)
|
||||
)
|
||||
self.dh_key = self.ecc_key.dh(
|
||||
command.public_key_x[::-1],
|
||||
command.public_key_y[::-1],
|
||||
)[::-1]
|
||||
logger.debug(f'DH key: {self.dh_key.hex()}')
|
||||
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
# Check against shared OOB data
|
||||
if self.peer_oob_data:
|
||||
confirm_verifier = crypto.f4(
|
||||
self.peer_public_key_x,
|
||||
self.peer_public_key_x,
|
||||
self.peer_oob_data.r,
|
||||
bytes(1),
|
||||
)
|
||||
if not self.check_expected_value(
|
||||
self.peer_oob_data.c,
|
||||
confirm_verifier,
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
):
|
||||
return
|
||||
|
||||
if self.is_initiator:
|
||||
self.send_pairing_confirm_command()
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
self.send_pairing_random_command()
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.display_or_input_passkey()
|
||||
@@ -1672,6 +1853,7 @@ class Session:
|
||||
if self.pairing_method in (
|
||||
PairingMethod.JUST_WORKS,
|
||||
PairingMethod.NUMERIC_COMPARISON,
|
||||
PairingMethod.OOB,
|
||||
):
|
||||
# We can now send the confirmation value
|
||||
self.send_pairing_confirm_command()
|
||||
@@ -1700,7 +1882,6 @@ class Session:
|
||||
else:
|
||||
self.send_pairing_dhkey_check_command()
|
||||
else:
|
||||
assert self.ltk
|
||||
self.start_encryption(self.ltk)
|
||||
|
||||
def on_smp_pairing_failed_command(
|
||||
@@ -1750,6 +1931,7 @@ class Manager(EventEmitter):
|
||||
sessions: Dict[int, Session]
|
||||
pairing_config_factory: Callable[[Connection], PairingConfig]
|
||||
session_proxy: Type[Session]
|
||||
_ecc_key: Optional[crypto.EccKey]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1769,9 +1951,28 @@ class Manager(EventEmitter):
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
connection.send_l2cap_pdu(cid, bytes(command))
|
||||
|
||||
def on_smp_security_request_command(
|
||||
self, connection: Connection, request: SMP_Security_Request_Command
|
||||
) -> None:
|
||||
connection.emit('security_request', request.auth_req)
|
||||
|
||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Security request is more than just pairing, so let applications handle them
|
||||
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||
self.on_smp_security_request_command(
|
||||
connection, cast(SMP_Security_Request_Command, command)
|
||||
)
|
||||
return
|
||||
|
||||
# Look for a session with this connection, and create one if none exists
|
||||
if not (session := self.sessions.get(connection.handle)):
|
||||
if connection.role == BT_CENTRAL_ROLE:
|
||||
@@ -1782,13 +1983,6 @@ class Manager(EventEmitter):
|
||||
)
|
||||
self.sessions[connection.handle] = session
|
||||
|
||||
# Parse the L2CAP payload into an SMP Command object
|
||||
command = SMP_Command.from_bytes(pdu)
|
||||
logger.debug(
|
||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
|
||||
# Delegate the handling of the command to the session
|
||||
session.on_smp_command(command)
|
||||
|
||||
@@ -1832,9 +2026,8 @@ class Manager(EventEmitter):
|
||||
) -> None:
|
||||
# Store the keys in the key store
|
||||
if self.device.keystore and identity_address is not None:
|
||||
await self.device.keystore.update(str(identity_address), keys)
|
||||
await self.device.refresh_resolving_list()
|
||||
|
||||
# Make sure on_pairing emits after key update.
|
||||
await self.device.update_keys(str(identity_address), keys)
|
||||
# Notify the device
|
||||
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import datetime
|
||||
from typing import BinaryIO, Generator
|
||||
import os
|
||||
|
||||
from bumble import core
|
||||
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
|
||||
|
||||
|
||||
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
|
||||
"""
|
||||
if ':' not in spec:
|
||||
raise ValueError('snooper type prefix missing')
|
||||
raise core.InvalidArgumentError('snooper type prefix missing')
|
||||
|
||||
snooper_type, snooper_args = spec.split(':', maxsplit=1)
|
||||
|
||||
if snooper_type == 'btsnoop':
|
||||
if ':' not in snooper_args:
|
||||
raise ValueError('I/O type for btsnoop snooper type missing')
|
||||
raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
|
||||
|
||||
io_type, io_name = snooper_args.split(':', maxsplit=1)
|
||||
if io_type == 'file':
|
||||
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
|
||||
_SNOOPER_INSTANCE_COUNT -= 1
|
||||
return
|
||||
|
||||
raise ValueError(f'I/O type {io_type} not supported')
|
||||
raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
|
||||
|
||||
raise ValueError(f'snooper type {snooper_type} not found')
|
||||
raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')
|
||||
|
||||
@@ -18,9 +18,9 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from ..controller import Controller
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
|
||||
from ..snoop import create_snooper
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -53,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
|
||||
async def open_transport(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The name must be <type>:<metadata><parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types), and
|
||||
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
|
||||
enclosed in [].
|
||||
If there are not metadata or parameter, the : after the <type> may be omitted.
|
||||
Examples:
|
||||
* usb:0
|
||||
* usb:[driver=rtk]0
|
||||
* android-netsim
|
||||
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
@@ -69,88 +77,116 @@ async def open_transport(name: str) -> Transport:
|
||||
* usb
|
||||
* pyusb
|
||||
* android-emulator
|
||||
* android-netsim
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
scheme, *tail = name.split(':', 1)
|
||||
spec = tail[0] if tail else None
|
||||
metadata = None
|
||||
if spec:
|
||||
# Metadata may precede the spec
|
||||
if spec.startswith('['):
|
||||
metadata_str, *tail = spec[1:].split(']')
|
||||
spec = tail[0] if tail else None
|
||||
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
|
||||
|
||||
transport = await _open_transport(scheme, spec)
|
||||
if metadata:
|
||||
transport.source.metadata = { # type: ignore[attr-defined]
|
||||
**metadata,
|
||||
**getattr(transport.source, 'metadata', {}),
|
||||
}
|
||||
# pylint: disable=line-too-long
|
||||
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
|
||||
|
||||
return _wrap_transport(transport)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
scheme, *spec = name.split(':', 1)
|
||||
if scheme == 'serial' and spec:
|
||||
from .serial import open_serial_transport
|
||||
|
||||
return await open_serial_transport(spec[0])
|
||||
return await open_serial_transport(spec)
|
||||
|
||||
if scheme == 'udp' and spec:
|
||||
from .udp import open_udp_transport
|
||||
|
||||
return await open_udp_transport(spec[0])
|
||||
return await open_udp_transport(spec)
|
||||
|
||||
if scheme == 'tcp-client' and spec:
|
||||
from .tcp_client import open_tcp_client_transport
|
||||
|
||||
return await open_tcp_client_transport(spec[0])
|
||||
return await open_tcp_client_transport(spec)
|
||||
|
||||
if scheme == 'tcp-server' and spec:
|
||||
from .tcp_server import open_tcp_server_transport
|
||||
|
||||
return await open_tcp_server_transport(spec[0])
|
||||
return await open_tcp_server_transport(spec)
|
||||
|
||||
if scheme == 'ws-client' and spec:
|
||||
from .ws_client import open_ws_client_transport
|
||||
|
||||
return await open_ws_client_transport(spec[0])
|
||||
return await open_ws_client_transport(spec)
|
||||
|
||||
if scheme == 'ws-server' and spec:
|
||||
from .ws_server import open_ws_server_transport
|
||||
|
||||
return await open_ws_server_transport(spec[0])
|
||||
return await open_ws_server_transport(spec)
|
||||
|
||||
if scheme == 'pty':
|
||||
from .pty import open_pty_transport
|
||||
|
||||
return await open_pty_transport(spec[0] if spec else None)
|
||||
return await open_pty_transport(spec)
|
||||
|
||||
if scheme == 'file':
|
||||
from .file import open_file_transport
|
||||
|
||||
return await open_file_transport(spec[0] if spec else None)
|
||||
assert spec is not None
|
||||
return await open_file_transport(spec)
|
||||
|
||||
if scheme == 'vhci':
|
||||
from .vhci import open_vhci_transport
|
||||
|
||||
return await open_vhci_transport(spec[0] if spec else None)
|
||||
return await open_vhci_transport(spec)
|
||||
|
||||
if scheme == 'hci-socket':
|
||||
from .hci_socket import open_hci_socket_transport
|
||||
|
||||
return await open_hci_socket_transport(spec[0] if spec else None)
|
||||
return await open_hci_socket_transport(spec)
|
||||
|
||||
if scheme == 'usb':
|
||||
from .usb import open_usb_transport
|
||||
|
||||
return await open_usb_transport(spec[0] if spec else None)
|
||||
assert spec
|
||||
return await open_usb_transport(spec)
|
||||
|
||||
if scheme == 'pyusb':
|
||||
from .pyusb import open_pyusb_transport
|
||||
|
||||
return await open_pyusb_transport(spec[0] if spec else None)
|
||||
assert spec
|
||||
return await open_pyusb_transport(spec)
|
||||
|
||||
if scheme == 'android-emulator':
|
||||
from .android_emulator import open_android_emulator_transport
|
||||
|
||||
return await open_android_emulator_transport(spec[0] if spec else None)
|
||||
return await open_android_emulator_transport(spec)
|
||||
|
||||
if scheme == 'android-netsim':
|
||||
from .android_netsim import open_android_netsim_transport
|
||||
|
||||
return await open_android_netsim_transport(spec[0] if spec else None)
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
if scheme == 'unix':
|
||||
from .unix import open_unix_client_transport
|
||||
|
||||
assert spec
|
||||
return await open_unix_client_transport(spec)
|
||||
|
||||
raise TransportSpecError('unknown transport scheme')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -167,11 +203,13 @@ async def open_transport_or_link(name: str) -> Transport:
|
||||
|
||||
"""
|
||||
if name.startswith('link-relay:'):
|
||||
logger.warning('Link Relay has been deprecated.')
|
||||
from ..controller import Controller
|
||||
from ..link import RemoteLink # lazy import
|
||||
|
||||
link = RemoteLink(name[11:])
|
||||
await link.wait_until_connected()
|
||||
controller = Controller('remote', link=link)
|
||||
controller = Controller('remote', link=link) # type:ignore[arg-type]
|
||||
|
||||
class LinkTransport(Transport):
|
||||
async def close(self):
|
||||
|
||||
@@ -18,7 +18,15 @@
|
||||
import logging
|
||||
import grpc.aio
|
||||
|
||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
|
||||
from typing import Optional, Union
|
||||
|
||||
from .common import (
|
||||
PumpedTransport,
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
|
||||
@@ -33,7 +41,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_emulator_transport(spec):
|
||||
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a transport connection to an Android emulator via its gRPC interface.
|
||||
The parameter string has this syntax:
|
||||
@@ -66,8 +74,8 @@ async def open_android_emulator_transport(spec):
|
||||
# Parse the parameters
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = 8554
|
||||
if spec is not None:
|
||||
server_port = '8554'
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
@@ -75,13 +83,14 @@ async def open_android_emulator_transport(spec):
|
||||
elif ':' in param:
|
||||
server_host, server_port = param.split(':')
|
||||
else:
|
||||
raise ValueError('invalid parameter')
|
||||
raise TransportSpecError('invalid parameter')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||
if mode == 'host':
|
||||
# Connect as a host
|
||||
service = EmulatedBluetoothServiceStub(channel)
|
||||
@@ -91,13 +100,16 @@ async def open_android_emulator_transport(spec):
|
||||
service = VhciForwardingServiceStub(channel)
|
||||
hci_device = HciDevice(service.attachVhci())
|
||||
else:
|
||||
raise ValueError('invalid mode')
|
||||
raise TransportSpecError('invalid mode')
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
class EmulatorTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = EmulatorTransport(
|
||||
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
|
||||
)
|
||||
transport.start()
|
||||
|
||||
|
||||
@@ -18,30 +18,36 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import grpc.aio
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .common import (
|
||||
import grpc.aio
|
||||
|
||||
import bumble
|
||||
from bumble.transport.common import (
|
||||
ParserSource,
|
||||
PumpedTransport,
|
||||
PumpedPacketSource,
|
||||
PumpedPacketSink,
|
||||
Transport,
|
||||
TransportSpecError,
|
||||
TransportInitError,
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import (
|
||||
from .grpc_protobuf.netsim.packet_streamer_pb2_grpc import (
|
||||
PacketStreamerStub,
|
||||
PacketStreamerServicer,
|
||||
add_PacketStreamerServicer_to_server,
|
||||
)
|
||||
from .grpc_protobuf.packet_streamer_pb2 import PacketRequest, PacketResponse
|
||||
from .grpc_protobuf.hci_packet_pb2 import HCIPacket
|
||||
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
|
||||
from .grpc_protobuf.common_pb2 import ChipKind
|
||||
from .grpc_protobuf.netsim.packet_streamer_pb2 import PacketRequest, PacketResponse
|
||||
from .grpc_protobuf.netsim.hci_packet_pb2 import HCIPacket
|
||||
from .grpc_protobuf.netsim.startup_pb2 import Chip, ChipInfo, DeviceInfo
|
||||
from .grpc_protobuf.netsim.common_pb2 import ChipKind
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -54,6 +60,7 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_NAME = 'bumble0'
|
||||
DEFAULT_MANUFACTURER = 'Bumble'
|
||||
DEFAULT_VARIANT = ''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -66,6 +73,9 @@ def get_ini_dir() -> Optional[pathlib.Path]:
|
||||
elif sys.platform == 'linux':
|
||||
if xdg_runtime_dir := os.environ.get('XDG_RUNTIME_DIR', None):
|
||||
return pathlib.Path(xdg_runtime_dir)
|
||||
tmpdir = os.environ.get('TMPDIR', '/tmp')
|
||||
if pathlib.Path(tmpdir).is_dir():
|
||||
return pathlib.Path(tmpdir)
|
||||
elif sys.platform == 'win32':
|
||||
if local_app_data_dir := os.environ.get('LOCALAPPDATA', None):
|
||||
return pathlib.Path(local_app_data_dir) / 'Temp'
|
||||
@@ -74,14 +84,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port() -> int:
|
||||
def ini_file_name(instance_number: int) -> str:
|
||||
suffix = f'_{instance_number}' if instance_number > 0 else ''
|
||||
return f'netsim{suffix}.ini'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port(instance_number: int) -> int:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return 0
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
logger.debug(f'Looking for .ini file at {ini_file}')
|
||||
if ini_file.is_file():
|
||||
logger.debug(f'Found .ini file at {ini_file}')
|
||||
with open(ini_file, 'r') as ini_file_data:
|
||||
for line in ini_file_data.readlines():
|
||||
if '=' in line:
|
||||
@@ -90,12 +106,14 @@ def find_grpc_port() -> int:
|
||||
logger.debug(f'gRPC port = {value}')
|
||||
return int(value)
|
||||
|
||||
logger.debug('no grpc.port property found in .ini file')
|
||||
|
||||
# Not found
|
||||
return 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def publish_grpc_port(grpc_port) -> bool:
|
||||
def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return False
|
||||
@@ -104,7 +122,7 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
logger.debug('ini directory does not exist')
|
||||
return False
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
try:
|
||||
ini_file.write_text(f'grpc.port={grpc_port}\n')
|
||||
logger.debug(f"published gRPC port at {ini_file}")
|
||||
@@ -121,13 +139,16 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise ValueError('invalid port')
|
||||
raise TransportSpecError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not publish_grpc_port(server_port):
|
||||
instance_number = int(options.get('instance', "0"))
|
||||
if not publish_grpc_port(server_port, instance_number):
|
||||
logger.warning("unable to publish gRPC port")
|
||||
|
||||
class HciDevice:
|
||||
@@ -181,18 +202,14 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
data = (
|
||||
bytes([request.hci_packet.packet_type]) + request.hci_packet.packet
|
||||
)
|
||||
logger.debug(f'<<< PACKET: {data.hex()}')
|
||||
self.on_data_received(data)
|
||||
|
||||
def send_packet(self, data):
|
||||
async def send():
|
||||
await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
async def send_packet(self, data):
|
||||
return await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
|
||||
self.loop.create_task(send())
|
||||
)
|
||||
|
||||
def terminate(self):
|
||||
self.task.cancel()
|
||||
@@ -226,19 +243,19 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
logger.debug('gRPC server cancelled')
|
||||
await self.grpc_server.stop(None)
|
||||
|
||||
def on_packet(self, packet):
|
||||
async def send_packet(self, packet):
|
||||
if not self.device:
|
||||
logger.debug('no device, dropping packet')
|
||||
return
|
||||
|
||||
self.device.send_packet(packet)
|
||||
return await self.device.send_packet(packet)
|
||||
|
||||
async def StreamPackets(self, _request_iterator, context):
|
||||
logger.debug('StreamPackets request')
|
||||
|
||||
# Check that we won't already have a device
|
||||
# Check that we don't already have a device
|
||||
if self.device:
|
||||
logger.debug('busy, already serving a device')
|
||||
logger.debug('Busy, already serving a device')
|
||||
return PacketResponse(error='Busy')
|
||||
|
||||
# Instantiate a new device
|
||||
@@ -259,43 +276,80 @@ async def open_android_netsim_controller_transport(server_host, server_port):
|
||||
await server.start()
|
||||
asyncio.get_running_loop().create_task(server.serve())
|
||||
|
||||
class GrpcServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
return GrpcServerTransport(server, server)
|
||||
sink = PumpedPacketSink(server.send_packet)
|
||||
sink.start()
|
||||
return Transport(server, sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
async def open_android_netsim_host_transport_with_address(
|
||||
server_host: Optional[str],
|
||||
server_port: int,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||
server_port = find_grpc_port(instance_number)
|
||||
if not server_port:
|
||||
raise TransportInitError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
return await open_android_netsim_host_transport_with_channel(
|
||||
channel,
|
||||
options,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport_with_channel(
|
||||
channel, options: Optional[Dict[str, str]] = None
|
||||
):
|
||||
# Wrapper for I/O operations
|
||||
class HciDevice:
|
||||
def __init__(self, name, manufacturer, hci_device):
|
||||
def __init__(self, name, variant, manufacturer, hci_device):
|
||||
self.name = name
|
||||
self.variant = variant
|
||||
self.manufacturer = manufacturer
|
||||
self.hci_device = hci_device
|
||||
|
||||
async def start(self): # Send the startup info
|
||||
chip_info = ChipInfo(
|
||||
device_info = DeviceInfo(
|
||||
name=self.name,
|
||||
chip=Chip(kind=ChipKind.BLUETOOTH, manufacturer=self.manufacturer),
|
||||
kind='BUMBLE',
|
||||
version=bumble.__version__,
|
||||
sdk_version=platform.python_version(),
|
||||
build_id=platform.platform(),
|
||||
arch=platform.machine(),
|
||||
variant=self.variant,
|
||||
)
|
||||
chip = Chip(kind=ChipKind.BLUETOOTH, manufacturer=self.manufacturer)
|
||||
chip_info = ChipInfo(name=self.name, chip=chip, device_info=device_info)
|
||||
logger.debug(f'Sending chip info to netsim: {chip_info}')
|
||||
await self.hci_device.write(PacketRequest(initial_info=chip_info))
|
||||
|
||||
async def read(self):
|
||||
response = await self.hci_device.read()
|
||||
response_type = response.WhichOneof('response_type')
|
||||
|
||||
if response_type == 'error':
|
||||
logger.warning(f'received error: {response.error}')
|
||||
raise RuntimeError(response.error)
|
||||
elif response_type == 'hci_packet':
|
||||
raise TransportInitError(response.error)
|
||||
|
||||
if response_type == 'hci_packet':
|
||||
return (
|
||||
bytes([response.hci_packet.packet_type])
|
||||
+ response.hci_packet.packet
|
||||
)
|
||||
|
||||
raise ValueError('unsupported response type')
|
||||
raise TransportSpecError('unsupported response type')
|
||||
|
||||
async def write(self, packet):
|
||||
await self.hci_device.write(
|
||||
@@ -304,38 +358,31 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
)
|
||||
)
|
||||
|
||||
name = options.get('name', DEFAULT_NAME)
|
||||
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
|
||||
variant = (
|
||||
DEFAULT_VARIANT if options is None else options.get('variant', DEFAULT_VARIANT)
|
||||
)
|
||||
manufacturer = DEFAULT_MANUFACTURER
|
||||
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
server_host = 'localhost'
|
||||
server_port = find_grpc_port()
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
# Connect as a host
|
||||
service = PacketStreamerStub(channel)
|
||||
hci_device = HciDevice(
|
||||
name=name,
|
||||
variant=variant,
|
||||
manufacturer=manufacturer,
|
||||
hci_device=service.StreamPackets(),
|
||||
)
|
||||
await hci_device.start()
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
class GrpcTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = GrpcTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
)
|
||||
transport.start()
|
||||
|
||||
@@ -343,7 +390,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_transport(spec):
|
||||
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a transport connection as a client or server, implementing Android's `netsim`
|
||||
simulator protocol over gRPC.
|
||||
@@ -357,6 +404,11 @@ async def open_android_netsim_transport(spec):
|
||||
to connect *to* a netsim server (netsim is the controller), or accept
|
||||
connections *as* a netsim-compatible server.
|
||||
|
||||
instance=<n>
|
||||
Specifies an instance number, with <n> > 0. This is used to determine which
|
||||
.init file to use. In `host` mode, it is ignored when the <host>:<port>
|
||||
specifier is present, since in that case no .ini file is used.
|
||||
|
||||
In `host` mode:
|
||||
The <host>:<port> part is optional. When not specified, the transport
|
||||
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
|
||||
@@ -366,6 +418,9 @@ async def open_android_netsim_transport(spec):
|
||||
The "chip" name, used to identify the "chip" instance. This
|
||||
may be useful when several clients are connected, since each needs to use a
|
||||
different name.
|
||||
variant=<variant>
|
||||
The device info variant field, which may be used to convey a device or
|
||||
application type (ex: "virtual-speaker", or "keyboard")
|
||||
|
||||
In `controller` mode:
|
||||
The <host>:<port> part is required. <host> may be the address of a local network
|
||||
@@ -385,26 +440,29 @@ async def open_android_netsim_transport(spec):
|
||||
params = spec.split(',') if spec else []
|
||||
if params and ':' in params[0]:
|
||||
# Explicit <host>:<port>
|
||||
host, port = params[0].split(':')
|
||||
host, port_str = params[0].split(':')
|
||||
port = int(port_str)
|
||||
params_offset = 1
|
||||
else:
|
||||
host = None
|
||||
port = 0
|
||||
params_offset = 0
|
||||
|
||||
options = {}
|
||||
options: Dict[str, str] = {}
|
||||
for param in params[params_offset:]:
|
||||
if '=' not in param:
|
||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
||||
raise TransportSpecError('invalid parameter, expected <name>=<value>')
|
||||
option_name, option_value = param.split('=')
|
||||
options[option_name] = option_value
|
||||
|
||||
mode = options.get('mode', 'host')
|
||||
if mode == 'host':
|
||||
return await open_android_netsim_host_transport(host, port, options)
|
||||
return await open_android_netsim_host_transport_with_address(
|
||||
host, port, options
|
||||
)
|
||||
if mode == 'controller':
|
||||
if host is None:
|
||||
raise ValueError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port)
|
||||
raise TransportSpecError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port, options)
|
||||
|
||||
raise ValueError('invalid mode option')
|
||||
raise TransportSpecError('invalid mode option')
|
||||
|
||||
@@ -20,11 +20,13 @@ import contextlib
|
||||
import struct
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import ContextManager
|
||||
import io
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from .. import hci
|
||||
from ..colors import color
|
||||
from ..snoop import Snooper
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
from bumble.snoop import Snooper
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -36,19 +38,41 @@ logger = logging.getLogger(__name__)
|
||||
# Information needed to parse HCI packets with a generic parser:
|
||||
# For each packet type, the info represents:
|
||||
# (length-size, length-offset, unpack-type)
|
||||
HCI_PACKET_INFO = {
|
||||
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
||||
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
|
||||
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportLostError(Exception):
|
||||
"""
|
||||
The Transport has been lost/disconnected.
|
||||
"""
|
||||
# Errors
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportLostError(core.BaseBumbleError, RuntimeError):
|
||||
"""The Transport has been lost/disconnected."""
|
||||
|
||||
|
||||
class TransportInitError(core.BaseBumbleError, RuntimeError):
|
||||
"""Error raised when the transport cannot be initialized."""
|
||||
|
||||
|
||||
class TransportSpecError(core.BaseBumbleError, ValueError):
|
||||
"""Error raised when the transport spec is invalid."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Typing Protocols
|
||||
# -----------------------------------------------------------------------------
|
||||
class TransportSink(Protocol):
|
||||
def on_packet(self, packet: bytes) -> None: ...
|
||||
|
||||
|
||||
class TransportSource(Protocol):
|
||||
terminated: asyncio.Future[None]
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None: ...
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -57,18 +81,15 @@ class PacketPump:
|
||||
Pump HCI packets from a reader to a sink.
|
||||
"""
|
||||
|
||||
def __init__(self, reader, sink):
|
||||
def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
|
||||
self.reader = reader
|
||||
self.sink = sink
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
# Get a packet from the source
|
||||
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
|
||||
|
||||
# Deliver the packet to the sink
|
||||
self.sink.on_packet(packet)
|
||||
self.sink.on_packet(await self.reader.next_packet())
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! {error}')
|
||||
|
||||
@@ -86,18 +107,22 @@ class PacketParser:
|
||||
NEED_LENGTH = 1
|
||||
NEED_BODY = 2
|
||||
|
||||
def __init__(self, sink=None):
|
||||
sink: Optional[TransportSink]
|
||||
extended_packet_info: Dict[int, Tuple[int, int, str]]
|
||||
packet_info: Optional[Tuple[int, int, str]] = None
|
||||
|
||||
def __init__(self, sink: Optional[TransportSink] = None) -> None:
|
||||
self.sink = sink
|
||||
self.extended_packet_info = {}
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.state = PacketParser.NEED_TYPE
|
||||
self.bytes_needed = 1
|
||||
self.packet = bytearray()
|
||||
self.packet_info = None
|
||||
|
||||
def feed_data(self, data):
|
||||
def feed_data(self, data: bytes) -> None:
|
||||
data_offset = 0
|
||||
data_left = len(data)
|
||||
while data_left and self.bytes_needed:
|
||||
@@ -114,10 +139,13 @@ class PacketParser:
|
||||
packet_type
|
||||
) or self.extended_packet_info.get(packet_type)
|
||||
if self.packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type}')
|
||||
raise core.InvalidPacketError(
|
||||
f'invalid packet type {packet_type}'
|
||||
)
|
||||
self.state = PacketParser.NEED_LENGTH
|
||||
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
|
||||
elif self.state == PacketParser.NEED_LENGTH:
|
||||
assert self.packet_info is not None
|
||||
body_length = struct.unpack_from(
|
||||
self.packet_info[2], self.packet, 1 + self.packet_info[1]
|
||||
)[0]
|
||||
@@ -130,12 +158,12 @@ class PacketParser:
|
||||
try:
|
||||
self.sink.on_packet(bytes(self.packet))
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
color(f'!!! Exception in on_packet: {error}', 'red')
|
||||
)
|
||||
self.reset()
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
|
||||
|
||||
@@ -145,31 +173,33 @@ class PacketReader:
|
||||
Reader that reads HCI packets from a sync source.
|
||||
"""
|
||||
|
||||
def __init__(self, source):
|
||||
def __init__(self, source: io.BufferedReader) -> None:
|
||||
self.source = source
|
||||
self.at_end = False
|
||||
|
||||
def next_packet(self):
|
||||
def next_packet(self) -> Optional[bytes]:
|
||||
# Get the packet type
|
||||
packet_type = self.source.read(1)
|
||||
if len(packet_type) != 1:
|
||||
self.at_end = True
|
||||
return None
|
||||
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
header = self.source.read(header_size)
|
||||
if len(header) != header_size:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
# Read the body
|
||||
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
|
||||
body = self.source.read(body_length)
|
||||
if len(body) != body_length:
|
||||
raise ValueError('packet too short')
|
||||
raise core.InvalidPacketError('packet too short')
|
||||
|
||||
return packet_type + header + body
|
||||
|
||||
@@ -180,17 +210,17 @@ class AsyncPacketReader:
|
||||
Reader that reads HCI packets from an async source.
|
||||
"""
|
||||
|
||||
def __init__(self, source):
|
||||
def __init__(self, source: asyncio.StreamReader) -> None:
|
||||
self.source = source
|
||||
|
||||
async def next_packet(self):
|
||||
async def next_packet(self) -> bytes:
|
||||
# Get the packet type
|
||||
packet_type = await self.source.readexactly(1)
|
||||
|
||||
# Get the packet info based on its type
|
||||
packet_info = HCI_PACKET_INFO.get(packet_type[0])
|
||||
if packet_info is None:
|
||||
raise ValueError(f'invalid packet type {packet_type} found')
|
||||
raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
|
||||
|
||||
# Read the header (that includes the length)
|
||||
header_size = packet_info[0] + packet_info[1]
|
||||
@@ -209,64 +239,81 @@ class AsyncPipeSink:
|
||||
Sink that forwards packets asynchronously to another sink.
|
||||
"""
|
||||
|
||||
def __init__(self, sink):
|
||||
def __init__(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
self.loop = asyncio.get_running_loop()
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.loop.call_soon(self.sink.on_packet, packet)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ParserSource:
|
||||
class BaseSource:
|
||||
"""
|
||||
Base class designed to be subclassed by transport-specific source classes
|
||||
"""
|
||||
|
||||
terminated: asyncio.Future
|
||||
parser: PacketParser
|
||||
terminated: asyncio.Future[None]
|
||||
sink: Optional[TransportSink]
|
||||
|
||||
def __init__(self):
|
||||
self.parser = PacketParser()
|
||||
def __init__(self) -> None:
|
||||
self.terminated = asyncio.get_running_loop().create_future()
|
||||
self.sink = None
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
self.parser.set_packet_sink(sink)
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
|
||||
def on_transport_lost(self):
|
||||
self.terminated.set_result(None)
|
||||
if self.parser.sink:
|
||||
try:
|
||||
self.parser.sink.on_transport_lost()
|
||||
except AttributeError:
|
||||
pass
|
||||
def on_transport_lost(self) -> None:
|
||||
if not self.terminated.done():
|
||||
self.terminated.set_result(None)
|
||||
|
||||
async def wait_for_termination(self):
|
||||
if self.sink:
|
||||
if hasattr(self.sink, 'on_transport_lost'):
|
||||
self.sink.on_transport_lost()
|
||||
|
||||
async def wait_for_termination(self) -> None:
|
||||
"""
|
||||
Convenience method for backward compatibility. Prefer using the `terminated`
|
||||
attribute instead.
|
||||
"""
|
||||
return await self.terminated
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ParserSource(BaseSource):
|
||||
"""
|
||||
Base class for sources that use an HCI parser.
|
||||
"""
|
||||
|
||||
parser: PacketParser
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parser = PacketParser()
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
super().set_packet_sink(sink)
|
||||
self.parser.set_packet_sink(sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamPacketSource(asyncio.Protocol, ParserSource):
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.parser.feed_data(data)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class StreamPacketSink:
|
||||
def __init__(self, transport):
|
||||
def __init__(self, transport: asyncio.WriteTransport) -> None:
|
||||
self.transport = transport
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.transport.write(packet)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
@@ -286,7 +333,7 @@ class Transport:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, source, sink):
|
||||
def __init__(self, source: TransportSource, sink: TransportSink) -> None:
|
||||
self.source = source
|
||||
self.sink = sink
|
||||
|
||||
@@ -300,34 +347,41 @@ class Transport:
|
||||
return iter((self.source, self.sink))
|
||||
|
||||
async def close(self) -> None:
|
||||
self.source.close()
|
||||
self.sink.close()
|
||||
if hasattr(self.source, 'close'):
|
||||
self.source.close()
|
||||
if hasattr(self.sink, 'close'):
|
||||
self.sink.close()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PumpedPacketSource(ParserSource):
|
||||
def __init__(self, receive):
|
||||
pump_task: Optional[asyncio.Task[None]]
|
||||
|
||||
def __init__(self, receive) -> None:
|
||||
super().__init__()
|
||||
self.receive_function = receive
|
||||
self.pump_task = None
|
||||
|
||||
def start(self):
|
||||
async def pump_packets():
|
||||
def start(self) -> None:
|
||||
async def pump_packets() -> None:
|
||||
while True:
|
||||
try:
|
||||
packet = await self.receive_function()
|
||||
self.parser.feed_data(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('source pump task done')
|
||||
if not self.terminated.done():
|
||||
self.terminated.set_result(None)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while waiting for packet: {error}')
|
||||
self.terminated.set_result(error)
|
||||
if not self.terminated.done():
|
||||
self.terminated.set_exception(error)
|
||||
break
|
||||
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self.pump_task:
|
||||
self.pump_task.cancel()
|
||||
|
||||
@@ -339,7 +393,7 @@ class PumpedPacketSink:
|
||||
self.packet_queue = asyncio.Queue()
|
||||
self.pump_task = None
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.packet_queue.put_nowait(packet)
|
||||
|
||||
def start(self):
|
||||
@@ -348,7 +402,7 @@ class PumpedPacketSink:
|
||||
try:
|
||||
packet = await self.packet_queue.get()
|
||||
await self.send_function(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('sink pump task done')
|
||||
break
|
||||
except Exception as error:
|
||||
@@ -364,18 +418,20 @@ class PumpedPacketSink:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PumpedTransport(Transport):
|
||||
def __init__(self, source, sink, close_function):
|
||||
super().__init__(source, sink)
|
||||
self.close_function = close_function
|
||||
source: PumpedPacketSource
|
||||
sink: PumpedPacketSink
|
||||
|
||||
def start(self):
|
||||
def __init__(
|
||||
self,
|
||||
source: PumpedPacketSource,
|
||||
sink: PumpedPacketSink,
|
||||
) -> None:
|
||||
super().__init__(source, sink)
|
||||
|
||||
def start(self) -> None:
|
||||
self.source.start()
|
||||
self.sink.start()
|
||||
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await self.close_function()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopingTransport(Transport):
|
||||
@@ -394,34 +450,45 @@ class SnoopingTransport(Transport):
|
||||
return SnoopingTransport(
|
||||
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
|
||||
)
|
||||
raise RuntimeError('unexpected code path') # Satisfy the type checker
|
||||
raise core.UnreachableError() # Satisfy the type checker
|
||||
|
||||
class Source:
|
||||
def __init__(self, source, snooper):
|
||||
sink: TransportSink
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
return getattr(self.source, 'metadata', {})
|
||||
|
||||
def __init__(self, source: TransportSource, snooper: Snooper):
|
||||
self.source = source
|
||||
self.snooper = snooper
|
||||
self.sink = None
|
||||
self.terminated = source.terminated
|
||||
|
||||
def set_packet_sink(self, sink):
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
self.sink = sink
|
||||
self.source.set_packet_sink(self)
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
class Sink:
|
||||
def __init__(self, sink, snooper):
|
||||
def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
|
||||
self.sink = sink
|
||||
self.snooper = snooper
|
||||
|
||||
def on_packet(self, packet):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
if self.sink:
|
||||
self.sink.on_packet(packet)
|
||||
|
||||
def __init__(self, transport, snooper, close_snooper=None):
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
snooper: Snooper,
|
||||
close_snooper=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_file_transport(spec):
|
||||
async def open_file_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a File transport (typically not for a real file, but for a PTY or other unix
|
||||
virtual files).
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: hci_packet.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10hci_packet.proto\x12\rnetsim.packet\"\xb2\x01\n\tHCIPacket\x12\x38\n\x0bpacket_type\x18\x01 \x01(\x0e\x32#.netsim.packet.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"[\n\nPacketType\x12\x1a\n\x16HCI_PACKET_UNSPECIFIED\x10\x00\x12\x0b\n\x07\x43OMMAND\x10\x01\x12\x07\n\x03\x41\x43L\x10\x02\x12\x07\n\x03SCO\x10\x03\x12\t\n\x05\x45VENT\x10\x04\x12\x07\n\x03ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hci_packet_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
|
||||
_HCIPACKET._serialized_start=36
|
||||
_HCIPACKET._serialized_end=214
|
||||
_HCIPACKET_PACKETTYPE._serialized_start=123
|
||||
_HCIPACKET_PACKETTYPE._serialized_end=214
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
0
bumble/transport/grpc_protobuf/netsim/__init__.py
Normal file
0
bumble/transport/grpc_protobuf/netsim/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user