Compare commits

...

104 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
32bb7cdaf3 add support for multiple concurrent broadcasts 2026-01-01 18:24:03 -08:00
zxzxwu
b4261548e8 Merge pull request #848 from zxzxwu/typing
Ruff: Add and fix UP rules
2026-01-01 16:47:05 +08:00
Gilles Boccon-Gibod
9161cea577 Merge pull request #846 from google/gbg/ruff-hot-fix 2025-12-31 14:14:47 -08:00
Josh Wu
3f643de4c1 Ruff: Add and fix UP rules 2026-01-01 03:25:32 +08:00
Gilles Boccon-Gibod
7c7b792cf9 remove unused import 2025-12-30 13:22:27 -08:00
Gilles Boccon-Gibod
8e28f4e159 Merge pull request #845 from google/gbg/ruff
use ruff for linting and import sorting
2025-12-30 11:38:48 -08:00
zxzxwu
8823cf108f Merge pull request #840 from zxzxwu/credit
L2CAP: Enhanced Credit-based Flow Control Mode
2025-12-30 20:26:44 +08:00
Gilles Boccon-Gibod
4fb501a0ef use ruff for linting and import sorting 2025-12-29 19:28:45 -08:00
Gilles Boccon-Gibod
ad0753b959 Merge pull request #843 from dlech/type-hints
Fix missing type hints on Device.notify_subscribers()
2025-12-29 16:35:46 -08:00
Gilles Boccon-Gibod
f12cccf6cd Merge pull request #844 from dlech/remove-unused-imports
Remove unused imports
2025-12-29 16:28:08 -08:00
David Lechner
5bbbe5e40f Remove unused imports
Mechanically remove unused imports with:

    ruff check --select F401 --fix --extend-exclude grpc_protobuf
2025-12-29 17:19:11 -06:00
David Lechner
793fcd750c Fix missing type hints on Device.notify_subscribers()
Add type hints for all arguments. Otherwise static checkers complain
when you try to use it.
2025-12-29 16:03:46 -06:00
Gilles Boccon-Gibod
ae2c638256 Merge pull request #842 from dlech/fix-duplicate-GATT_CONTENT_CONTROL_ID_CHARACTERISTIC
GATT: fix redefinition of GATT_CONTENT_CONTROL_ID_CHARACTERISTIC
2025-12-29 12:12:54 -08:00
David Lechner
9ad0eafe37 GATT: remove duplicate GATT_CONTENT_CONTROL_ID_CHARACTERISTIC
Remove the first occurrence of GATT_CONTENT_CONTROL_ID_CHARACTERISTIC.

The "Telephone Bearer Service (TBS)" section also defines
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC later, so we don't need this one.
2025-12-29 13:57:52 -06:00
Josh Wu
618e977f20 L2CAP: Enhanced Credit-based Flow Control Mode 2025-12-23 19:51:55 +08:00
zxzxwu
7fdc4f624e Merge pull request #838 from salmanmkc/upgrade-github-actions-node24-general
Upgrade GitHub Actions to latest versions
2025-12-18 17:32:16 +08:00
zxzxwu
255ca60d95 Merge pull request #839 from google/dependabot/pip/docs/mkdocs/pip-d9bbda99d0
Bump pymdown-extensions from 10.0 to 10.16.1 in /docs/mkdocs in the pip group across 1 directory
2025-12-17 19:20:19 +08:00
zxzxwu
716f57de46 Merge pull request #837 from salmanmkc/upgrade-github-actions-node24
Upgrade GitHub Actions for Node 24 compatibility
2025-12-17 19:20:14 +08:00
Salman Muin Kayser Chishti
95a987d3a4 Fix pypa/gh-action-pypi-publish to use SHA pinning
Pin to release/v1.13 for security best practices.
The v1 tag doesn't exist - only release/v1 branch exists.

Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com>
2025-12-17 10:31:35 +00:00
dependabot[bot]
6858c591aa Bump pymdown-extensions
Bumps the pip group with 1 update in the /docs/mkdocs directory: [pymdown-extensions](https://github.com/facelessuser/pymdown-extensions).


Updates `pymdown-extensions` from 10.0 to 10.16.1
- [Release notes](https://github.com/facelessuser/pymdown-extensions/releases)
- [Commits](https://github.com/facelessuser/pymdown-extensions/compare/10.0...10.16.1)

---
updated-dependencies:
- dependency-name: pymdown-extensions
  dependency-version: 10.16.1
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-12-16 20:46:58 +00:00
Salman Muin Kayser Chishti
e03b9cb441 Upgrade GitHub Actions to latest versions 2025-12-16 14:34:19 +00:00
Salman Muin Kayser Chishti
ade36f8d04 Upgrade GitHub Actions for Node 24 compatibility 2025-12-16 14:34:13 +00:00
Gilles Boccon-Gibod
48744ee9db Merge pull request #833 from google/gbg/netsim_enhancements
android-netsim transport enhancements
2025-12-15 15:38:20 -08:00
Gilles Boccon-Gibod
302e496890 android-netsim transport enhancements 2025-12-15 15:14:57 -08:00
Gilles Boccon-Gibod
6649464cd6 Merge pull request #835 from google/gbg/fix-rust-latest 2025-12-15 11:20:19 -08:00
Gilles Boccon-Gibod
c46df21385 use 1.91.0 instead of stable until a fix is available 2025-12-14 16:32:44 -08:00
Gilles Boccon-Gibod
7a35f5d095 bump pdl dependencies versions 2025-12-14 11:37:49 -08:00
zxzxwu
73f2853c5e Merge pull request #830 from zxzxwu/bridge
Add some docs about Android and Hardware
2025-12-10 18:52:15 +08:00
Josh Wu
de3009e296 Add some docs about Android and Hardware 2025-12-10 18:08:13 +08:00
zxzxwu
e47cb5512c Merge pull request #779 from zxzxwu/l2cap
L2CAP Enhanced Retransmission mode
2025-12-03 21:57:48 +08:00
zxzxwu
3171b5a19e Merge pull request #828 from zxzxwu/rust
Rust: Fix cargo-all-features to 1.11.0
2025-12-01 16:31:21 +08:00
Josh Wu
456cb59b48 L2CAP: FCS Implementation 2025-12-01 16:10:45 +08:00
zxzxwu
33ca324e41 Merge pull request #827 from zxzxwu/emu
Implement extended advertising emulation
2025-12-01 15:57:42 +08:00
Josh Wu
a84f0279b1 Refactor LE emulation with LL and Air Interface 2025-11-28 16:10:38 +08:00
Josh Wu
b93ba007ed Rust: Fix cargo-all-features to 1.11.0 2025-11-28 02:25:52 +08:00
Josh Wu
d2a4c2a8e4 Implement extended advertising emulation 2025-11-27 20:56:10 +08:00
Josh Wu
57e05781ad L2CAP: Enhanced Retransmission Mode 2025-11-24 16:17:11 +08:00
zxzxwu
bae6c1df97 Merge pull request #826 from ljodal/ljodal/cancel-pending-l2cap-connection
Cancel l2cap connection result future on abort
2025-11-19 18:52:32 +08:00
Sigurd Ljødal
7292c2785e Cancel l2cap connection result future on abort
This cancels the `connection_result` future of LeCreditBasedChannel when
abort() is called, e.g. if the LE connection disconnects. This makes it
possible for code waiting for a connection to open to detect that the
connection has failed.

Fixes google/bumble#825
2025-11-14 14:52:09 +01:00
khsiao-google
42711d3d31 Merge pull request #824 from khsiao-google/test_coverage
Add remote name request
2025-11-11 06:06:37 +08:00
khsiao-google
67a61ae34d Update tests/device_test.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2025-11-11 05:34:46 +08:00
khsiao-google
a62f981556 Add remote name request 2025-11-10 14:04:50 +00:00
zxzxwu
6b56b10b6e Merge pull request #823 from zxzxwu/lmp
Refactor classic emulation with LMP protocol
2025-11-09 15:57:47 +08:00
Josh Wu
e0dee2135f Basic LMP implementation 2025-11-09 15:50:12 +08:00
zxzxwu
bb9aa12a74 Merge pull request #822 from zxzxwu/call_soon
Emulation: Improve import, typing, and use call_soon
2025-11-09 15:00:12 +08:00
Josh Wu
da64f66bce Emulation: Improve import, typing, and use call_soon 2025-11-08 22:43:51 +08:00
zxzxwu
f000a3f30a Merge pull request #802 from zxzxwu/version
Upgrade Python version to 3.10-3.14
2025-11-07 23:22:07 +08:00
Gilles Boccon-Gibod
8ad48f92b3 Merge pull request #792 from markusjellitsch/task/fix-deprecated-warnings
Fix - deprecated warning for datetime.utcnow() with Python >= 3.12
2025-11-07 10:58:04 +01:00
zxzxwu
a827669f62 Merge pull request #817 from zxzxwu/device
Use EventWatcher and check_result
2025-11-07 17:16:04 +08:00
Josh Wu
4bee8d5287 Use EventWatcher and send_command(check_result=True) in all similar patterns 2025-11-07 00:37:57 +08:00
Josh Wu
5431941fe7 Upgrade Python version to 3.10-3.14 2025-11-05 04:45:05 +08:00
zxzxwu
d112901a17 Merge pull request #814 from zxzxwu/hid-fix
Fix wrong HID PSM
2025-11-04 15:20:20 +08:00
Josh Wu
2d74aef0e9 Fix wrong HID PSM 2025-11-04 01:36:07 +08:00
khsiao-google
f06e19e1ca Merge pull request #809 from khsiao-google/update
[Typing] Add controller.py typing
2025-11-03 18:58:13 +08:00
khsiao-google
36aefb280d Merge branch 'main' into update 2025-11-03 09:37:44 +00:00
zxzxwu
227f5cf62e Merge pull request #783 from zxzxwu/avrcp
AVCTP: Change callback packet type to bytes
2025-11-03 15:40:18 +08:00
Gilles Boccon-Gibod
1336cfa42c Merge pull request #813 from XenoKovah/main
Trivial change: Sorting VID/PIDs and adding new values
2025-11-02 19:08:45 +01:00
Xeno Kovah
0ca7b8b322 Sorting VID/PIDs and adding observed values on ZEXMTE (https://zexmtebluetooth.com/#Products) devices 2025-11-02 12:36:46 -05:00
Josh Wu
eef5304a36 AVCTP: Change callback packet type to bytes 2025-11-02 18:03:25 +08:00
khsiao-google
1a2141126c [Typing] Add controller.py typing 2025-11-01 09:30:36 +00:00
markus
6ed9a98490 use backquotes instead of regular quotes 2025-10-31 18:50:30 +01:00
zxzxwu
19b7660f88 Merge pull request #812 from markusjellitsch/fix/controller-dict-remove
Fix: RuntimeError in controller.py
2025-11-01 00:05:20 +08:00
zxzxwu
1932f14fb6 Merge pull request #811 from zxzxwu/websockets
Upgrade websockets dependency to 15.0.1+
2025-11-01 00:05:06 +08:00
markus
b70b92097f fix RuntimeError: dictionary change during iteration 2025-10-31 11:56:31 +01:00
markus
b6a800c692 use timezone utc for TIMESTAMP_ANCHOR 2025-10-31 11:35:47 +01:00
Josh Wu
d43f5573a6 Upgrade websockets dependency to 15.0.1+ 2025-10-31 17:35:13 +08:00
zxzxwu
1982168a9f Merge pull request #806 from zxzxwu/avrcp-response
AVRCP: Reply ACCEPTED on set absolute volume
2025-10-28 14:39:26 +08:00
Josh Wu
5e1794a15b AVRCP: Reply ACCEPTED on set absolute volume 2025-10-28 00:05:18 +08:00
Gilles Boccon-Gibod
578f7f054d Merge pull request #804 from graynode/rfcomm-tx-credit-goes-negative-fix
Fixed bug where it's possible for rfcomm tx_credit to go negative resulting in l2cap disconnect from peripheral
2025-10-26 14:25:29 +01:00
graynode
4b25b3581d updated per PR input 2025-10-24 10:09:02 -04:00
graynode
9601c7f287 fixed formatting issue 2025-10-24 09:30:45 -04:00
graynode
dae3ec5cba Fixed bug where it's possible for tx_credit to goe negative 2025-10-23 21:56:00 -04:00
zxzxwu
95225a1774 Merge pull request #803 from zxzxwu/avdtp
AVDTP: Migrate enums
2025-10-23 13:45:48 +08:00
Josh Wu
e54a26393e AVDTP: Add missing type annotations 2025-10-22 20:54:28 +08:00
Josh Wu
5dc76cf7b4 Migrate AVDTP enums 2025-10-22 20:41:51 +08:00
zxzxwu
6c68115660 Merge pull request #799 from zxzxwu/avdtp
Migrate AVDTP packets to dataclasses
2025-10-22 20:01:08 +08:00
zxzxwu
88ef65a4e2 Merge pull request #798 from khsiao-google/update
HFP: Change configuration attribute types to Sequence
2025-10-22 13:52:20 +08:00
zxzxwu
324b26d8f2 Merge pull request #801 from zyanwu-google/feat/intel_ddc
feat(intel): clarify firmware/DDC flow and preserve driver metadata
2025-10-22 13:51:16 +08:00
Josh Wu
a43b403511 Migrate AVDTP packets to dataclasses 2025-10-21 18:54:48 +08:00
zyanwu-google
c657494362 feat(intel): clarify firmware/DDC flow and preserve driver metadata
- Add explanatory comments across intel driver to clarify metadata parsing.
- Ensure driver selection preserves runtime options (e.g. "intel/ddc_override:AABB")
  so driver-specific metadata is passed through to the host and available to
  drivers via host.hci_metadata.
- Ensure transport parsing regex and metadata extraction so transport/source
  metadata is populated and visible to drivers.
- Example usage: passing [driver=intel/ddc_override:AABB] will be preserved and
  can be consumed by the Intel driver to apply a DDC override blob.
2025-10-21 09:00:38 +00:00
khsiao-google
11505f08b7 [Typing] Change to Sequence 2025-10-20 08:47:40 +00:00
khsiao-google
9bf9ed5f59 [Typing] Change list to Iterable 2025-10-10 15:32:06 +00:00
zxzxwu
0fa517a4f6 Merge pull request #793 from zain2983/main
Minor fixes
2025-10-03 15:54:13 +08:00
Z1
a11962a487 Minor fixes 2025-10-02 19:26:30 +00:00
markus
374a1c623f fix python 3.13 linter deprecated warnings for utcnow() 2025-09-26 22:49:46 +02:00
markus
82ffc6b23b Revert "fix python 3.13 linter deprecated warnings for utcnow()"
This reverts commit 589bbfcf19.
2025-09-26 22:46:57 +02:00
markus
589bbfcf19 fix python 3.13 linter deprecated warnings for utcnow() 2025-09-26 22:20:57 +02:00
zxzxwu
32d448edf3 Merge pull request #790 from markusjellitsch/task/fix-cis-reconnect
Fix - Allow re-creation of CIS link when not successfull
2025-09-26 19:55:49 +08:00
markus
3d615b13ce fix accessing pending_cis dict 2025-09-26 12:38:38 +02:00
Markus Jellitsch
1ad92dc759 Update bumble/device.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2025-09-26 12:25:50 +02:00
markus
aacfd4328c satisfy the linter, return None 2025-09-26 12:02:54 +02:00
markus
6aa1f5211c use local cis_link.handle to the pop the dict 2025-09-26 11:13:52 +02:00
markus
df8e454ee5 pop cis link only when cis created successfully 2025-09-26 10:58:37 +02:00
Gilles Boccon-Gibod
aec50ac616 Merge pull request #789 from google/gbg/nrf-uart-flow-control 2025-09-26 09:34:33 +02:00
Gilles Boccon-Gibod
6a3eaa457f python 3.9 compat 2025-09-26 08:42:10 +02:00
zxzxwu
6e6b4cd4b2 Merge pull request #773 from wescande/main
HAP: wait for MTU to process reconnection event
2025-09-26 01:36:45 +08:00
Gilles Boccon-Gibod
aa1d7933da enhance serial port transport 2025-09-25 18:31:14 +02:00
zxzxwu
34e0f293c2 Merge pull request #788 from zxzxwu/device
Fix wrong with_connection_from_address parameter
2025-09-23 19:44:50 +08:00
Josh Wu
85215df2c3 Fix wrong with_connection_from_address parameter 2025-09-23 17:55:47 +08:00
zxzxwu
f8223ca81f Merge pull request #780 from google/dependabot/cargo/rust/cargo-ad4b9ff1ea
Bump the cargo group across 1 directory with 5 updates
2025-09-19 14:50:45 +08:00
zxzxwu
2b0b1ad726 Merge pull request #781 from zxzxwu/connections
Revert pending_connections
2025-09-19 14:45:48 +08:00
Josh Wu
58debcd8bb Revert pending_connections 2025-09-19 12:32:28 +08:00
dependabot[bot]
6eba81e3dd Bump the cargo group across 1 directory with 5 updates
Bumps the cargo group with 4 updates in the /rust directory: [tokio](https://github.com/tokio-rs/tokio), [h2](https://github.com/hyperium/h2), [openssl](https://github.com/sfackler/rust-openssl) and [rustix](https://github.com/bytecodealliance/rustix).


Updates `tokio` from 1.32.0 to 1.38.2
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.32.0...tokio-1.38.2)

Updates `h2` from 0.3.21 to 0.3.27
- [Release notes](https://github.com/hyperium/h2/releases)
- [Changelog](https://github.com/hyperium/h2/blob/v0.3.27/CHANGELOG.md)
- [Commits](https://github.com/hyperium/h2/compare/v0.3.21...v0.3.27)

Updates `mio` from 0.8.8 to 0.8.11
- [Release notes](https://github.com/tokio-rs/mio/releases)
- [Changelog](https://github.com/tokio-rs/mio/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tokio-rs/mio/compare/v0.8.8...v0.8.11)

Updates `openssl` from 0.10.60 to 0.10.73
- [Release notes](https://github.com/sfackler/rust-openssl/releases)
- [Commits](https://github.com/sfackler/rust-openssl/compare/openssl-v0.10.60...openssl-v0.10.73)

Updates `rustix` from 0.38.10 to 0.38.44
- [Release notes](https://github.com/bytecodealliance/rustix/releases)
- [Changelog](https://github.com/bytecodealliance/rustix/blob/main/CHANGES.md)
- [Commits](https://github.com/bytecodealliance/rustix/compare/v0.38.10...v0.38.44)

---
updated-dependencies:
- dependency-name: tokio
  dependency-version: 1.38.2
  dependency-type: direct:production
  dependency-group: cargo
- dependency-name: h2
  dependency-version: 0.3.27
  dependency-type: indirect
  dependency-group: cargo
- dependency-name: mio
  dependency-version: 0.8.11
  dependency-type: indirect
  dependency-group: cargo
- dependency-name: openssl
  dependency-version: 0.10.73
  dependency-type: indirect
  dependency-group: cargo
- dependency-name: rustix
  dependency-version: 0.38.44
  dependency-type: indirect
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-09-17 08:10:17 +00:00
William Escande
8a5f6a61d5 HAP: wait for MTU to process reconnection event
When HAP reconnect, it sends indication of all events that happen during
the disconnection.
But it should wait for the profile to be ready and for the MTU to have
been negotiated or else the remote may not be ready yet.

As a side effect of this, the current GattServer doesn't re-populate the
handle of subscriber during a reconnection, we have to bypass this check
to send the notification
2025-09-16 16:18:16 -07:00
161 changed files with 6906 additions and 4199 deletions

View File

@@ -18,18 +18,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
python-version: ["3.10", "3.11", "3.12", "3.13.0", "3.14"]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies

View File

@@ -40,7 +40,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v6
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -22,10 +22,10 @@ jobs:
steps:
- name: Check out from Git
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up JDK
uses: actions/setup-java@v4
uses: actions/setup-java@v5
with:
distribution: 'zulu'
java-version: 17

View File

@@ -26,9 +26,9 @@ jobs:
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set Up Python 3.11
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: 3.11
- name: Install
@@ -46,7 +46,7 @@ jobs:
run: cat rootcanal.log
- name: Upload Mobly logs
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: mobly-logs-${{ strategy.job-index }}
path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -18,18 +18,18 @@ jobs:
strategy:
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -48,14 +48,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
rust-version: [ "1.80.0", "stable" ]
# Rust runtime doesn't support 3.14 yet.
python-version: ["3.10", "3.11", "3.12", "3.13"]
rust-version: [ "1.80.0", "1.91.0" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
@@ -68,7 +69,7 @@ jobs:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build

View File

@@ -14,13 +14,13 @@ jobs:
steps:
- name: Check out from Git
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install dependencies
@@ -31,7 +31,7 @@ jobs:
run: python -m build
- name: Publish package to PyPI
if: github.event_name == 'release' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1.13
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if you are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,6 @@ import logging
import statistics
import struct
import time
from typing import Optional
import click
@@ -257,8 +256,8 @@ async def pre_power_on(device: Device, classic: bool) -> None:
async def post_power_on(
device: Device,
le_scan: Optional[tuple[int, int]],
le_advertise: Optional[int],
le_scan: tuple[int, int] | None,
le_advertise: int | None,
classic_page_scan: bool,
classic_inquiry_scan: bool,
) -> None:
@@ -1300,7 +1299,7 @@ class IsoClient(StreamedPacketIO):
super().__init__()
self.device = device
self.ready = asyncio.Event()
self.cis_link: Optional[CisLink] = None
self.cis_link: CisLink | None = None
async def on_connection(
self, connection: Connection, cis_link: CisLink, sender: bool
@@ -1341,7 +1340,7 @@ class IsoServer(StreamedPacketIO):
):
super().__init__()
self.device = device
self.cis_link: Optional[CisLink] = None
self.cis_link: CisLink | None = None
self.ready = asyncio.Event()
logging.info(

View File

@@ -24,7 +24,6 @@ import logging
import os
import re
from collections import OrderedDict
from typing import Optional, Union
import click
import humanize
@@ -126,8 +125,8 @@ def parse_phys(phys):
# Console App
# -----------------------------------------------------------------------------
class ConsoleApp:
connected_peer: Optional[Peer]
connection_phy: Optional[ConnectionPHY]
connected_peer: Peer | None
connection_phy: ConnectionPHY | None
def __init__(self):
self.known_addresses = set()
@@ -520,7 +519,7 @@ class ConsoleApp:
self.show_attributes(attributes)
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
def find_remote_characteristic(self, param) -> CharacteristicProxy | None:
if not self.connected_peer:
return None
parts = param.split('.')
@@ -542,9 +541,7 @@ class ConsoleApp:
return None
def find_local_attribute(
self, param
) -> Optional[Union[Characteristic, Descriptor]]:
def find_local_attribute(self, param) -> Characteristic | Descriptor | None:
parts = param.split('.')
if len(parts) == 3:
service_uuid = UUID(parts[0])
@@ -1096,9 +1093,7 @@ class DeviceListener(Device.Listener, Connection.Listener):
if self.app.connected_peer.connection.is_encrypted
else 'not encrypted'
)
self.app.append_to_output(
'connection encryption change: ' f'{encryption_state}'
)
self.app.append_to_output(f'connection encryption change: {encryption_state}')
def on_connection_data_length_change(self):
self.app.append_to_output(

View File

@@ -35,8 +35,6 @@ from bumble.hci import (
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
CodecID,
HCI_Command,
HCI_Command_Complete_Event,
@@ -54,6 +52,7 @@ from bumble.hci import (
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
LeFeature,
SpecificationVersion,
map_null_terminated_utf8_string,
)
from bumble.host import Host
@@ -275,7 +274,7 @@ async def async_main(
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f},'
f'average={sum(latencies) / len(latencies):.2f},'
),
[f'{latency:.4}' for latency in latencies],
'\n',
@@ -289,14 +288,20 @@ async def async_main(
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
SpecificationVersion(host.local_version.hci_version).name,
)
print(
color(' HCI Subversion:', 'green'),
f'0x{host.local_version.hci_subversion:04x}',
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
SpecificationVersion(host.local_version.lmp_version).name,
)
print(
color(' LMP Subversion:', 'green'),
f'0x{host.local_version.lmp_subversion:04x}',
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info
await get_classic_info(host)

View File

@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
import asyncio
import time
from typing import Optional
import click
@@ -41,7 +40,7 @@ class Loopback:
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_handle: int | None = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0

View File

@@ -16,7 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
from typing import Callable, Iterable, Optional
from collections.abc import Callable, Iterable
import click
@@ -174,7 +174,7 @@ async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
async def show_device_info(peer, done: asyncio.Future | None) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
@@ -215,7 +215,6 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(

View File

@@ -61,7 +61,6 @@ async def dump_gatt_db(peer, done):
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(

View File

@@ -268,7 +268,6 @@ class UiServer:
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: str | None,

View File

@@ -18,7 +18,6 @@
import asyncio
import logging
import os
import struct
import click
from prompt_toolkit.shortcuts import PromptSession
@@ -528,7 +527,9 @@ async def pair(
if advertise_appearance:
advertise_appearance = advertise_appearance.upper()
try:
advertise_appearance_int = int(advertise_appearance)
appearance = data_types.Appearance.from_int(
int(advertise_appearance)
)
except ValueError:
category, subcategory = advertise_appearance.split('/')
try:
@@ -546,12 +547,11 @@ async def pair(
except ValueError:
print(color(f'Invalid subcategory {subcategory}', 'red'))
return
advertise_appearance_int = int(
Appearance(category_enum, subcategory_enum)
appearance = data_types.Appearance(
category_enum, subcategory_enum
)
advertising_data_types.append(
data_types.Appearance(category_enum, subcategory_enum)
)
advertising_data_types.append(appearance)
device.advertising_data = bytes(AdvertisingData(advertising_data_types))
await device.start_advertising(
auto_restart=True,

View File

@@ -19,7 +19,7 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
@click.option(
'--transport',
help='HCI transport',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
default='tcp-client:127.0.0.1:<rootcanal-port>',
)
@click.option(
'--config',
@@ -44,7 +44,7 @@ def retrieve_config(config: str) -> dict[str, Any]:
if not config:
return {}
with open(config, 'r') as f:
with open(config) as f:
return json.load(f)

View File

@@ -19,7 +19,6 @@ from __future__ import annotations
import asyncio
import logging
from typing import Optional, Union
import click
@@ -47,14 +46,13 @@ from bumble.avdtp import (
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
MediaCodecCapabilities,
MediaPacketPump,
find_avdtp_service_with_connection,
)
from bumble.avdtp import Protocol as AvdtpProtocol
from bumble.avdtp import find_avdtp_service_with_connection
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.colors import color
from bumble.core import AdvertisingData
from bumble.core import AdvertisingData, DeviceClass, PhysicalTransport
from bumble.core import ConnectionError as BumbleConnectionError
from bumble.core import DeviceClass, PhysicalTransport
from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import HCI_CONNECTION_ALREADY_EXISTS_ERROR, Address, HCI_Constant
from bumble.pairing import PairingConfig
@@ -191,7 +189,7 @@ class Player:
def __init__(
self,
transport: str,
device_config: Optional[str],
device_config: str | None,
authenticate: bool,
encrypt: bool,
) -> None:
@@ -199,8 +197,8 @@ class Player:
self.device_config = device_config
self.authenticate = authenticate
self.encrypt = encrypt
self.avrcp_protocol: Optional[AvrcpProtocol] = None
self.done: Optional[asyncio.Event]
self.avrcp_protocol: AvrcpProtocol | None = None
self.done: asyncio.Event | None
async def run(self, workload) -> None:
self.done = asyncio.Event()
@@ -315,7 +313,7 @@ class Player:
codec_type: int,
vendor_id: int,
codec_id: int,
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource,
codec_capabilities: MediaCodecCapabilities,
):
# Discover all endpoints on the remote device
@@ -381,11 +379,11 @@ class Player:
print(f">>> {color(address.to_string(False), 'yellow')}:")
print(f" Device Class (raw): {class_of_device:06X}")
major_class_name = DeviceClass.major_device_class_name(major_device_class)
print(" Device Major Class: " f"{major_class_name}")
print(f" Device Major Class: {major_class_name}")
minor_class_name = DeviceClass.minor_device_class_name(
major_device_class, minor_device_class
)
print(" Device Minor Class: " f"{minor_class_name}")
print(f" Device Minor Class: {minor_class_name}")
print(
" Device Services: "
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
@@ -420,7 +418,7 @@ class Player:
async def play(
self,
device: Device,
address: Optional[str],
address: str | None,
audio_format: str,
audio_file: str,
) -> None:
@@ -449,7 +447,7 @@ class Player:
return input_file.read(byte_count)
# Obtain the codec capabilities from the stream
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource
vendor_id = 0
codec_id = 0
if audio_format == "sbc":

View File

@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
import asyncio
import time
from typing import Optional
import click
@@ -82,14 +81,14 @@ class ServerBridge:
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Optional[Device] = None
self.device: Device | None = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
self.rfcomm_channel: rfcomm.DLC | None = None
self.tcp_tracer: Tracer | None
self.rfcomm_tracer: Tracer | None
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
@@ -242,14 +241,14 @@ class ClientBridge:
self.tcp_port = tcp_port
self.authenticate = authenticate
self.encrypt = encrypt
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.device: Device | None = None
self.connection: Connection | None = None
self.rfcomm_client: rfcomm.Client | None
self.rfcomm_mux: rfcomm.Multiplexer | None
self.tcp_connected: bool = False
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
self.tcp_tracer: Tracer | None
self.rfcomm_tracer: Tracer | None
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))

View File

@@ -217,9 +217,7 @@ async def scan(
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
help=('Use this IRK for resolving private addresses (may be used more than once)'),
multiple=True,
)
@click.option(

View File

@@ -26,7 +26,6 @@ import pathlib
import subprocess
import weakref
from importlib import resources
from typing import Optional
import aiohttp
import click
@@ -156,7 +155,7 @@ class QueuedOutput(Output):
packets: asyncio.Queue
extractor: AudioExtractor
packet_pump_task: Optional[asyncio.Task]
packet_pump_task: asyncio.Task | None
started: bool
def __init__(self, extractor):
@@ -230,8 +229,8 @@ class WebSocketOutput(QueuedOutput):
class FfplayOutput(QueuedOutput):
MAX_QUEUE_SIZE = 32768
subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: Optional[asyncio.Task]
subprocess: asyncio.subprocess.Process | None
ffplay_task: asyncio.Task | None
def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec))

View File

@@ -21,11 +21,12 @@ import dataclasses
import enum
import logging
import struct
from collections.abc import AsyncGenerator
from typing import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import ClassVar
from typing_extensions import ClassVar, Self
from typing_extensions import Self
from bumble import utils
from bumble.codecs import AacAudioRtpPacket
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.core import (
@@ -59,19 +60,18 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
A2DP_MPEG_2_4_AAC_CODEC_TYPE = 0x02
A2DP_ATRAC_FAMILY_CODEC_TYPE = 0x03
A2DP_NON_A2DP_CODEC_TYPE = 0xFF
class CodecType(utils.OpenIntEnum):
SBC = 0x00
MPEG_1_2_AUDIO = 0x01
MPEG_2_4_AAC = 0x02
ATRAC_FAMILY = 0x03
NON_A2DP = 0xFF
A2DP_CODEC_TYPE_NAMES = {
A2DP_SBC_CODEC_TYPE: 'A2DP_SBC_CODEC_TYPE',
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE: 'A2DP_MPEG_1_2_AUDIO_CODEC_TYPE',
A2DP_MPEG_2_4_AAC_CODEC_TYPE: 'A2DP_MPEG_2_4_AAC_CODEC_TYPE',
A2DP_ATRAC_FAMILY_CODEC_TYPE: 'A2DP_ATRAC_FAMILY_CODEC_TYPE',
A2DP_NON_A2DP_CODEC_TYPE: 'A2DP_NON_A2DP_CODEC_TYPE'
}
A2DP_SBC_CODEC_TYPE = CodecType.SBC
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = CodecType.MPEG_1_2_AUDIO
A2DP_MPEG_2_4_AAC_CODEC_TYPE = CodecType.MPEG_2_4_AAC
A2DP_ATRAC_FAMILY_CODEC_TYPE = CodecType.ATRAC_FAMILY
A2DP_NON_A2DP_CODEC_TYPE = CodecType.NON_A2DP
SBC_SYNC_WORD = 0x9C
@@ -259,9 +259,48 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
]
# -----------------------------------------------------------------------------
class MediaCodecInformation:
'''Base Media Codec Information.'''
@classmethod
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod
def from_bytes(cls, data: bytes) -> Self:
del data # Unused.
raise NotImplementedError
def __bytes__(self) -> bytes:
raise NotImplementedError
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcMediaCodecInformation:
class SbcMediaCodecInformation(MediaCodecInformation):
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
@@ -345,7 +384,7 @@ class SbcMediaCodecInformation:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacMediaCodecInformation:
class AacMediaCodecInformation(MediaCodecInformation):
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
@@ -427,7 +466,7 @@ class AacMediaCodecInformation:
@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation:
class VendorSpecificMediaCodecInformation(MediaCodecInformation):
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''
@@ -451,7 +490,7 @@ class VendorSpecificMediaCodecInformation:
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}' ')',
f' value: {self.value.hex()})',
]
)
@@ -647,7 +686,7 @@ class SbcPacketSource:
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += sum((frame.sample_count for frame in frames))
sample_count += sum(frame.sample_count for frame in frames)
frames = [frame]
frames_size = len(frame.payload)
else:

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from bumble import core
@@ -36,7 +35,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
if in_quotes:
token.extend(char)
if char == b'\"':
if char == b'"':
in_quotes = False
tokens.append(token[1:-1])
token = bytearray()
@@ -63,18 +62,18 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> list[Union[bytes, list]]:
def parse_parameters(buffer: bytes) -> list[bytes | list]:
"""Parse the parameters using the comma and parenthesis separators.
Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: list[list] = [[]]
current: Union[bytes, list] = bytes()
current: bytes | list = b''
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = bytes()
current = b''
elif token == b'(':
accumulator.append([])
elif token == b')':

View File

@@ -29,15 +29,12 @@ import enum
import functools
import inspect
import struct
from collections.abc import Awaitable, Callable
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
ClassVar,
Generic,
Optional,
TypeVar,
Union,
)
from bumble import hci, utils
@@ -220,7 +217,7 @@ class ATT_PDU:
fields: ClassVar[hci.Fields] = ()
op_code: int = dataclasses.field(init=False)
name: str = dataclasses.field(init=False)
_payload: Optional[bytes] = dataclasses.field(default=None, init=False)
_payload: bytes | None = dataclasses.field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
@@ -760,26 +757,24 @@ class AttributeValue(Generic[_T]):
def __init__(
self,
read: Union[
Callable[[Connection], _T],
Callable[[Connection], Awaitable[_T]],
None,
] = None,
write: Union[
Callable[[Connection, _T], None],
Callable[[Connection, _T], Awaitable[None]],
None,
] = None,
read: (
Callable[[Connection], _T] | Callable[[Connection], Awaitable[_T]] | None
) = None,
write: (
Callable[[Connection, _T], None]
| Callable[[Connection, _T], Awaitable[None]]
| None
) = None,
):
self._read = read
self._write = write
def read(self, connection: Connection) -> Union[_T, Awaitable[_T]]:
def read(self, connection: Connection) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(connection)
def write(self, connection: Connection, value: _T) -> Union[Awaitable[None], None]:
def write(self, connection: Connection, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(connection, value)
@@ -828,13 +823,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
EVENT_READ = "read"
EVENT_WRITE = "write"
value: Union[AttributeValue[_T], _T, None]
value: AttributeValue[_T] | _T | None
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[AttributeValue[_T], _T, None] = None,
attribute_type: str | bytes | UUID,
permissions: str | Attribute.Permissions,
value: AttributeValue[_T] | _T | None = None,
) -> None:
utils.EventEmitter.__init__(self)
self.handle = 0
@@ -883,7 +878,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value: Union[_T, None]
value: _T | None
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)

View File

@@ -19,14 +19,15 @@ from __future__ import annotations
import abc
import asyncio
import concurrent.futures
import dataclasses
import enum
import logging
import pathlib
import sys
import wave
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, BinaryIO
from bumble.colors import color
@@ -176,7 +177,7 @@ class ThreadedAudioOutput(AudioOutput):
"""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
self._write_task = asyncio.create_task(self._write_loop())
@@ -405,7 +406,7 @@ class ThreadedAudioInput(AudioInput):
"""Base class for AudioInput implementation where reading samples may block."""
def __init__(self) -> None:
self._thread_pool = ThreadPoolExecutor(1)
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
@abc.abstractmethod
@@ -545,5 +546,6 @@ class SoundDeviceAudioInput(ThreadedAudioInput):
return bytes(pcm_buffer)
def _close(self):
self._stream.stop()
self._stream = None
if self._stream:
self._stream.stop()
self._stream = None

View File

@@ -19,7 +19,6 @@ from __future__ import annotations
import enum
import struct
from typing import Union
from bumble import core, utils
@@ -166,7 +165,7 @@ class Frame:
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
ctype_or_response: CommandFrame.CommandType | ResponseFrame.ResponseCode,
) -> bytes:
# TODO: support extended subunit types and ids.
return (

View File

@@ -19,10 +19,10 @@ from __future__ import annotations
import logging
import struct
from collections.abc import Callable
from enum import IntEnum
from typing import Callable, Optional, cast
from bumble import avc, core, l2cap
from bumble import core, l2cap
from bumble.colors import color
# -----------------------------------------------------------------------------
@@ -144,9 +144,9 @@ class MessageAssembler:
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
CommandHandler = Callable[[int, bytes], None]
command_handlers: dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
ResponseHandler = Callable[[int, bytes | None], None]
response_handlers: dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
@@ -204,20 +204,15 @@ class Protocol:
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
self.command_handlers[pid](transaction_label, payload)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
response_payload = None if ipid else payload
self.response_handlers[pid](transaction_label, response_payload)
def send_message(
self,
@@ -262,7 +257,7 @@ class Protocol:
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')

File diff suppressed because it is too large Load Diff

View File

@@ -22,21 +22,9 @@ import enum
import functools
import logging
import struct
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass, field
from typing import (
AsyncIterator,
Awaitable,
Callable,
ClassVar,
Iterable,
List,
Optional,
Sequence,
SupportsBytes,
TypeVar,
Union,
cast,
)
from typing import ClassVar, SupportsBytes, TypeVar
from bumble import avc, avctp, core, hci, l2cap, utils
from bumble.colors import color
@@ -208,7 +196,7 @@ def make_controller_service_sdp_records(
service_record_handle: int,
avctp_version: tuple[int, int] = (1, 4),
avrcp_version: tuple[int, int] = (1, 6),
supported_features: Union[int, ControllerFeatures] = 1,
supported_features: int | ControllerFeatures = 1,
) -> list[ServiceAttribute]:
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
@@ -300,7 +288,7 @@ def make_target_service_sdp_records(
service_record_handle: int,
avctp_version: tuple[int, int] = (1, 4),
avrcp_version: tuple[int, int] = (1, 6),
supported_features: Union[int, TargetFeatures] = 0x23,
supported_features: int | TargetFeatures = 0x23,
) -> list[ServiceAttribute]:
# TODO: support a way to compute the supported features from a feature list
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
@@ -490,7 +478,7 @@ class BrowseableItem:
MEDIA_ELEMENT = 0x03
item_type: ClassVar[Type]
_payload: Optional[bytes] = None
_payload: bytes | None = None
subclasses: ClassVar[dict[Type, type[BrowseableItem]]] = {}
fields: ClassVar[hci.Fields] = ()
@@ -684,7 +672,7 @@ class PduAssembler:
6.3.1 AVRCP specific AV//C commands
"""
pdu_id: Optional[PduId]
pdu_id: PduId | None
payload: bytes
def __init__(self, callback: Callable[[PduId, bytes], None]) -> None:
@@ -737,7 +725,7 @@ class PduAssembler:
# -----------------------------------------------------------------------------
class Command:
pdu_id: ClassVar[PduId]
_payload: Optional[bytes] = None
_payload: bytes | None = None
_Command = TypeVar('_Command', bound='Command')
subclasses: ClassVar[dict[int, type[Command]]] = {}
@@ -1029,7 +1017,7 @@ class AddToNowPlayingCommand(Command):
# -----------------------------------------------------------------------------
class Response:
pdu_id: PduId
_payload: Optional[bytes] = None
_payload: bytes | None = None
fields: ClassVar[hci.Fields] = ()
subclasses: ClassVar[dict[PduId, type[Response]]] = {}
@@ -1091,7 +1079,7 @@ class NotImplementedResponse(Response):
class GetCapabilitiesResponse(Response):
pdu_id = PduId.GET_CAPABILITIES
capability_id: GetCapabilitiesCommand.CapabilityId
capabilities: Sequence[Union[SupportsBytes, bytes]]
capabilities: Sequence[SupportsBytes | bytes]
@classmethod
def from_parameters(cls, parameters: bytes) -> Response:
@@ -1104,7 +1092,7 @@ class GetCapabilitiesResponse(Response):
capability_id = GetCapabilitiesCommand.CapabilityId(parameters[0])
capability_count = parameters[1]
capabilities: list[Union[SupportsBytes, bytes]]
capabilities: list[SupportsBytes | bytes]
if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
capabilities = [EventId(parameters[2 + x]) for x in range(capability_count)]
else:
@@ -1375,7 +1363,7 @@ class AddToNowPlayingResponse(Response):
# -----------------------------------------------------------------------------
class Event:
event_id: EventId
_pdu: Optional[bytes] = None
_pdu: bytes | None = None
_Event = TypeVar('_Event', bound='Event')
subclasses: ClassVar[dict[int, type[Event]]] = {}
@@ -1448,13 +1436,13 @@ class PlayerApplicationSettingChangedEvent(Event):
attribute_id: ApplicationSetting.AttributeId = field(
metadata=ApplicationSetting.AttributeId.type_metadata(1)
)
value_id: Union[
ApplicationSetting.EqualizerOnOffStatus,
ApplicationSetting.RepeatModeStatus,
ApplicationSetting.ShuffleOnOffStatus,
ApplicationSetting.ScanOnOffStatus,
ApplicationSetting.GenericValue,
] = field(metadata=hci.metadata(1))
value_id: (
ApplicationSetting.EqualizerOnOffStatus
| ApplicationSetting.RepeatModeStatus
| ApplicationSetting.ShuffleOnOffStatus
| ApplicationSetting.ScanOnOffStatus
| ApplicationSetting.GenericValue
) = field(metadata=hci.metadata(1))
def __post_init__(self) -> None:
super().__post_init__()
@@ -1640,17 +1628,17 @@ class Protocol(utils.EventEmitter):
delegate: Delegate
send_transaction_label: int
command_pdu_assembler: PduAssembler
receive_command_state: Optional[ReceiveCommandState]
receive_command_state: ReceiveCommandState | None
response_pdu_assembler: PduAssembler
receive_response_state: Optional[ReceiveResponseState]
avctp_protocol: Optional[avctp.Protocol]
receive_response_state: ReceiveResponseState | None
avctp_protocol: avctp.Protocol | None
free_commands: asyncio.Queue
pending_commands: dict[int, PendingCommand] # Pending commands, by label
notification_listeners: dict[EventId, NotificationListener]
@staticmethod
def _check_vendor_dependent_frame(
frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame],
frame: avc.VendorDependentCommandFrame | avc.VendorDependentResponseFrame,
) -> bool:
if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID:
logger.debug("unsupported company id, ignoring")
@@ -1662,7 +1650,7 @@ class Protocol(utils.EventEmitter):
return True
def __init__(self, delegate: Optional[Delegate] = None) -> None:
def __init__(self, delegate: Delegate | None = None) -> None:
super().__init__()
self.delegate = delegate if delegate else Delegate()
self.command_pdu_assembler = PduAssembler(self._on_command_pdu)
@@ -1762,7 +1750,11 @@ class Protocol(utils.EventEmitter):
),
)
response = self._check_response(response_context, GetCapabilitiesResponse)
return cast(List[EventId], response.capabilities)
return list(
capability
for capability in response.capabilities
if isinstance(capability, EventId)
)
async def get_play_status(self) -> SongAndPlayStatus:
"""Get the play status of the connected peer."""
@@ -2012,11 +2004,14 @@ class Protocol(utils.EventEmitter):
self.emit(self.EVENT_STOP)
def _on_avctp_command(
self, transaction_label: int, command: avc.CommandFrame
) -> None:
def _on_avctp_command(self, transaction_label: int, payload: bytes) -> None:
command = avc.CommandFrame.from_bytes(payload)
if not isinstance(command, avc.CommandFrame):
raise core.InvalidPacketError(
f"{command} is not a valid AV/C Command Frame"
)
logger.debug(
f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}"
f"<<< AVCTP Command, transaction_label={transaction_label}: {command}"
)
# Only addressing the unit, or the PANEL subunit with subunit ID 0 is supported
@@ -2072,9 +2067,12 @@ class Protocol(utils.EventEmitter):
# TODO handle other types
self.send_not_implemented_response(transaction_label, command)
def _on_avctp_response(
self, transaction_label: int, response: Optional[avc.ResponseFrame]
) -> None:
def _on_avctp_response(self, transaction_label: int, payload: bytes | None) -> None:
response = avc.ResponseFrame.from_bytes(payload) if payload else None
if not isinstance(response, avc.ResponseFrame):
raise core.InvalidPacketError(
f"{response} is not a valid AV/C Response Frame"
)
logger.debug(
f"<<< AVCTP Response, transaction_label={transaction_label}: {response}"
)
@@ -2176,7 +2174,7 @@ class Protocol(utils.EventEmitter):
# NOTE: with a small number of supported responses, a manual switch like this
# is Ok, but if/when more responses are supported, a lookup mechanism would be
# more appropriate.
response: Optional[Response] = None
response: Response | None = None
if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(pdu_id=pdu_id, status_code=StatusCode(pdu[0]))
elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
@@ -2391,7 +2389,7 @@ class Protocol(utils.EventEmitter):
effective_volume = await self.delegate.get_absolute_volume()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
avc.ResponseFrame.ResponseCode.ACCEPTED,
SetAbsoluteVolumeResponse(effective_volume),
)

View File

@@ -163,23 +163,23 @@ class AacAudioRtpPacket:
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
frame_length_flag = reader.read(1)
reader.read(1) # frame_length_flag
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
core_coder_delay = reader.read(14)
reader.read(14) # core_coder_delay
extension_flag = reader.read(1)
if not channel_configuration:
raise core.InvalidPacketError('program_config_element not supported')
if audio_object_type in (6, 20):
layer_nr = reader.read(3)
reader.read(3) # layer_nr
if extension_flag:
if audio_object_type == 22:
num_of_sub_frame = reader.read(5)
layer_length = reader.read(11)
reader.read(5) # num_of_sub_frame
reader.read(11) # layer_length
if audio_object_type in (17, 19, 20, 23):
aac_section_data_resilience_flags = reader.read(1)
aac_scale_factor_data_resilience_flags = reader.read(1)
aac_spectral_data_resilience_flags = reader.read(1)
reader.read(1) # aac_section_data_resilience_flags
reader.read(1) # aac_scale_factor_data_resilience_flags
reader.read(1) # aac_spectral_data_resilience_flags
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@@ -364,10 +364,10 @@ class AacAudioRtpPacket:
if audio_mux_version_a != 0:
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
AacAudioRtpPacket.read_latm_value(reader) # tara_buffer_fullness
# stream_cnt = 0
reader.read(1) # all_streams_same_time_framing
reader.read(6) # num_sub_frames
num_program = reader.read(4)
if num_program != 0:
raise core.InvalidPacketError('num_program != 0 not supported')
@@ -391,9 +391,9 @@ class AacAudioRtpPacket:
reader.skip(asc_len)
frame_length_type = reader.read(3)
if frame_length_type == 0:
latm_buffer_fullness = reader.read(8)
reader.read(8) # latm_buffer_fullness
elif frame_length_type == 1:
frame_length = reader.read(9)
reader.read(9) # frame_length
else:
raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
@@ -413,7 +413,7 @@ class AacAudioRtpPacket:
break
crc_check_present = reader.read(1)
if crc_check_present:
crc_checksum = reader.read(8)
reader.read(8) # crc_checksum
return cls(other_data_present, other_data_len_bits, audio_specific_config)

View File

@@ -13,7 +13,6 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial
from typing import Optional, Union
class ColorError(ValueError):
@@ -38,7 +37,7 @@ STYLES = (
)
ColorSpec = Union[str, int]
ColorSpec = str | int
def _join(*values: ColorSpec) -> str:
@@ -56,14 +55,14 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ColorError('Invalid color spec "%s"' % spec)
raise ColorError(f'Invalid color spec "{spec}"')
def color(
s: str,
fg: Optional[ColorSpec] = None,
bg: Optional[ColorSpec] = None,
style: Optional[str] = None,
fg: ColorSpec | None = None,
bg: ColorSpec | None = None,
style: str | None = None,
) -> str:
codes: list[ColorSpec] = []
@@ -76,10 +75,10 @@ def color(
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ColorError('Invalid style "%s"' % style_part)
raise ColorError(f'Invalid style "{style_part}"')
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
return f'\x1b[{_join(*codes)}m{s}\x1b[0m'
else:
return s

File diff suppressed because it is too large Load Diff

View File

@@ -20,15 +20,11 @@ from __future__ import annotations
import dataclasses
import enum
import struct
from collections.abc import Iterable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Iterable,
Literal,
Optional,
Type,
Union,
cast,
overload,
)
@@ -103,7 +99,7 @@ class BaseError(BaseBumbleError):
def __init__(
self,
error_code: Optional[int],
error_code: int | None,
error_namespace: str = '',
error_name: str = '',
details: str = '',
@@ -216,11 +212,9 @@ class UUID:
UUIDS: list[UUID] = [] # Registry of all instances created
uuid_bytes: bytes
name: Optional[str]
name: str | None
def __init__(
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
) -> None:
def __init__(self, uuid_str_or_int: str | int, name: str | None = None) -> None:
if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
@@ -253,7 +247,7 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
def from_bytes(cls, uuid_bytes: bytes, name: str | None = None) -> UUID:
if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
@@ -264,11 +258,11 @@ class UUID:
raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
def from_16_bits(cls, uuid_16: int, name: str | None = None) -> UUID:
return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod
def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
def from_32_bits(cls, uuid_32: int, name: str | None = None) -> UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
@@ -734,7 +728,7 @@ class ClassOfDevice:
MajorDeviceClass.HEALTH: HEALTH_MINOR_DEVICE_CLASS_LABELS,
}
_MINOR_DEVICE_CLASSES: ClassVar[dict[MajorDeviceClass, Type]] = {
_MINOR_DEVICE_CLASSES: ClassVar[dict[MajorDeviceClass, type]] = {
MajorDeviceClass.COMPUTER: ComputerMinorDeviceClass,
MajorDeviceClass.PHONE: PhoneMinorDeviceClass,
MajorDeviceClass.LAN_NETWORK_ACCESS_POINT: LanNetworkMinorDeviceClass,
@@ -749,17 +743,17 @@ class ClassOfDevice:
major_service_classes: MajorServiceClasses
major_device_class: MajorDeviceClass
minor_device_class: Union[
ComputerMinorDeviceClass,
PhoneMinorDeviceClass,
LanNetworkMinorDeviceClass,
AudioVideoMinorDeviceClass,
PeripheralMinorDeviceClass,
WearableMinorDeviceClass,
ToyMinorDeviceClass,
HealthMinorDeviceClass,
int,
]
minor_device_class: (
ComputerMinorDeviceClass
| PhoneMinorDeviceClass
| LanNetworkMinorDeviceClass
| AudioVideoMinorDeviceClass
| PeripheralMinorDeviceClass
| WearableMinorDeviceClass
| ToyMinorDeviceClass
| HealthMinorDeviceClass
| int
)
@classmethod
def from_int(cls, class_of_device: int) -> Self:
@@ -1548,7 +1542,7 @@ class DataType:
return f"{self.__class__.__name__}({self.value_string()})"
@classmethod
def from_advertising_data(cls, advertising_data: AdvertisingData) -> Optional[Self]:
def from_advertising_data(cls, advertising_data: AdvertisingData) -> Self | None:
if (data := advertising_data.get(cls.ad_type, raw=True)) is None:
return None
@@ -1576,16 +1570,16 @@ class DataType:
# -----------------------------------------------------------------------------
# Advertising Data
# -----------------------------------------------------------------------------
AdvertisingDataObject = Union[
list[UUID],
tuple[UUID, bytes],
bytes,
str,
int,
tuple[int, int],
tuple[int, bytes],
Appearance,
]
AdvertisingDataObject = (
list[UUID]
| tuple[UUID, bytes]
| bytes
| str
| int
| tuple[int, int]
| tuple[int, bytes]
| Appearance
)
class AdvertisingData:
@@ -1722,7 +1716,7 @@ class AdvertisingData:
def __init__(
self,
ad_structures: Optional[Iterable[Union[tuple[int, bytes], DataType]]] = None,
ad_structures: Iterable[tuple[int, bytes] | DataType] | None = None,
) -> None:
if ad_structures is None:
ad_structures = []
@@ -2020,7 +2014,7 @@ class AdvertisingData:
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> Optional[list[UUID]]: ...
) -> list[UUID] | None: ...
@overload
def get(
@@ -2031,7 +2025,7 @@ class AdvertisingData:
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> Optional[tuple[UUID, bytes]]: ...
) -> tuple[UUID, bytes] | None: ...
@overload
def get(
@@ -2043,7 +2037,7 @@ class AdvertisingData:
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> Optional[Optional[str]]: ...
) -> str | None: ...
@overload
def get(
@@ -2055,38 +2049,36 @@ class AdvertisingData:
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> Optional[int]: ...
) -> int | None: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> Optional[tuple[int, int]]: ...
) -> tuple[int, int] | None: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> Optional[tuple[int, bytes]]: ...
) -> tuple[int, bytes] | None: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> Optional[Appearance]: ...
) -> Appearance | None: ...
@overload
def get(self, type_id: int, raw: Literal[True]) -> Optional[bytes]: ...
def get(self, type_id: int, raw: Literal[True]) -> bytes | None: ...
@overload
def get(
self, type_id: int, raw: bool = False
) -> Optional[AdvertisingDataObject]: ...
def get(self, type_id: int, raw: bool = False) -> AdvertisingDataObject | None: ...
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]:
def get(self, type_id: int, raw: bool = False) -> AdvertisingDataObject | None:
'''
Get advertising data as a simple AdvertisingDataObject object.

View File

@@ -25,10 +25,11 @@ try:
from bumble.crypto.cryptography import EccKey, aes_cmac, e
except ImportError:
logging.getLogger(__name__).debug(
"Unable to import cryptography, use built-in primitives."
"Unable to import cryptography, using built-in primitives."
)
from bumble.crypto.builtin import EccKey, aes_cmac, e # type: ignore[assignment]
_EccKey = EccKey # For the linter only
# -----------------------------------------------------------------------------
# Logging

View File

@@ -29,7 +29,6 @@ import dataclasses
import functools
import secrets
import struct
from typing import Optional
from bumble import core
@@ -85,7 +84,6 @@ class _AES:
# fmt: on
def __init__(self, key: bytes) -> None:
if len(key) not in (16, 24, 32):
raise core.InvalidArgumentError(f'Invalid key size {len(key)}')
@@ -112,7 +110,6 @@ class _AES:
r_con_pointer = 0
t = kc
while t < round_key_count:
tt = tk[kc - 1]
tk[0] ^= (
(self._S[(tt >> 16) & 0xFF] << 24)
@@ -269,7 +266,6 @@ class _ECB:
class _CBC:
def __init__(self, key: bytes, iv: bytes = bytes(16)) -> None:
if len(iv) != 16:
raise core.InvalidArgumentError(
@@ -302,7 +298,6 @@ class _CBC:
class _CMAC:
def __init__(
self,
key: bytes,
@@ -313,7 +308,7 @@ class _CMAC:
self.digest_size = mac_len
self._key = key
self._block_size = bs = 16
self._mac_tag: Optional[bytes] = None
self._mac_tag: bytes | None = None
self._update_after_digest = update_after_digest
# Section 5.3 of NIST SP 800 38B and Appendix B
@@ -352,7 +347,7 @@ class _CMAC:
self._last_ct = zero_block
# Last block that was encrypted with AES
self._last_pt: Optional[bytes] = None
self._last_pt: bytes | None = None
# Counter for total message size
self._data_size = 0
@@ -414,7 +409,6 @@ class _CMAC:
self._last_pt = _xor(second_last, data_block[-bs:])
def digest(self) -> bytes:
bs = self._block_size
if self._mac_tag is not None and not self._update_after_digest:

View File

@@ -25,7 +25,8 @@ from __future__ import annotations
import dataclasses
import math
import struct
from typing import Any, ClassVar, Sequence
from collections.abc import Sequence
from typing import Any, ClassVar
from typing_extensions import Self

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
# -----------------------------------------------------------------------------
# Constants
@@ -167,12 +166,12 @@ class G722Decoder:
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
def decode_frame(self, encoded_data: bytes | bytearray) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
def g722_decode(self, result_array, encoded_data: bytes | bytearray) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,8 @@ from __future__ import annotations
import logging
import pathlib
import platform
from typing import TYPE_CHECKING, Iterable, Optional
from collections.abc import Iterable
from typing import TYPE_CHECKING
from bumble.drivers import intel, rtk
from bumble.drivers.common import Driver
@@ -41,7 +42,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Optional[Driver]:
async def get_driver_for_host(host: Host) -> Driver | None:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
@@ -49,6 +50,10 @@ async def get_driver_for_host(host: Host) -> Optional[Driver]:
driver_classes: dict[str, type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# The "driver" metadata may include runtime options after a '/' (for example
# "intel/ddc=..."). Keep only the base driver name (the portion before the
# first slash) so it matches a key in driver_classes (e.g. "intel").
driver_name = driver_name.split("/")[0]
# Only probe a single driver
probe_list = [driver_name]
else:

View File

@@ -29,7 +29,7 @@ import os
import pathlib
import platform
import struct
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from bumble import core, hci, utils
from bumble.drivers import common
@@ -353,8 +353,8 @@ class Driver(common.Driver):
self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name.
self.ddc_addon: Optional[bytes] = None
self.ddc_override: Optional[bytes] = None
self.ddc_addon: bytes | None = None
self.ddc_override: bytes | None = None
driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"):
for key, value in [
@@ -380,7 +380,7 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list"
)
return False
@@ -459,6 +459,10 @@ class Driver(common.Driver):
== ModeOfOperation.OPERATIONAL
):
logger.debug("firmware already loaded")
# If the firmeare is already loaded, still attempt to load any
# device configuration (DDC). DDC can be applied independently of a
# firmware reload and may contain runtime overrides or patches.
await self.load_ddc_if_any()
return
# We only support some platforms and variants.
@@ -479,9 +483,7 @@ class Driver(common.Driver):
raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = (
"ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
f"ibt-{device_info[ValueType.CNVI]:04X}-{device_info[ValueType.CNVR]:04X}"
)
logger.debug(f"FW base name: {firmware_base_name}")
@@ -598,17 +600,39 @@ class Driver(common.Driver):
await self.reset_complete.wait()
logger.debug("reset complete")
# Load the device config if there is one.
await self.load_ddc_if_any(firmware_base_name)
async def load_ddc_if_any(self, firmware_base_name: str | None = None) -> None:
"""
Check for and load any Device Data Configuration (DDC) blobs.
Args:
firmware_base_name: Base name of the selected firmware (e.g. "ibt-XXXX-YYYY").
If None, don't attempt to look up a .ddc file that
corresponds to the firmware image.
Priority:
1. If a ddc_override was provided via driver metadata, use it (highest priority).
2. Otherwise, if firmware_base_name is provided, attempt to find a .ddc file
that corresponds to the selected firmware image.
3. Finally, if a ddc_addon was provided, append/load it after the primary DDC.
"""
# If an explicit DDC override was supplied, use it and skip file lookup.
if self.ddc_override:
logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
# Only attempt .ddc file lookup if a firmware_base_name was provided.
if firmware_base_name is None:
logger.debug(
"no firmware_base_name provided; skipping .ddc file lookup"
)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon:
logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon)

View File

@@ -115,12 +115,14 @@ RTK_USB_PRODUCTS = {
# Realtek 8761BUV
(0x0B05, 0x190E),
(0x0BDA, 0x8771),
(0x0BDA, 0x877B),
(0x0BDA, 0xA728),
(0x0BDA, 0xA729),
(0x2230, 0x0016),
(0x2357, 0x0604),
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
(0x0BDA, 0x877B),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
@@ -482,7 +484,7 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list"
)
return False

View File

@@ -28,7 +28,8 @@ import enum
import functools
import logging
import struct
from typing import Iterable, Optional, Sequence, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import TypeVar
from bumble.att import Attribute, AttributeValue
from bumble.colors import color
@@ -227,7 +228,6 @@ GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
@@ -356,7 +356,7 @@ class Service(Attribute):
def __init__(
self,
uuid: Union[str, UUID],
uuid: str | UUID,
characteristics: Iterable[Characteristic],
primary=True,
included_services: Iterable[Service] = (),
@@ -379,7 +379,7 @@ class Service(Attribute):
self.characteristics = list(characteristics)
self.primary = primary
def get_advertising_data(self) -> Optional[bytes]:
def get_advertising_data(self) -> bytes | None:
"""
Get Service specific advertising data
Defined by each Service, default value is empty
@@ -503,10 +503,10 @@ class Characteristic(Attribute[_T]):
def __init__(
self,
uuid: Union[str, bytes, UUID],
uuid: str | bytes | UUID,
properties: Characteristic.Properties,
permissions: Union[str, Attribute.Permissions],
value: Union[AttributeValue[_T], _T, None] = None,
permissions: str | Attribute.Permissions,
value: AttributeValue[_T] | _T | None = None,
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)

View File

@@ -22,7 +22,8 @@
from __future__ import annotations
import struct
from typing import Any, Callable, Generic, Iterable, Literal, Optional, TypeVar
from collections.abc import Callable, Iterable
from typing import Any, Generic, Literal, TypeVar
from bumble import utils
from bumble.core import InvalidOperationError
@@ -74,8 +75,8 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter[_T]):
def __init__(
self,
characteristic: Characteristic,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
encode: Callable[[_T], bytes] | None = None,
decode: Callable[[bytes], _T] | None = None,
):
super().__init__(characteristic)
self.encode = encode
@@ -101,8 +102,8 @@ class DelegatedCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T]):
def __init__(
self,
characteristic_proxy: CharacteristicProxy,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
encode: Callable[[_T], bytes] | None = None,
decode: Callable[[bytes], _T] | None = None,
):
super().__init__(characteristic_proxy)
self.encode = encode
@@ -361,5 +362,4 @@ class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
def decode_value(self, value: bytes) -> _T3:
int_value = int.from_bytes(value, self.byteorder)
a = self.cls(int_value)
return self.cls(int_value)

View File

@@ -28,16 +28,13 @@ from __future__ import annotations
import asyncio
import logging
import struct
from collections.abc import Callable, Iterable
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Optional,
TypeVar,
Union,
)
from bumble import att, core, utils
@@ -192,7 +189,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type: UUID) -> Optional[DescriptorProxy]:
def get_descriptor(self, descriptor_type: UUID) -> DescriptorProxy | None:
for descriptor in self.descriptors:
if descriptor.type == descriptor_type:
return descriptor
@@ -204,7 +201,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
async def subscribe(
self,
subscriber: Optional[Callable[[_T], Any]] = None,
subscriber: Callable[[_T], Any] | None = None,
prefer_notify: bool = True,
) -> None:
if subscriber is not None:
@@ -253,7 +250,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: type[TemplateService]
@classmethod
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
def from_client(cls, client: Client) -> ProfileServiceProxy | None:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -264,13 +261,11 @@ class Client:
services: list[ServiceProxy]
cached_values: dict[int, tuple[datetime, bytes]]
notification_subscribers: dict[
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
int, set[CharacteristicProxy | Callable[[bytes], Any]]
]
indication_subscribers: dict[
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[att.ATT_PDU]]
pending_request: Optional[att.ATT_PDU]
indication_subscribers: dict[int, set[CharacteristicProxy | Callable[[bytes], Any]]]
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
def __init__(self, connection: Connection) -> None:
self.connection = connection
@@ -360,7 +355,7 @@ class Client:
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None
self, uuid: UUID, service: ServiceProxy | None = None
) -> list[CharacteristicProxy[bytes]]:
services = [service] if service else self.services
return [
@@ -369,13 +364,14 @@ class Client:
if c.uuid == uuid
]
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
Union[
ServiceProxy,
tuple[ServiceProxy, CharacteristicProxy],
tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
]
]:
def get_attribute_grouping(
self, attribute_handle: int
) -> (
ServiceProxy
| tuple[ServiceProxy, CharacteristicProxy]
| tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy]
| None
):
"""
Get the attribute(s) associated with an attribute handle
"""
@@ -478,7 +474,7 @@ class Client:
return services
async def discover_service(self, uuid: Union[str, UUID]) -> list[ServiceProxy]:
async def discover_service(self, uuid: str | UUID) -> list[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -612,7 +608,7 @@ class Client:
return included_services
async def discover_characteristics(
self, uuids, service: Optional[ServiceProxy]
self, uuids, service: ServiceProxy | None
) -> list[CharacteristicProxy[bytes]]:
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
@@ -699,9 +695,9 @@ class Client:
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
characteristic: CharacteristicProxy | None = None,
start_handle: int | None = None,
end_handle: int | None = None,
) -> list[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -810,7 +806,7 @@ class Client:
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[Any], Any]] = None,
subscriber: Callable[[Any], Any] | None = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
@@ -860,7 +856,7 @@ class Client:
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[Any], Any]] = None,
subscriber: Callable[[Any], Any] | None = None,
force: bool = False,
) -> None:
'''
@@ -925,7 +921,7 @@ class Client:
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
self, attribute: int | AttributeProxy, no_long_read: bool = False
) -> bytes:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -980,7 +976,7 @@ class Client:
return attribute_value
async def read_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy]
self, uuid: UUID, service: ServiceProxy | None
) -> list[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
@@ -1038,7 +1034,7 @@ class Client:
async def write_value(
self,
attribute: Union[int, AttributeProxy],
attribute: int | AttributeProxy,
value: bytes,
with_response: bool = False,
) -> None:

View File

@@ -29,7 +29,8 @@ import asyncio
import logging
import struct
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, Optional, TypeVar
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar
from bumble import att, utils
from bumble.colors import color
@@ -73,7 +74,7 @@ class Server(utils.EventEmitter):
attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[int, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
pending_confirmations: defaultdict[int, asyncio.futures.Future | None]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -109,7 +110,7 @@ class Server(utils.EventEmitter):
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle: int) -> Optional[att.Attribute]:
def get_attribute(self, handle: int) -> att.Attribute | None:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -126,7 +127,7 @@ class Server(utils.EventEmitter):
def get_attribute_group(
self, handle: int, group_type: type[AttributeGroupType]
) -> Optional[AttributeGroupType]:
) -> AttributeGroupType | None:
return next(
(
attribute
@@ -137,7 +138,7 @@ class Server(utils.EventEmitter):
None,
)
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
def get_service_attribute(self, service_uuid: UUID) -> Service | None:
return next(
(
attribute
@@ -151,7 +152,7 @@ class Server(utils.EventEmitter):
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[tuple[CharacteristicDeclaration, Characteristic]]:
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
@@ -176,7 +177,7 @@ class Server(utils.EventEmitter):
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
) -> Descriptor | None:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
@@ -334,7 +335,7 @@ class Server(utils.EventEmitter):
self,
connection: Connection,
attribute: att.Attribute,
value: Optional[bytes] = None,
value: bytes | None = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
@@ -377,7 +378,7 @@ class Server(utils.EventEmitter):
self,
connection: Connection,
attribute: att.Attribute,
value: Optional[bytes] = None,
value: bytes | None = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
@@ -437,7 +438,7 @@ class Server(utils.EventEmitter):
self,
indicate: bool,
attribute: att.Attribute,
value: Optional[bytes] = None,
value: bytes | None = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
@@ -464,7 +465,7 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute,
value: Optional[bytes] = None,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -474,7 +475,7 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: Optional[bytes] = None,
value: bytes | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)

View File

@@ -24,17 +24,13 @@ import functools
import logging
import secrets
import struct
from collections.abc import Sequence
from collections.abc import Callable, Iterable, Sequence
from dataclasses import field
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
TypeVar,
Union,
cast,
)
@@ -106,7 +102,7 @@ def map_class_of_device(class_of_device):
)
def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
def phy_list_to_bits(phys: Iterable[Phy] | None) -> int:
if phys is None:
return 0
@@ -119,7 +115,6 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
class SpecableEnum(utils.OpenIntEnum):
@classmethod
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
@@ -147,7 +142,6 @@ class SpecableEnum(utils.OpenIntEnum):
class SpecableFlag(enum.IntFlag):
@classmethod
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
@@ -186,8 +180,8 @@ class SpecableFlag(enum.IntFlag):
# - "v" for variable length bytes with a leading length byte
# - an integer [1, 4] for 1-byte, 2-byte or 4-byte unsigned little-endian integers
# - an integer [-2, -1] for 1-byte, 2-byte signed little-endian integers
FieldSpec = Union[dict[str, Any], Callable[[bytes, int], tuple[int, Any]], str, int]
Fields = Sequence[Union[tuple[str, FieldSpec], 'Fields']]
FieldSpec = dict[str, Any] | Callable[[bytes, int], tuple[int, Any]] | str | int
Fields = Sequence['tuple[str, FieldSpec] | Fields']
@dataclasses.dataclass
@@ -213,22 +207,44 @@ def metadata(
HCI_VENDOR_OGF = 0x3F
# HCI Version
HCI_VERSION_BLUETOOTH_CORE_1_0B = 0
HCI_VERSION_BLUETOOTH_CORE_1_1 = 1
HCI_VERSION_BLUETOOTH_CORE_1_2 = 2
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR = 3
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR = 4
HCI_VERSION_BLUETOOTH_CORE_3_0_HS = 5
HCI_VERSION_BLUETOOTH_CORE_4_0 = 6
HCI_VERSION_BLUETOOTH_CORE_4_1 = 7
HCI_VERSION_BLUETOOTH_CORE_4_2 = 8
HCI_VERSION_BLUETOOTH_CORE_5_0 = 9
HCI_VERSION_BLUETOOTH_CORE_5_1 = 10
HCI_VERSION_BLUETOOTH_CORE_5_2 = 11
HCI_VERSION_BLUETOOTH_CORE_5_3 = 12
HCI_VERSION_BLUETOOTH_CORE_5_4 = 13
HCI_VERSION_BLUETOOTH_CORE_6_0 = 14
# Specification Version
class SpecificationVersion(utils.OpenIntEnum):
BLUETOOTH_CORE_1_0B = 0
BLUETOOTH_CORE_1_1 = 1
BLUETOOTH_CORE_1_2 = 2
BLUETOOTH_CORE_2_0_EDR = 3
BLUETOOTH_CORE_2_1_EDR = 4
BLUETOOTH_CORE_3_0_HS = 5
BLUETOOTH_CORE_4_0 = 6
BLUETOOTH_CORE_4_1 = 7
BLUETOOTH_CORE_4_2 = 8
BLUETOOTH_CORE_5_0 = 9
BLUETOOTH_CORE_5_1 = 10
BLUETOOTH_CORE_5_2 = 11
BLUETOOTH_CORE_5_3 = 12
BLUETOOTH_CORE_5_4 = 13
BLUETOOTH_CORE_6_0 = 14
BLUETOOTH_CORE_6_1 = 15
BLUETOOTH_CORE_6_2 = 16
# For backwards compatibility only
HCI_VERSION_BLUETOOTH_CORE_1_0B = SpecificationVersion.BLUETOOTH_CORE_1_0B
HCI_VERSION_BLUETOOTH_CORE_1_1 = SpecificationVersion.BLUETOOTH_CORE_1_1
HCI_VERSION_BLUETOOTH_CORE_1_2 = SpecificationVersion.BLUETOOTH_CORE_1_2
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR = SpecificationVersion.BLUETOOTH_CORE_2_0_EDR
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR = SpecificationVersion.BLUETOOTH_CORE_2_1_EDR
HCI_VERSION_BLUETOOTH_CORE_3_0_HS = SpecificationVersion.BLUETOOTH_CORE_3_0_HS
HCI_VERSION_BLUETOOTH_CORE_4_0 = SpecificationVersion.BLUETOOTH_CORE_4_0
HCI_VERSION_BLUETOOTH_CORE_4_1 = SpecificationVersion.BLUETOOTH_CORE_4_1
HCI_VERSION_BLUETOOTH_CORE_4_2 = SpecificationVersion.BLUETOOTH_CORE_4_2
HCI_VERSION_BLUETOOTH_CORE_5_0 = SpecificationVersion.BLUETOOTH_CORE_5_0
HCI_VERSION_BLUETOOTH_CORE_5_1 = SpecificationVersion.BLUETOOTH_CORE_5_1
HCI_VERSION_BLUETOOTH_CORE_5_2 = SpecificationVersion.BLUETOOTH_CORE_5_2
HCI_VERSION_BLUETOOTH_CORE_5_3 = SpecificationVersion.BLUETOOTH_CORE_5_3
HCI_VERSION_BLUETOOTH_CORE_5_4 = SpecificationVersion.BLUETOOTH_CORE_5_4
HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
@@ -246,9 +262,10 @@ HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
}
# LMP Version
LMP_VERSION_NAMES = HCI_VERSION_NAMES
# HCI Packet types
@@ -1786,20 +1803,20 @@ class HCI_Object:
@classmethod
def dict_and_offset_from_bytes(
cls, data: bytes, offset: int, fields: Fields
cls, data: bytes, offset: int, object_fields: Fields
) -> tuple[int, collections.OrderedDict[str, Any]]:
result = collections.OrderedDict[str, Any]()
for field in fields:
if isinstance(field, list):
for object_field in object_fields:
if isinstance(object_field, list):
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
# Set fields first, because item_count might be 0.
for sub_field_name, _ in field:
for sub_field_name, _ in object_field:
result[sub_field_name] = []
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
for sub_field_name, sub_field_type in object_field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
@@ -1807,7 +1824,7 @@ class HCI_Object:
offset += size
continue
field_name, field_type = field
field_name, field_type = object_field
assert isinstance(field_name, str)
field_value, field_size = HCI_Object.parse_field(
data, offset, cast(FieldSpec, field_type)
@@ -1890,26 +1907,26 @@ class HCI_Object:
return field_bytes
@staticmethod
def dict_to_bytes(hci_object, fields):
def dict_to_bytes(hci_object, object_fields):
result = bytearray()
for field in fields:
if isinstance(field, list):
for object_field in object_fields:
if isinstance(object_field, list):
# The field is an array. The serialized form starts with a 1-byte
# item count. We use the length of the first array field as the
# array count, since all array fields have the same number of items.
item_count = len(hci_object[field[0][0]])
item_count = len(hci_object[object_field[0][0]])
result += bytes([item_count]) + b''.join(
b''.join(
HCI_Object.serialize_field(
hci_object[sub_field_name][i], sub_field_type
)
for sub_field_name, sub_field_type in field
for sub_field_name, sub_field_type in object_field
)
for i in range(item_count)
)
continue
(field_name, field_type) = field
(field_name, field_type) = object_field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -1967,15 +1984,15 @@ class HCI_Object:
)
@staticmethod
def format_fields(hci_object, fields, indentation='', value_mappers=None):
if not fields:
def format_fields(hci_object, object_fields, indentation='', value_mappers=None):
if not object_fields:
return ''
# Build array of formatted key:value pairs
field_strings = []
for field in fields:
if isinstance(field, list):
for sub_field in field:
for object_field in object_fields:
if isinstance(object_field, list):
for sub_field in object_field:
sub_field_name, sub_field_type = sub_field
item_count = len(hci_object[sub_field_name])
for i in range(item_count):
@@ -1993,7 +2010,7 @@ class HCI_Object:
)
continue
field_name, field_type = field
field_name, field_type = object_field
field_value = hci_object[field_name]
field_strings.append(
(
@@ -2016,16 +2033,16 @@ class HCI_Object:
@classmethod
def fields_from_dataclass(cls, obj: Any) -> list[Any]:
stack: list[list[Any]] = [[]]
for field in dataclasses.fields(obj):
for object_field in dataclasses.fields(obj):
# Fields without metadata should be ignored.
if not isinstance(
(metadata := field.metadata.get("bumble.hci")), FieldMetadata
(metadata := object_field.metadata.get("bumble.hci")), FieldMetadata
):
continue
if metadata.list_begin:
stack.append([])
if metadata.spec:
stack[-1].append((field.name, metadata.spec))
stack[-1].append((object_field.name, metadata.spec))
if metadata.list_end:
top = stack.pop()
stack[-1].append(top)
@@ -2158,7 +2175,7 @@ class Address:
def __init__(
self,
address: Union[bytes, str],
address: bytes | str,
address_type: AddressType = RANDOM_DEVICE_ADDRESS,
) -> None:
'''
@@ -2423,9 +2440,9 @@ class HCI_Command(HCI_Packet):
def __init__(
self,
parameters: Optional[bytes] = None,
parameters: bytes | None = None,
*,
op_code: Optional[int] = None,
op_code: int | None = None,
**kwargs,
) -> None:
# op_code should be set in cls.
@@ -3441,6 +3458,17 @@ class HCI_Write_Synchronous_Flow_Control_Enable_Command(HCI_Command):
synchronous_flow_control_enable: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_Set_Controller_To_Host_Flow_Control_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.38 Set Controller To Host Flow Control command
'''
flow_control_enable: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -4338,6 +4366,15 @@ class HCI_LE_Write_Suggested_Default_Data_Length_Command(HCI_Command):
suggested_max_tx_time: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Local_P_256_Public_Key_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.36 LE LE Read Local P-256 Public Key command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -4365,6 +4402,15 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Resolving_List_Size_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.41 LE Read Resolving List Size command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -5028,6 +5074,15 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
sync_handle: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Transmit_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.74 LE Read Transmit Power command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -5678,7 +5733,7 @@ class HCI_Event(HCI_Packet):
hci_packet_type = HCI_EVENT_PACKET
event_names: dict[int, str] = {}
event_classes: dict[int, type[HCI_Event]] = {}
vendor_factories: list[Callable[[bytes], Optional[HCI_Event]]] = []
vendor_factories: list[Callable[[bytes], HCI_Event | None]] = []
event_code: int
fields: Fields = ()
_parameters: bytes = b''
@@ -5739,14 +5794,12 @@ class HCI_Event(HCI_Packet):
return event_class
@classmethod
def add_vendor_factory(
cls, factory: Callable[[bytes], Optional[HCI_Event]]
) -> None:
def add_vendor_factory(cls, factory: Callable[[bytes], HCI_Event | None]) -> None:
cls.vendor_factories.append(factory)
@classmethod
def remove_vendor_factory(
cls, factory: Callable[[bytes], Optional[HCI_Event]]
cls, factory: Callable[[bytes], HCI_Event | None]
) -> None:
if factory in cls.vendor_factories:
cls.vendor_factories.remove(factory)
@@ -5759,7 +5812,7 @@ class HCI_Event(HCI_Packet):
if len(parameters) != length:
raise InvalidPacketError('invalid packet length')
subclass: Optional[type[HCI_Event]]
subclass: type[HCI_Event] | None
if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call
# loops
@@ -5797,9 +5850,9 @@ class HCI_Event(HCI_Packet):
def __init__(
self,
parameters: Optional[bytes] = None,
parameters: bytes | None = None,
*,
event_code: Optional[int] = None,
event_code: int | None = None,
**kwargs,
):
if event_code is not None:
@@ -5908,9 +5961,7 @@ class HCI_Extended_Event(HCI_Event):
cls.subevent_names.update(cls.subevent_map(symbols))
@classmethod
def subclass_from_parameters(
cls, parameters: bytes
) -> Optional[HCI_Extended_Event]:
def subclass_from_parameters(cls, parameters: bytes) -> HCI_Extended_Event | None:
"""
Factory method that parses the subevent code, finds a registered subclass,
and creates an instance if found.
@@ -5930,9 +5981,9 @@ class HCI_Extended_Event(HCI_Event):
def __init__(
self,
parameters: Optional[bytes] = None,
parameters: bytes | None = None,
*,
subevent_code: Optional[int] = None,
subevent_code: int | None = None,
**kwargs,
) -> None:
if subevent_code is not None:
@@ -6928,7 +6979,7 @@ class HCI_Command_Complete_Event(HCI_Event):
command_opcode: int = field(
metadata=metadata({'size': 2, 'mapper': HCI_Command.command_name})
)
return_parameters: Union[bytes, HCI_Object, int] = field(metadata=metadata("*"))
return_parameters: bytes | HCI_Object | int = field(metadata=metadata("*"))
def map_return_parameters(self, return_parameters):
'''Map simple 'status' return parameters to their named constant form'''
@@ -7512,20 +7563,20 @@ class HCI_IsoDataPacket(HCI_Packet):
iso_sdu_fragment: bytes
pb_flag: int
ts_flag: int = 0
time_stamp: Optional[int] = None
packet_sequence_number: Optional[int] = None
iso_sdu_length: Optional[int] = None
packet_status_flag: Optional[int] = None
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
packet_status_flag: int | None = None
def __post_init__(self) -> None:
self.ts_flag = self.time_stamp is not None
@staticmethod
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
time_stamp: Optional[int] = None
packet_sequence_number: Optional[int] = None
iso_sdu_length: Optional[int] = None
packet_status_flag: Optional[int] = None
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
packet_status_flag: int | None = None
pos = 1
pdu_info, data_total_length = struct.unpack_from('<HH', packet, pos)
@@ -7608,7 +7659,7 @@ class HCI_IsoDataPacket(HCI_Packet):
# -----------------------------------------------------------------------------
class HCI_AclDataPacketAssembler:
current_data: Optional[bytes]
current_data: bytes | None
def __init__(self, callback: Callable[[bytes], Any]) -> None:
self.callback = callback

View File

@@ -20,7 +20,7 @@ from __future__ import annotations
import datetime
import logging
from collections.abc import Callable, MutableMapping
from typing import Any, Optional, cast
from typing import Any, cast
from bumble import avc, avctp, avdtp, avrcp, crypto, rfcomm, sdp
from bumble.att import ATT_CID, ATT_PDU
@@ -70,7 +70,7 @@ AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
class PacketTracer:
class AclStream:
psms: MutableMapping[int, int]
peer: Optional[PacketTracer.AclStream]
peer: PacketTracer.AclStream | None
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
@@ -201,7 +201,7 @@ class PacketTracer:
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: Optional[datetime.datetime] = None
self.packet_timestamp: datetime.datetime | None = None
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info(
@@ -230,7 +230,7 @@ class PacketTracer:
self.peer.end_acl_stream(connection_handle)
def on_packet(
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
self, timestamp: datetime.datetime | None, packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet)
@@ -262,7 +262,7 @@ class PacketTracer:
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
timestamp: datetime.datetime | None = None,
) -> None:
if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet)

View File

@@ -25,7 +25,8 @@ import enum
import logging
import re
import traceback
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Union
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar
from typing_extensions import Self
@@ -80,7 +81,7 @@ class HfpProtocol:
dlc.sink = self.feed
def feed(self, data: Union[bytes, str]) -> None:
def feed(self, data: bytes | str) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
@@ -324,8 +325,8 @@ class CallInfo:
status: CallInfoStatus
mode: CallInfoMode
multi_party: CallInfoMultiParty
number: Optional[str] = None
type: Optional[int] = None
number: str | None = None
type: int | None = None
@dataclasses.dataclass
@@ -353,10 +354,10 @@ class CallLineIdentification:
number: str
type: int
subaddr: Optional[str] = None
satype: Optional[int] = None
alpha: Optional[str] = None
cli_validity: Optional[int] = None
subaddr: str | None = None
satype: int | None = None
alpha: str | None = None
cli_validity: int | None = None
@classmethod
def parse_from(cls, parameters: list[bytes]) -> Self:
@@ -489,9 +490,9 @@ STATUS_CODES = {
@dataclasses.dataclass
class HfConfiguration:
supported_hf_features: list[HfFeature]
supported_hf_indicators: list[HfIndicator]
supported_audio_codecs: list[AudioCodec]
supported_hf_features: collections.abc.Sequence[HfFeature]
supported_hf_indicators: collections.abc.Sequence[HfIndicator]
supported_audio_codecs: collections.abc.Sequence[AudioCodec]
@dataclasses.dataclass
@@ -584,7 +585,7 @@ class AgIndicatorState:
indicator: AgIndicator
supported_values: set[int]
current_status: int
index: Optional[int] = None
index: int | None = None
enabled: bool = True
@property
@@ -597,7 +598,7 @@ class AgIndicatorState:
supported_values_text = (
f'({",".join(str(v) for v in self.supported_values)})'
)
return f'(\"{self.indicator.value}\",{supported_values_text})'
return f'("{self.indicator.value}",{supported_values_text})'
@classmethod
def call(cls: type[Self]) -> Self:
@@ -728,7 +729,7 @@ class HfProtocol(utils.EventEmitter):
command_lock: asyncio.Lock
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[Optional[AtResponse]]
unsolicited_queue: asyncio.Queue[AtResponse | None]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
@@ -753,7 +754,7 @@ class HfProtocol(utils.EventEmitter):
# Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features)
self.supported_audio_codecs = configuration.supported_audio_codecs
self.supported_audio_codecs = list(configuration.supported_audio_codecs)
self.hf_indicators = {
indicator: HfIndicatorState(indicator=indicator)
@@ -820,7 +821,7 @@ class HfProtocol(utils.EventEmitter):
cmd: str,
timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE,
) -> Union[None, AtResponse, list[AtResponse]]:
) -> None | AtResponse | list[AtResponse]:
"""
Sends an AT command and wait for the peer response.
Wait for the AT responses sent by the peer, to the status code.
@@ -1351,7 +1352,7 @@ class AgProtocol(utils.EventEmitter):
logger.warning(f'AG indicator {indicator} is disabled')
indicator_state.current_status = value
self.send_response(f'+CIEV: {index+1},{value}')
self.send_response(f'+CIEV: {index + 1},{value}')
async def negotiate_codec(self, codec: AudioCodec) -> None:
"""Starts codec negotiation."""
@@ -1411,13 +1412,13 @@ class AgProtocol(utils.EventEmitter):
self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec)))
def _on_chld(self, operation_code: bytes) -> None:
call_index: Optional[int] = None
call_index: int | None = None
if len(operation_code) > 1:
call_index = int(operation_code[1:])
operation_code = operation_code[:1] + b'x'
try:
operation = CallHoldOperation(operation_code.decode())
except:
except Exception:
logger.error(f'Invalid operation: {operation_code.decode()}')
self.send_cme_error(CmeError.OPERATION_NOT_SUPPORTED)
return
@@ -1481,8 +1482,8 @@ class AgProtocol(utils.EventEmitter):
def _on_cmer(
self,
mode: bytes,
keypad: Optional[bytes] = None,
display: Optional[bytes] = None,
keypad: bytes | None = None,
display: bytes | None = None,
indicator: bytes = b'',
) -> None:
if (
@@ -1589,7 +1590,7 @@ class AgProtocol(utils.EventEmitter):
def _on_clcc(self) -> None:
for call in self.calls:
number_text = f',\"{call.number}\"' if call.number is not None else ''
number_text = f',"{call.number}"' if call.number is not None else ''
type_text = f',{call.type}' if call.type is not None else ''
response = (
f'+CLCC: {call.index}'
@@ -1844,7 +1845,7 @@ def make_ag_sdp_records(
async def find_hf_sdp_record(
connection: device.Connection,
) -> Optional[tuple[int, ProfileVersion, HfSdpFeature]]:
) -> tuple[int, ProfileVersion, HfSdpFeature] | None:
"""Searches a Hands-Free SDP record from remote device.
Args:
@@ -1864,9 +1865,9 @@ async def find_hf_sdp_record(
],
)
for attribute_lists in search_result:
channel: Optional[int] = None
version: Optional[ProfileVersion] = None
features: Optional[HfSdpFeature] = None
channel: int | None = None
version: ProfileVersion | None = None
features: HfSdpFeature | None = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -1896,7 +1897,7 @@ async def find_hf_sdp_record(
async def find_ag_sdp_record(
connection: device.Connection,
) -> Optional[tuple[int, ProfileVersion, AgSdpFeature]]:
) -> tuple[int, ProfileVersion, AgSdpFeature] | None:
"""Searches an Audio-Gateway SDP record from remote device.
Args:
@@ -1915,9 +1916,9 @@ async def find_ag_sdp_record(
],
)
for attribute_lists in search_result:
channel: Optional[int] = None
version: Optional[ProfileVersion] = None
features: Optional[AgSdpFeature] = None
channel: int | None = None
version: ProfileVersion | None = None
features: AgSdpFeature | None = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:

View File

@@ -21,8 +21,8 @@ import enum
import logging
import struct
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, Optional
from typing_extensions import override
@@ -195,9 +195,9 @@ class SendHandshakeMessage(Message):
# -----------------------------------------------------------------------------
class HID(ABC, utils.EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
l2cap_ctrl_channel: l2cap.ClassicChannel | None = None
l2cap_intr_channel: l2cap.ClassicChannel | None = None
connection: device.Connection | None = None
EVENT_INTERRUPT_DATA = "interrupt_data"
EVENT_CONTROL_DATA = "control_data"
@@ -212,7 +212,7 @@ class HID(ABC, utils.EventEmitter):
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.remote_device_bd_address: Address | None = None
self.device = device
self.role = role
@@ -246,7 +246,7 @@ class HID(ABC, utils.EventEmitter):
# Create a new L2CAP connection - interrupt channel
try:
channel = await self.connection.create_l2cap_channel(
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM)
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
@@ -353,10 +353,10 @@ class Device(HID):
data: bytes = b''
status: int = 0
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
get_report_cb: Callable[[int, int, int], GetSetStatus] | None = None
set_report_cb: Callable[[int, int, int, bytes], GetSetStatus] | None = None
get_protocol_cb: Callable[[], GetSetStatus] | None = None
set_protocol_cb: Callable[[int], GetSetStatus] | None = None
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)

View File

@@ -22,7 +22,8 @@ import collections
import dataclasses
import logging
import struct
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from bumble import drivers, hci, utils
from bumble.colors import color
@@ -108,8 +109,7 @@ class DataPacketQueue(utils.EventEmitter):
if self._packets:
logger.debug(
f'{self._in_flight} packets in flight, '
f'{len(self._packets)} in queue'
f'{self._in_flight} packets in flight, {len(self._packets)} in queue'
)
def flush(self, connection_handle: int) -> None:
@@ -199,7 +199,7 @@ class Connection:
self.peer_address = peer_address
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
acl_packet_queue: Optional[DataPacketQueue] = (
acl_packet_queue: DataPacketQueue | None = (
host.le_acl_packet_queue
if transport == PhysicalTransport.LE
else host.acl_packet_queue
@@ -242,20 +242,18 @@ class Host(utils.EventEmitter):
bis_links: dict[int, IsoLink]
sco_links: dict[int, ScoLink]
bigs: dict[int, set[int]]
acl_packet_queue: Optional[DataPacketQueue] = None
le_acl_packet_queue: Optional[DataPacketQueue] = None
iso_packet_queue: Optional[DataPacketQueue] = None
hci_sink: Optional[TransportSink] = None
acl_packet_queue: DataPacketQueue | None = None
le_acl_packet_queue: DataPacketQueue | None = None
iso_packet_queue: DataPacketQueue | None = None
hci_sink: TransportSink | None = None
hci_metadata: dict[str, Any]
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[hci.Address], Awaitable[Optional[bytes]]]]
long_term_key_provider: Callable[[int, bytes, int], Awaitable[bytes | None]] | None
link_key_provider: Callable[[hci.Address], Awaitable[bytes | None]] | None
def __init__(
self,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
controller_source: TransportSource | None = None,
controller_sink: TransportSink | None = None,
) -> None:
super().__init__()
@@ -267,7 +265,7 @@ class Host(utils.EventEmitter):
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None
self.pending_response: Optional[asyncio.Future[Any]] = None
self.pending_response: asyncio.Future[Any] | None = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
@@ -280,7 +278,7 @@ class Host(utils.EventEmitter):
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
self.snooper: Optional[Snooper] = None
self.snooper: Snooper | None = None
# Connect to the source and sink if specified
if controller_source:
@@ -291,9 +289,9 @@ class Host(utils.EventEmitter):
def find_connection_by_bd_addr(
self,
bd_addr: hci.Address,
transport: Optional[int] = None,
transport: int | None = None,
check_address_type: bool = False,
) -> Optional[Connection]:
) -> Connection | None:
for connection in self.connections.values():
if bytes(connection.peer_address) == bytes(bd_addr):
if (
@@ -550,7 +548,7 @@ class Host(utils.EventEmitter):
logger.debug(
'HCI LE flow control: '
f'le_acl_data_packet_length={le_acl_data_packet_length},'
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets},'
f'iso_data_packet_length={iso_data_packet_length},'
f'total_num_iso_data_packets={total_num_iso_data_packets}'
)
@@ -633,7 +631,7 @@ class Host(utils.EventEmitter):
)
@property
def controller(self) -> Optional[TransportSink]:
def controller(self) -> TransportSink | None:
return self.hci_sink
@controller.setter
@@ -642,7 +640,7 @@ class Host(utils.EventEmitter):
if controller:
self.set_packet_source(controller)
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
def set_packet_sink(self, sink: TransportSink | None) -> None:
self.hci_sink = sink
def set_packet_source(self, source: TransportSource) -> None:
@@ -657,7 +655,7 @@ class Host(utils.EventEmitter):
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: Optional[int] = None
self, command, check_result=False, response_timeout: int | None = None
):
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
@@ -707,7 +705,7 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
@@ -718,27 +716,24 @@ class Host(utils.EventEmitter):
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
bytes_remaining = len(l2cap_pdu)
offset = 0
pb_flag = 0
while bytes_remaining:
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
max_packet_size = packet_queue.max_packet_size
for offset in range(0, len(sdu), max_packet_size):
pdu = sdu[offset : offset + max_packet_size]
acl_packet = hci.HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
pb_flag=1 if offset > 0 else 0,
bc_flag=0,
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
data_total_length=len(pdu),
data=pdu,
)
logger.debug(
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu
)
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet, connection_handle)
pb_flag = 1
offset += data_total_length
bytes_remaining -= data_total_length
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu)))
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
if connection := self.connections.get(connection_handle):
@@ -903,7 +898,7 @@ class Host(utils.EventEmitter):
self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(
self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event]
self, event: hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event
):
if self.pending_response:
# Check that it is what we were expecting
@@ -966,11 +961,11 @@ class Host(utils.EventEmitter):
def on_hci_le_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
event: (
hci.HCI_LE_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
),
):
# Check if this is a cancellation
if event.status == hci.HCI_SUCCESS:
@@ -1015,10 +1010,10 @@ class Host(utils.EventEmitter):
def on_hci_le_enhanced_connection_complete_event(
self,
event: Union[
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
event: (
hci.HCI_LE_Enhanced_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
),
):
# Just use the same implementation as for the non-enhanced event for now
self.on_hci_le_connection_complete_event(event)
@@ -1397,8 +1392,7 @@ class Host(utils.EventEmitter):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### SCO CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
f'### SCO CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}'
)
self.sco_links[event.connection_handle] = ScoLink(
@@ -1450,7 +1444,7 @@ class Host(utils.EventEmitter):
def on_hci_le_data_length_change_event(
self, event: hci.HCI_LE_Data_Length_Change_Event
):
if (connection := self.connections.get(event.connection_handle)) is None:
if event.connection_handle not in self.connections:
logger.warning('!!! DATA LENGTH CHANGE: unknown handle')
return

View File

@@ -27,7 +27,7 @@ import dataclasses
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from typing_extensions import Self
@@ -51,8 +51,8 @@ class PairingKeys:
class Key:
value: bytes
authenticated: bool = False
ediv: Optional[int] = None
rand: Optional[bytes] = None
ediv: int | None = None
rand: bytes | None = None
@classmethod
def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
@@ -74,17 +74,17 @@ class PairingKeys:
return key_dict
address_type: Optional[hci.AddressType] = None
ltk: Optional[Key] = None
ltk_central: Optional[Key] = None
ltk_peripheral: Optional[Key] = None
irk: Optional[Key] = None
csrk: Optional[Key] = None
link_key: Optional[Key] = None # Classic
link_key_type: Optional[int] = None # Classic
address_type: hci.AddressType | None = None
ltk: Key | None = None
ltk_central: Key | None = None
ltk_peripheral: Key | None = None
irk: Key | None = None
csrk: Key | None = None
link_key: Key | None = None # Classic
link_key_type: int | None = None # Classic
@classmethod
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Key | None:
key_dict = keys_dict.get(key_name)
if key_dict is None:
return None
@@ -156,7 +156,7 @@ class KeyStore:
async def update(self, name: str, keys: PairingKeys) -> None:
pass
async def get(self, _name: str) -> Optional[PairingKeys]:
async def get(self, _name: str) -> PairingKeys | None:
return None
async def get_all(self) -> list[tuple[str, PairingKeys]]:
@@ -274,7 +274,7 @@ class JsonKeyStore(KeyStore):
@classmethod
def from_device(
cls: type[Self], device: Device, filename: Optional[str] = None
cls: type[Self], device: Device, filename: str | None = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
@@ -297,7 +297,7 @@ class JsonKeyStore(KeyStore):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
with open(self.filename, 'r', encoding='utf-8') as json_file:
with open(self.filename, encoding='utf-8') as json_file:
db = json.load(json_file)
except FileNotFoundError:
db = {}
@@ -348,7 +348,7 @@ class JsonKeyStore(KeyStore):
key_map.clear()
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]:
async def get(self, name: str) -> PairingKeys | None:
_, key_map = await self.load()
if name not in key_map:
return None
@@ -370,7 +370,7 @@ class MemoryKeyStore(KeyStore):
async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys
async def get(self, name: str) -> Optional[PairingKeys]:
async def get(self, name: str) -> PairingKeys | None:
return self.all_keys.get(name)
async def get_all(self) -> list[tuple[str, PairingKeys]]:

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
@@ -18,18 +19,12 @@ import asyncio
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional
from typing import TYPE_CHECKING
from bumble import controller, core
from bumble.hci import (
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_SUCCESS,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
Address,
HCI_Connection_Complete_Event,
Role,
)
from bumble import core, hci, ll, lmp
if TYPE_CHECKING:
from bumble import controller
# -----------------------------------------------------------------------------
# Logging
@@ -37,18 +32,6 @@ from bumble.hci import (
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
result = {}
for param_str in params_str.split(','):
if '=' in param_str:
key, value = param_str.split('=')
result[key] = value
return result
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
@@ -62,37 +45,34 @@ class LocalLink:
def __init__(self):
self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None
############################################################
# Common utils
############################################################
def add_controller(self, controller):
def add_controller(self, controller: controller.Controller):
logger.debug(f'new controller: {controller}')
self.controllers.add(controller)
def remove_controller(self, controller):
def remove_controller(self, controller: controller.Controller):
self.controllers.remove(controller)
def find_controller(self, address):
def find_le_controller(self, address: hci.Address) -> controller.Controller | None:
for controller in self.controllers:
if controller.random_address == address:
return controller
for connection in controller.le_connections.values():
if connection.self_address == address:
return controller
return None
def find_classic_controller(
self, address: Address
) -> Optional[controller.Controller]:
self, address: hci.Address
) -> controller.Controller | None:
for controller in self.controllers:
if controller.public_address == address:
return controller
return None
def get_pending_connection(self):
return self.pending_connection
############################################################
# LE handlers
############################################################
@@ -100,16 +80,16 @@ class LocalLink:
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
def send_acl_data(self, sender_controller, destination_address, transport, data):
def send_acl_data(
self,
sender_controller: controller.Controller,
destination_address: hci.Address,
transport: core.PhysicalTransport,
data: bytes,
):
# Send the data to the first controller with a matching address
if transport == core.PhysicalTransport.LE:
destination_controller = self.find_controller(destination_address)
destination_controller = self.find_le_controller(destination_address)
source_address = sender_controller.random_address
elif transport == core.PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address)
@@ -118,262 +98,52 @@ class LocalLink:
raise ValueError("unsupported transport type")
if destination_controller is not None:
destination_controller.on_link_acl_data(source_address, transport, data)
def on_connection_complete(self):
# Check that we expect this call
if not self.pending_connection:
logger.warning('on_connection_complete with no pending connection')
return
central_address, le_create_connection_command = self.pending_connection
self.pending_connection = None
# Find the controller that initiated the connection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
asyncio.get_running_loop().call_soon(
lambda: destination_controller.on_link_acl_data(
source_address, transport, data
)
)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(self, central_address, le_create_connection_command):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, initiating_address, target_address, disconnect_command
def send_advertising_pdu(
self,
sender_controller: controller.Controller,
packet: ll.AdvertisingPdu,
):
# Find the controller that initiated the disconnection
if not (initiating_controller := self.find_controller(initiating_address)):
logger.warning('!!! Initiating controller not found')
return
loop = asyncio.get_running_loop()
for c in self.controllers:
if c != sender_controller:
loop.call_soon(c.on_ll_advertising_pdu, packet)
# Disconnect from the first controller with a matching address
if target_controller := self.find_controller(target_address):
target_controller.on_link_disconnected(
initiating_address, disconnect_command.reason
)
initiating_controller.on_link_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, initiating_address, target_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {initiating_address} -> '
f'{target_address}: reason = {disconnect_command.reason}'
)
args = [initiating_address, target_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
def send_ll_control_pdu(
self,
sender_address: hci.Address,
receiver_address: hci.Address,
packet: ll.ControlPdu,
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
if not (receiver_controller := self.find_le_controller(receiver_address)):
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}"
)
asyncio.get_running_loop().call_soon(
lambda: receiver_controller.on_ll_control_pdu(sender_address, packet)
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
def classic_connect(self, initiator_controller, responder_address):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
initiator_controller.on_classic_connection_complete(
responder_address, HCI_PAGE_TIMEOUT_ERROR
)
return
self.pending_classic_connection = (initiator_controller, responder_controller)
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
HCI_Connection_Complete_Event.LinkType.ACL,
)
def classic_accept_connection(
self, responder_controller, initiator_address, responder_role
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
)
return
async def task():
if responder_role != Role.PERIPHERAL:
initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role))
)
initiator_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_SUCCESS
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, responder_role
)
responder_controller.on_classic_connection_complete(
initiator_controller.public_address, HCI_SUCCESS
)
self.pending_classic_connection = None
def classic_disconnect(self, initiator_controller, responder_address, reason):
logger.debug(
f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
async def task():
initiator_controller.on_classic_disconnected(responder_address, reason)
asyncio.create_task(task())
responder_controller.on_classic_disconnected(
initiator_controller.public_address, reason
)
def classic_switch_role(
self, initiator_controller, responder_address, initiator_new_role
):
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
return
async def task():
initiator_controller.on_classic_role_change(
responder_address, initiator_new_role
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
def send_lmp_packet(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
sender_controller: controller.Controller,
receiver_address: hci.Address,
packet: lmp.Packet,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
if not (receiver_controller := self.find_classic_controller(receiver_address)):
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}"
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
asyncio.get_running_loop().call_soon(
lambda: receiver_controller.on_lmp_packet(
sender_controller.public_address, packet
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)

200
bumble/ll.py Normal file
View File

@@ -0,0 +1,200 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import ClassVar
from bumble import hci
# -----------------------------------------------------------------------------
# Advertising PDU
# -----------------------------------------------------------------------------
class AdvertisingPdu:
"""Base Advertising Physical Channel PDU class.
See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
@dataclasses.dataclass
class ConnectInd(AdvertisingPdu):
initiator_address: hci.Address
advertiser_address: hci.Address
interval: int
latency: int
timeout: int
@dataclasses.dataclass
class AdvInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvDirectInd(AdvertisingPdu):
advertiser_address: hci.Address
target_address: hci.Address
@dataclasses.dataclass
class AdvNonConnInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvExtInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
target_address: hci.Address | None = None
adi: int | None = None
tx_power: int | None = None
# -----------------------------------------------------------------------------
# LL Control PDU
# -----------------------------------------------------------------------------
class ControlPdu:
"""Base LL Control PDU Class.
See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
class Opcode(hci.SpecableEnum):
LL_CONNECTION_UPDATE_IND = 0x00
LL_CHANNEL_MAP_IND = 0x01
LL_TERMINATE_IND = 0x02
LL_ENC_REQ = 0x03
LL_ENC_RSP = 0x04
LL_START_ENC_REQ = 0x05
LL_START_ENC_RSP = 0x06
LL_UNKNOWN_RSP = 0x07
LL_FEATURE_REQ = 0x08
LL_FEATURE_RSP = 0x09
LL_PAUSE_ENC_REQ = 0x0A
LL_PAUSE_ENC_RSP = 0x0B
LL_VERSION_IND = 0x0C
LL_REJECT_IND = 0x0D
LL_PERIPHERAL_FEATURE_REQ = 0x0E
LL_CONNECTION_PARAM_REQ = 0x0F
LL_CONNECTION_PARAM_RSP = 0x10
LL_REJECT_EXT_IND = 0x11
LL_PING_REQ = 0x12
LL_PING_RSP = 0x13
LL_LENGTH_REQ = 0x14
LL_LENGTH_RSP = 0x15
LL_PHY_REQ = 0x16
LL_PHY_RSP = 0x17
LL_PHY_UPDATE_IND = 0x18
LL_MIN_USED_CHANNELS_IND = 0x19
LL_CTE_REQ = 0x1A
LL_CTE_RSP = 0x1B
LL_PERIODIC_SYNC_IND = 0x1C
LL_CLOCK_ACCURACY_REQ = 0x1D
LL_CLOCK_ACCURACY_RSP = 0x1E
LL_CIS_REQ = 0x1F
LL_CIS_RSP = 0x20
LL_CIS_IND = 0x21
LL_CIS_TERMINATE_IND = 0x22
LL_POWER_CONTROL_REQ = 0x23
LL_POWER_CONTROL_RSP = 0x24
LL_POWER_CHANGE_IND = 0x25
LL_SUBRATE_REQ = 0x26
LL_SUBRATE_IND = 0x27
LL_CHANNEL_REPORTING_IND = 0x28
LL_CHANNEL_STATUS_IND = 0x29
LL_PERIODIC_SYNC_WR_IND = 0x2A
LL_FEATURE_EXT_REQ = 0x2B
LL_FEATURE_EXT_RSP = 0x2C
LL_CS_SEC_RSP = 0x2D
LL_CS_CAPABILITIES_REQ = 0x2E
LL_CS_CAPABILITIES_RSP = 0x2F
LL_CS_CONFIG_REQ = 0x30
LL_CS_CONFIG_RSP = 0x31
LL_CS_REQ = 0x32
LL_CS_RSP = 0x33
LL_CS_IND = 0x34
LL_CS_TERMINATE_REQ = 0x35
LL_CS_FAE_REQ = 0x36
LL_CS_FAE_RSP = 0x37
LL_CS_CHANNEL_MAP_IND = 0x38
LL_CS_SEC_REQ = 0x39
LL_CS_TERMINATE_RSP = 0x3A
LL_FRAME_SPACE_REQ = 0x3B
LL_FRAME_SPACE_RSP = 0x3C
opcode: ClassVar[Opcode]
@dataclasses.dataclass
class TerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_TERMINATE_IND
error_code: int
@dataclasses.dataclass
class EncReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_ENC_REQ
rand: bytes
ediv: int
ltk: bytes
@dataclasses.dataclass
class CisReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisTerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND
cig_id: int
cis_id: int
error_code: int

324
bumble/lmp.py Normal file
View File

@@ -0,0 +1,324 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from dataclasses import dataclass, field
from typing import TypeVar
from bumble import hci, utils
class Opcode(utils.OpenIntEnum):
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary.
Follow the alphabetical order defined there.
'''
# fmt: off
LMP_ACCEPTED = 3
LMP_ACCEPTED_EXT = 127 << 8 + 1
LMP_AU_RAND = 11
LMP_AUTO_RATE = 35
LMP_CHANNEL_CLASSIFICATION = 127 << 8 + 17
LMP_CHANNEL_CLASSIFICATION_REQ = 127 << 8 + 16
LMP_CLK_ADJ = 127 << 8 + 5
LMP_CLK_ADJ_ACK = 127 << 8 + 6
LMP_CLK_ADJ_REQ = 127 << 8 + 7
LMP_CLKOFFSET_REQ = 5
LMP_CLKOFFSET_RES = 6
LMP_COMB_KEY = 9
LMP_DECR_POWER_REQ = 32
LMP_DETACH = 7
LMP_DHKEY_CHECK = 65
LMP_ENCAPSULATED_HEADER = 61
LMP_ENCAPSULATED_PAYLOAD = 62
LMP_ENCRYPTION_KEY_SIZE_MASK_REQ= 58
LMP_ENCRYPTION_KEY_SIZE_MASK_RES= 59
LMP_ENCRYPTION_KEY_SIZE_REQ = 16
LMP_ENCRYPTION_MODE_REQ = 15
LMP_ESCO_LINK_REQ = 127 << 8 + 12
LMP_FEATURES_REQ = 39
LMP_FEATURES_REQ_EXT = 127 << 8 + 3
LMP_FEATURES_RES = 40
LMP_FEATURES_RES_EXT = 127 << 8 + 4
LMP_HOLD = 20
LMP_HOLD_REQ = 21
LMP_HOST_CONNECTION_REQ = 51
LMP_IN_RAND = 8
LMP_INCR_POWER_REQ = 31
LMP_IO_CAPABILITY_REQ = 127 << 8 + 25
LMP_IO_CAPABILITY_RES = 127 << 8 + 26
LMP_KEYPRESS_NOTIFICATION = 127 << 8 + 30
LMP_MAX_POWER = 33
LMP_MAX_SLOT = 45
LMP_MAX_SLOT_REQ = 46
LMP_MIN_POWER = 34
LMP_NAME_REQ = 1
LMP_NAME_RES = 2
LMP_NOT_ACCEPTED = 4
LMP_NOT_ACCEPTED_EXT = 127 << 8 + 2
LMP_NUMERIC_COMPARISON_FAILED = 127 << 8 + 27
LMP_OOB_FAILED = 127 << 8 + 29
LMP_PACKET_TYPE_TABLE_REQ = 127 << 8 + 11
LMP_PAGE_MODE_REQ = 53
LMP_PAGE_SCAN_MODE_REQ = 54
LMP_PASSKEY_FAILED = 127 << 8 + 28
LMP_PAUSE_ENCRYPTION_AES_REQ = 66
LMP_PAUSE_ENCRYPTION_REQ = 127 << 8 + 23
LMP_PING_REQ = 127 << 8 + 33
LMP_PING_RES = 127 << 8 + 34
LMP_POWER_CONTROL_REQ = 127 << 8 + 31
LMP_POWER_CONTROL_RES = 127 << 8 + 32
LMP_PREFERRED_RATE = 36
LMP_QUALITY_OF_SERVICE = 41
LMP_QUALITY_OF_SERVICE_REQ = 42
LMP_REMOVE_ESCO_LINK_REQ = 127 << 8 + 13
LMP_REMOVE_SCO_LINK_REQ = 44
LMP_RESUME_ENCRYPTION_REQ = 127 << 8 + 24
LMP_SAM_DEFINE_MAP = 127 << 8 + 36
LMP_SAM_SET_TYPE0 = 127 << 8 + 35
LMP_SAM_SWITCH = 127 << 8 + 37
LMP_SCO_LINK_REQ = 43
LMP_SET_AFH = 60
LMP_SETUP_COMPLETE = 49
LMP_SIMPLE_PAIRING_CONFIRM = 63
LMP_SIMPLE_PAIRING_NUMBER = 64
LMP_SLOT_OFFSET = 52
LMP_SNIFF_REQ = 23
LMP_SNIFF_SUBRATING_REQ = 127 << 8 + 21
LMP_SNIFF_SUBRATING_RES = 127 << 8 + 22
LMP_SRES = 12
LMP_START_ENCRYPTION_REQ = 17
LMP_STOP_ENCRYPTION_REQ = 18
LMP_SUPERVISION_TIMEOUT = 55
LMP_SWITCH_REQ = 19
LMP_TEMP_KEY = 14
LMP_TEMP_RAND = 13
LMP_TEST_ACTIVATE = 56
LMP_TEST_CONTROL = 57
LMP_TIMING_ACCURACY_REQ = 47
LMP_TIMING_ACCURACY_RES = 48
LMP_UNIT_KEY = 10
LMP_UNSNIFF_REQ = 24
LMP_USE_SEMI_PERMANENT_KEY = 50
LMP_VERSION_REQ = 37
LMP_VERSION_RES = 38
# fmt: on
@classmethod
def parse_from(cls, data: bytes, offset: int = 0) -> tuple[int, Opcode]:
opcode = data[offset]
if opcode in (124, 127):
opcode = struct.unpack('>H', data)[0]
return offset + 2, Opcode(opcode)
return offset + 1, Opcode(opcode)
def __bytes__(self) -> bytes:
if self.value >> 8:
return struct.pack('>H', self.value)
return bytes([self.value])
@classmethod
def type_metadata(cls):
return hci.metadata(
{
'serializer': bytes,
'parser': lambda data, offset: (Opcode.parse_from(data, offset)),
}
)
class Packet:
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary
'''
subclasses: dict[int, type[Packet]] = {}
opcode: Opcode
fields: hci.Fields = ()
_payload: bytes = b''
_Packet = TypeVar("_Packet", bound="Packet")
@classmethod
def subclass(cls, subclass: type[_Packet]) -> type[_Packet]:
# Register a factory for this class
cls.subclasses[subclass.opcode] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod
def from_bytes(cls, data: bytes) -> Packet:
offset, opcode = Opcode.parse_from(data)
if not (subclass := cls.subclasses.get(opcode)):
instance = Packet()
instance.opcode = opcode
else:
instance = subclass(
**hci.HCI_Object.dict_from_bytes(data, offset, subclass.fields)
)
instance.payload = data[offset:]
return instance
@property
def payload(self) -> bytes:
if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@payload.setter
def payload(self, value: bytes) -> None:
self._payload = value
def __bytes__(self) -> bytes:
return bytes(self.opcode) + self.payload
@Packet.subclass
@dataclass
class LmpAccepted(Packet):
opcode = Opcode.LMP_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAccepted(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAcceptedExt(Packet):
opcode = Opcode.LMP_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAcceptedExt(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAuRand(Packet):
opcode = Opcode.LMP_AU_RAND
random_number: bytes = field(metadata=hci.metadata(16))
@Packet.subclass
@dataclass
class LmpDetach(Packet):
opcode = Opcode.LMP_DETACH
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpEscoLinkReq(Packet):
opcode = Opcode.LMP_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
esco_lt_addr: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_esco: int = field(metadata=hci.metadata(1))
t_esco: int = field(metadata=hci.metadata(1))
w_esco: int = field(metadata=hci.metadata(1))
esco_packet_type_c_to_p: int = field(metadata=hci.metadata(1))
esco_packet_type_p_to_c: int = field(metadata=hci.metadata(1))
packet_length_c_to_p: int = field(metadata=hci.metadata(2))
packet_length_p_to_c: int = field(metadata=hci.metadata(2))
air_mode: int = field(metadata=hci.metadata(1))
negotiation_state: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpHostConnectionReq(Packet):
opcode = Opcode.LMP_HOST_CONNECTION_REQ
@Packet.subclass
@dataclass
class LmpRemoveEscoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpRemoveScoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpScoLinkReq(Packet):
opcode = Opcode.LMP_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_sco: int = field(metadata=hci.metadata(1))
t_sco: int = field(metadata=hci.metadata(1))
sco_packet: int = field(metadata=hci.metadata(1))
air_mode: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpSwitchReq(Packet):
opcode = Opcode.LMP_SWITCH_REQ
switch_instant: int = field(metadata=hci.metadata(4), default=0)
@Packet.subclass
@dataclass
class LmpNameReq(Packet):
opcode = Opcode.LMP_NAME_REQ
name_offset: int = field(metadata=hci.metadata(2))
@Packet.subclass
@dataclass
class LmpNameRes(Packet):
opcode = Opcode.LMP_NAME_RES
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import enum
import secrets
from dataclasses import dataclass
from typing import Optional
from bumble import hci
from bumble.core import AdvertisingData, LeRole
@@ -45,16 +44,16 @@ from bumble.smp import (
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[hci.Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
address: hci.Address | None = None
role: LeRole | None = None
shared_data: OobSharedData | None = None
legacy_context: OobLegacyContext | None = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
shared_data_c: bytes | None = None
shared_data_r: bytes | None = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = hci.Address(ad_data)
@@ -181,14 +180,14 @@ class PairingDelegate:
"""Compare two numbers."""
return True
async def get_number(self) -> Optional[int]:
async def get_number(self) -> int | None:
"""
Return an optional number as an answer to a passkey request.
Returning `None` will result in a negative reply.
"""
return 0
async def get_string(self, max_length: int) -> Optional[str]:
async def get_string(self, max_length: int) -> str | None:
"""
Return a string whose utf-8 encoding is up to max_length bytes.
"""
@@ -239,18 +238,18 @@ class PairingConfig:
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
our_context: OobContext | None
peer_data: OobSharedData | None
legacy_context: OobLegacyContext | None
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
delegate: PairingDelegate | None = None,
identity_address_type: AddressType | None = None,
oob: OobConfig | None = None,
) -> None:
self.sc = sc
self.mitm = mitm

View File

@@ -19,7 +19,7 @@ This module implement the Pandora Bluetooth test APIs for the Bumble stack.
__version__ = "0.0.1"
from typing import Callable, List, Optional
from collections.abc import Callable
import grpc
import grpc.aio
@@ -58,7 +58,7 @@ def register_servicer_hook(
async def serve(
bumble: PandoraDevice,
config: Config = Config(),
grpc_server: Optional[grpc.aio.Server] = None,
grpc_server: grpc.aio.Server | None = None,
port: int = 0,
) -> None:
# initialize a gRPC server if not provided.

View File

@@ -16,7 +16,7 @@
from __future__ import annotations
from typing import Any, Optional
from typing import Any
from bumble import transport
from bumble.core import (
@@ -54,7 +54,7 @@ class PandoraDevice:
# HCI transport name & instance.
_hci_name: str
_hci: Optional[transport.Transport] # type: ignore[name-defined]
_hci: transport.Transport | None # type: ignore[name-defined]
def __init__(self, config: dict[str, Any]) -> None:
self.config = config
@@ -74,7 +74,9 @@ class PandoraDevice:
# open HCI transport & set device host.
self._hci = await transport.open_transport(self._hci_name)
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
self.device.host = Host(
controller_source=self._hci.source, controller_sink=self._hci.sink
) # type: ignore[no-untyped-call]
# power-on.
await self.device.power_on()
@@ -96,7 +98,7 @@ class PandoraDevice:
await self.close()
await self.open()
def info(self) -> Optional[dict[str, str]]:
def info(self) -> dict[str, str] | None:
return {
'public_bd_address': str(self.device.public_address),
'random_address': str(self.device.random_address),

View File

@@ -17,12 +17,15 @@ from __future__ import annotations
import asyncio
import logging
import struct
from typing import AsyncGenerator, Optional, cast
from collections.abc import AsyncGenerator
from typing import cast
import grpc
import grpc.aio
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import (
any_pb2, # pytype: disable=pyi-error
empty_pb2, # pytype: disable=pyi-error
)
from pandora import host_pb2
from pandora.host_grpc_aio import HostServicer
from pandora.host_pb2 import (
@@ -302,7 +305,9 @@ class HostService(HostServicer):
await disconnection_future
self.log.debug("Disconnected")
finally:
connection.remove_listener(connection.EVENT_DISCONNECTION, on_disconnection) # type: ignore
connection.remove_listener(
connection.EVENT_DISCONNECTION, on_disconnection
) # type: ignore
return empty_pb2.Empty()
@@ -539,7 +544,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_advertising()
)
except:
except Exception:
pass
@utils.rpc
@@ -609,7 +614,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_scanning()
)
except:
except Exception:
pass
@utils.rpc
@@ -619,7 +624,7 @@ class HostService(HostServicer):
self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[
Optional[tuple[Address, int, AdvertisingData, int]]
tuple[Address, int, AdvertisingData, int] | None
] = asyncio.Queue()
complete_handler = self.device.on(
self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None)
@@ -644,14 +649,18 @@ class HostService(HostServicer):
)
finally:
self.device.remove_listener(self.device.EVENT_INQUIRY_COMPLETE, complete_handler) # type: ignore
self.device.remove_listener(self.device.EVENT_INQUIRY_RESULT, result_handler) # type: ignore
self.device.remove_listener(
self.device.EVENT_INQUIRY_COMPLETE, complete_handler
) # type: ignore
self.device.remove_listener(
self.device.EVENT_INQUIRY_RESULT, result_handler
) # type: ignore
try:
self.log.debug('Stop inquiry')
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_discovery()
)
except:
except Exception:
pass
@utils.rpc

View File

@@ -18,15 +18,15 @@ import json
import logging
from asyncio import Future
from asyncio import Queue as AsyncQueue
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import AsyncGenerator, Optional, Union
import grpc
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import COMMAND_NOT_UNDERSTOOD, INVALID_CID_IN_REQUEST
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error
from pandora.l2cap_pb2 import (
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
@@ -41,6 +41,7 @@ from pandora.l2cap_pb2 import (
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error
from bumble.core import InvalidArgumentError, OutOfResourcesError
from bumble.device import Device
@@ -55,7 +56,7 @@ from bumble.l2cap import (
from bumble.pandora import utils
from bumble.pandora.config import Config
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
L2capChannel = ClassicChannel | LeCreditBasedChannel
@dataclass
@@ -106,10 +107,8 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None
l2cap_server: ClassicChannelServer | LeCreditBasedChannelServer | None = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
@@ -216,7 +215,7 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,

View File

@@ -17,13 +17,15 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from typing import Any
import grpc
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from google.protobuf import (
any_pb2, # pytype: disable=pyi-error
empty_pb2, # pytype: disable=pyi-error
wrappers_pb2, # pytype: disable=pyi-error
)
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
@@ -64,7 +66,7 @@ class PairingDelegate(BasePairingDelegate):
def __init__(
self,
connection: BumbleConnection,
service: "SecurityService",
service: SecurityService,
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
@@ -130,7 +132,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def get_number(self) -> Optional[int]:
async def get_number(self) -> int | None:
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -147,7 +149,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'passkey'
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
async def get_string(self, max_length: int) -> str | None:
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -195,8 +197,8 @@ class SecurityService(SecurityServicer):
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Security', 'device': device}
)
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.event_queue: asyncio.Queue[PairingEvent] | None = None
self.event_answer: AsyncIterator[PairingEventAnswer] | None = None
self.device = device
self.config = config
@@ -231,7 +233,7 @@ class SecurityService(SecurityServicer):
if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated
link_key_type: Optional[int] = None
link_key_type: int | None = None
if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address))
):
@@ -410,8 +412,8 @@ class SecurityService(SecurityServicer):
wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future()
)
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
authenticate_task: asyncio.Future[None] | None = None
pair_task: asyncio.Future[None] | None = None
async def authenticate() -> None:
if (encryption := connection.encryption) != 0:
@@ -455,9 +457,9 @@ class SecurityService(SecurityServicer):
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
bumble.utils.AsyncRunner.spawn(connection.pair())
listeners: dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
listeners: dict[str, Callable[..., None | Awaitable[None]]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'),
@@ -500,7 +502,7 @@ class SecurityService(SecurityServicer):
return WaitSecurityResponse(**kwargs)
async def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
self, connection: BumbleConnection, level: SecurityLevel | LESecurityLevel
) -> bool:
self.log.debug(
str(

View File

@@ -18,7 +18,8 @@ import contextlib
import functools
import inspect
import logging
from typing import Any, Generator, MutableMapping, Optional
from collections.abc import Generator, MutableMapping
from typing import Any
import grpc
from google.protobuf.message import Message # pytype: disable=pyi-error
@@ -34,7 +35,7 @@ ADDRESS_TYPES: dict[str, AddressType] = {
}
def address_from_request(request: Message, field: Optional[str]) -> Address:
def address_from_request(request: Message, field: str | None) -> Address:
if field is None:
return Address.ANY
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
@@ -95,8 +96,7 @@ def rpc(func: Any) -> Any:
@functools.wraps(func)
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
for v in func(self, request, context):
yield v
yield from func(self, request, context)
@functools.wraps(func)
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import logging
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import utils
from bumble.att import ATT_Error
@@ -129,7 +128,7 @@ class AudioInputState:
mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0
attribute: Optional[Attribute] = None
attribute: Attribute | None = None
def __bytes__(self) -> bytes:
return bytes(
@@ -199,7 +198,6 @@ class AudioInputControlPoint:
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
@@ -317,7 +315,7 @@ class AudioInputDescription:
'''
audio_input_description: str = "Bluetooth"
attribute: Optional[Attribute] = None
attribute: Attribute | None = None
def on_read(self, _connection: Connection) -> str:
return self.audio_input_description
@@ -340,11 +338,11 @@ class AICSService(TemplateService):
def __init__(
self,
audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_state: AudioInputState | None = None,
gain_settings_properties: GainSettingsProperties | None = None,
audio_input_type: str = "local",
audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: Optional[AudioInputDescription] = None,
audio_input_status: AudioInputStatus | None = None,
audio_input_description: AudioInputDescription | None = None,
):
self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state

View File

@@ -25,7 +25,7 @@ import asyncio
import dataclasses
import enum
import logging
from typing import Iterable, Optional, Union
from collections.abc import Iterable
from bumble import utils
from bumble.device import Peer
@@ -230,7 +230,7 @@ class AmsClient(utils.EventEmitter):
self.supported_commands = set()
@classmethod
async def for_peer(cls, peer: Peer) -> Optional[AmsClient]:
async def for_peer(cls, peer: Peer) -> AmsClient | None:
ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy)
if ams_proxy is None:
return None
@@ -263,9 +263,7 @@ class AmsClient(utils.EventEmitter):
async def observe(
self,
entity: EntityId,
attributes: Iterable[
Union[PlayerAttributeId, QueueAttributeId, TrackAttributeId]
],
attributes: Iterable[PlayerAttributeId | QueueAttributeId | TrackAttributeId],
) -> None:
await self._ams_proxy.entity_update.write_value(
bytes([entity] + list(attributes)), with_response=True

View File

@@ -27,7 +27,7 @@ import datetime
import enum
import logging
import struct
from typing import Optional, Sequence, Union
from collections.abc import Sequence
from bumble import utils
from bumble.att import ATT_Error
@@ -116,7 +116,7 @@ class NotificationAttributeId(utils.OpenIntEnum):
@dataclasses.dataclass
class NotificationAttribute:
attribute_id: NotificationAttributeId
value: Union[str, int, datetime.datetime]
value: str | int | datetime.datetime
@dataclasses.dataclass
@@ -242,10 +242,10 @@ class AncsProxy(ProfileServiceProxy):
class AncsClient(utils.EventEmitter):
_expected_response_command_id: Optional[CommandId]
_expected_response_notification_uid: Optional[int]
_expected_response_app_identifier: Optional[str]
_expected_app_identifier: Optional[str]
_expected_response_command_id: CommandId | None
_expected_response_notification_uid: int | None
_expected_response_app_identifier: str | None
_expected_app_identifier: str | None
_expected_response_tuples: int
_response_accumulator: bytes
@@ -255,12 +255,12 @@ class AncsClient(utils.EventEmitter):
super().__init__()
self._ancs_proxy = ancs_proxy
self._command_semaphore = asyncio.Semaphore()
self._response: Optional[asyncio.Future] = None
self._response: asyncio.Future | None = None
self._reset_response()
self._started = False
@classmethod
async def for_peer(cls, peer: Peer) -> Optional[AncsClient]:
async def for_peer(cls, peer: Peer) -> AncsClient | None:
ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy)
if ancs_proxy is None:
return None
@@ -316,7 +316,7 @@ class AncsClient(utils.EventEmitter):
# Not enough data yet.
return
attributes: list[Union[NotificationAttribute, AppAttribute]] = []
attributes: list[NotificationAttribute | AppAttribute] = []
if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES:
(notification_uid,) = struct.unpack_from(
@@ -342,7 +342,7 @@ class AncsClient(utils.EventEmitter):
str_value = attribute_data[3 : 3 + attribute_data_length].decode(
"utf-8"
)
value: Union[str, int, datetime.datetime]
value: str | int | datetime.datetime
if attribute_id == NotificationAttributeId.MESSAGE_SIZE:
value = int(str_value)
elif attribute_id == NotificationAttributeId.DATE:
@@ -415,7 +415,7 @@ class AncsClient(utils.EventEmitter):
self,
notification_uid: int,
attributes: Sequence[
Union[NotificationAttributeId, tuple[NotificationAttributeId, int]]
NotificationAttributeId | tuple[NotificationAttributeId, int]
],
) -> list[NotificationAttribute]:
if not self._started:

View File

@@ -24,7 +24,7 @@ import logging
import struct
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar
from bumble import colors, device, gatt, gatt_client, hci, utils
from bumble.profiles import le_audio
@@ -49,7 +49,7 @@ class ASE_Operation:
classes: dict[int, type[ASE_Operation]] = {}
op_code: Opcode
name: str
fields: Optional[Sequence[Any]] = None
fields: Sequence[Any] | None = None
ase_id: Sequence[int]
class Opcode(enum.IntEnum):
@@ -278,7 +278,7 @@ class AseStateMachine(gatt.Characteristic):
EVENT_STATE_CHANGE = "state_change"
cis_link: Optional[device.CisLink] = None
cis_link: device.CisLink | None = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
@@ -290,7 +290,7 @@ class AseStateMachine(gatt.Characteristic):
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
codec_specific_configuration: CodecSpecificConfiguration | bytes = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
@@ -610,7 +610,7 @@ class AudioStreamControlService(gatt.TemplateService):
ase_state_machines: dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic[bytes]
_active_client: Optional[device.Connection] = None
_active_client: device.Connection | None = None
def __init__(
self,

View File

@@ -19,7 +19,8 @@
import enum
import logging
import struct
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
from bumble import data_types, gatt, gatt_client, l2cap, utils
from bumble.core import AdvertisingData
@@ -90,20 +91,20 @@ class AshaService(gatt.TemplateService):
EVENT_DISCONNECTED = "disconnected"
EVENT_VOLUME_CHANGED = "volume_changed"
audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None
volume: Optional[int] = None
other_state: Optional[int] = None
connection: Optional[Connection] = None
audio_sink: Callable[[bytes], Any] | None
active_codec: Codec | None = None
audio_type: AudioType | None = None
volume: int | None = None
other_state: int | None = None
connection: Connection | None = None
def __init__(
self,
capability: int,
hisyncid: Union[list[int], bytes],
hisyncid: list[int] | bytes,
device: Device,
psm: int = 0,
audio_sink: Optional[Callable[[bytes], Any]] = None,
audio_sink: Callable[[bytes], Any] | None = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01,
render_delay_milliseconds: int = 0,

View File

@@ -21,7 +21,8 @@ from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, Optional, Sequence
from collections.abc import Sequence
from typing import ClassVar
from bumble import core, device, gatt, gatt_adapters, gatt_client, hci, utils
@@ -337,7 +338,12 @@ class BroadcastAudioScanService(gatt.TemplateService):
b"12", # TEST
)
super().__init__([self.battery_level_characteristic])
super().__init__(
[
self.broadcast_audio_scan_control_point_characteristic,
self.broadcast_receive_state_characteristic,
]
)
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
@@ -351,7 +357,7 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes]
broadcast_receive_states: list[
gatt_client.CharacteristicProxy[Optional[BroadcastReceiveState]]
gatt_client.CharacteristicProxy[BroadcastReceiveState | None]
]
def __init__(self, service_proxy: gatt_client.ServiceProxy):

View File

@@ -16,7 +16,6 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from typing import Optional
from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC,
@@ -56,7 +55,7 @@ class BatteryService(TemplateService):
class BatteryServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = BatteryService
battery_level: Optional[CharacteristicProxy[int]]
battery_level: CharacteristicProxy[int] | None
def __init__(self, service_proxy):
self.service_proxy = service_proxy

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import enum
import struct
from typing import Optional
from bumble import core, crypto, device, gatt, gatt_client
@@ -96,17 +95,17 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
coordinated_set_size_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_lock_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_rank_characteristic: Optional[gatt.Characteristic[bytes]] = None
coordinated_set_size_characteristic: gatt.Characteristic[bytes] | None = None
set_member_lock_characteristic: gatt.Characteristic[bytes] | None = None
set_member_rank_characteristic: gatt.Characteristic[bytes] | None = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
coordinated_set_size: int | None = None,
set_member_lock: MemberLock | None = None,
set_member_rank: int | None = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise core.InvalidArgumentError(
@@ -198,9 +197,9 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
coordinated_set_size: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy[bytes]] = None
coordinated_set_size: gatt_client.CharacteristicProxy[bytes] | None = None
set_member_lock: gatt_client.CharacteristicProxy[bytes] | None = None
set_member_rank: gatt_client.CharacteristicProxy[bytes] | None = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -17,7 +17,6 @@
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional
from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE,
@@ -54,14 +53,14 @@ class DeviceInformationService(TemplateService):
def __init__(
self,
manufacturer_name: Optional[str] = None,
model_number: Optional[str] = None,
serial_number: Optional[str] = None,
hardware_revision: Optional[str] = None,
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None,
manufacturer_name: str | None = None,
model_number: str | None = None,
serial_number: str | None = None,
hardware_revision: str | None = None,
firmware_revision: str | None = None,
software_revision: str | None = None,
system_id: tuple[int, int] | None = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes | None = None,
# TODO: pnp_id
):
characteristics: list[Characteristic[bytes]] = [
@@ -109,14 +108,14 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService
manufacturer_name: Optional[CharacteristicProxy[str]]
model_number: Optional[CharacteristicProxy[str]]
serial_number: Optional[CharacteristicProxy[str]]
hardware_revision: Optional[CharacteristicProxy[str]]
firmware_revision: Optional[CharacteristicProxy[str]]
software_revision: Optional[CharacteristicProxy[str]]
system_id: Optional[CharacteristicProxy[tuple[int, int]]]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy[bytes]]
manufacturer_name: CharacteristicProxy[str] | None
model_number: CharacteristicProxy[str] | None
serial_number: CharacteristicProxy[str] | None
hardware_revision: CharacteristicProxy[str] | None
firmware_revision: CharacteristicProxy[str] | None
software_revision: CharacteristicProxy[str] | None
system_id: CharacteristicProxy[tuple[int, int]] | None
ieee_regulatory_certification_data_list: CharacteristicProxy[bytes] | None
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy

View File

@@ -19,7 +19,6 @@
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Union
from bumble.core import Appearance
from bumble.gatt import (
@@ -54,7 +53,7 @@ class GenericAccessService(TemplateService):
appearance_characteristic: Characteristic[bytes]
def __init__(
self, device_name: str, appearance: Union[Appearance, tuple[int, int], int] = 0
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
):
if isinstance(appearance, int):
appearance_int = appearance
@@ -88,8 +87,8 @@ class GenericAccessService(TemplateService):
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: Optional[CharacteristicProxy[str]]
appearance: Optional[CharacteristicProxy[Appearance]]
device_name: CharacteristicProxy[str] | None
appearance: CharacteristicProxy[Appearance] | None
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy

View File

@@ -40,7 +40,6 @@ class GenericAttributeProfileService(gatt.TemplateService):
database_hash_enabled: bool = True,
service_change_enabled: bool = True,
) -> None:
if server_supported_features is not None:
self.server_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,

View File

@@ -19,7 +19,6 @@
# -----------------------------------------------------------------------------
import struct
from enum import IntFlag
from typing import Optional
from bumble.gatt import (
GATT_BGR_FEATURES_CHARACTERISTIC,
@@ -77,18 +76,18 @@ class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic
ugg_features: Optional[Characteristic] = None
ugt_features: Optional[Characteristic] = None
bgs_features: Optional[Characteristic] = None
bgr_features: Optional[Characteristic] = None
ugg_features: Characteristic | None = None
ugt_features: Characteristic | None = None
bgs_features: Characteristic | None = None
bgr_features: Characteristic | None = None
def __init__(
self,
gmap_role: GmapRole,
ugg_features: Optional[UggFeatures] = None,
ugt_features: Optional[UgtFeatures] = None,
bgs_features: Optional[BgsFeatures] = None,
bgr_features: Optional[BgrFeatures] = None,
ugg_features: UggFeatures | None = None,
ugt_features: UgtFeatures | None = None,
bgs_features: BgsFeatures | None = None,
bgr_features: BgrFeatures | None = None,
) -> None:
characteristics = []
@@ -150,10 +149,10 @@ class GamingAudioService(TemplateService):
class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService
ugg_features: Optional[CharacteristicProxy[UggFeatures]] = None
ugt_features: Optional[CharacteristicProxy[UgtFeatures]] = None
bgs_features: Optional[CharacteristicProxy[BgsFeatures]] = None
bgr_features: Optional[CharacteristicProxy[BgrFeatures]] = None
ugg_features: CharacteristicProxy[UggFeatures] | None = None
ugt_features: CharacteristicProxy[UgtFeatures] | None = None
bgs_features: CharacteristicProxy[BgsFeatures] | None = None
bgr_features: CharacteristicProxy[BgrFeatures] | None = None
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any
from bumble import att, gatt, gatt_adapters, gatt_client, utils
from bumble.core import InvalidArgumentError, InvalidStateError
@@ -145,7 +145,7 @@ class PresetChangedOperation:
return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId
additional_parameters: Union[Generic, int]
additional_parameters: Generic | int
def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
@@ -235,7 +235,7 @@ class HearingAccessService(gatt.TemplateService):
preset_records: dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool
other_server_in_binaural_set: Optional[HearingAccessService] = None
other_server_in_binaural_set: HearingAccessService | None = None
preset_changed_operations_history_per_device: dict[
Address, list[PresetChangedOperation]
@@ -273,12 +273,19 @@ class HearingAccessService(gatt.TemplateService):
def on_disconnection(_reason) -> None:
self.currently_connected_clients.discard(connection)
@connection.on(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
def on_mtu_update(*_: Any) -> None:
self.on_incoming_connection(connection)
@connection.on(connection.EVENT_CONNECTION_ENCRYPTION_CHANGE)
def on_encryption_change(*_: Any) -> None:
self.on_incoming_connection(connection)
@connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None:
self.on_incoming_paired_connection(connection)
self.on_incoming_connection(connection)
if connection.peer_resolvable_address:
self.on_incoming_paired_connection(connection)
self.on_incoming_connection(connection)
self.hearing_aid_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
@@ -315,9 +322,30 @@ class HearingAccessService(gatt.TemplateService):
]
)
def on_incoming_paired_connection(self, connection: Connection):
def on_incoming_connection(self, connection: Connection):
'''Setup initial operations to handle a remote bonded HAP device'''
# TODO Should we filter on HAP device only ?
if not connection.is_encrypted:
logging.debug(f'HAS: {connection.peer_address} is not encrypted')
return
if not connection.peer_resolvable_address:
logging.debug(f'HAS: {connection.peer_address} is not paired')
return
if connection.att_mtu < 49:
logging.debug(
f'HAS: {connection.peer_address} invalid MTU={connection.att_mtu}'
)
return
if connection.peer_address in self.currently_connected_clients:
logging.debug(
f'HAS: Already connected to {connection.peer_address} nothing to do'
)
return
self.currently_connected_clients.add(connection)
if (
connection.peer_address
@@ -457,6 +485,7 @@ class HearingAccessService(gatt.TemplateService):
connection,
self.hearing_aid_preset_control_point,
value=op_list[0].to_bytes(len(op_list) == 1),
force=True, # TODO GATT notification subscription should be persistent
)
# Remove item once sent, and keep the non sent item in the list
op_list.pop(0)

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import struct
from enum import IntEnum
from typing import Optional
from bumble import core
from bumble.att import ATT_Error
@@ -207,13 +206,13 @@ class HeartRateService(TemplateService):
class HeartRateServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
heart_rate_measurement: Optional[
CharacteristicProxy[HeartRateService.HeartRateMeasurement]
]
body_sensor_location: Optional[
CharacteristicProxy[HeartRateService.BodySensorLocation]
]
heart_rate_control_point: Optional[CharacteristicProxy[int]]
heart_rate_measurement: (
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None
)
body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None
)
heart_rate_control_point: CharacteristicProxy[int] | None
def __init__(self, service_proxy):
self.service_proxy = service_proxy

View File

@@ -137,7 +137,7 @@ class Metadata:
values.append(str(decoded))
return '\n'.join(
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
f'{indent}{key}: {" " * (max_key_length - len(key))}{value}'
for key, value in zip(keys, values)
)

View File

@@ -22,7 +22,7 @@ import asyncio
import dataclasses
import enum
import struct
from typing import TYPE_CHECKING, ClassVar, Optional
from typing import TYPE_CHECKING, ClassVar
from typing_extensions import Self
@@ -196,7 +196,7 @@ class MediaControlService(gatt.TemplateService):
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: Optional[str] = None) -> None:
def __init__(self, media_player_name: str | None = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
@@ -337,32 +337,32 @@ class MediaControlServiceProxy(
EVENT_TRACK_DURATION = "track_duration"
EVENT_TRACK_POSITION = "track_position"
media_player_name: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_changed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_title: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_duration: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_position: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playback_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_track_segments_object_id: Optional[
gatt_client.CharacteristicProxy[bytes]
] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_order: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_state: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point_opcodes_supported: Optional[
gatt_client.CharacteristicProxy[bytes]
] = None
search_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
content_control_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_player_name: gatt_client.CharacteristicProxy[bytes] | None = None
media_player_icon_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
media_player_icon_url: gatt_client.CharacteristicProxy[bytes] | None = None
track_changed: gatt_client.CharacteristicProxy[bytes] | None = None
track_title: gatt_client.CharacteristicProxy[bytes] | None = None
track_duration: gatt_client.CharacteristicProxy[bytes] | None = None
track_position: gatt_client.CharacteristicProxy[bytes] | None = None
playback_speed: gatt_client.CharacteristicProxy[bytes] | None = None
seeking_speed: gatt_client.CharacteristicProxy[bytes] | None = None
current_track_segments_object_id: gatt_client.CharacteristicProxy[bytes] | None = (
None
)
current_track_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
next_track_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
parent_group_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
current_group_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
playing_order: gatt_client.CharacteristicProxy[bytes] | None = None
playing_orders_supported: gatt_client.CharacteristicProxy[bytes] | None = None
media_state: gatt_client.CharacteristicProxy[bytes] | None = None
media_control_point: gatt_client.CharacteristicProxy[bytes] | None = None
media_control_point_opcodes_supported: (
gatt_client.CharacteristicProxy[bytes] | None
) = None
search_control_point: gatt_client.CharacteristicProxy[bytes] | None = None
search_results_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
content_control_id: gatt_client.CharacteristicProxy[bytes] | None = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import dataclasses
import logging
import struct
from typing import Optional, Sequence, Union
from collections.abc import Sequence
from bumble import gatt, gatt_adapters, gatt_client, hci
from bumble.profiles import le_audio
@@ -39,7 +39,7 @@ class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
codec_specific_capabilities: CodecSpecificCapabilities | bytes
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
@@ -56,7 +56,7 @@ class PacRecord:
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
codec_specific_capabilities: CodecSpecificCapabilities | bytes
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
@@ -101,10 +101,10 @@ class PacRecord:
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: Optional[gatt.Characteristic[bytes]]
sink_audio_locations: Optional[gatt.Characteristic[bytes]]
source_pac: Optional[gatt.Characteristic[bytes]]
source_audio_locations: Optional[gatt.Characteristic[bytes]]
sink_pac: gatt.Characteristic[bytes] | None
sink_audio_locations: gatt.Characteristic[bytes] | None
source_pac: gatt.Characteristic[bytes] | None
source_audio_locations: gatt.Characteristic[bytes] | None
available_audio_contexts: gatt.Characteristic[bytes]
supported_audio_contexts: gatt.Characteristic[bytes]
@@ -115,9 +115,9 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: Optional[AudioLocation] = None,
sink_audio_locations: AudioLocation | None = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: Optional[AudioLocation] = None,
source_audio_locations: AudioLocation | None = None,
) -> None:
characteristics = []
@@ -183,14 +183,10 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
None
)
source_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
None
)
sink_pac: gatt_client.CharacteristicProxy[list[PacRecord]] | None = None
sink_audio_locations: gatt_client.CharacteristicProxy[AudioLocation] | None = None
source_pac: gatt_client.CharacteristicProxy[list[PacRecord]] | None = None
source_audio_locations: gatt_client.CharacteristicProxy[AudioLocation] | None = None
available_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
supported_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]

View File

@@ -22,6 +22,7 @@ import enum
from typing_extensions import Self
from bumble import core, data_types, gatt
from bumble.profiles import le_audio
@@ -46,3 +47,18 @@ class PublicBroadcastAnnouncement:
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
data_types.ServiceData16BitUUID(
gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE, bytes(self)
)
]
)
)
def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata)
return bytes([self.features, len(metadata_bytes)]) + metadata_bytes

View File

@@ -20,9 +20,9 @@ from __future__ import annotations
import dataclasses
import enum
from typing import Sequence
from collections.abc import Sequence
from bumble import att, device, gatt, gatt_adapters, gatt_client, utils
from bumble import att, device, gatt, gatt_adapters, gatt_client
# -----------------------------------------------------------------------------
# Constants

View File

@@ -18,7 +18,6 @@
# -----------------------------------------------------------------------------
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import utils
from bumble.att import ATT_Error
@@ -69,7 +68,7 @@ class ErrorCode(utils.OpenIntEnum):
class VolumeOffsetState:
volume_offset: int = 0
change_counter: int = 0
attribute: Optional[Characteristic] = None
attribute: Characteristic | None = None
def __bytes__(self) -> bytes:
return struct.pack('<hB', self.volume_offset, self.change_counter)
@@ -93,7 +92,7 @@ class VolumeOffsetState:
@dataclass
class VocsAudioLocation:
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
attribute: Optional[Characteristic] = None
attribute: Characteristic | None = None
def __bytes__(self) -> bytes:
return struct.pack('<I', self.audio_location)
@@ -118,7 +117,6 @@ class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
@@ -148,7 +146,7 @@ class VolumeOffsetControlPoint:
@dataclass
class AudioOutputDescription:
audio_output_description: str = ''
attribute: Optional[Characteristic] = None
attribute: Characteristic | None = None
@classmethod
def from_bytes(cls, data: bytes):
@@ -173,11 +171,10 @@ class VolumeOffsetControlService(TemplateService):
def __init__(
self,
volume_offset_state: Optional[VolumeOffsetState] = None,
audio_location: Optional[VocsAudioLocation] = None,
audio_output_description: Optional[AudioOutputDescription] = None,
volume_offset_state: VolumeOffsetState | None = None,
audio_location: VocsAudioLocation | None = None,
audio_output_description: AudioOutputDescription | None = None,
) -> None:
self.volume_offset_state = (
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
)

View File

@@ -22,7 +22,8 @@ import collections
import dataclasses
import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, Union
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing_extensions import Self
@@ -119,7 +120,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
service_record_handle: int, channel: int, uuid: UUID | None = None
) -> list[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
@@ -186,7 +187,7 @@ async def find_rfcomm_channels(connection: Connection) -> dict[int, list[UUID]]:
)
for attribute_lists in search_result:
service_classes: list[UUID] = []
channel: Optional[int] = None
channel: int | None = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -207,7 +208,7 @@ async def find_rfcomm_channels(connection: Connection) -> dict[int, list[UUID]]:
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
) -> int | None:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
@@ -473,15 +474,15 @@ class DLC(utils.EventEmitter):
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.connection_result: Optional[asyncio.Future] = None
self.disconnection_result: Optional[asyncio.Future] = None
self.connection_result: asyncio.Future | None = None
self.disconnection_result: asyncio.Future | None = None
self.drained = asyncio.Event()
self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
maxlen=DEFAULT_RX_QUEUE_SIZE
)
self._sink: Optional[Callable[[bytes], None]] = None
self._sink: Callable[[bytes], None] | None = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -490,11 +491,11 @@ class DLC(utils.EventEmitter):
)
@property
def sink(self) -> Optional[Callable[[bytes], None]]:
def sink(self) -> Callable[[bytes], None] | None:
return self._sink
@sink.setter
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
def sink(self, sink: Callable[[bytes], None] | None) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
@@ -674,10 +675,14 @@ class DLC(utils.EventEmitter):
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
# Get the next chunk, up to MTU size
if rx_credits_needed > 0:
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
chunk = bytes([rx_credits_needed])
self.rx_credits += rx_credits_needed
tx_credit_spent = len(chunk) > 1
if self.tx_buffer and self.tx_credits > 0:
chunk += self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
tx_credit_spent = True
else:
tx_credit_spent = False
else:
chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :]
@@ -708,7 +713,7 @@ class DLC(utils.EventEmitter):
self.drained.set()
# Stream protocol
def write(self, data: Union[bytes, str]) -> None:
def write(self, data: bytes | str) -> None:
# We can only send bytes
if not isinstance(data, bytes):
if isinstance(data, str):
@@ -765,10 +770,10 @@ class Multiplexer(utils.EventEmitter):
EVENT_DLC = "dlc"
connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], Optional[tuple[int, int]]]]
connection_result: asyncio.Future | None
disconnection_result: asyncio.Future | None
open_result: asyncio.Future | None
acceptor: Callable[[int], tuple[int, int] | None] | None
dlcs: dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
@@ -780,7 +785,7 @@ class Multiplexer(utils.EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.open_result = None
self.open_pn: Optional[RFCOMM_MCC_PN] = None
self.open_pn: RFCOMM_MCC_PN | None = None
self.open_rx_max_credits = 0
self.acceptor = None
@@ -1027,8 +1032,8 @@ class Multiplexer(utils.EventEmitter):
# -----------------------------------------------------------------------------
class Client:
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]
multiplexer: Multiplexer | None
l2cap_channel: l2cap.ClassicChannel | None
def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
@@ -1141,7 +1146,7 @@ class Server(utils.EventEmitter):
# Notify
self.emit(self.EVENT_START, multiplexer)
def accept_dlc(self, channel_number: int) -> Optional[tuple[int, int]]:
def accept_dlc(self, channel_number: int) -> tuple[int, int] | None:
return self.dlc_configs.get(channel_number)
def on_dlc(self, dlc: DLC) -> None:

View File

@@ -20,7 +20,8 @@ from __future__ import annotations
import asyncio
import logging
import struct
from typing import TYPE_CHECKING, Iterable, NewType, Optional, Sequence, Union
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, NewType
from typing_extensions import Self
@@ -497,7 +498,7 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(
attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> Optional[DataElement]:
) -> DataElement | None:
return next(
(
attribute.value
@@ -528,7 +529,7 @@ class ServiceAttribute:
def to_string(self, with_colors=False):
if with_colors:
return (
f'Attribute(id={color(self.id_name(self.id),"magenta")},'
f'Attribute(id={color(self.id_name(self.id), "magenta")},'
f'value={self.value})'
)
@@ -778,11 +779,11 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
class Client:
def __init__(self, connection: Connection, mtu: int = 0) -> None:
self.connection = connection
self.channel: Optional[l2cap.ClassicChannel] = None
self.channel: l2cap.ClassicChannel | None = None
self.mtu = mtu
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request: Optional[SDP_PDU] = None
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
self.pending_request: SDP_PDU | None = None
self.pending_response: asyncio.futures.Future[SDP_PDU] | None = None
self.next_transaction_id = 0
async def connect(self) -> None:
@@ -898,7 +899,7 @@ class Client:
async def search_attributes(
self,
uuids: Iterable[core.UUID],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
attribute_ids: Iterable[int | tuple[int, int]],
) -> list[list[ServiceAttribute]]:
"""
Search for attributes by UUID and attribute IDs.
@@ -970,7 +971,7 @@ class Client:
async def get_attributes(
self,
service_record_handle: int,
attribute_ids: Iterable[Union[int, tuple[int, int]]],
attribute_ids: Iterable[int | tuple[int, int]],
) -> list[ServiceAttribute]:
"""
Get attributes for a service.
@@ -1042,10 +1043,10 @@ class Client:
# -----------------------------------------------------------------------------
class Server:
CONTINUATION_STATE = bytes([0x01, 0x00])
channel: Optional[l2cap.ClassicChannel]
channel: l2cap.ClassicChannel | None
Service = NewType('Service', list[ServiceAttribute])
service_records: dict[int, Service]
current_response: Union[None, bytes, tuple[int, list[int]]]
current_response: None | bytes | tuple[int, list[int]]
def __init__(self, device: Device) -> None:
self.device = device
@@ -1123,7 +1124,7 @@ class Server:
self,
continuation_state: bytes,
transaction_id: int,
) -> Optional[bool]:
) -> bool | None:
# Check if this is a valid continuation
if len(continuation_state) > 1:
if (

View File

@@ -27,17 +27,9 @@ from __future__ import annotations
import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
ClassVar,
Optional,
TypeVar,
cast,
)
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
from bumble import crypto, utils
from bumble.colors import color
@@ -213,10 +205,10 @@ class SMP_Command:
fields: ClassVar[Fields]
code: int = field(default=0, init=False)
name: str = field(default='', init=False)
_payload: Optional[bytes] = field(default=None, init=False)
_payload: bytes | None = field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> "SMP_Command":
def from_bytes(cls, pdu: bytes) -> SMP_Command:
code = pdu[0]
subclass = SMP_Command.smp_classes.get(code)
@@ -554,7 +546,7 @@ class OobContext:
r: bytes
def __init__(
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
self, ecc_key: crypto.EccKey | None = None, r: bytes | None = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
@@ -570,7 +562,7 @@ class OobLegacyContext:
tk: bytes
def __init__(self, tk: Optional[bytes] = None) -> None:
def __init__(self, tk: bytes | None = None) -> None:
self.tk = crypto.r() if tk is None else tk
@@ -677,31 +669,31 @@ class Session:
self.stk = None
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None
self.link_key: bytes | None = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
self.peer_random_value: bytes | None = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
self.peer_ltk = None
self.peer_ediv = None
self.peer_rand: Optional[bytes] = None
self.peer_rand: bytes | None = None
self.peer_identity_resolving_key = None
self.peer_bd_addr: Optional[Address] = None
self.peer_bd_addr: Address | None = None
self.peer_signature_key = None
self.peer_expected_distributions: list[type[SMP_Command]] = []
self.dh_key = b''
self.confirm_value = None
self.passkey: Optional[int] = None
self.passkey: int | None = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method: PairingMethod = PairingMethod.JUST_WORKS
self.pairing_config = pairing_config
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.wait_before_continuing: asyncio.Future[None] | None = None
self.completed = False
self.ctkd_task: Optional[Awaitable[None]] = None
self.ctkd_task: Awaitable[None] | None = None
# Decide if we're the initiator or the responder
self.is_initiator = is_initiator
@@ -720,7 +712,7 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result: Optional[asyncio.Future[None]] = (
self.pairing_result: asyncio.Future[None] | None = (
asyncio.get_running_loop().create_future()
)
else:
@@ -828,7 +820,7 @@ class Session:
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk
@@ -939,7 +931,7 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6)
)
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
def input_passkey(self, next_steps: Callable[[], None] | None = None) -> None:
# Prompt the user for the passkey displayed on the peer
def after_input(passkey: int) -> None:
self.passkey = passkey
@@ -956,7 +948,7 @@ class Session:
self.prompt_user_for_number(after_input)
def display_or_input_passkey(
self, next_steps: Optional[Callable[[], None]] = None
self, next_steps: Callable[[], None] | None = None
) -> None:
if self.passkey_display:
@@ -1006,7 +998,6 @@ class Session:
self.send_command(response)
def send_pairing_confirm_command(self) -> None:
if self.pairing_method != PairingMethod.OOB:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
@@ -1842,7 +1833,6 @@ class Session:
self.send_public_key_command()
def next_steps() -> None:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
@@ -1929,7 +1919,7 @@ class Manager(utils.EventEmitter):
sessions: dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: type[Session]
_ecc_key: Optional[crypto.EccKey]
_ecc_key: crypto.EccKey | None
def __init__(
self,
@@ -2022,7 +2012,7 @@ class Manager(utils.EventEmitter):
self.device.on_pairing_start(session.connection)
async def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
self, session: Session, identity_address: Address | None, keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
@@ -2041,7 +2031,7 @@ class Manager(utils.EventEmitter):
def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> Optional[bytes]:
) -> bytes | None:
if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv)

View File

@@ -16,13 +16,14 @@ import datetime
import logging
import os
import struct
from collections.abc import Generator
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
from typing import BinaryIO, Generator
from typing import BinaryIO
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -65,7 +66,7 @@ class BtSnooper(Snooper):
"""
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MS = datetime.timedelta(microseconds=1)
@@ -85,7 +86,13 @@ class BtSnooper(Snooper):
# Compute the current timestamp
timestamp = (
int((datetime.datetime.utcnow() - self.TIMESTAMP_ANCHOR) / self.ONE_MS)
int(
(
datetime.datetime.now(tz=datetime.timezone.utc)
- self.TIMESTAMP_ANCHOR
)
/ self.ONE_MS
)
+ self.TIMESTAMP_DELTA
)
@@ -129,7 +136,7 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
pid: the current process ID.
instance: the instance ID in the current process.
@@ -153,7 +160,7 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)

View File

@@ -18,7 +18,6 @@
import logging
import os
import re
from typing import Optional
from bumble import utils
from bumble.snoop import create_snooper
@@ -84,7 +83,12 @@ async def open_transport(name: str) -> Transport:
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec and (m := re.search(r'\[(\w+=\w+(?:,\w+=\w+)*,?)\]', spec)):
# If a spec is provided, check for a metadata section in square brackets.
# The regex captures a comma-separated list of key=value pairs (allowing an
# optional trailing comma). The key is matched by \w+ and the value by [^,\]]+,
# meaning the value may contain any character except a comma or a closing
# bracket (']').
if spec and (m := re.search(r'\[(\w+=[^,\]]+(?:,\w+=[^,\]]+)*,?)\]', spec)):
metadata_str = m.group(1)
if m.start() == 0:
# <metadata><spec>
@@ -106,7 +110,7 @@ async def open_transport(name: str) -> Transport:
# -----------------------------------------------------------------------------
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
async def _open_transport(scheme: str, spec: str | None) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements

View File

@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional, Union
import grpc.aio
@@ -44,7 +43,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
async def open_android_emulator_transport(spec: str | None) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
@@ -89,7 +88,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
logger.debug('connecting to gRPC server at %s', server_address)
channel = grpc.aio.insecure_channel(server_address)
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)

View File

@@ -22,7 +22,6 @@ import os
import pathlib
import platform
import sys
from typing import Optional
import grpc.aio
@@ -66,7 +65,7 @@ DEFAULT_VARIANT = ''
# -----------------------------------------------------------------------------
def get_ini_dir() -> Optional[pathlib.Path]:
def get_ini_dir() -> pathlib.Path | None:
if sys.platform == 'darwin':
if tmpdir := os.getenv('TMPDIR', None):
return pathlib.Path(tmpdir)
@@ -100,7 +99,7 @@ def find_grpc_port(instance_number: int) -> int:
ini_file = ini_dir / ini_file_name(instance_number)
logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file():
with open(ini_file, 'r') as ini_file_data:
with open(ini_file) as ini_file_data:
for line in ini_file_data.readlines():
if '=' in line:
key, value = line.split('=')
@@ -131,7 +130,11 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
def cleanup():
logger.debug("removing .ini file")
ini_file.unlink()
try:
ini_file.unlink()
except OSError as error:
# Don't log at exception level, since this may happen normally.
logger.debug(f'failed to remove .ini file ({error})')
atexit.register(cleanup)
return True
@@ -142,7 +145,7 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: dict[str, str]
server_host: str | None, server_port: int, options: dict[str, str]
) -> Transport:
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -152,21 +155,26 @@ async def open_android_netsim_controller_transport(
logger.warning("unable to publish gRPC port")
class HciDevice:
def __init__(self, context, on_data_received):
def __init__(self, context, server):
self.context = context
self.on_data_received = on_data_received
self.server = server
self.name = None
self.sink = None
self.loop = asyncio.get_running_loop()
self.done = self.loop.create_future()
self.task = self.loop.create_task(self.pump())
async def pump(self):
try:
await self.pump_loop()
except asyncio.CancelledError:
logger.debug('Pump task canceled')
if not self.done.done():
self.done.set_result(None)
finally:
if self.sink:
logger.debug('Releasing sink')
self.server.release_sink()
self.sink = None
logger.debug('Pump task terminated')
async def pump_loop(self):
while True:
@@ -182,15 +190,26 @@ async def open_android_netsim_controller_transport(
if request.WhichOneof('request_type') == 'initial_info':
logger.debug(f'Received initial info: {request}')
self.name = request.initial_info.name
# We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type')
error = PacketResponse(error='Unsupported chip type')
await self.context.write(error)
return
# return
continue
self.name = request.initial_info.name
continue
# Lease the sink so that no other device can send
self.sink = self.server.lease_sink(self)
if self.sink is None:
logger.warning('Another device is already connected')
error = PacketResponse(error='Device busy')
await self.context.write(error)
# return
continue
continue
# Expect a data packet
request_type = request.WhichOneof('request_type')
@@ -201,10 +220,10 @@ async def open_android_netsim_controller_transport(
continue
# Process the packet
data = (
assert self.sink is not None
self.sink(
bytes([request.hci_packet.packet_type]) + request.hci_packet.packet
)
self.on_data_received(data)
async def send_packet(self, data):
return await self.context.write(
@@ -213,12 +232,6 @@ async def open_android_netsim_controller_transport(
)
)
def terminate(self):
self.task.cancel()
async def wait_for_termination(self):
await self.done
server_address = f'{server_host}:{server_port}'
class Server(PacketStreamerServicer, ParserSource):
@@ -254,27 +267,27 @@ async def open_android_netsim_controller_transport(
return await self.device.send_packet(packet)
async def StreamPackets(self, _request_iterator, context):
def lease_sink(self, device):
if self.device:
return None
self.device = device
return self.parser.feed_data
def release_sink(self):
self.device = None
async def StreamPackets(self, request_iterator, context):
logger.debug('StreamPackets request')
# Check that we don't already have a device
if self.device:
logger.debug('Busy, already serving a device')
return PacketResponse(error='Busy')
# Instantiate a new device
self.device = HciDevice(context, self.parser.feed_data)
device = HciDevice(context, self)
# Wait for the device to terminate
logger.debug('Waiting for device to terminate')
# Pump packets to/from the device
logger.debug('Pumping device packets')
try:
await self.device.wait_for_termination()
except asyncio.CancelledError:
logger.debug('Request canceled')
self.device.terminate()
logger.debug('Device terminated')
self.device = None
await device.pump()
finally:
logger.debug('Pump terminated')
server = Server()
await server.start()
@@ -287,9 +300,9 @@ async def open_android_netsim_controller_transport(
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_address(
server_host: Optional[str],
server_host: str | None,
server_port: int,
options: Optional[dict[str, str]] = None,
options: dict[str, str] | None = None,
):
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -314,7 +327,7 @@ async def open_android_netsim_host_transport_with_address(
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_channel(
channel, options: Optional[dict[str, str]] = None
channel, options: dict[str, str] | None = None
):
# Wrapper for I/O operations
class HciDevice:
@@ -394,7 +407,7 @@ async def open_android_netsim_host_transport_with_channel(
# -----------------------------------------------------------------------------
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
async def open_android_netsim_transport(spec: str | None) -> Transport:
'''
Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC.

View File

@@ -22,7 +22,8 @@ import contextlib
import io
import logging
import struct
from typing import Any, ContextManager, Optional, Protocol
from collections.abc import Awaitable, Callable
from typing import Any, Protocol
from bumble import core, hci
from bumble.colors import color
@@ -106,11 +107,11 @@ class PacketParser:
NEED_LENGTH = 1
NEED_BODY = 2
sink: Optional[TransportSink]
sink: TransportSink | None
extended_packet_info: dict[int, tuple[int, int, str]]
packet_info: Optional[tuple[int, int, str]] = None
packet_info: tuple[int, int, str] | None = None
def __init__(self, sink: Optional[TransportSink] = None) -> None:
def __init__(self, sink: TransportSink | None = None) -> None:
self.sink = sink
self.extended_packet_info = {}
self.reset()
@@ -175,7 +176,7 @@ class PacketReader:
self.source = source
self.at_end = False
def next_packet(self) -> Optional[bytes]:
def next_packet(self) -> bytes | None:
# Get the packet type
packet_type = self.source.read(1)
if len(packet_type) != 1:
@@ -252,7 +253,7 @@ class BaseSource:
"""
terminated: asyncio.Future[None]
sink: Optional[TransportSink]
sink: TransportSink | None
def __init__(self) -> None:
self.terminated = asyncio.get_running_loop().create_future()
@@ -356,7 +357,7 @@ class Transport:
# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
pump_task: Optional[asyncio.Task[None]]
pump_task: asyncio.Task[None] | None
def __init__(self, receive) -> None:
super().__init__()
@@ -389,15 +390,17 @@ class PumpedPacketSource(ParserSource):
# -----------------------------------------------------------------------------
class PumpedPacketSink:
def __init__(self, send):
pump_task: asyncio.Task[None] | None
def __init__(self, send: Callable[[bytes], Awaitable[Any]]):
self.send_function = send
self.packet_queue = asyncio.Queue()
self.packet_queue = asyncio.Queue[bytes]()
self.pump_task = None
def on_packet(self, packet: bytes) -> None:
self.packet_queue.put_nowait(packet)
def start(self):
def start(self) -> None:
async def pump_packets():
while True:
try:
@@ -440,7 +443,7 @@ class SnoopingTransport(Transport):
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
transport: Transport, snooper: contextlib.AbstractContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.

View File

@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import io
import logging
from bumble.transport.common import StreamPacketSink, StreamPacketSource, Transport
@@ -36,7 +35,7 @@ async def open_file_transport(spec: str) -> Transport:
'''
# Open the file
file = io.open(spec, 'r+b', buffering=0)
file = open(spec, 'r+b', buffering=0)
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(

View File

@@ -22,7 +22,6 @@ import logging
import os
import socket
import struct
from typing import Optional
from bumble.transport.common import ParserSource, Transport
@@ -33,7 +32,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
async def open_hci_socket_transport(spec: str | None) -> Transport:
'''
Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -87,7 +86,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
)
!= 0
):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
raise OSError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, hci_socket):

View File

@@ -17,12 +17,10 @@
# -----------------------------------------------------------------------------
import asyncio
import atexit
import io
import logging
import os
import pty
import tty
from typing import Optional
from bumble.transport.common import StreamPacketSink, StreamPacketSource, Transport
@@ -33,7 +31,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_pty_transport(spec: Optional[str]) -> Transport:
async def open_pty_transport(spec: str | None) -> Transport:
'''
Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link
@@ -48,11 +46,11 @@ async def open_pty_transport(spec: Optional[str]) -> Transport:
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
StreamPacketSource, io.open(primary, 'rb', closefd=False)
StreamPacketSource, open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
asyncio.BaseProtocol, open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

@@ -19,7 +19,6 @@ import asyncio
import logging
import threading
import time
from typing import Optional
import usb.core
import usb.util
@@ -284,7 +283,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
device = await _power_cycle(device) # type: ignore
except Exception as e:
logging.debug(e, stack_info=True)
logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore
logging.info(
f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}"
) # type: ignore
# Collect the metadata
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
@@ -370,7 +371,9 @@ async def _power_cycle(device: UsbDevice) -> UsbDevice:
# Device needs to be find again otherwise it will appear as disconnected
return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore
except USBError:
logger.exception(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore
logger.exception(
f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition."
) # type: ignore
return device
@@ -385,7 +388,7 @@ def _set_port_status(device: UsbDevice, port: int, on: bool):
)
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
def _find_device_by_path(sys_path: str) -> UsbDevice | None:
"""Finds a USB device based on its system path."""
bus_num, *port_parts = sys_path.split('-')
ports = [int(port) for port in port_parts[0].split('.')]
@@ -398,7 +401,7 @@ def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
return None
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
def _find_hub_by_device_path(sys_path: str) -> UsbDevice | None:
"""Finds the USB hub associated with a specific device path."""
hub_sys_path = sys_path.rsplit('.', 1)[0]
hub_device = _find_device_by_path(hub_sys_path)

View File

@@ -28,25 +28,56 @@ from bumble.transport.common import StreamPacketSink, StreamPacketSource, Transp
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_POST_OPEN_DELAY = 0.5 # in seconds
# -----------------------------------------------------------------------------
# Classes and Functions
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class SerialPacketSource(StreamPacketSource):
def __init__(self) -> None:
super().__init__()
self._ready = asyncio.Event()
async def wait_until_ready(self) -> None:
await self._ready.wait()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
logger.debug('connection made')
self._ready.set()
def connection_lost(self, exc: Exception | None) -> None:
logger.debug('connection lost')
self.on_transport_lost()
# -----------------------------------------------------------------------------
async def open_serial_transport(spec: str) -> Transport:
'''
Open a serial port transport.
The parameter string has this syntax:
<device-path>[,<speed>][,rtscts][,dsrdtr]
<device-path>[,<speed>][,rtscts][,dsrdtr][,delay]
When <speed> is omitted, the default value of 1000000 is used
When "rtscts" is specified, RTS/CTS hardware flow control is enabled
When "dsrdtr" is specified, DSR/DTR hardware flow control is enabled
When "delay" is specified, a short delay is added after opening the port
Examples:
/dev/tty.usbmodem0006839912172
/dev/tty.usbmodem0006839912172,1000000
/dev/tty.usbmodem0006839912172,rtscts
/dev/tty.usbmodem0006839912172,rtscts,delay
'''
speed = 1000000
rtscts = False
dsrdtr = False
delay = 0.0
if ',' in spec:
parts = spec.split(',')
device = parts[0]
@@ -55,13 +86,16 @@ async def open_serial_transport(spec: str) -> Transport:
rtscts = True
elif part == 'dsrdtr':
dsrdtr = True
elif part == 'delay':
delay = DEFAULT_POST_OPEN_DELAY
elif part.isnumeric():
speed = int(part)
else:
device = spec
serial_transport, packet_source = await serial_asyncio.create_serial_connection(
asyncio.get_running_loop(),
StreamPacketSource,
SerialPacketSource,
device,
baudrate=speed,
rtscts=rtscts,
@@ -69,4 +103,23 @@ async def open_serial_transport(spec: str) -> Transport:
)
packet_sink = StreamPacketSink(serial_transport)
logger.debug('waiting for the port to be ready')
await packet_source.wait_until_ready()
logger.debug('port is ready')
# Try to assert DTR
assert serial_transport.serial is not None
try:
serial_transport.serial.dtr = True
logger.debug(
f"DSR={serial_transport.serial.dsr}, DTR={serial_transport.serial.dtr}"
)
except Exception as e:
logger.warning(f'could not assert DTR: {e}')
# Wait a bit after opening the port, if requested
if delay > 0.0:
logger.debug(f'waiting {delay} seconds after opening the port')
await asyncio.sleep(delay)
return Transport(packet_source, packet_sink)

View File

@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional
from bumble.transport.common import Transport
from bumble.transport.file import open_file_transport
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_vhci_transport(spec: Optional[str]) -> Transport:
async def open_vhci_transport(spec: str | None) -> Transport:
'''
Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device

View File

@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
import logging
import websockets.client
import websockets.asyncio.client
from bumble.transport.common import (
PumpedPacketSink,
@@ -42,7 +42,7 @@ async def open_ws_client_transport(spec: str) -> Transport:
Example: ws://localhost:7681/v1/websocket/bt
'''
websocket = await websockets.client.connect(spec)
websocket = await websockets.asyncio.client.connect(spec)
class WsTransport(PumpedTransport):
async def close(self):

View File

@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
import logging
import websockets
import websockets.asyncio.server
from bumble.transport.common import ParserSource, PumpedPacketSink, Transport
@@ -40,7 +40,12 @@ async def open_ws_server_transport(spec: str) -> Transport:
'''
class WsServerTransport(Transport):
def __init__(self):
sink: PumpedPacketSink
source: ParserSource
connection: websockets.asyncio.server.ServerConnection | None
server: websockets.asyncio.server.Server | None
def __init__(self) -> None:
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = None
@@ -48,17 +53,19 @@ async def open_ws_server_transport(spec: str) -> Transport:
super().__init__(source, sink)
async def serve(self, local_host, local_port):
async def serve(self, local_host: str, local_port: str) -> None:
self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.serve(
ws_handler=self.on_connection,
self.server = await websockets.asyncio.server.serve(
handler=self.on_connection,
host=local_host if local_host != '_' else None,
port=int(local_port),
)
logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection):
async def on_connection(
self, connection: websockets.asyncio.server.ServerConnection
) -> None:
logger.debug(
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
@@ -77,11 +84,11 @@ async def open_ws_server_transport(spec: str) -> Transport:
# We're now disconnected
self.connection = None
async def send_packet(self, packet):
async def send_packet(self, packet: bytes) -> None:
if self.connection is None:
logger.debug('no connection, dropping packet')
return
return await self.connection.send(packet)
await self.connection.send(packet)
local_host, local_port = spec.rsplit(':', maxsplit=1)
transport = WsServerTransport()

View File

@@ -22,16 +22,12 @@ import collections
import enum
import functools
import logging
import sys
import warnings
from collections.abc import Awaitable, Callable
from typing import (
Any,
Awaitable,
Callable,
Optional,
Protocol,
TypeVar,
Union,
overload,
)
@@ -170,8 +166,8 @@ class EventWatcher:
) -> _Handler: ...
def on(
self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
) -> _Handler | Callable[[_Handler], _Handler]:
'''Watch an event until the context is closed.
Args:
@@ -199,8 +195,8 @@ class EventWatcher:
) -> _Handler: ...
def once(
self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
) -> _Handler | Callable[[_Handler], _Handler]:
'''Watch an event for once.
Args:
@@ -241,11 +237,7 @@ def cancel_on_event(
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
@@ -537,3 +529,20 @@ class IntConvertible(Protocol):
def __init__(self, value: int) -> None: ...
def __int__(self) -> int: ...
# -----------------------------------------------------------------------------
def crc_16(data: bytes) -> int:
"""Calculate CRC-16-IBM of given data.
Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed)
"""
crc = 0x0000
for byte in data:
crc ^= byte
for _ in range(8):
if (crc & 0x0001) > 0:
crc = (crc >> 1) ^ 0xA001
else:
crc = crc >> 1
return crc

View File

@@ -18,7 +18,6 @@
import dataclasses
import struct
from dataclasses import field
from typing import Optional
from bumble import hci
@@ -51,6 +50,7 @@ class HCI_LE_Get_Vendor_Capabilities_Command(hci.HCI_Command):
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('max_advt_instances', 1),
@@ -137,6 +137,7 @@ class HCI_Get_Controller_Activity_Energy_Info_Command(hci.HCI_Command):
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
return_parameters_fields = [
('status', hci.STATUS_SPEC),
('total_tx_time_ms', 4),
@@ -207,7 +208,7 @@ class HCI_Android_Vendor_Event(hci.HCI_Extended_Event):
@classmethod
def subclass_from_parameters(
cls, parameters: bytes
) -> Optional[hci.HCI_Extended_Event]:
) -> hci.HCI_Extended_Event | None:
subevent_code = parameters[0]
if subevent_code == HCI_BLUETOOTH_QUALITY_REPORT_EVENT:
quality_report_id = parameters[1]
@@ -229,6 +230,7 @@ class HCI_Bluetooth_Quality_Report_Event(HCI_Android_Vendor_Event):
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
'''
quality_report_id: int = field(metadata=hci.metadata(1))
packet_types: int = field(metadata=hci.metadata(1))
connection_handle: int = field(metadata=hci.metadata(2))

Some files were not shown because too many files have changed in this diff Show More