Compare commits

..

107 Commits

Author SHA1 Message Date
Lucas Abel
4c6320f98a Merge pull request #142 from AlanRosenthal/main
Fix small bug with services set via --device-config
2023-03-23 12:47:14 -07:00
Lucas Abel
cc0d56ad14 Merge pull request #152 from duohoo/g722_decoder
Add G722 decoder with pure python implementation
2023-03-23 12:45:07 -07:00
Lucas Abel
0019fa8e79 Merge pull request #149 from yuyangh/yuyangh/add_ASHA_event_emit
Add ASHA event emitter
2023-03-23 12:44:42 -07:00
Lucas Abel
7ae1bf8959 Merge pull request #148 from yuyangh/yuyangh/add_audio_status_point
Add ASHA audio status point
2023-03-23 12:43:35 -07:00
Yuyang Huang
9541cb6db0 Add ASHA audio status point 2023-03-23 12:15:10 -07:00
Lucas Abel
1cd13dfc19 Merge pull request #153 from benquike/main
Add 1 bug fix and a few features in bumble
2023-03-23 10:31:02 -07:00
Hui Peng
d4346c3c9b delegate the HCI_PIN_Code_Request event on host 2023-03-23 10:14:56 -07:00
Hui Peng
afe8765508 Add on_pin_code_request to support legacy BT classic pairing 2023-03-23 10:14:56 -07:00
Hui Peng
41d1772cb5 Add test for HCI_PIN_Code_Request_Reply_Command 2023-03-23 10:14:51 -07:00
Hui Peng
6e9078d60e Add implemenetation of HCI_PIN_Code_Request_Reply_Command 2023-03-23 09:50:50 -07:00
Hui Peng
d5c7d0db57 Fix a bug in HCI_Object.dict_from_bytes 2023-03-23 08:57:10 -07:00
Hui Peng
b70ebdef73 Allow Device.enable_classic to be configurable 2023-03-23 08:56:32 -07:00
Duo Ho
3af027e234 fix comments 2023-03-23 04:36:02 +00:00
Gilles Boccon-Gibod
6e719ca9fd Merge pull request #147 from google/gbg/btbench
add benchmark tool and doc
2023-03-22 21:13:24 -07:00
Duo Ho
1a580d1c1e Add G722 decoder with pure python implementation 2023-03-23 03:07:45 +00:00
Alan Rosenthal
aee7348687 Merge pull request #151 from AlanRosenthal/alan/add-types-via-pytypes
Used pytype to find some missing types / fix small issue
2023-03-22 20:58:25 -04:00
Gilles Boccon-Gibod
864889ccab rename .run to .spawn 2023-03-22 17:26:32 -07:00
Alan Rosenthal
fda00dcb28 Used pytype to find some missing types
```
pytype --pythonpath . ./bumble/device.py
```
2023-03-22 14:46:41 +00:00
Yuyang Huang
77e5618ce7 Add ASHA event emitter 2023-03-21 18:00:50 -07:00
Yuyang Huang
6fa857ad13 Add ASHA event emitter 2023-03-21 15:38:29 -07:00
Gilles Boccon-Gibod
bc29f327ef address PR comments, take 2. 2023-03-21 15:33:34 -07:00
Gilles Boccon-Gibod
1894b96de4 address PR comments 2023-03-21 15:01:46 -07:00
Gilles Boccon-Gibod
c4fb63d35c Merge pull request #146 from google/gbg/snoop-file
add auto-snooping for transports
2023-03-21 09:15:07 -07:00
Gilles Boccon-Gibod
33ae047765 add reversed role example doc 2023-03-20 18:35:22 -07:00
Gilles Boccon-Gibod
1efa2e9d44 add benchmark tool and doc 2023-03-20 18:25:21 -07:00
Gilles Boccon-Gibod
aa9af61cbe improve exception messages 2023-03-20 12:14:28 -07:00
Gilles Boccon-Gibod
dc3ac3060e add auto-snooping for transports 2023-03-20 11:06:50 -07:00
Alan Rosenthal
c34c5fdf17 Fix small bug with services set via --device-config
before:
```
  File "/home/alanrosenthal/code/fitbit/bumble/bumble/gatt.py", line 572, in __str__
    f'Descriptor(handle=0x{self.handle:04X}, '
  File "/home/alanrosenthal/code/fitbit/bumble/bumble/att.py", line 756, in read_value
    self.permissions & self.READ_REQUIRES_ENCRYPTION
TypeError: unsupported operand type(s) for &: 'str' and 'int'
```
2023-03-14 18:16:46 -04:00
Gilles Boccon-Gibod
e77723a5f9 Merge pull request #135 from google/gbg/snoop
add snoop support
2023-03-07 09:16:33 -08:00
Gilles Boccon-Gibod
fe8cf51432 Merge pull request #139 from google/gbg/hotfix-001
two small hotfixes
2023-03-07 09:16:15 -08:00
Gilles Boccon-Gibod
97a0e115ae two small hotfixes 2023-03-05 20:24:16 -08:00
Lucas Abel
46e7aac77c Merge pull request #138 from rahularya50/aryarahul/fix-att-perms
Add support for ATT permissions on server-side
2023-03-03 16:18:45 -08:00
Rahul Arya
08a6f4fa49 Add support for ATT permissions on server-side 2023-03-03 16:11:33 -08:00
Lucas Abel
ca063eda0b Merge pull request #132 from rahularya50/aryarahul/fix-uuid
Fix UUID byte-order in serialization
2023-03-03 15:48:50 -08:00
Rahul Arya
c97ba4319f Fix UUID byte-order in serialization 2023-03-03 22:38:21 +00:00
Gilles Boccon-Gibod
a5275ade29 add snoop support 2023-03-02 14:34:49 -08:00
Lucas Abel
e7b39c4188 Merge pull request #130 from google/uael/self-host-ainsicolors
Effort to make Bumble self hosted into AOSP
2023-02-23 15:31:23 -08:00
uael
0594eaef09 link: make websockets import lazy 2023-02-23 21:06:12 +00:00
uael
05200284d2 a2dp: get rid of construct dependency 2023-02-23 21:01:17 +00:00
uael
d21da78aa3 overall: host a minimal copy of ainsicolors 2023-02-23 20:53:06 +00:00
Gilles Boccon-Gibod
fbc7cf02a3 Merge pull request #129 from google/gbg/smp-improvements
improve smp compatibility with other OS flows
2023-02-14 19:10:51 -08:00
Gilles Boccon-Gibod
a8beb6b1ff remove stale comment 2023-02-14 16:05:46 -08:00
Gilles Boccon-Gibod
2d44de611f make pylint happy 2023-02-14 16:04:20 -08:00
Lucas Abel
9874bb3b37 Merge pull request #128 from google/uael/device-smp-patch
Small patches for device and SMP
2023-02-14 13:15:16 -08:00
uael
6645ad47ee smp: add a small type hint 2023-02-14 21:04:39 +00:00
uael
ad27de7717 device: remove "feature" which enable accept to return the same connection has connect 2023-02-14 21:04:39 +00:00
Gilles Boccon-Gibod
e6fc63b2d8 improve smp compatibility with other OS flows 2023-02-13 10:53:00 -08:00
Gilles Boccon-Gibod
1321c7da81 Merge pull request #125 from google/gbg/gh-124
fix getting the filename from the keystore option.
2023-02-10 20:17:38 -08:00
Gilles Boccon-Gibod
5a1b03fd91 format 2023-02-08 10:54:27 -08:00
Gilles Boccon-Gibod
de47721753 fix typo caused by an earlier refactor. 2023-02-08 09:56:11 -08:00
Gilles Boccon-Gibod
83a76a75d3 fix getting the filename from the keystore option. 2023-02-08 09:40:19 -08:00
Lucas Abel
d5b5ef8313 Merge pull request #122 from google/uael/abort-on-fix-invalid-state
utils: fix possible invalide state error while canceling future for `abort_on`
2023-02-06 17:13:34 -08:00
uael
856a8d53cd utils: fix possible invalide state error while canceling future for abort_on 2023-02-06 16:58:23 +00:00
Gilles Boccon-Gibod
177c273a57 Merge pull request #121 from google/gbg/replace-bitstruct
replace bitstruct with construct
2023-02-05 11:33:36 -08:00
Gilles Boccon-Gibod
24a863983d Merge branch 'gbg/replace-bitstruct' of https://github.com/google/bumble into gbg/replace-bitstruct
# Conflicts:
#	bumble/a2dp.py
#	pyproject.toml
2023-02-04 09:31:18 -08:00
Gilles Boccon-Gibod
b7ef09d4a3 fix format 2023-02-04 09:26:31 -08:00
Gilles Boccon-Gibod
b5b6cd13b8 replace bitstruct with construct 2023-02-04 09:23:13 -08:00
Gilles Boccon-Gibod
ef781bc374 replace bitstruct with construct 2023-02-03 19:41:07 -08:00
Lucas Abel
00978c1d63 Merge pull request #118 from google/uael/type-hints
overall: add types hints to the small subset used by avatar
2023-02-02 12:48:40 -08:00
uael
b731f6f556 overall: add types hints to the small subset used by avatar 2023-02-02 19:37:55 +00:00
Lucas Abel
ed261886e1 Merge pull request #119 from google/uael/fix-ci-packages-version
build: fix version of packages running checks in CI
2023-02-02 11:03:34 -08:00
uael
5e18094c31 build: fix version of packages running checks in CI 2023-02-02 17:23:15 +00:00
Lucas Abel
9a9b4e5bf1 Merge pull request #117 from google/uael/host-fixes
host: fixed `.latency` attribute error
2023-01-27 17:38:11 -08:00
Abel Lucas
895f1618d8 host: fixed .latency attribute error 2023-01-27 23:05:43 +00:00
Gilles Boccon-Gibod
52746e0c68 Merge pull request #116 from google/barbibulle-patch-1
fix libusb-package dependency
2023-01-25 15:59:42 -08:00
Gilles Boccon-Gibod
f9b7072423 Update setup.cfg 2023-01-25 15:37:33 -08:00
Gilles Boccon-Gibod
fa4be1958f Merge pull request #114 from google/gbg/fix-constant-typo
fix typo in constant name
2023-01-23 08:50:07 -08:00
Gilles Boccon-Gibod
f1686d8a9a fix typo in constant name 2023-01-22 19:10:13 -08:00
Gilles Boccon-Gibod
5c6a7f2036 Merge pull request #113 from google/gbg/mypy
add basic support for mypy type checking
2023-01-20 08:08:19 -08:00
Gilles Boccon-Gibod
99758e4b7d add basic support for mypy type checking 2023-01-20 00:20:50 -08:00
Alan Rosenthal
7385de6a69 Merge pull request #95 from AlanRosenthal/alan/fix_show_attributes
Fix `show attributes`
2023-01-19 14:57:22 -05:00
Alan Rosenthal
bb297e7516 Fix show attributes
`show attributes` wasn't being populated since `show_attributes()` was never called.

Also updated `show attributes` to match the color and indentation of `show services`
2023-01-19 12:21:37 -05:00
Lucas Abel
8a91c614c7 Merge pull request #109 from qiaoccolato/main
transport: make libusb_package optional
2023-01-18 14:48:05 -08:00
Qiao Yang
70a50a74b7 transport: make libusb_package optional 2023-01-17 15:17:11 -08:00
Gilles Boccon-Gibod
6a16c61c5f Merge pull request #111 from google/gbg/fix-null-address-setting
don't set a random address when it is 00:00:00:00:00:00
2023-01-13 21:35:32 -08:00
Gilles Boccon-Gibod
0a22f2f7c7 use HCI_LE_Rand 2023-01-13 16:59:34 -08:00
Gilles Boccon-Gibod
422b05ad51 don't set a random address when it is 00:00:00:00:00:00 2023-01-13 13:22:27 -08:00
Gilles Boccon-Gibod
16e926a216 Merge pull request #107 from yuyangh/yuyangh/add_ASHA_L2CAP
add ASHA L2CAP and Event Emitter
2023-01-13 11:05:16 -08:00
Gilles Boccon-Gibod
e94dc66d0c Merge pull request #110 from aleksandrovrts/hci-socket_fix
Fix bug when use hci-socket transport
2023-01-11 09:35:23 -08:00
Aleksandr Aleksandrov
e37c77532b hci_socket.py: fix socket.fileno() call 2023-01-11 16:16:45 +03:00
Gilles Boccon-Gibod
8b9ce03e86 Merge pull request #108 from google/gbg/fix-bluez-vhci
support more commands in controller.py
2023-01-08 14:40:26 -08:00
Gilles Boccon-Gibod
7e854efbbb support more commands in controller.py 2023-01-06 21:51:47 -08:00
Yuyang Huang
64b75be29b add psm parameter for testing support 2023-01-03 16:39:45 -08:00
Yuyang Huang
06018211fe emit event for ASHA l2cap packet 2023-01-03 15:01:32 -08:00
Yuyang Huang
e640991608 Merge branch 'google:main' into yuyangh/add_ASHA_L2CAP 2023-01-03 14:58:37 -08:00
Yuyang Huang
1068a6858d improve logging 2022-12-20 13:33:18 -08:00
Lucas Abel
17db5dd4ff Merge pull request #103 from google/uael/device-fixes
Misc device fixes
2022-12-20 12:15:49 -08:00
Abel Lucas
ea0a7e2347 device: commit LE connection **before** reading it's PHY 2022-12-20 19:25:43 +00:00
Yuyang Huang
6febd1ba35 add L2CAP CoC to ASHA 2022-12-20 11:15:58 -08:00
Gilles Boccon-Gibod
ea6a8d4339 Merge pull request #104 from google/gbg/fix-windll-load
fix libusb loading on Windows
2022-12-20 08:05:57 -08:00
Abel Lucas
ce049865a4 device: always prefer R2 for remote name request 2022-12-20 01:48:08 +00:00
Gilles Boccon-Gibod
6e0129b71d fix libusb loading on Windows 2022-12-18 22:00:26 -08:00
Gilles Boccon-Gibod
7ae3a1d973 Merge pull request #101 from google/gbg/formatting-linting-automation
formatting linting automation
2022-12-16 19:39:28 -08:00
Gilles Boccon-Gibod
c2959dadb4 formatting and linting automation
Squashed commits:
[cd479ba] formatting and linting automation
[7fbfabb] formatting and linting automation
[c4f9505] fix after rebase
[f506ad4] rename job
[441d517] update doc (+7 squashed commits)
[2e1b416] fix invoke and github action
[6ae5bb4] doc for git blame
[44b5461] add GitHub action
[b07474f] add docs
[4cd9a6f] more linter fixes
[db71901] wip
[540dc88] wip
2022-12-15 23:07:17 -08:00
Michael Mogenson
80fe2ea422 Merge pull request #102 from mogenson/libusb_package
Load libusb-1.0 shared library from libusb_package wheel
2022-12-15 21:53:56 -05:00
Lucas Abel
08e6590a76 Merge pull request #88 from google/uael/abort_on_event
Host: spawn each asynchronous task with the right aliveness
2022-12-15 12:46:37 -08:00
Abel Lucas
f580ffcbc3 device: set as Secure Connection when encrypted with AES 2022-12-15 17:02:21 +00:00
Abel Lucas
5178c866ac classic: add to .encrypt the possibilty to disable encryption 2022-12-15 17:02:21 +00:00
Abel Lucas
441933bd64 reverted: 662704e "classic: complete authentication when being the .authenticate acceptor" 2022-12-15 17:02:21 +00:00
Abel Lucas
287df94090 host: spawn each asynchronous task with the right aliveness 2022-12-15 17:02:21 +00:00
Michael Mogenson
86f9496575 Load libusb-1.0 shared library from libusb_package wheel
It would be nice to pip install bumble without having to first install
the libusb system dependency. Expecially on platforms like Windows and
Mac, without a default package manager.

The libusb_package Python package distributes prebuilt libusb-1.0 shared
libraries for each OS and architecture as binary wheels for the pyusb
project. Add this package as a dependency for bumble.

For the pyusb transport, the libusb_package.find() function is a drop-in
replacement for pyusb.core.find(). It searches the libusb_package
site-path before system paths and creates a pyusb backend.

For the usb transport, use libusb_package.get_library_path() to return a
path to the libusb-1.0 library in site-packages. If this path exists,
create a ctypes DLL and init the usb1 backend. This only needs to be
done once. All future calls to usb1 will use this opened library.
If the library path does not exist, do nothing, and usb1 will search
default system paths when the usb1.USBContext object is created.

This commit pins the libusb_package dependency at 1.0.26.0 to ensure
every bumble install uses the exact same version of the libusb library.
2022-12-15 10:22:02 -05:00
Gilles Boccon-Gibod
f5fe3d87f2 Merge pull request #98 from yuyangh/yuyangh/update_asha_advertising
update ASHA AdvertisingData
2022-12-12 15:23:47 -08:00
Yuyang Huang
f65bed2ec4 Merge branch 'main' into yuyangh/update_asha_advertising 2022-12-12 13:35:17 -08:00
Gilles Boccon-Gibod
3efe35065d Merge pull request #96 from google/gbg/black
format with Black
2022-12-12 13:27:58 -08:00
Yuyang Huang
83b42488ea update ASHA AdvertisingData
previously the ASHA AdvertisingData uses INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, now it lets user to define whether it is complete list or incomplete list
2022-12-12 11:42:48 -08:00
Gilles Boccon-Gibod
8bef344879 Merge pull request #94 from AlanRosenthal/alan/bumble_version_in_show_device
Add bumble's version to `show device`
2022-12-10 09:00:30 -08:00
Alan Rosenthal
55e2f23e29 Add bumble's version to show device 2022-12-09 12:23:45 -05:00
172 changed files with 6285 additions and 1883 deletions

2
.git-blame-ignore-revs Normal file
View File

@@ -0,0 +1,2 @@
# Migrate code style to Black
135df0dcc01ab765f432e19b1a5202d29bd55545

35
.github/workflows/code-check.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
# Check the code against the formatter and linter
name: Code format and lint check
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
check:
name: Check Code
runs-on: ubuntu-latest
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development]"
- name: Check
run: |
invoke project.pre-commit

3
.gitignore vendored
View File

@@ -6,3 +6,6 @@ dist/
docs/mkdocs/site
test-results.xml
__pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json

80
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,80 @@
{
"cSpell.words": [
"Abortable",
"altsetting",
"ansiblue",
"ansicyan",
"ansigreen",
"ansimagenta",
"ansired",
"ansiyellow",
"appendleft",
"ASHA",
"asyncio",
"ATRAC",
"avdtp",
"bitpool",
"bitstruct",
"BSCP",
"BTPROTO",
"CCCD",
"cccds",
"cmac",
"CONNECTIONLESS",
"csrcs",
"datagram",
"DATALINK",
"delayreport",
"deregisters",
"deregistration",
"dhkey",
"diversifier",
"Fitbit",
"GATTLINK",
"HANDSFREE",
"keydown",
"keyup",
"levelname",
"libc",
"libusb",
"MITM",
"NDIS",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"psms",
"pyee",
"pyusb",
"rfcomm",
"ROHC",
"rssi",
"SEID",
"seids",
"SERV",
"ssrc",
"strerror",
"subband",
"subbands",
"subevent",
"Subrating",
"substates",
"tobytes",
"tsep",
"usbmodem",
"vhci",
"websockets",
"xcursor",
"ycursor"
],
"[python]": {
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -199,4 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -15,10 +15,10 @@ Bumble is a full-featured Bluetooth stack written entirely in Python. It support
## Documentation
Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with:
```
mkdocs build -f docs/mkdocs/mkdocs.yml
mkdocs build -f docs/mkdocs/mkdocs.yml
```
## Usage
@@ -29,7 +29,7 @@ For a quick start to using Bumble, see the [Getting Started](docs/mkdocs/src/get
### Dependencies
To install package dependencies needed to run the bumble examples execute the following commands:
To install package dependencies needed to run the bumble examples, execute the following commands:
```
python -m pip install --upgrade pip
@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License

View File

@@ -47,5 +47,3 @@ NOTE: this assumes you're running a Link Relay on port `10723`.
## `console.py`
A simple text-based-ui interactive Bluetooth device with GATT client capabilities.

1207
apps/bench.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -27,19 +27,6 @@ import re
from collections import OrderedDict
import click
import colors
from bumble.core import UUID, AdvertisingData, TimeoutError, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
)
from prompt_toolkit import Application
from prompt_toolkit.history import FileHistory
@@ -63,6 +50,22 @@ from prompt_toolkit.layout import (
Dimension,
)
from bumble import __version__
import bumble.core
from bumble import colors
from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -74,12 +77,6 @@ DISPLAY_MAX_RSSI = -30
RSSI_MONITOR_INTERVAL = 5.0 # Seconds
# -----------------------------------------------------------------------------
# Globals
# -----------------------------------------------------------------------------
App = None
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -103,19 +100,19 @@ def rssi_bar(rssi):
def parse_phys(phys):
if phys.lower() == '*':
return None
else:
phy_list = []
elements = phys.lower().split(',')
for element in elements:
if element == '1m':
phy_list.append(HCI_LE_1M_PHY)
elif element == '2m':
phy_list.append(HCI_LE_2M_PHY)
elif element == 'coded':
phy_list.append(HCI_LE_CODED_PHY)
else:
raise ValueError('invalid PHY name')
return phy_list
phy_list = []
elements = phys.lower().split(',')
for element in elements:
if element == '1m':
phy_list.append(HCI_LE_1M_PHY)
elif element == '2m':
phy_list.append(HCI_LE_2M_PHY)
elif element == 'coded':
phy_list.append(HCI_LE_CODED_PHY)
else:
raise ValueError('invalid PHY name')
return phy_list
# -----------------------------------------------------------------------------
@@ -157,10 +154,10 @@ class ConsoleApp:
'rssi': {'on': None, 'off': None},
'show': {
'scan': None,
'services': None,
'attributes': None,
'log': None,
'device': None,
'local-services': None,
'remote-services': None,
},
'filter': {
'address': None,
@@ -200,8 +197,8 @@ class ConsoleApp:
)
self.output_max_lines = 20
self.scan_results_text = FormattedTextControl()
self.services_text = FormattedTextControl()
self.attributes_text = FormattedTextControl()
self.local_services_text = FormattedTextControl()
self.remote_services_text = FormattedTextControl()
self.device_text = FormattedTextControl()
self.log_text = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1))
@@ -217,12 +214,12 @@ class ConsoleApp:
filter=Condition(lambda: self.top_tab == 'scan'),
),
ConditionalContainer(
Frame(Window(self.services_text), title='Services'),
filter=Condition(lambda: self.top_tab == 'services'),
Frame(Window(self.local_services_text), title='Local Services'),
filter=Condition(lambda: self.top_tab == 'local-services'),
),
ConditionalContainer(
Frame(Window(self.attributes_text), title='Attributes'),
filter=Condition(lambda: self.top_tab == 'attributes'),
Frame(Window(self.remote_services_text), title='Remove Services'),
filter=Condition(lambda: self.top_tab == 'remote-services'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
@@ -251,15 +248,16 @@ class ConsoleApp:
layout = Layout(container, focused_element=self.input_field)
kb = KeyBindings()
key_bindings = KeyBindings()
@kb.add("c-c")
@kb.add("c-q")
@key_bindings.add("c-c")
@key_bindings.add("c-q")
def _(event):
event.app.exit()
# pylint: disable=invalid-name
self.ui = Application(
layout=layout, style=style, key_bindings=kb, full_screen=True
layout=layout, style=style, key_bindings=key_bindings, full_screen=True
)
async def run_async(self, device_config, transport):
@@ -274,8 +272,8 @@ class ConsoleApp:
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for c in random.sample(range(255), 5):
random_address += f":{c:02X}"
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
@@ -283,6 +281,7 @@ class ConsoleApp:
self.device.listener = DeviceListener(self)
await self.device.power_on()
self.show_device(self.device)
self.show_local_services(self.device.gatt_server.attributes)
# Run the UI
await self.ui.run_async()
@@ -292,7 +291,7 @@ class ConsoleApp:
def add_known_address(self, address):
self.known_addresses.add(address)
def accept_input(self, buff):
def accept_input(self, _):
if len(self.input_field.text) == 0:
return
self.append_to_output([('', '* '), ('ansicyan', self.input_field.text)], False)
@@ -311,12 +310,24 @@ class ConsoleApp:
connection_state = 'CONNECTING'
elif self.connected_peer:
connection = self.connected_peer.connection
connection_parameters = f'{connection.parameters.connection_interval}/{connection.parameters.peripheral_latency}/{connection.parameters.supervision_timeout}'
connection_parameters = (
f'{connection.parameters.connection_interval}/'
f'{connection.parameters.peripheral_latency}/'
f'{connection.parameters.supervision_timeout}'
)
if connection.transport == BT_LE_TRANSPORT:
phy_state = f' RX={le_phy_name(connection.phy.rx_phy)}/TX={le_phy_name(connection.phy.tx_phy)}'
phy_state = (
f' RX={le_phy_name(connection.phy.rx_phy)}/'
f'TX={le_phy_name(connection.phy.tx_phy)}'
)
else:
phy_state = ''
connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}{phy_state}'
connection_state = (
f'{connection.peer_address} '
f'{connection_parameters} '
f'{connection.data_length}'
f'{phy_state}'
)
encryption_state = (
'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
)
@@ -349,37 +360,45 @@ class ConsoleApp:
self.scan_results_text.text = ANSI('\n'.join(lines))
self.ui.invalidate()
def show_services(self, services):
def show_remote_services(self, services):
lines = []
del self.known_attributes[:]
for service in services:
lines.append(('ansicyan', str(service) + '\n'))
lines.append(("ansicyan", f"{service}\n"))
for characteristic in service.characteristics:
lines.append(('ansimagenta', ' ' + str(characteristic) + '\n'))
lines.append(('ansimagenta', f' {characteristic} + \n'))
self.known_attributes.append(
f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}'
)
self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}')
self.known_attributes.append(f'#{characteristic.handle:X}')
for descriptor in characteristic.descriptors:
lines.append(('ansigreen', ' ' + str(descriptor) + '\n'))
lines.append(("ansigreen", f" {descriptor}\n"))
self.services_text.text = lines
self.remote_services_text.text = lines
self.ui.invalidate()
def show_attributes(self, attributes):
def show_local_services(self, attributes):
lines = []
for attribute in attributes:
lines.append(('ansicyan', f'{attribute}\n'))
if isinstance(attribute, Service):
lines.append(("ansicyan", f"{attribute}\n"))
elif isinstance(attribute, (Characteristic, CharacteristicDeclaration)):
lines.append(("ansimagenta", f" {attribute}\n"))
elif isinstance(attribute, Descriptor):
lines.append(("ansigreen", f" {attribute}\n"))
else:
lines.append(("ansiyellow", f"{attribute}\n"))
self.attributes_text.text = lines
self.local_services_text.text = lines
self.ui.invalidate()
def show_device(self, device):
lines = []
lines.append(('ansicyan', 'Bumble Version: '))
lines.append(('', f'{__version__}\n'))
lines.append(('ansicyan', 'Name: '))
lines.append(('', f'{device.name}\n'))
lines.append(('ansicyan', 'Public Address: '))
@@ -407,7 +426,10 @@ class ConsoleApp:
advertising_interval = (
device.advertising_interval_min
if device.advertising_interval_min == device.advertising_interval_max
else f"{device.advertising_interval_min} to {device.advertising_interval_max}"
else (
f'{device.advertising_interval_min} to '
f'{device.advertising_interval_max}'
)
)
lines.append(('ansicyan', 'Advertising Interval: '))
lines.append(('', f'{advertising_interval}\n'))
@@ -416,7 +438,7 @@ class ConsoleApp:
self.ui.invalidate()
def append_to_output(self, line, invalidate=True):
if type(line) is str:
if isinstance(line, str):
line = [('', line)]
self.output_lines = self.output_lines[-self.output_max_lines :]
self.output_lines.append(line)
@@ -454,7 +476,7 @@ class ConsoleApp:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.show_services(self.connected_peer.services)
self.show_remote_services(self.connected_peer.services)
async def discover_attributes(self):
if not self.connected_peer:
@@ -486,6 +508,8 @@ class ConsoleApp:
if characteristic.handle == attribute_handle:
return characteristic
return None
async def rssi_monitor_loop(self):
while True:
if self.monitor_rssi and self.connected_peer:
@@ -517,7 +541,8 @@ class ConsoleApp:
if not params[1].startswith("filter="):
self.show_error(
'invalid syntax',
'expected address filter=key1:value1,key2:value,... available filters: address',
'expected address filter=key1:value1,key2:value,... '
'available filters: address',
)
# regex: (word):(any char except ,)
matches = re.findall(r"(\w+):([^,]+)", params[1])
@@ -575,10 +600,10 @@ class ConsoleApp:
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except TimeoutError:
except bumble.core.TimeoutError:
self.show_error('connection timed out')
async def do_disconnect(self, params):
async def do_disconnect(self, _):
if self.device.is_le_connecting:
await self.device.cancel_connection()
else:
@@ -592,7 +617,8 @@ class ConsoleApp:
if len(params) != 1 or len(params[0].split('/')) != 3:
self.show_error(
'invalid syntax',
'expected update-parameters <interval-min>-<interval-max>/<max-latency>/<supervision>',
'expected update-parameters <interval-min>-<interval-max>'
'/<max-latency>/<supervision>',
)
return
@@ -613,7 +639,7 @@ class ConsoleApp:
supervision_timeout,
)
async def do_encrypt(self, params):
async def do_encrypt(self, _):
if not self.connected_peer:
self.show_error('not connected')
return
@@ -636,18 +662,25 @@ class ConsoleApp:
async def do_show(self, params):
if params:
if params[0] in {'scan', 'services', 'attributes', 'log', 'device'}:
if params[0] in {
'scan',
'log',
'device',
'local-services',
'remote-services',
}:
self.top_tab = params[0]
self.ui.invalidate()
async def do_get_phy(self, params):
async def do_get_phy(self, _):
if not self.connected_peer:
self.show_error('not connected')
return
phy = await self.connected_peer.connection.get_phy()
self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}'
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, '
f'TX={HCI_Constant.le_phy_name(phy[1])}'
)
async def do_request_mtu(self, params):
@@ -790,10 +823,10 @@ class ConsoleApp:
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_exit(self, params):
async def do_exit(self, _):
self.ui.exit()
async def do_quit(self, params):
async def do_quit(self, _):
self.ui.exit()
async def do_filter(self, params):
@@ -824,7 +857,7 @@ class DeviceListener(Device.Listener, Connection.Listener):
else:
self._address_filter = re.compile(filter_addr)
self.scan_results = OrderedDict(
filter(lambda x: self.filter_address_match(x), self.scan_results)
filter(self.filter_address_match, self.scan_results)
)
self.app.show_scan_results(self.scan_results)
@@ -835,6 +868,7 @@ class DeviceListener(Device.Listener, Connection.Listener):
return bool(self.address_filter.match(address))
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
self.app.connected_peer = Peer(connection)
self.app.connection_rssi = None
@@ -843,14 +877,16 @@ class DeviceListener(Device.Listener, Connection.Listener):
def on_disconnection(self, reason):
self.app.append_to_output(
f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}'
f'disconnected from {self.app.connected_peer}, '
f'reason: {HCI_Constant.error_name(reason)}'
)
self.app.connected_peer = None
self.app.connection_rssi = None
def on_connection_parameters_update(self):
self.app.append_to_output(
f'connection parameters update: {self.app.connected_peer.connection.parameters}'
f'connection parameters update: '
f'{self.app.connected_peer.connection.parameters}'
)
def on_connection_phy_update(self):
@@ -864,13 +900,19 @@ class DeviceListener(Device.Listener, Connection.Listener):
)
def on_connection_encryption_change(self):
encryption_state = (
'encrypted'
if self.app.connected_peer.connection.is_encrypted
else 'not encrypted'
)
self.app.append_to_output(
f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}'
'connection encryption change: ' f'{encryption_state}'
)
def on_connection_data_length_change(self):
self.app.append_to_output(
f'connection data length change: {self.app.connected_peer.connection.data_length}'
'connection data length change: '
f'{self.app.connected_peer.connection.data_length}'
)
def on_advertisement(self, advertisement):
@@ -927,10 +969,16 @@ class ScanResult:
else:
name = ''
# Remove any '/P' qualifier suffix from the address string
address_str = str(self.address).replace('/P', '')
# RSSI bar
bar_string = rssi_bar(self.rssi)
bar_padding = ' ' * (DEFAULT_RSSI_BAR_WIDTH + 5 - len(bar_string))
return f'{address_color(str(self.address))} [{type_color(address_type_string)}] {bar_string} {bar_padding} {name}'
return (
f'{address_color(address_str)} [{type_color(address_type_string)}] '
f'{bar_string} {bar_padding} {name}'
)
# -----------------------------------------------------------------------------
@@ -958,7 +1006,7 @@ def main(device_config, transport):
if not os.path.isdir(BUMBLE_USER_DIR):
os.mkdir(BUMBLE_USER_DIR)
# Create an instane of the app
# Create an instance of the app
app = ConsoleApp()
# Setup logging
@@ -975,4 +1023,4 @@ def main(device_config, transport):
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -19,9 +19,9 @@ import asyncio
import os
import logging
import click
from colors import color
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
@@ -30,6 +30,8 @@ from bumble.hci import (
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
@@ -45,11 +47,20 @@ from bumble.host import Host
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if response.return_parameters.status == HCI_SUCCESS:
if command_succeeded(response):
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
@@ -57,7 +68,7 @@ async def get_classic_info(host):
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if response.return_parameters.status == HCI_SUCCESS:
if command_succeeded(response):
print()
print(
color('Local Name:', 'yellow'),
@@ -73,7 +84,7 @@ async def get_le_info(host):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
@@ -84,7 +95,7 @@ async def get_le_info(host):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
@@ -93,7 +104,7 @@ async def get_le_info(host):
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if response.return_parameters.status == HCI_SUCCESS:
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(

View File

@@ -29,12 +29,13 @@ from bumble.transport import open_transport_or_link
async def async_main():
if len(sys.argv) != 3:
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]'
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2')
return
# Create a loccal link to attach the controllers to
# Create a local link to attach the controllers to
link = LocalLink()
# Create a transport and controller for all requested names

View File

@@ -19,9 +19,9 @@ import asyncio
import os
import logging
import click
from colors import color
from bumble.core import ProtocolError, TimeoutError
import bumble.core
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.transport import open_transport_or_link
@@ -49,9 +49,9 @@ async def dump_gatt_db(peer, done):
try:
value = await attribute.read_value()
print(color(f'{value.hex()}', 'green'))
except ProtocolError as error:
except bumble.core.ProtocolError as error:
print(color(error, 'red'))
except TimeoutError:
except bumble.core.TimeoutError:
print(color('read timeout', 'red'))
if done is not None:

View File

@@ -20,8 +20,8 @@ import os
import struct
import logging
import click
from colors import color
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic, CharacteristicValue
@@ -99,6 +99,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
print(f'=== Connected to {connection}')
self.peer = Peer(connection)
@@ -158,7 +159,8 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def on_disconnection(self, reason):
print(
color(
f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}',
f'!!! Disconnected from {self.peer}, '
f'reason={HCI_Constant.error_name(reason)}',
'red',
)
)
@@ -189,7 +191,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
@@ -209,6 +211,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
self.transport = None
# Register as a listener
device.listener = self
@@ -264,7 +267,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
@@ -276,7 +279,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, connection, data):
def on_rx_write(self, _connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
@@ -284,7 +287,8 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(
f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}'
f'### [GATT TX] subscription from {peer}: '
f'{"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
@@ -335,7 +339,9 @@ async def run(
# Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(), remote_addr=(send_host, send_port)
# pylint: disable-next=unnecessary-lambda
lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port),
)
await device.power_on()

View File

@@ -35,10 +35,13 @@ logger = logging.getLogger(__name__)
async def async_main():
if len(sys.argv) < 3:
print(
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]'
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> '
'[command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 '
'serial:/dev/tty.usbmodem0006839912171,1000000 '
'0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return
@@ -82,6 +85,8 @@ async def async_main():
# Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True)
return None
_ = HCI_Bridge(
hci_host_source,
hci_host_sink,

View File

@@ -16,11 +16,11 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import click
import logging
import os
from colors import color
import click
from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
@@ -89,7 +89,8 @@ class ServerBridge:
# Connect to the TCP server
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...',
f'### Connecting to TCP {self.bridge.tcp_host}:'
f'{self.bridge.tcp_port}...',
'yellow',
)
)
@@ -98,8 +99,8 @@ class ServerBridge:
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, error):
print(color(f'!!! TCP connection lost: {error}', 'red'))
def connection_lost(self, exc):
print(color(f'!!! TCP connection lost: {exc}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
@@ -178,8 +179,8 @@ class ClientBridge:
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peername = writer.get_extra_info('peername')
print(color(f'<<< TCP connection from {peername}', 'magenta'))
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
@@ -346,4 +347,4 @@ def client(context, bluetooth_address, tcp_host, tcp_port):
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -16,7 +16,6 @@
# Imports
# ----------------------------------------------------------------------------
import sys
import websockets
import logging
import json
import asyncio
@@ -24,7 +23,9 @@ import argparse
import uuid
import os
from urllib.parse import urlparse
from colors import color
import websockets
from bumble.colors import color
# -----------------------------------------------------------------------------
# Logging
@@ -98,7 +99,11 @@ class Connection:
self.address = address
def __str__(self):
return f'Connection(address="{self.address}", client={self.websocket.remote_address[0]}:{self.websocket.remote_address[1]})'
return (
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
# ----------------------------------------------------------------------------
@@ -139,8 +144,8 @@ class Room:
# Parse the message to decide how to handle it
if message.startswith('@'):
# This is a targetted message
await self.on_targetted_message(connection, message)
# This is a targeted message
await self.on_targeted_message(connection, message)
elif message.startswith('/'):
# This is an RPC request
await self.on_rpc_request(connection, message)
@@ -169,7 +174,7 @@ class Room:
await connection.send_message(result or 'result:{}')
async def on_targetted_message(self, connection, message):
async def on_targeted_message(self, connection, message):
target, *payload = message.split(' ', 1)
if not payload:
return error_to_json('missing arguments')
@@ -178,7 +183,8 @@ class Room:
# Determine what targets to send to
if target == '*':
# Send to all connections in the room except the connection from which the message was received
# Send to all connections in the room except the connection from which the
# message was received
connections = [c for c in self.connections if c != connection]
else:
connections = self.find_connections_by_address(target)
@@ -216,9 +222,10 @@ class Relay:
def start(self):
logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(connection):
async def serve_as_controller(self, connection):
pass
async def serve(self, websocket, path):
@@ -265,7 +272,7 @@ def main():
# Setup logger
if args.log_config:
from logging import config
from logging import config # pylint: disable=import-outside-toplevel
config.fileConfig(args.log_config)
else:

View File

@@ -19,9 +19,9 @@ import asyncio
import os
import logging
import click
import aioconsole
from colors import color
from prompt_toolkit.shortcuts import PromptSession
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.smp import PairingDelegate, PairingConfig
@@ -42,9 +42,23 @@ from bumble.att import (
)
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
def terminate(self):
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
@@ -58,7 +72,18 @@ class Delegate(PairingDelegate):
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.prompt = prompt
self.do_prompt = do_prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
response = await session.prompt_async()
return response.lower().strip()
async def update_peer_name(self):
if self.peer_name is not None:
@@ -73,93 +98,83 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]'
async def accept(self):
if self.prompt:
if self.do_prompt:
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
response = await self.prompt('>>> Accept? ')
if response == 'yes':
return True
elif response == 'no':
if response == 'no':
return False
else:
# Accept silently
return True
# Accept silently
return True
async def compare_numbers(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
response = response.lower().strip()
if response == 'yes':
return True
elif response == 'no':
if response == 'no':
return False
async def get_number(self):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN
while True:
try:
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
return int(await self.prompt('>>> Enter PIN: '))
except ValueError:
pass
async def display_number(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print(f'### PIN: {number:0{digits}}')
self.print('###-----------------------------------')
# -----------------------------------------------------------------------------
async def get_peer_name(peer, mode):
if mode == 'classic':
return await peer.request_name()
else:
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
return None
# -----------------------------------------------------------------------------
@@ -172,12 +187,12 @@ def read_with_error(connection):
if AUTHENTICATION_ERROR_RETURNED[0]:
return bytes([1])
else:
AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, value):
def write_with_error(connection, _value):
if not connection.is_encrypted:
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
@@ -232,6 +247,7 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -239,6 +255,7 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -256,6 +273,8 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -326,7 +345,19 @@ async def pair(
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# Run until the user asks to exit
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
# -----------------------------------------------------------------------------
@@ -360,7 +391,11 @@ async def pair(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.option(
'--keystore-file',
metavar='<filename>',
help='File in which to store the pairing keys',
)
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
@@ -378,7 +413,13 @@ def main(
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run(
pair(
mode,

View File

@@ -19,8 +19,8 @@ import asyncio
import os
import logging
import click
from colors import color
from bumble.colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
@@ -92,7 +92,8 @@ class AdvertisementPrinter:
print(
f'>>> {color(address, address_color)} '
f'[{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}'
f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'

View File

@@ -17,8 +17,8 @@
# -----------------------------------------------------------------------------
import struct
import click
from colors import color
from bumble.colors import color
from bumble import hci
from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
@@ -27,7 +27,8 @@ from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...)
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not
exactly the same...)
'''
DATALINK_H1 = 1001
@@ -47,10 +48,7 @@ class SnoopPacketReader:
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if (
self.data_link_type != self.DATALINK_H4
and self.data_link_type != self.DATALINK_H1
):
if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self):
@@ -62,9 +60,9 @@ class SnoopPacketReader:
original_length,
included_length,
packet_flags,
cumulative_drops,
timestamp_seconds,
timestamp_microsecond,
_cumulative_drops,
_timestamp_seconds,
_timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
# Abort on truncated packets
@@ -90,8 +88,8 @@ class SnoopPacketReader:
packet_flags & 1,
bytes([packet_type]) + self.source.read(included_length),
)
else:
return (packet_flags & 1, self.source.read(included_length))
return (packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
@@ -105,13 +103,14 @@ class SnoopPacketReader:
help='Format of the input file',
)
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, filename):
input = open(filename, 'rb')
if format == 'h4':
packet_reader = PacketReader(input)
def read_next_packet():
(0, packet_reader.next_packet())
return (0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
@@ -128,9 +127,8 @@ def main(format, filename):
except Exception as error:
print(color(f'!!! {error}', 'red'))
pass
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -28,10 +28,11 @@
# -----------------------------------------------------------------------------
import os
import logging
import sys
import click
import usb1
from colors import color
from bumble.colors import color
from bumble.transport.usb import load_libusb
# -----------------------------------------------------------------------------
@@ -94,9 +95,9 @@ def show_device_details(device):
print(f' Configuration {configuration.getConfigurationValue()}')
for interface in configuration:
for setting in interface:
alternateSetting = setting.getAlternateSetting()
alternate_setting = setting.getAlternateSetting()
suffix = (
f'/{alternateSetting}' if interface.getNumSettings() > 1 else ''
f'/{alternate_setting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(), setting.getSubClass(), setting.getProtocol()
@@ -111,7 +112,8 @@ def show_device_details(device):
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}'
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
)
@@ -122,7 +124,7 @@ def get_class_info(cls, subclass, protocol):
if class_info is None:
class_string = f'0x{cls:02X}'
else:
if type(class_info) is tuple:
if isinstance(class_info, tuple):
class_string = class_info[0]
subclass_info = class_info[1].get(subclass)
if subclass_info:
@@ -169,6 +171,7 @@ def is_bluetooth_hci(device):
def main(verbose):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
@@ -272,4 +275,4 @@ def main(verbose):
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,4 @@
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown version"

View File

@@ -16,10 +16,8 @@
# Imports
# -----------------------------------------------------------------------------
import struct
import bitstruct
import logging
from collections import namedtuple
from colors import color
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -134,14 +132,15 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
# -----------------------------------------------------------------------------
def flags_to_list(flags, values):
result = []
for i in range(len(values)):
for i, value in enumerate(values):
if flags & (1 << (len(values) - i - 1)):
result.append(values[i])
result.append(value)
return result
# -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
@@ -191,6 +190,7 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
# -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
@@ -257,7 +257,6 @@ class SbcMediaCodecInformation(
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
@@ -273,9 +272,22 @@ class SbcMediaCodecInformation(
}
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
)
@classmethod
@@ -324,13 +336,23 @@ class SbcMediaCodecInformation(
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __bytes__(self) -> bytes:
return bytes(
[
(self.sampling_frequency << 4) | self.channel_mode,
(self.block_length << 4)
| (self.subbands << 2)
| self.allocation_method,
self.minimum_bitpool_value,
self.maximum_bitpool_value,
]
)
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
@@ -348,14 +370,13 @@ class SbcMediaCodecInformation(
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
BIT_FIELDS = 'u8u12u2p2u1u23'
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
@@ -379,9 +400,15 @@ class AacMediaCodecInformation(
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
@classmethod
@@ -392,6 +419,7 @@ class AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
@@ -408,8 +436,17 @@ class AacMediaCodecInformation(
bitrate=bitrate,
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __bytes__(self) -> bytes:
return bytes(
[
self.object_type & 0xFF,
(self.sampling_frequency >> 4) & 0xFF,
(((self.sampling_frequency & 0x0F) << 4) | (self.channels << 2)) & 0xFF,
((self.vbr << 7) | ((self.bitrate >> 16) & 0x7F)) & 0xFF,
((self.bitrate >> 8) & 0xFF) & 0xFF,
self.bitrate & 0xFF,
]
)
def __str__(self):
object_types = [
@@ -423,6 +460,7 @@ class AacMediaCodecInformation(
'[7]',
]
channels = [1, 2]
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
@@ -455,6 +493,7 @@ class VendorSpecificMediaCodecInformation:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self):
# pylint: disable=line-too-long
return '\n'.join(
[
'VendorSpecificMediaCodecInformation(',
@@ -489,7 +528,13 @@ class SbcFrame:
return self.sample_count / self.sampling_frequency
def __str__(self):
return f'SBC(sf={self.sampling_frequency},cm={self.channel_mode},br={self.bitrate},sc={self.sample_count},size={len(self.payload)})'
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
@@ -551,6 +596,7 @@ class SbcPacketSource:
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
@@ -582,7 +628,7 @@ class SbcPacketSource:
# Prepare for next packets
sequence_number += 1
timestamp += sum([frame.sample_count for frame in frames])
timestamp += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:

View File

@@ -22,16 +22,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from colors import color
from __future__ import annotations
import functools
import struct
from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING
from .core import *
from .hci import *
from bumble.core import UUID, name_or_number, get_dict_key_by_value
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
ATT_CID = 0x04
@@ -165,21 +173,14 @@ ATT_ERROR_NAMES = {
ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y)
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def key_with_value(dictionary, target_value):
for key, value in dictionary.items():
if value == target_value:
return key
return None
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# Exceptions
@@ -201,8 +202,9 @@ class ATT_PDU:
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
'''
pdu_classes = {}
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
name = None
@staticmethod
def from_bytes(pdu):
@@ -724,22 +726,44 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = {
READABLE: 'READABLE',
WRITEABLE: 'WRITEABLE',
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
}
@staticmethod
def string_to_permissions(permissions_str: str):
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
permissions_str.split(","),
0,
)
def __init__(self, attribute_type, permissions, value=b''):
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
self.permissions = permissions
if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions)
else:
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if type(attribute_type) is str:
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif type(attribute_type) is bytes:
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
# Convert the value to a byte array
if type(value) is str:
if isinstance(value, str):
self.value = bytes(value, 'utf-8')
else:
self.value = value
@@ -750,32 +774,72 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes):
return value_bytes
def read_value(self, connection):
def read_value(self, connection: Connection):
if (
self.permissions & self.READ_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
if read := getattr(self.value, 'read', None):
try:
value = read(connection)
value = read(connection) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
return self.encode_value(value)
def write_value(self, connection, value_bytes):
def write_value(self, connection: Connection, value_bytes):
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None):
try:
write(connection, value)
write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = value
self.emit('write', connection, value)
def __repr__(self):
if type(self.value) is bytes:
if isinstance(self.value, bytes):
value_str = self.value.hex()
else:
value_str = str(self.value)
@@ -783,4 +847,8 @@ class Attribute(EventEmitter):
value_string = f', value={self.value.hex()}'
else:
value_string = ''
return f'Attribute(handle=0x{self.handle:04X}, type={self.type}, permissions={self.permissions}{value_string})'
return (
f'Attribute(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'permissions={self.permissions}{value_string})'
)

View File

@@ -15,12 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import struct
import time
import logging
from colors import color
from pyee import EventEmitter
from typing import Dict, Type
from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
@@ -38,6 +39,7 @@ from .a2dp import (
VendorSpecificMediaCodecInformation,
)
from . import sdp
from .colors import color
# -----------------------------------------------------------------------------
# Logging
@@ -49,6 +51,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
AVDTP_PSM = 0x0019
@@ -198,6 +201,8 @@ AVDTP_STATE_NAMES = {
}
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
@@ -318,7 +323,18 @@ class MediaPacket:
return header + self.payload
def __str__(self):
return f'RTP(v={self.version},p={self.padding},x={self.extension},m={self.marker},pt={self.payload_type},sn={self.sequence_number},ts={self.timestamp},ssrc={self.ssrc},csrcs={self.csrc_list},payload_size={len(self.payload)})'
return (
f'RTP(v={self.version},'
f'p={self.padding},'
f'x={self.extension},'
f'm={self.marker},'
f'pt={self.payload_type},'
f'sn={self.sequence_number},'
f'ts={self.timestamp},'
f'ssrc={self.ssrc},'
f'csrcs={self.csrc_list},'
f'payload_size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
@@ -369,7 +385,7 @@ class MediaPacketPump:
# -----------------------------------------------------------------------------
class MessageAssembler:
class MessageAssembler: # pylint: disable=attribute-defined-outside-init
def __init__(self, callback):
self.callback = callback
self.reset()
@@ -390,16 +406,16 @@ class MessageAssembler:
message_type = pdu[0] & 3
logger.debug(
f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}'
f'transaction_label={transaction_label}, '
f'packet_type={Protocol.packet_type_name(packet_type)}, '
f'message_type={Message.message_type_name(message_type)}'
)
if (
packet_type == Protocol.SINGLE_PACKET
or packet_type == Protocol.START_PACKET
):
if packet_type in (Protocol.SINGLE_PACKET, Protocol.START_PACKET):
if self.message is not None:
# The previous message has not been terminated
logger.warning(
'received a start or single packet when expecting an end or continuation'
'received a start or single packet when expecting an end or '
'continuation'
)
self.reset()
@@ -413,23 +429,22 @@ class MessageAssembler:
else:
self.number_of_signal_packets = pdu[2]
self.message = pdu[3:]
elif (
packet_type == Protocol.CONTINUE_PACKET
or packet_type == Protocol.END_PACKET
):
elif packet_type in (Protocol.CONTINUE_PACKET, Protocol.END_PACKET):
if self.packet_count == 0:
logger.warning('unexpected continuation')
return
if transaction_label != self.transaction_label:
logger.warning(
f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}'
f'transaction label mismatch: expected {self.transaction_label}, '
f'received {transaction_label}'
)
return
if message_type != self.message_type:
logger.warning(
f'message type mismatch: expected {self.message_type}, received {message_type}'
f'message type mismatch: expected {self.message_type}, '
f'received {message_type}'
)
return
@@ -438,7 +453,9 @@ class MessageAssembler:
if packet_type == Protocol.END_PACKET:
if self.packet_count != self.number_of_signal_packets:
logger.warning(
f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}'
'incomplete fragmented message: '
f'expected {self.number_of_signal_packets} packets, '
f'received {self.packet_count}'
)
self.reset()
return
@@ -447,7 +464,9 @@ class MessageAssembler:
else:
if self.packet_count > self.number_of_signal_packets:
logger.warning(
f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}'
'too many packets: '
f'expected {self.number_of_signal_packets}, '
f'received {self.packet_count}'
)
self.reset()
return
@@ -515,7 +534,7 @@ class ServiceCapabilities:
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details=[]):
def to_string(self, details=[]): # pylint: disable=dangerous-default-value
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ details
@@ -562,10 +581,16 @@ class MediaCodecCapabilities(ServiceCapabilities):
self.media_codec_information = media_codec_information
def __str__(self):
codec_info = (
self.media_codec_information.hex()
if isinstance(self.media_codec_information, bytes)
else str(self.media_codec_information)
)
details = [
f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}',
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}',
f'codec_info={codec_info}',
]
return self.to_string(details)
@@ -591,7 +616,7 @@ class EndPointInfo:
# -----------------------------------------------------------------------------
class Message:
class Message: # pylint:disable=attribute-defined-outside-init
COMMAND = 0
GENERAL_REJECT = 1
RESPONSE_ACCEPT = 2
@@ -604,18 +629,19 @@ class Message:
RESPONSE_REJECT: 'RESPONSE_REJECT',
}
subclasses = {} # Subclasses, by signal identifier and message type
# Subclasses, by signal identifier and message type
subclasses: Dict[int, Dict[int, Type[Message]]] = {}
@staticmethod
def message_type_name(message_type):
return name_or_number(Message.MESSAGE_TYPE_NAMES, message_type)
@staticmethod
def subclass(cls):
def subclass(subclass):
# Infer the signal identifier and message subtype from the class name
name = cls.__name__
name = subclass.__name__
if name == 'General_Reject':
cls.signal_identifier = 0
subclass.signal_identifier = 0
signal_identifier_str = None
message_type = Message.COMMAND
elif name.endswith('_Command'):
@@ -630,22 +656,23 @@ class Message:
else:
raise ValueError('invalid class name')
cls.message_type = message_type
subclass.message_type = message_type
if signal_identifier_str is not None:
for (name, signal_identifier) in AVDTP_SIGNAL_IDENTIFIERS.items():
if name.lower().endswith(signal_identifier_str.lower()):
cls.signal_identifier = signal_identifier
subclass.signal_identifier = signal_identifier
break
# Register the subclass
Message.subclasses.setdefault(cls.signal_identifier, {})[
cls.message_type
] = cls
Message.subclasses.setdefault(subclass.signal_identifier, {})[
subclass.message_type
] = subclass
return cls
return subclass
# Factory method to create a subclass based on the signal identifier and message type
# Factory method to create a subclass based on the signal identifier and message
# type
@staticmethod
def create(signal_identifier, message_type, payload):
# Look for a registered subclass
@@ -676,18 +703,23 @@ class Message:
self.payload = payload
def to_string(self, details):
base = f'{color(f"{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_{Message.message_type_name(self.message_type)}", "yellow")}'
base = color(
f'{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_'
f'{Message.message_type_name(self.message_type)}',
'yellow',
)
if details:
if type(details) is str:
if isinstance(details, str):
return f'{base}: {details}'
else:
return (
base
+ ':\n'
+ '\n'.join([' ' + color(detail, 'cyan') for detail in details])
)
else:
return base
return (
base
+ ':\n'
+ '\n'.join([' ' + color(detail, 'cyan') for detail in details])
)
return base
def __str__(self):
return self.to_string(self.payload.hex())
@@ -703,8 +735,8 @@ class Simple_Command(Message):
self.acp_seid = self.payload[0] >> 2
def __init__(self, seid):
super().__init__(payload=bytes([seid << 2]))
self.acp_seid = seid
self.payload = bytes([seid << 2])
def __str__(self):
return self.to_string([f'ACP SEID: {self.acp_seid}'])
@@ -720,8 +752,8 @@ class Simple_Reject(Message):
self.error_code = self.payload[0]
def __init__(self, error_code):
super().__init__(payload=bytes([error_code]))
self.error_code = error_code
self.payload = bytes([self.error_code])
def __str__(self):
details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}']
@@ -752,13 +784,14 @@ class Discover_Response(Message):
)
def __init__(self, endpoints):
super().__init__(payload=b''.join([bytes(endpoint) for endpoint in endpoints]))
self.endpoints = endpoints
self.payload = b''.join([bytes(endpoint) for endpoint in endpoints])
def __str__(self):
details = []
for endpoint in self.endpoints:
details.extend(
# pylint: disable=line-too-long
[
f'ACP SEID: {endpoint.seid}',
f' in_use: {endpoint.in_use}',
@@ -788,8 +821,10 @@ class Get_Capabilities_Response(Message):
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload)
def __init__(self, capabilities):
super().__init__(
payload=ServiceCapabilities.serialize_capabilities(capabilities)
)
self.capabilities = capabilities
self.payload = ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self):
details = [str(capability) for capability in self.capabilities]
@@ -841,12 +876,13 @@ class Set_Configuration_Command(Message):
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[2:])
def __init__(self, acp_seid, int_seid, capabilities):
super().__init__(
payload=bytes([acp_seid << 2, int_seid << 2])
+ ServiceCapabilities.serialize_capabilities(capabilities)
)
self.acp_seid = acp_seid
self.int_seid = int_seid
self.capabilities = capabilities
self.payload = bytes(
[acp_seid << 2, int_seid << 2]
) + ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self):
details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [
@@ -875,14 +911,20 @@ class Set_Configuration_Reject(Message):
self.error_code = self.payload[1]
def __init__(self, service_category, error_code):
super().__init__(payload=bytes([service_category, error_code]))
self.service_category = service_category
self.error_code = error_code
self.payload = bytes([service_category, self.error_code])
def __str__(self):
details = [
f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
(
'service_category: '
f'{name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}'
),
(
'error_code: '
f'{name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
),
]
return self.to_string(details)
@@ -906,8 +948,10 @@ class Get_Configuration_Response(Message):
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload)
def __init__(self, capabilities):
super().__init__(
payload=ServiceCapabilities.serialize_capabilities(capabilities)
)
self.capabilities = capabilities
self.payload = ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self):
details = [str(capability) for capability in self.capabilities]
@@ -930,6 +974,7 @@ class Reconfigure_Command(Message):
'''
def init_from_payload(self):
# pylint: disable=attribute-defined-outside-init
self.acp_seid = self.payload[0] >> 2
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[1:])
@@ -991,8 +1036,8 @@ class Start_Command(Message):
self.acp_seids = [x >> 2 for x in self.payload]
def __init__(self, seids):
super().__init__(payload=bytes([seid << 2 for seid in seids]))
self.acp_seids = seids
self.payload = bytes([seid << 2 for seid in self.acp_seids])
def __str__(self):
return self.to_string([f'ACP SEIDs: {self.acp_seids}'])
@@ -1018,9 +1063,9 @@ class Start_Reject(Message):
self.error_code = self.payload[1]
def __init__(self, acp_seid, error_code):
super().__init__(payload=bytes([acp_seid << 2, error_code]))
self.acp_seid = acp_seid
self.error_code = error_code
self.payload = bytes([self.acp_seid << 2, self.error_code])
def __str__(self):
details = [
@@ -1126,7 +1171,7 @@ class General_Reject(Message):
'''
def to_string(self, details):
return f'{color(f"GENERAL_REJECT", "yellow")}'
return color('GENERAL_REJECT', 'yellow')
# -----------------------------------------------------------------------------
@@ -1137,6 +1182,7 @@ class DelayReport_Command(Message):
'''
def init_from_payload(self):
# pylint: disable=attribute-defined-outside-init
self.acp_seid = self.payload[0] >> 2
self.delay = (self.payload[1] << 8) | (self.payload[2])
@@ -1206,9 +1252,11 @@ class Protocol:
l2cap_channel.on('open', self.on_l2cap_channel_open)
def get_local_endpoint_by_seid(self, seid):
if seid > 0 and seid <= len(self.local_endpoints):
if 0 < seid <= len(self.local_endpoints):
return self.local_endpoints[seid - 1]
return None
def add_source(self, codec_capabilities, packet_pump):
seid = len(self.local_endpoints) + 1
source = LocalSource(self, seid, codec_capabilities, packet_pump)
@@ -1288,12 +1336,15 @@ class Protocol:
if has_media_transport and has_codec:
return endpoint
return None
def on_pdu(self, pdu):
self.message_assembler.on_pdu(pdu)
def on_message(self, transaction_label, message):
logger.debug(
f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}'
f'{color("<<< Received AVDTP message", "magenta")}: '
f'[{transaction_label}] {message}'
)
# Check that the identifier is not reserved
@@ -1311,7 +1362,12 @@ class Protocol:
if message.message_type == Message.COMMAND:
# Command
handler_name = f'on_{AVDTP_SIGNAL_NAMES.get(message.signal_identifier,"").replace("AVDTP_","").lower()}_command'
signal_name = (
AVDTP_SIGNAL_NAMES.get(message.signal_identifier, "")
.replace("AVDTP_", "")
.lower()
)
handler_name = f'on_{signal_name}_command'
handler = getattr(self, handler_name, None)
if handler:
try:
@@ -1344,7 +1400,8 @@ class Protocol:
def send_message(self, transaction_label, message):
logger.debug(
f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}'
f'{color(">>> Sending AVDTP message", "magenta")}: '
f'[{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
@@ -1398,10 +1455,7 @@ class Protocol:
response = await transaction_result
# Check for errors
if (
response.message_type == Message.GENERAL_REJECT
or response.message_type == Message.RESPONSE_REJECT
):
if response.message_type in (Message.GENERAL_REJECT, Message.RESPONSE_REJECT):
raise ProtocolError(response.error_code, 'avdtp')
return response
@@ -1424,8 +1478,8 @@ class Protocol:
async def get_capabilities(self, seid):
if self.version > (1, 2):
return await self.send_command(Get_All_Capabilities_Command(seid))
else:
return await self.send_command(Get_Capabilities_Command(seid))
return await self.send_command(Get_Capabilities_Command(seid))
async def set_configuration(self, acp_seid, int_seid, capabilities):
return await self.send_command(
@@ -1451,7 +1505,7 @@ class Protocol:
async def abort(self, seid):
return await self.send_command(Abort_Command(seid))
def on_discover_command(self, command):
def on_discover_command(self, _command):
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
@@ -1689,7 +1743,7 @@ class Stream:
self.change_state(AVDTP_OPEN_STATE)
async def close(self):
if self.state not in {AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}:
if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE):
raise InvalidStateError('current state is not OPEN or STREAMING')
logger.debug('closing local endpoint')
@@ -1718,13 +1772,14 @@ class Stream:
return result
self.change_state(AVDTP_CONFIGURED_STATE)
return None
def on_get_configuration_command(self, configuration):
if self.state not in {
if self.state not in (
AVDTP_CONFIGURED_STATE,
AVDTP_OPEN_STATE,
AVDTP_STREAMING_STATE,
}:
):
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command(configuration)
@@ -1737,6 +1792,8 @@ class Stream:
if result is not None:
return result
return None
def on_open_command(self):
if self.state != AVDTP_CONFIGURED_STATE:
return Open_Reject(AVDTP_BAD_STATE_ERROR)
@@ -1749,6 +1806,7 @@ class Stream:
self.protocol.channel_acceptor = self
self.change_state(AVDTP_OPEN_STATE)
return None
def on_start_command(self):
if self.state != AVDTP_OPEN_STATE:
@@ -1764,6 +1822,7 @@ class Stream:
return result
self.change_state(AVDTP_STREAMING_STATE)
return None
def on_suspend_command(self):
if self.state != AVDTP_STREAMING_STATE:
@@ -1774,9 +1833,10 @@ class Stream:
return result
self.change_state(AVDTP_OPEN_STATE)
return None
def on_close_command(self):
if self.state not in {AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}:
if self.state not in (AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE):
return Open_Reject(AVDTP_BAD_STATE_ERROR)
result = self.local_endpoint.on_close_command()
@@ -1792,6 +1852,8 @@ class Stream:
# TODO: set a timer as we wait for the RTP channel to be closed
pass
return None
def on_abort_command(self):
if self.rtp_channel is None:
# No need to wait
@@ -1819,7 +1881,7 @@ class Stream:
self.local_endpoint.in_use = 0
self.rtp_channel = None
if self.state in {AVDTP_CLOSING_STATE, AVDTP_ABORTING_STATE}:
if self.state in (AVDTP_CLOSING_STATE, AVDTP_ABORTING_STATE):
self.change_state(AVDTP_IDLE_STATE)
else:
logger.warning('unexpected channel close while not CLOSING or ABORTING')
@@ -1839,7 +1901,10 @@ class Stream:
local_endpoint.in_use = 1
def __str__(self):
return f'Stream({self.local_endpoint.seid} -> {self.remote_endpoint.seid} {self.state_name(self.state)})'
return (
f'Stream({self.local_endpoint.seid} -> '
f'{self.remote_endpoint.seid} {self.state_name(self.state)})'
)
# -----------------------------------------------------------------------------
@@ -1852,12 +1917,14 @@ class StreamEndPoint:
self.capabilities = capabilities
def __str__(self):
media_type = f'{name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}'
tsep = f'{name_or_number(AVDTP_TSEP_NAMES, self.tsep)}'
return '\n'.join(
[
'SEP(',
f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}',
f' media_type={media_type}',
f' tsep={tsep}',
f' in_use={self.in_use}',
' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]),
@@ -1902,11 +1969,11 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint):
def __init__(
self, protocol, seid, media_type, tsep, capabilities, configuration=[]
self, protocol, seid, media_type, tsep, capabilities, configuration=None
):
super().__init__(seid, media_type, tsep, 0, capabilities)
self.protocol = protocol
self.configuration = configuration
self.configuration = configuration if configuration is not None else []
self.stream = None
async def start(self):
@@ -1968,14 +2035,14 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
async def start(self):
if self.packet_pump:
return await self.packet_pump.start(self.stream.rtp_channel)
else:
self.emit('start', self.stream.rtp_channel)
self.emit('start', self.stream.rtp_channel)
async def stop(self):
if self.packet_pump:
return await self.packet_pump.stop()
else:
self.emit('stop')
self.emit('stop')
def on_set_configuration_command(self, configuration):
# For now, blindly accept the configuration
@@ -2018,6 +2085,7 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def on_avdtp_packet(self, packet):
rtp_packet = MediaPacket.from_bytes(packet)
logger.debug(
f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}'
f'{color("<<< RTP Packet:", "green")} '
f'{rtp_packet} {rtp_packet.payload[:16].hex()}'
)
self.emit('rtp_packet', rtp_packet)

103
bumble/colors.py Normal file
View File

@@ -0,0 +1,103 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial
from typing import List, Optional, Union
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
# ANSI style names
STYLES = (
'none',
'bold',
'faint',
'italic',
'underline',
'blink',
'blink2',
'negative',
'concealed',
'crossed',
)
ColorSpec = Union[str, int]
def _join(*values: ColorSpec) -> str:
return ';'.join(str(v) for v in values)
def _color_code(spec: ColorSpec, base: int) -> str:
if isinstance(spec, str):
spec = spec.strip().lower()
if spec == 'default':
return _join(base + 9)
elif spec in COLORS:
return _join(base + COLORS.index(spec))
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
def color(
s: str,
fg: Optional[ColorSpec] = None,
bg: Optional[ColorSpec] = None,
style: Optional[str] = None,
) -> str:
codes: List[ColorSpec] = []
if fg:
codes.append(_color_code(fg, 30))
if bg:
codes.append(_color_code(bg, 40))
if style:
for style_part in style.split('+'):
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
else:
return s
# Foreground color shortcuts
black = partial(color, fg='black')
red = partial(color, fg='red')
green = partial(color, fg='green')
yellow = partial(color, fg='yellow')
blue = partial(color, fg='blue')
magenta = partial(color, fg='magenta')
cyan = partial(color, fg='cyan')
white = partial(color, fg='white')
# Style shortcuts
bold = partial(color, style='bold')
none = partial(color, style='none')
faint = partial(color, style='faint')
italic = partial(color, style='italic')
underline = partial(color, style='underline')
blink = partial(color, style='blink')
blink2 = partial(color, style='blink2')
negative = partial(color, style='negative')
concealed = partial(color, style='concealed')
crossed = partial(color, style='crossed')

View File

@@ -17,6 +17,7 @@
# the `generate_company_id_list.py` script
# -----------------------------------------------------------------------------
# pylint: disable=line-too-long
COMPANY_IDENTIFIERS = {
0x0000: "Ericsson Technology Licensing",
0x0001: "Nokia Mobile Phones",
@@ -196,28 +197,28 @@ COMPANY_IDENTIFIERS = {
0x00AF: "Cinetix",
0x00B0: "Passif Semiconductor Corp",
0x00B1: "Saris Cycling Group, Inc",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B5: "Swirl Networks",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00C8: "GeLo Inc",
0x00C9: "Evluma",
0x00CA: "MC10",
@@ -249,10 +250,10 @@ COMPANY_IDENTIFIERS = {
0x00E4: "Laird Connectivity, Inc. formerly L.S. Research Inc.",
0x00E5: "Eden Software Consultants Ltd.",
0x00E6: "Freshtemp",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00EB: "Server Technology Inc.",
0x00EC: "BioResearch Associates",
0x00ED: "Jolly Logic, LLC",

View File

@@ -19,9 +19,36 @@ import logging
import asyncio
import itertools
import random
import struct
from bumble.colors import color
from bumble.core import BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE
from bumble.hci import (
HCI_ACL_DATA_PACKET,
HCI_COMMAND_DISALLOWED_ERROR,
HCI_COMMAND_PACKET,
HCI_COMMAND_STATUS_PENDING,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_EVENT_PACKET,
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
HCI_LE_1M_PHY,
HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0,
Address,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event,
HCI_LE_Advertising_Report_Event,
HCI_LE_Connection_Complete_Event,
HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
)
from .hci import *
from .l2cap import *
# -----------------------------------------------------------------------------
# Logging
@@ -83,13 +110,19 @@ class Controller:
self.manufacturer_name = 0xFFFF
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.event_mask = 0
self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000'
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_event_mask = 0
self.advertising_parameters = None
self.le_features = bytes.fromhex('ff49010000000000')
self.le_states = bytes.fromhex('ffff3fffff030000')
self.advertising_channel_tx_power = 0
self.filter_accept_list_size = 8
self.filter_duplicates = False
self.resolving_list_size = 8
self.supported_max_tx_octets = 27
self.supported_max_tx_time = 10000 # microseconds
@@ -133,7 +166,8 @@ class Controller:
@host.setter
def host(self, host):
'''
Sets the host (sink) for this controller, and set this controller as the controller (sink) for the host
Sets the host (sink) for this controller, and set this controller as the
controller (sink) for the host
'''
self.set_packet_sink(host)
if host:
@@ -151,7 +185,7 @@ class Controller:
@public_address.setter
def public_address(self, address):
if type(address) is str:
if isinstance(address, str):
address = Address(address)
self._public_address = address
@@ -161,7 +195,7 @@ class Controller:
@random_address.setter
def random_address(self, address):
if type(address) is str:
if isinstance(address, str):
address = Address(address)
self._random_address = address
logger.debug(f'new random address: {address}')
@@ -175,7 +209,8 @@ class Controller:
def on_hci_packet(self, packet):
logger.debug(
f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}'
f'{color("<<<", "blue")} [{self.name}] '
f'{color("HOST -> CONTROLLER", "blue")}: {packet}'
)
# If the packet is a command, invoke the handler for this packet
@@ -192,7 +227,7 @@ class Controller:
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, self.on_hci_command)
result = handler(command)
if type(result) is bytes:
if isinstance(result, bytes):
self.send_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
@@ -201,7 +236,7 @@ class Controller:
)
)
def on_hci_event_packet(self, event):
def on_hci_event_packet(self, _event):
logger.warning('!!! unexpected event packet')
def on_hci_acl_data_packet(self, packet):
@@ -218,7 +253,8 @@ class Controller:
def send_hci_packet(self, packet):
logger.debug(
f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}'
f'{color(">>>", "green")} [{self.name}] '
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
)
if self.host:
self.host.on_packet(packet.to_bytes())
@@ -312,7 +348,7 @@ class Controller:
# Remove the connection
del self.peripheral_connections[peer_address]
else:
logger.warn(f'!!! No peripheral connection found for {peer_address}')
logger.warning(f'!!! No peripheral connection found for {peer_address}')
def on_link_peripheral_connection_complete(
self, le_create_connection_command, status
@@ -339,6 +375,7 @@ class Controller:
# Say that the connection has completed
self.send_hci_packet(
# pylint: disable=line-too-long
HCI_LE_Connection_Complete_Event(
status=status,
connection_handle=connection.handle if connection else 0,
@@ -391,9 +428,9 @@ class Controller:
# Remove the connection
del self.central_connections[peer_address]
else:
logger.warn(f'!!! No central connection found for {peer_address}')
logger.warning(f'!!! No central connection found for {peer_address}')
def on_link_encrypted(self, peer_address, rand, ediv, ltk):
def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk):
# For now, just setup the encryption without asking the host
if connection := self.find_connection_by_address(peer_address):
self.send_hci_packet(
@@ -420,8 +457,8 @@ class Controller:
return
# Send a scan report
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
report = HCI_LE_Advertising_Report_Event.Report(
HCI_LE_Advertising_Report_Event.Report.FIELDS,
event_type=HCI_LE_Advertising_Report_Event.ADV_IND,
address_type=sender_address.address_type,
address=sender_address,
@@ -431,8 +468,8 @@ class Controller:
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
# Simulate a scan response
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
report = HCI_LE_Advertising_Report_Event.Report(
HCI_LE_Advertising_Report_Event.Report.FIELDS,
event_type=HCI_LE_Advertising_Report_Event.SCAN_RSP,
address_type=sender_address.address_type,
address=sender_address,
@@ -505,7 +542,7 @@ class Controller:
command.connection_handle
)
):
logger.warn('connection not found')
logger.warning('connection not found')
return
if self.link:
@@ -521,7 +558,7 @@ class Controller:
self.event_mask = command.event_mask
return bytes([HCI_SUCCESS])
def on_hci_reset_command(self, command):
def on_hci_reset_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.2 Reset Command
'''
@@ -543,7 +580,7 @@ class Controller:
pass
return bytes([HCI_SUCCESS])
def on_hci_read_local_name_command(self, command):
def on_hci_read_local_name_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.12 Read Local Name Command
'''
@@ -553,21 +590,22 @@ class Controller:
return bytes([HCI_SUCCESS]) + local_name
def on_hci_read_class_of_device_command(self, command):
def on_hci_read_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.25 Read Class of Device Command
'''
return bytes([HCI_SUCCESS, 0, 0, 0])
def on_hci_write_class_of_device_command(self, command):
def on_hci_write_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.26 Write Class of Device Command
'''
return bytes([HCI_SUCCESS])
def on_hci_read_synchronous_flow_control_enable_command(self, command):
def on_hci_read_synchronous_flow_control_enable_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.36 Read Synchronous Flow Control Enable Command
See Bluetooth spec Vol 2, Part E - 7.3.36 Read Synchronous Flow Control Enable
Command
'''
if self.sync_flow_control:
ret = 1
@@ -577,7 +615,8 @@ class Controller:
def on_hci_write_synchronous_flow_control_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.37 Write Synchronous Flow Control Enable Command
See Bluetooth spec Vol 2, Part E - 7.3.37 Write Synchronous Flow Control Enable
Command
'''
ret = HCI_SUCCESS
if command.synchronous_flow_control_enable == 1:
@@ -588,7 +627,7 @@ class Controller:
ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR
return bytes([ret])
def on_hci_write_simple_pairing_mode_command(self, command):
def on_hci_write_simple_pairing_mode_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command
'''
@@ -601,13 +640,13 @@ class Controller:
self.event_mask_page_2 = command.event_mask_page_2
return bytes([HCI_SUCCESS])
def on_hci_read_le_host_support_command(self, command):
def on_hci_read_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.78 Write LE Host Support Command
'''
return bytes([HCI_SUCCESS, 1, 0])
def on_hci_write_le_host_support_command(self, command):
def on_hci_write_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.79 Write LE Host Support Command
'''
@@ -616,12 +655,13 @@ class Controller:
def on_hci_write_authenticated_payload_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.94 Write Authenticated Payload Timeout Command
See Bluetooth spec Vol 2, Part E - 7.3.94 Write Authenticated Payload Timeout
Command
'''
# TODO
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_read_local_version_information_command(self, command):
def on_hci_read_local_version_information_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.1 Read Local Version Information Command
'''
@@ -635,19 +675,19 @@ class Controller:
self.lmp_subversion,
)
def on_hci_read_local_supported_commands_command(self, command):
def on_hci_read_local_supported_commands_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.2 Read Local Supported Commands Command
'''
return bytes([HCI_SUCCESS]) + self.supported_commands
def on_hci_read_local_supported_features_command(self, command):
def on_hci_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.3 Read Local Supported Features Command
'''
return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_bd_addr_command(self, command):
def on_hci_read_bd_addr_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command
'''
@@ -665,7 +705,7 @@ class Controller:
self.le_event_mask = command.le_event_mask
return bytes([HCI_SUCCESS])
def on_hci_le_read_buffer_size_command(self, command):
def on_hci_le_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command
'''
@@ -676,9 +716,10 @@ class Controller:
self.hc_total_num_le_data_packets,
)
def on_hci_le_read_local_supported_features_command(self, command):
def on_hci_le_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.3 LE Read Local Supported Features Command
See Bluetooth spec Vol 2, Part E - 7.8.3 LE Read Local Supported Features
Command
'''
return bytes([HCI_SUCCESS]) + self.le_features
@@ -696,9 +737,10 @@ class Controller:
self.advertising_parameters = command
return bytes([HCI_SUCCESS])
def on_hci_le_read_advertising_channel_tx_power_command(self, command):
def on_hci_le_read_advertising_physical_channel_tx_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Channel Tx Power Command
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Physical Channel
Tx Power Command
'''
return bytes([HCI_SUCCESS, self.advertising_channel_tx_power])
@@ -779,33 +821,36 @@ class Controller:
)
)
def on_hci_le_create_connection_cancel_command(self, command):
def on_hci_le_create_connection_cancel_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.13 LE Create Connection Cancel Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_filter_accept_list_size_command(self, command):
def on_hci_le_read_filter_accept_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.14 LE Read Filter Accept List Size Command
See Bluetooth spec Vol 2, Part E - 7.8.14 LE Read Filter Accept List Size
Command
'''
return bytes([HCI_SUCCESS, self.filter_accept_list_size])
def on_hci_le_clear_filter_accept_list_command(self, command):
def on_hci_le_clear_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.15 LE Clear Filter Accept List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_filter_accept_list_command(self, command):
def on_hci_le_add_device_to_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.16 LE Add Device To Filter Accept List Command
See Bluetooth spec Vol 2, Part E - 7.8.16 LE Add Device To Filter Accept List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_remove_device_from_filter_accept_list_command(self, command):
def on_hci_le_remove_device_from_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.17 LE Remove Device From Filter Accept List Command
See Bluetooth spec Vol 2, Part E - 7.8.17 LE Remove Device From Filter Accept
List Command
'''
return bytes([HCI_SUCCESS])
@@ -832,7 +877,7 @@ class Controller:
)
)
def on_hci_le_rand_command(self, command):
def on_hci_le_rand_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.23 LE Rand Command
'''
@@ -849,7 +894,7 @@ class Controller:
command.connection_handle
)
):
logger.warn('connection not found')
logger.warning('connection not found')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
# Notify that the connection is now encrypted
@@ -869,15 +914,18 @@ class Controller:
)
)
def on_hci_le_read_supported_states_command(self, command):
return None
def on_hci_le_read_supported_states_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.27 LE Read Supported States Command
'''
return bytes([HCI_SUCCESS]) + self.le_states
def on_hci_le_read_suggested_default_data_length_command(self, command):
def on_hci_le_read_suggested_default_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length Command
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length
Command
'''
return struct.pack(
'<BHH',
@@ -888,33 +936,35 @@ class Controller:
def on_hci_le_write_suggested_default_data_length_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length Command
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length
Command
'''
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack(
'<HH', command.parameters[:4]
)
return bytes([HCI_SUCCESS])
def on_hci_le_read_local_p_256_public_key_command(self, command):
def on_hci_le_read_local_p_256_public_key_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.36 LE Read P-256 Public Key Command
'''
# TODO create key and send HCI_LE_Read_Local_P-256_Public_Key_Complete event
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_resolving_list_command(self, command):
def on_hci_le_add_device_to_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.38 LE Add Device To Resolving List Command
See Bluetooth spec Vol 2, Part E - 7.8.38 LE Add Device To Resolving List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_clear_resolving_list_command(self, command):
def on_hci_le_clear_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.40 LE Clear Resolving List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_resolving_list_size_command(self, command):
def on_hci_le_read_resolving_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.41 LE Read Resolving List Size Command
'''
@@ -922,7 +972,8 @@ class Controller:
def on_hci_le_set_address_resolution_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable Command
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable
Command
'''
ret = HCI_SUCCESS
if command.address_resolution_enable == 1:
@@ -935,12 +986,13 @@ class Controller:
def on_hci_le_set_resolvable_private_address_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.45 LE Set Resolvable Private Address Timeout Command
See Bluetooth spec Vol 2, Part E - 7.8.45 LE Set Resolvable Private Address
Timeout Command
'''
self.le_rpa_timeout = command.rpa_timeout
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_data_length_command(self, command):
def on_hci_le_read_maximum_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.46 LE Read Maximum Data Length Command
'''
@@ -955,7 +1007,7 @@ class Controller:
def on_hci_le_read_phy_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY command
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY Command
'''
return struct.pack(
'<BHBB',
@@ -975,3 +1027,9 @@ class Controller:
'rx_phys': command.rx_phys,
}
return bytes([HCI_SUCCESS])
def on_hci_le_read_transmit_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)

View File

@@ -15,7 +15,9 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List, Optional, Tuple, Union, cast
from .company_ids import COMPANY_IDENTIFIERS
@@ -100,7 +102,7 @@ class ProtocolError(BaseError):
"""Protocol Error"""
class TimeoutError(Exception):
class TimeoutError(Exception): # pylint: disable=redefined-builtin
"""Timeout Error"""
@@ -112,7 +114,7 @@ class InvalidStateError(Exception):
"""Invalid State Error"""
class ConnectionError(BaseError):
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""
FAILURE = 0x01
@@ -142,13 +144,16 @@ class ConnectionError(BaseError):
class UUID:
'''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
Note that this class expects and works in little-endian byte-order throughout.
The exception is when interacting with strings, which are in big-endian byte-order.
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
if type(uuid_str_or_int) is int:
if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
if len(uuid_str_or_int) == 36:
@@ -168,7 +173,8 @@ class UUID:
self.name = name
def register(self):
# Register this object in the class registry, and update the entry's name if it wasn't set already
# Register this object in the class registry, and update the entry's name if
# it wasn't set already
for uuid in self.UUIDS:
if self == uuid:
if uuid.name is None:
@@ -179,15 +185,15 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes, name=None):
if len(uuid_bytes) in {2, 4, 16}:
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
self.name = name
return self.register()
else:
raise ValueError('only 2, 4 and 16 bytes are allowed')
raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16, name=None):
@@ -198,20 +204,28 @@ class UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
def parse_uuid(cls, bytes, offset):
return len(bytes), cls.from_bytes(bytes[offset:])
def parse_uuid(cls, uuid_as_bytes, offset):
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod
def parse_uuid_2(cls, bytes, offset):
return offset + 2, cls.from_bytes(bytes[offset : offset + 2])
def parse_uuid_2(cls, uuid_as_bytes, offset):
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128=False):
if len(self.uuid_bytes) == 16 or not force_128:
'''
Serialize UUID in little-endian byte-order
'''
if not force_128:
return self.uuid_bytes
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.uuid_bytes + UUID.BASE_UUID
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
else:
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
assert False, "unreachable"
def to_pdu_bytes(self):
'''
@@ -222,19 +236,19 @@ class UUID:
'''
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self):
def to_hex_str(self) -> str:
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
else:
return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
def __bytes__(self):
return self.to_bytes()
@@ -242,7 +256,8 @@ class UUID:
def __eq__(self, other):
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
elif type(other) is str:
if isinstance(other, str):
return UUID(other) == self
return False
@@ -252,11 +267,11 @@ class UUID:
def __str__(self):
if len(self.uuid_bytes) == 2:
v = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{v:04X}'
uuid = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{uuid:04X}'
elif len(self.uuid_bytes) == 4:
v = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}'
uuid = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{uuid:08X}'
else:
result = '-'.join(
[
@@ -267,10 +282,11 @@ class UUID:
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
if self.name is not None:
return result + f' ({self.name})'
else:
return result
return result
def __repr__(self):
return str(self)
@@ -280,6 +296,7 @@ class UUID:
# Common UUID constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
# Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -386,6 +403,7 @@ BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401,
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
@@ -393,6 +411,7 @@ BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402,
# -----------------------------------------------------------------------------
class DeviceClass:
# fmt: off
# pylint: disable=line-too-long
# Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
@@ -562,6 +581,7 @@ class DeviceClass:
}
# fmt: on
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
@@ -598,8 +618,14 @@ class DeviceClass:
# -----------------------------------------------------------------------------
# Advertising Data
# -----------------------------------------------------------------------------
AdvertisingObject = Union[
List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes]
]
class AdvertisingData:
# fmt: off
# pylint: disable=line-too-long
# This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01
@@ -712,9 +738,14 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10
# fmt: on
ad_structures: List[Tuple[int, bytes]]
def __init__(self, ad_structures=[]):
# fmt: on
# pylint: enable=line-too-long
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None:
if ad_structures is None:
ad_structures = []
self.ad_structures = ad_structures[:]
@staticmethod
@@ -739,7 +770,7 @@ class AdvertisingData:
return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod
def uuid_list_to_objects(ad_data, uuid_size):
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
uuids = []
offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data):
@@ -814,53 +845,65 @@ class AdvertisingData:
return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements
@staticmethod
def ad_data_to_object(ad_type, ad_data):
if ad_type in {
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject:
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
}:
):
return AdvertisingData.uuid_list_to_objects(ad_data, 2)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
}:
):
return AdvertisingData.uuid_list_to_objects(ad_data, 4)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
}:
):
return AdvertisingData.uuid_list_to_objects(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
elif ad_type in {
if ad_type in (
AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI,
}:
):
return ad_data.decode("utf-8")
elif ad_type in {AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS}:
return ad_data[0]
elif ad_type in {
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
}:
return struct.unpack('<H', ad_data)[0]
elif ad_type == AdvertisingData.CLASS_OF_DEVICE:
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
elif ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return struct.unpack('<HH', ad_data)
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
else:
return ad_data
):
return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return cast(Tuple[int, int], struct.unpack('<HH', ad_data))
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
return ad_data
def append(self, data):
offset = 0
@@ -873,30 +916,27 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data))
offset += length
def get(self, type_id, return_all=False, raw=False):
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingObject]:
'''
Get Advertising Data Structure(s) with a given type
If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches.
Returns a (possibly empty) list of matches.
'''
def process_ad_data(ad_data):
def process_ad_data(ad_data: bytes) -> AdvertisingObject:
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all:
return [
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
]
else:
return next(
(
process_ad_data(ad[1])
for ad in self.ad_structures
if ad[0] == type_id
),
None,
)
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]:
'''
Get Advertising Data Structure(s) with a given type
Returns the first entry, or None if no structure matches.
'''
all = self.get_all(type_id, raw=raw)
return all[0] if all else None
def __bytes__(self):
return b''.join(

View File

@@ -125,7 +125,7 @@ def e(key, data):
# -----------------------------------------------------------------------------
def ah(k, r):
def ah(k, r): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
'''
@@ -136,9 +136,10 @@ def ah(k, r):
# -----------------------------------------------------------------------------
def c1(k, r, preq, pres, iat, rat, ia, ra):
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for LE Legacy Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
'''
p1 = bytes([iat, rat]) + preq + pres
@@ -149,7 +150,8 @@ def c1(k, r, preq, pres, iat, rat, ia, ra):
# -----------------------------------------------------------------------------
def s1(k, r1, r2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
'''
return e(k, r2[0:8] + r1[0:8])
@@ -170,7 +172,8 @@ def aes_cmac(m, k):
# -----------------------------------------------------------------------------
def f4(u, v, x, z):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
'''
return bytes(
reversed(
@@ -182,7 +185,8 @@ def f4(u, v, x, z):
# -----------------------------------------------------------------------------
def f5(w, n1, n2, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation Function f5
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
'''
@@ -222,9 +226,10 @@ def f5(w, n1, n2, a1, a2):
# -----------------------------------------------------------------------------
def f6(w, n1, n2, r, io_cap, a1, a2):
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
'''
return bytes(
reversed(
@@ -244,7 +249,8 @@ def f6(w, n1, n2, r, io_cap, a1, a2):
# -----------------------------------------------------------------------------
def g2(u, v, x, y):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(

416
bumble/decoder.py Normal file
View File

@@ -0,0 +1,416 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
WL = [-60, -30, 58, 172, 334, 538, 1198, 3042]
RL42 = [0, 7, 6, 5, 4, 3, 2, 1, 7, 6, 5, 4, 3, 2, 1, 0]
ILB = [
2048,
2093,
2139,
2186,
2233,
2282,
2332,
2383,
2435,
2489,
2543,
2599,
2656,
2714,
2774,
2834,
2896,
2960,
3025,
3091,
3158,
3228,
3298,
3371,
3444,
3520,
3597,
3676,
3756,
3838,
3922,
4008,
]
WH = [0, -214, 798]
RH2 = [2, 1, 2, 1]
# Values in QM2/QM4/QM6 left shift three bits than original g722 specification.
QM2 = [-7408, -1616, 7408, 1616]
QM4 = [
0,
-20456,
-12896,
-8968,
-6288,
-4240,
-2584,
-1200,
20456,
12896,
8968,
6288,
4240,
2584,
1200,
0,
]
QM6 = [
-136,
-136,
-136,
-136,
-24808,
-21904,
-19008,
-16704,
-14984,
-13512,
-12280,
-11192,
-10232,
-9360,
-8576,
-7856,
-7192,
-6576,
-6000,
-5456,
-4944,
-4464,
-4008,
-3576,
-3168,
-2776,
-2400,
-2032,
-1688,
-1360,
-1040,
-728,
24808,
21904,
19008,
16704,
14984,
13512,
12280,
11192,
10232,
9360,
8576,
7856,
7192,
6576,
6000,
5456,
4944,
4464,
4008,
3576,
3168,
2776,
2400,
2032,
1688,
1360,
1040,
728,
432,
136,
-432,
-136,
]
QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# fmt: on
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder(object):
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
specification for the required information. G722 specification:
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self):
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
self._band[0].det = 32
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
for code in encoded_data:
higher_bits = (code >> 6) & 0x03
lower_bits = code & 0x3F
rlow = self.lower_sub_band_decoder(lower_bits)
rhigh = self.higher_sub_band_decoder(higher_bits)
# Apply the receive QMF
self._x[:22] = self._x[2:]
self._x[22] = rlow + rhigh
self._x[23] = rlow - rhigh
xout2 = sum(self._x[2 * i] * QMF_COEFFS[i] for i in range(12))
xout1 = sum(self._x[2 * i + 1] * QMF_COEFFS[11 - i] for i in range(12))
result_length = self.update_decoded_result(
xout1, result_length, result_array
)
result_length = self.update_decoded_result(
xout2, result_length, result_array
)
return result_length
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
# INVQBL
wd1 = lower_bits
wd2 = QM6[wd1]
wd1 >>= 2
wd2 = (self._band[0].det * wd2) >> 15
# RECONS
rlow = self._band[0].s + wd2
# Block 6L
# LIMIT
if rlow > 16383:
rlow = 16383
elif rlow < -16384:
rlow = -16384
# Block 2L
# INVQAL
wd2 = QM4[wd1]
dlowt = (self._band[0].det * wd2) >> 15
# Block 3L
# LOGSCL
wd2 = RL42[wd1]
wd1 = (self._band[0].nb * 127) >> 7
wd1 += WL[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 18432:
wd1 = 18432
self._band[0].nb = wd1
# SCALEL
wd1 = (self._band[0].nb >> 6) & 31
wd2 = 8 - (self._band[0].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[0].det = wd3 << 2
# Block 4L
self._band[0].block4(dlowt)
return rlow
def higher_sub_band_decoder(self, higher_bits) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
# INVQAH
wd2 = QM2[higher_bits]
dhigh = (self._band[1].det * wd2) >> 15
# Block 5H
# RECONS
rhigh = dhigh + self._band[1].s
# Block 6H
# LIMIT
if rhigh > 16383:
rhigh = 16383
elif rhigh < -16384:
rhigh = -16384
# Block 3H
# LOGSCH
wd2 = RH2[higher_bits]
wd1 = (self._band[1].nb * 127) >> 7
wd1 += WH[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 22528:
wd1 = 22528
self._band[1].nb = wd1
# SCALEH
wd1 = (self._band[1].nb >> 6) & 31
wd2 = 10 - (self._band[1].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[1].det = wd3 << 2
# Block 4H
self._band[1].block4(dhigh)
return rhigh
# -----------------------------------------------------------------------------
class Band(object):
"""Structure for G722 decode proccessing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self):
self._sp = 0
self._sz = 0
self._r = [0] * 3
self._a = [0] * 3
self._ap = [0] * 3
self._p = [0] * 3
self._d = [0] * 7
self._b = [0] * 7
self._bp = [0] * 7
self._sg = [0] * 7
def saturate(self, amp: int) -> int:
if amp > 32767:
return 32767
elif amp < -32768:
return -32768
else:
return amp
def block4(self, d: int) -> None:
"""Block4 for both lower and higher sub-band decoder."""
wd1 = 0
wd2 = 0
wd3 = 0
# RECONS
self._d[0] = d
self._r[0] = self.saturate(self.s + d)
# PARREC
self._p[0] = self.saturate(self._sz + d)
# UPPOL2
for i in range(3):
self._sg[i] = (self._p[i]) >> 15
wd1 = self.saturate((self._a[1]) << 2)
wd2 = -wd1 if self._sg[0] == self._sg[1] else wd1
if wd2 > 32767:
wd2 = 32767
wd3 = 128 if self._sg[0] == self._sg[2] else -128
wd3 += wd2 >> 7
wd3 += (self._a[2] * 32512) >> 15
if wd3 > 12288:
wd3 = 12288
elif wd3 < -12288:
wd3 = -12288
self._ap[2] = wd3
# UPPOL1
self._sg[0] = (self._p[0]) >> 15
self._sg[1] = (self._p[1]) >> 15
wd1 = 192 if self._sg[0] == self._sg[1] else -192
wd2 = (self._a[1] * 32640) >> 15
self._ap[1] = self.saturate(wd1 + wd2)
wd3 = self.saturate(15360 - self._ap[2])
if self._ap[1] > wd3:
self._ap[1] = wd3
elif self._ap[1] < -wd3:
self._ap[1] = -wd3
# UPZERO
wd1 = 0 if d == 0 else 128
self._sg[0] = d >> 15
for i in range(1, 7):
self._sg[i] = (self._d[i]) >> 15
wd2 = wd1 if self._sg[i] == self._sg[0] else -wd1
wd3 = (self._b[i] * 32640) >> 15
self._bp[i] = self.saturate(wd2 + wd3)
# DELAYA
for i in range(6, 0, -1):
self._d[i] = self._d[i - 1]
self._b[i] = self._bp[i]
for i in range(2, 0, -1):
self._r[i] = self._r[i - 1]
self._p[i] = self._p[i - 1]
self._a[i] = self._ap[i]
# FILTEP
self._sp = 0
for i in range(1, 3):
wd1 = self.saturate(self._r[i] + self._r[i])
self._sp += (self._a[i] * wd1) >> 15
self._sp = self.saturate(self._sp)
# FILTEZ
self._sz = 0
for i in range(6, 0, -1):
wd1 = self.saturate(self._d[i] + self._d[i])
self._sz += (self._b[i] * wd1) >> 15
self._sz = self.saturate(self._sz)
# PREDIC
self.s = self.saturate(self._sp + self._sz)

File diff suppressed because it is too large Load Diff

View File

@@ -25,14 +25,15 @@
from __future__ import annotations
import asyncio
import enum
import types
import functools
import logging
from pyee import EventEmitter
from colors import color
import struct
from typing import Optional, Sequence
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
from .core import *
from .hci import *
from .att import *
# -----------------------------------------------------------------------------
# Logging
@@ -43,6 +44,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
GATT_REQUEST_TIMEOUT = 30 # seconds
@@ -177,6 +179,7 @@ GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
@@ -201,9 +204,11 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION
'''
uuid: UUID
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already
if type(uuid) is str:
if isinstance(uuid, str):
uuid = UUID(uuid)
super().__init__(
@@ -214,11 +219,11 @@ class Service(Attribute):
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = []
# self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
def get_advertising_data(self):
def get_advertising_data(self) -> Optional[bytes]:
"""
Get Service specific advertising data
Defined by each Service, default value is empty
@@ -227,7 +232,12 @@ class Service(Attribute):
return None
def __str__(self):
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid})'
f'{"" if self.primary else "*"}'
)
# -----------------------------------------------------------------------------
@@ -271,15 +281,15 @@ class Characteristic(Attribute):
}
@staticmethod
def property_name(property):
return Characteristic.PROPERTY_NAMES.get(property, '')
def property_name(property_int):
return Characteristic.PROPERTY_NAMES.get(property_int, '')
@staticmethod
def properties_as_string(properties):
return ','.join(
[
Characteristic.property_name(p)
for p in Characteristic.PROPERTY_NAMES.keys()
for p in Characteristic.PROPERTY_NAMES
if properties & p
]
)
@@ -298,11 +308,11 @@ class Characteristic(Attribute):
properties,
permissions,
value=b'',
descriptors: list[Descriptor] = [],
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
self.uuid = self.type
if type(properties) is str:
if isinstance(properties, str):
self.properties = Characteristic.string_to_properties(properties)
else:
self.properties = properties
@@ -313,8 +323,15 @@ class Characteristic(Attribute):
if descriptor.type == descriptor_type:
return descriptor
return None
def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
# -----------------------------------------------------------------------------
@@ -335,7 +352,12 @@ class CharacteristicDeclaration(Attribute):
self.characteristic = characteristic
def __str__(self):
return f'CharacteristicDeclaration(handle=0x{self.handle:04X}, value_handle=0x{self.value_handle:04X}, uuid={self.characteristic.uuid}, properties={Characteristic.properties_as_string(self.characteristic.properties)})'
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, properties='
f'{Characteristic.properties_as_string(self.characteristic.properties)})'
)
# -----------------------------------------------------------------------------
@@ -395,14 +417,14 @@ class CharacteristicAdapter:
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in {
if name in (
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe',
}:
):
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
@@ -486,9 +508,9 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
the format.
'''
def __init__(self, characteristic, format):
def __init__(self, characteristic, pack_format):
super().__init__(characteristic)
self.struct = struct.Struct(format)
self.struct = struct.Struct(pack_format)
def pack(self, *values):
return self.struct.pack(*values)
@@ -497,7 +519,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
return self.struct.unpack(buffer)
def encode_value(self, value):
return self.pack(*value if type(value) is tuple else (value,))
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value):
unpacked = self.unpack(value)
@@ -510,14 +532,15 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the dictionary
by key, in the same order as they occur in the `keys` parameter.
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, format, keys):
super().__init__(characteristic, format)
def __init__(self, characteristic, pack_format, keys):
super().__init__(characteristic, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values):
return super().pack(*(values[key] for key in self.keys))
@@ -544,16 +567,18 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __init__(self, descriptor_type, permissions, value=b''):
super().__init__(descriptor_type, permissions, value)
def __str__(self):
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
)
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit field definition
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
field definition
'''
DEFAULT = 0x0000

View File

@@ -27,10 +27,32 @@ import asyncio
import logging
import struct
from colors import color
from pyee import EventEmitter
from .att import *
from .core import InvalidStateError, ProtocolError, TimeoutError
from .colors import color
from .hci import HCI_Constant
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_ERROR_RESPONSE,
ATT_INVALID_OFFSET_ERROR,
ATT_PDU,
ATT_RESPONSES,
ATT_Exchange_MTU_Request,
ATT_Find_By_Type_Value_Request,
ATT_Find_Information_Request,
ATT_Handle_Value_Confirmation,
ATT_Read_Blob_Request,
ATT_Read_By_Group_Type_Request,
ATT_Read_By_Type_Request,
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -40,7 +62,6 @@ from .gatt import (
Characteristic,
ClientCharacteristicConfigurationBits,
)
from .hci import *
# -----------------------------------------------------------------------------
# Logging
@@ -76,16 +97,17 @@ class AttributeProxy(EventEmitter):
return value_bytes
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
class ServiceProxy(AttributeProxy):
@staticmethod
def from_client(cls, client, service_uuid):
# The service and its characteristics are considered to have already been discovered
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
service = services[0] if services else None
return cls(service) if service else None
return service_class(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = (
@@ -97,7 +119,7 @@ class ServiceProxy(AttributeProxy):
self.uuid = uuid
self.characteristics = []
async def discover_characteristics(self, uuids=[]):
async def discover_characteristics(self, uuids=()):
return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid):
@@ -121,6 +143,8 @@ class CharacteristicProxy(AttributeProxy):
if descriptor.type == descriptor_type:
return descriptor
return None
async def discover_descriptors(self):
return await self.client.discover_descriptors(self)
@@ -148,7 +172,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
class DescriptorProxy(AttributeProxy):
@@ -214,9 +242,9 @@ class Client:
response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError:
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Request timeout', 'red'))
raise TimeoutError(f'GATT timeout for {request.name}')
raise core.TimeoutError(f'GATT timeout for {request.name}') from error
finally:
self.pending_request = None
self.pending_response = None
@@ -225,7 +253,8 @@ class Client:
def send_confirmation(self, confirmation):
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}'
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
@@ -300,7 +329,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
@@ -352,7 +382,7 @@ class Client:
'''
# Force uuid to be a UUID object
if type(uuid) is str:
if isinstance(uuid, str):
uuid = UUID(uuid)
starting_handle = 0x0001
@@ -375,7 +405,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
@@ -414,7 +445,7 @@ class Client:
return services
async def discover_included_services(self, service):
async def discover_included_services(self, _service):
'''
See Vol 3, Part G - 4.5.1 Find Included Services
'''
@@ -423,11 +454,12 @@ class Client:
async def discover_characteristics(self, uuids, service):
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 Discover Characteristics by UUID
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID
'''
# Cast the UUIDs type from string to object if needed
uuids = [UUID(uuid) if type(uuid) is str else uuid for uuid in uuids]
uuids = [UUID(uuid) if isinstance(uuid, str) else uuid for uuid in uuids]
# Decide which services to discover for
services = [service] if service else self.services
@@ -456,7 +488,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
@@ -532,7 +565,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while discovering descriptors: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
@@ -585,7 +619,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while discovering attributes: '
f'{HCI_Constant.error_name(response.error_code)}'
)
return []
break
@@ -607,7 +642,8 @@ class Client:
return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic, do it now
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
@@ -642,14 +678,16 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the characteristic
# emitting an 'update' event when a notification or indication is received
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
subscriber_set.add(characteristic)
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic, do it now
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
@@ -673,7 +711,7 @@ class Client:
# Cleanup if we removed the last one
if not subscribers:
subscriber_set.remove(characteristic.handle)
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
@@ -691,7 +729,7 @@ class Client:
'''
# Send a request to read
attribute_handle = attribute if type(attribute) is int else attribute.handle
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
@@ -720,9 +758,9 @@ class Client:
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
if (
response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR
or response.error_code == ATT_INVALID_OFFSET_ERROR
if response.error_code in (
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_INVALID_OFFSET_ERROR,
):
break
raise ProtocolError(
@@ -773,7 +811,8 @@ class Client:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}'
'!!! unexpected error while reading characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
@@ -799,13 +838,14 @@ class Client:
async def write_value(self, attribute, value, with_response=False):
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
`attribute` can be an Attribute object, or a handle value
'''
# Send a request or command to write
attribute_handle = attribute if type(attribute) is int else attribute.handle
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
if with_response:
response = await self.send_request(
ATT_Write_Request(
@@ -836,7 +876,8 @@ class Client:
logger.warning('!!! unexpected response, there is no pending request')
return
# Sanity check: the response should match the pending request unless it is an error response
# Sanity check: the response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
@@ -856,7 +897,12 @@ class Client:
handler(att_pdu)
else:
logger.warning(
f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}'
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
def on_att_handle_value_notification(self, notification):

View File

@@ -26,14 +26,52 @@
import asyncio
import logging
from collections import defaultdict
from typing import Tuple, Optional
import struct
from typing import List, Tuple, Optional
from pyee import EventEmitter
from colors import color
from .core import *
from .hci import *
from .att import *
from .gatt import *
from .colors import color
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
ATT_INVALID_HANDLE_ERROR,
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
ATT_Error_Response,
ATT_Exchange_MTU_Response,
ATT_Find_By_Type_Value_Response,
ATT_Find_Information_Response,
ATT_Handle_Value_Indication,
ATT_Handle_Value_Notification,
ATT_Read_Blob_Response,
ATT_Read_By_Group_Type_Response,
ATT_Read_By_Type_Response,
ATT_Read_Response,
ATT_Write_Response,
Attribute,
)
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
Service,
)
# -----------------------------------------------------------------------------
# Logging
@@ -51,6 +89,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# GATT Server
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
def __init__(self, device):
super().__init__()
self.device = device
@@ -101,6 +141,7 @@ class Server(EventEmitter):
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and isinstance(attribute, Service)
and attribute.uuid == service_uuid
),
None,
@@ -194,6 +235,7 @@ class Server(EventEmitter):
is None
):
self.add_attribute(
# pylint: disable=line-too-long
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE,
@@ -232,12 +274,13 @@ class Server(EventEmitter):
def write_cccd(self, connection, characteristic, value):
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}'
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check
if len(value) != 2:
logger.warn('CCCD value not 2 bytes long')
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
@@ -349,9 +392,9 @@ class Server(EventEmitter):
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError:
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}')
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[connection.handle] = None
@@ -425,7 +468,11 @@ class Server(EventEmitter):
else:
# Just ignore
logger.warning(
f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}'
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
#######################################################
@@ -436,7 +483,10 @@ class Server(EventEmitter):
Handler for requests without a more specific handler
'''
logger.warning(
f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}'
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
)
+ str(pdu)
)
response = ATT_Error_Response(
request_opcode_in_error=pdu.op_code,
@@ -492,8 +542,6 @@ class Server(EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
):
# TODO: check permissions
this_uuid_size = len(attribute.type.to_pdu_bytes())
if attributes:
@@ -556,11 +604,11 @@ class Server(EventEmitter):
if attributes:
handles_information_list = []
for attribute in attributes:
if attribute.type in {
if attribute.type in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
}:
):
# Part of a group
group_end_handle = attribute.end_group_handle
else:
@@ -587,6 +635,13 @@ class Server(EventEmitter):
'''
pdu_space_available = connection.att_mtu - 2
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
attributes = []
for attribute in (
attribute
@@ -596,10 +651,21 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# TODO: check permissions
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
if not attributes:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=attribute.handle,
error_code=error.error_code,
)
break
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
@@ -625,11 +691,7 @@ class Server(EventEmitter):
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
logging.debug(f"not found {request}")
self.send_response(connection, response)
@@ -639,10 +701,17 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(attribute_value=value[:value_size])
try:
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(attribute_value=value[:value_size])
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -657,29 +726,36 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
if request.value_offset > len(value):
try:
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
error_code=error.error_code,
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -692,11 +768,10 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
if request.attribute_group_type not in {
if request.attribute_group_type not in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE,
}:
):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
@@ -715,8 +790,10 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# Check the attribute value size
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
@@ -814,7 +891,7 @@ class Server(EventEmitter):
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, confirmation):
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''

View File

@@ -15,13 +15,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
import collections
import logging
import functools
from colors import color
from typing import Dict, Type, Union
from .colors import color
from .core import (
BT_BR_EDR_TRANSPORT,
AdvertisingData,
DeviceClass,
ProtocolError,
bit_flags_to_strings,
name_or_number,
padded_bytes,
)
from .core import *
# -----------------------------------------------------------------------------
# Logging
@@ -43,8 +54,8 @@ def key_with_value(dictionary, target_value):
return None
def indent_lines(str):
return '\n'.join([' ' + line for line in str.split('\n')])
def indent_lines(string):
return '\n'.join([' ' + line for line in string.split('\n')])
def map_null_terminated_utf8_string(utf8_bytes):
@@ -63,25 +74,32 @@ def map_class_of_device(class_of_device):
major_device_class,
minor_device_class,
) = DeviceClass.split_class_of_device(class_of_device)
return f'[{class_of_device:06X}] Services({",".join(DeviceClass.service_class_labels(service_classes))}),Class({DeviceClass.major_device_class_name(major_device_class)}|{DeviceClass.minor_device_class_name(major_device_class, minor_device_class)})'
return (
f'[{class_of_device:06X}] Services('
f'{",".join(DeviceClass.service_class_labels(service_classes))}),'
f'Class({DeviceClass.major_device_class_name(major_device_class)}|'
f'{DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}'
')'
)
def phy_list_to_bits(phys):
if phys is None:
return 0
else:
phy_bits = 0
for phy in phys:
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
raise ValueError('invalid PHY')
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
return phy_bits
phy_bits = 0
for phy in phys:
if phy not in HCI_LE_PHY_TYPE_TO_BIT:
raise ValueError('invalid PHY')
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
return phy_bits
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
# HCI Version
HCI_VERSION_BLUETOOTH_CORE_1_0B = 0
@@ -1355,8 +1373,11 @@ HCI_LE_SUPPORTED_FEATURES_NAMES = {
}
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# pylint: disable-next=unnecessary-lambda
STATUS_SPEC = {'size': 1, 'mapper': lambda x: HCI_Constant.status_name(x)}
@@ -1400,7 +1421,11 @@ class HCI_Constant:
# -----------------------------------------------------------------------------
class HCI_Error(ProtocolError):
def __init__(self, error_code):
super().__init__(error_code, 'hci', HCI_Constant.error_name(error_code))
super().__init__(
error_code,
error_namespace='hci',
error_name=HCI_Constant.error_name(error_code),
)
# -----------------------------------------------------------------------------
@@ -1418,25 +1443,25 @@ class HCI_StatusError(ProtocolError):
# -----------------------------------------------------------------------------
class HCI_Object:
@staticmethod
def init_from_fields(object, fields, values):
if type(values) is dict:
def init_from_fields(hci_object, fields, values):
if isinstance(values, dict):
for field_name, _ in fields:
setattr(object, field_name, values[field_name])
setattr(hci_object, field_name, values[field_name])
else:
for field_name, field_value in zip(fields, values):
setattr(object, field_name, field_value)
setattr(hci_object, field_name, field_value)
@staticmethod
def init_from_bytes(object, data, offset, fields):
def init_from_bytes(hci_object, data, offset, fields):
parsed = HCI_Object.dict_from_bytes(data, offset, fields)
HCI_Object.init_from_fields(object, parsed.keys(), parsed.values())
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod
def dict_from_bytes(data, offset, fields):
result = collections.OrderedDict()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, and/or size
if type(field_type) is dict:
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
@@ -1466,7 +1491,7 @@ class HCI_Object:
elif field_type == -2:
# 16-bit signed
field_value = struct.unpack_from('<h', data, offset)[0]
offset += 1
offset += 2
elif field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
@@ -1480,7 +1505,7 @@ class HCI_Object:
# 32-bit unsigned big-endian
field_value = struct.unpack_from('>I', data, offset)[0]
offset += 4
elif type(field_type) is int and field_type > 4 and field_type <= 256:
elif isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
field_value = data[offset : offset + field_type]
offset += field_type
@@ -1494,19 +1519,20 @@ class HCI_Object:
return result
@staticmethod
def dict_to_bytes(object, fields):
def dict_to_bytes(hci_object, fields):
result = bytearray()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, serializer, and/or size
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if type(field_type) is dict:
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
field_value = object[field_name]
field_value = hci_object[field_name]
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
@@ -1534,20 +1560,18 @@ class HCI_Object:
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if type(field_value) is int:
if field_value >= 0 and field_value <= 255:
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif (
type(field_value) is bytes
or type(field_value) is bytearray
or hasattr(field_value, 'to_bytes')
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if type(field_type) is int and field_type > 4 and field_type <= 256:
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or Pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
@@ -1584,40 +1608,42 @@ class HCI_Object:
@staticmethod
def format_field_value(value, indentation):
if type(value) is bytes:
if isinstance(value, bytes):
return value.hex()
elif isinstance(value, HCI_Object):
if isinstance(value, HCI_Object):
return '\n' + value.to_string(indentation)
else:
return str(value)
return str(value)
@staticmethod
def format_fields(object, keys, indentation='', value_mappers={}):
def format_fields(hci_object, keys, indentation='', value_mappers=None):
if not keys:
return ''
# Measure the widest field name
max_field_name_length = max(
[len(key[0] if type(key) is tuple else key) for key in keys]
(len(key[0] if isinstance(key, tuple) else key) for key in keys)
)
# Build array of formatted key:value pairs
fields = []
for key in keys:
value_mapper = None
if type(key) is tuple:
if isinstance(key, tuple):
# The key has an associated specifier
key, specifier = key
# Get the value mapper from the specifier
if type(specifier) is dict:
if isinstance(specifier, dict):
value_mapper = specifier.get('mapper')
# Get the value for the field
value = object[key]
value = hci_object[key]
# Map the value if needed
value_mapper = value_mappers.get(key, value_mapper)
if value_mappers:
value_mapper = value_mappers.get(key, value_mapper)
if value_mapper is not None:
value = value_mapper(value)
@@ -1639,7 +1665,7 @@ class HCI_Object:
self.fields = fields
self.init_from_fields(self, fields, kwargs)
def to_string(self, indentation='', value_mappers={}):
def to_string(self, indentation='', value_mappers=None):
return HCI_Object.format_fields(
self.__dict__, self.fields, indentation, value_mappers
)
@@ -1670,6 +1696,12 @@ class Address:
RANDOM_IDENTITY_ADDRESS: 'RANDOM_IDENTITY_ADDRESS',
}
# Type declarations
NIL: Address
ANY: Address
ANY_RANDOM: Address
# pylint: disable-next=unnecessary-lambda
ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)}
@staticmethod
@@ -1686,7 +1718,8 @@ class Address:
@staticmethod
def parse_address(data, offset):
# Fix the type to a default value. This is used for parsing type-less Classic addresses
# Fix the type to a default value. This is used for parsing type-less Classic
# addresses
return Address.parse_address_with_type(
data, offset, Address.PUBLIC_DEVICE_ADDRESS
)
@@ -1700,15 +1733,17 @@ class Address:
address_type = data[offset - 1]
return Address.parse_address_with_type(data, offset, address_type)
def __init__(self, address, address_type=RANDOM_DEVICE_ADDRESS):
def __init__(
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
):
'''
Initialize an instance. `address` may be a byte array in little-endian
format, or a hex string in big-endian format (with optional ':'
separators between the bytes).
If the address is a string suffixed with '/P', `address_type` is ignored and the type
is set to PUBLIC_DEVICE_ADDRESS.
If the address is a string suffixed with '/P', `address_type` is ignored and
the type is set to PUBLIC_DEVICE_ADDRESS.
'''
if type(address) is bytes:
if isinstance(address, bytes):
self.address_bytes = address
else:
# Check if there's a '/P' type specifier
@@ -1731,9 +1766,9 @@ class Address:
@property
def is_public(self):
return (
self.address_type == self.PUBLIC_DEVICE_ADDRESS
or self.address_type == self.PUBLIC_IDENTITY_ADDRESS
return self.address_type in (
self.PUBLIC_DEVICE_ADDRESS,
self.PUBLIC_IDENTITY_ADDRESS,
)
@property
@@ -1742,9 +1777,9 @@ class Address:
@property
def is_resolved(self):
return (
self.address_type == self.PUBLIC_IDENTITY_ADDRESS
or self.address_type == self.RANDOM_IDENTITY_ADDRESS
return self.address_type in (
self.PUBLIC_IDENTITY_ADDRESS,
self.RANDOM_IDENTITY_ADDRESS,
)
@property
@@ -1776,15 +1811,16 @@ class Address:
'''
String representation of the address, MSB first
'''
str = ':'.join([f'{x:02X}' for x in reversed(self.address_bytes)])
result = ':'.join([f'{x:02X}' for x in reversed(self.address_bytes)])
if not self.is_public:
return str
return str + '/P'
return result
return result + '/P'
# Predefined address values
Address.NIL = Address(b"\xff\xff\xff\xff\xff\xff", Address.PUBLIC_DEVICE_ADDRESS)
Address.ANY = Address(b"\x00\x00\x00\x00\x00\x00", Address.PUBLIC_DEVICE_ADDRESS)
Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_ADDRESS)
# -----------------------------------------------------------------------------
class OwnAddressType:
@@ -1801,9 +1837,10 @@ class OwnAddressType:
}
@staticmethod
def type_name(type):
return name_or_number(OwnAddressType.TYPE_NAMES, type)
def type_name(type_id):
return name_or_number(OwnAddressType.TYPE_NAMES, type_id)
# pylint: disable-next=unnecessary-lambda
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
@@ -1813,21 +1850,29 @@ class HCI_Packet:
Abstract Base class for HCI packets
'''
hci_packet_type: int
@staticmethod
def from_bytes(packet):
packet_type = packet[0]
if packet_type == HCI_COMMAND_PACKET:
return HCI_Command.from_bytes(packet)
elif packet_type == HCI_ACL_DATA_PACKET:
if packet_type == HCI_ACL_DATA_PACKET:
return HCI_AclDataPacket.from_bytes(packet)
elif packet_type == HCI_EVENT_PACKET:
if packet_type == HCI_EVENT_PACKET:
return HCI_Event.from_bytes(packet)
else:
return HCI_CustomPacket(packet)
return HCI_CustomPacket(packet)
def __init__(self, name):
self.name = name
def __bytes__(self) -> bytes:
raise NotImplementedError
def __repr__(self) -> str:
return self.name
@@ -1839,6 +1884,9 @@ class HCI_CustomPacket(HCI_Packet):
self.hci_packet_type = payload[0]
self.payload = payload
def __bytes__(self) -> bytes:
return self.payload
# -----------------------------------------------------------------------------
class HCI_Command(HCI_Packet):
@@ -1847,10 +1895,10 @@ class HCI_Command(HCI_Packet):
'''
hci_packet_type = HCI_COMMAND_PACKET
command_classes = {}
command_classes: Dict[int, Type[HCI_Command]] = {}
@staticmethod
def command(fields=[], return_parameters_fields=[]):
def command(fields=(), return_parameters_fields=()):
'''
Decorator used to declare and register subclasses
'''
@@ -1897,8 +1945,8 @@ class HCI_Command(HCI_Packet):
HCI_Command.__init__(self, op_code, parameters)
HCI_Object.init_from_bytes(self, parameters, 0, fields)
return self
else:
return cls.from_parameters(parameters)
return cls.from_parameters(parameters)
@staticmethod
def command_name(op_code):
@@ -2049,6 +2097,24 @@ class HCI_Link_Key_Request_Negative_Reply_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
('bd_addr', Address.parse_address),
('pin_code_length', 1),
('pin_code', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
('bd_addr', Address.parse_address),
],
)
class HCI_PIN_Code_Request_Reply_Command(HCI_Command):
'''
See Bluetooth spec @ 7.1.12 PIN Code Request Reply Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('bd_addr', Address.parse_address)],
@@ -2842,6 +2908,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
# pylint: disable=line-too-long,unnecessary-lambda
[
('advertising_interval_min', 2),
('advertising_interval_max', 2),
@@ -3077,6 +3144,16 @@ class HCI_LE_Read_Remote_Features_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[("status", STATUS_SPEC), ("random_number", 8)]
)
class HCI_LE_Rand_Command(HCI_Command):
"""
See Bluetooth spec @ 7.8.23 LE Rand Command
"""
# -----------------------------------------------------------------------------
@HCI_Command.command(
[
@@ -3089,7 +3166,8 @@ class HCI_LE_Read_Remote_Features_Command(HCI_Command):
class HCI_LE_Enable_Encryption_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.24 LE Enable Encryption Command
(renamed from "LE Start Encryption Command" in version prior to 5.2 of the specification)
(renamed from "LE Start Encryption Command" in version prior to 5.2 of the
specification)
'''
@@ -3144,7 +3222,8 @@ class HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(HCI_Command):
)
class HCI_LE_Remote_Connection_Parameter_Request_Negative_Reply_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.32 LE Remote Connection Parameter Request Negative Reply Command
See Bluetooth spec @ 7.8.32 LE Remote Connection Parameter Request Negative Reply
Command
'''
@@ -3356,6 +3435,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
# pylint: disable=line-too-long,unnecessary-lambda
fields=[
('advertising_handle', 1),
(
@@ -3422,6 +3502,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
@classmethod
def advertising_properties_string(cls, properties):
# pylint: disable=line-too-long
return f'[{",".join(bit_flags_to_strings(properties, cls.ADVERTISING_PROPERTIES_NAMES))}]'
@classmethod
@@ -3431,6 +3512,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
# pylint: disable=line-too-long,unnecessary-lambda
[
('advertising_handle', 1),
(
@@ -3480,6 +3562,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
# pylint: disable=line-too-long,unnecessary-lambda
[
('advertising_handle', 1),
(
@@ -3573,9 +3656,9 @@ class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
def __str__(self):
fields = [('enable:', self.enable)]
for i in range(len(self.advertising_handles)):
for i, advertising_handle in enumerate(self.advertising_handles):
fields.append(
(f'advertising_handle[{i}]: ', self.advertising_handles[i])
(f'advertising_handle[{i}]: ', advertising_handle)
)
fields.append((f'duration[{i}]: ', self.durations[i]))
fields.append(
@@ -3736,7 +3819,7 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command):
)
fields.append(
(f'{scanning_phy_str}.scan_interval:', self.scan_intervals[i])
),
)
fields.append((f'{scanning_phy_str}.scan_window: ', self.scan_windows[i]))
return (
@@ -3871,43 +3954,43 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
f'{initiating_phys_str}.scan_interval: ',
self.scan_intervals[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.scan_window: ',
self.scan_windows[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.connection_interval_min:',
self.connection_interval_mins[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.connection_interval_max:',
self.connection_interval_maxs[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.max_latency: ',
self.max_latencies[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.supervision_timeout: ',
self.supervision_timeouts[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.min_ce_length: ',
self.min_ce_lengths[i],
)
),
)
fields.append(
(
f'{initiating_phys_str}.max_ce_length: ',
@@ -3933,6 +4016,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
'privacy_mode',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_LE_Set_Privacy_Mode_Command.privacy_mode_name(
x
),
@@ -3975,11 +4059,11 @@ class HCI_Event(HCI_Packet):
'''
hci_packet_type = HCI_EVENT_PACKET
event_classes = {}
meta_event_classes = {}
event_classes: Dict[int, Type[HCI_Event]] = {}
meta_event_classes: Dict[int, Type[HCI_LE_Meta_Event]] = {}
@staticmethod
def event(fields=[]):
def event(fields=()):
'''
Decorator used to declare and register subclasses
'''
@@ -4005,16 +4089,16 @@ class HCI_Event(HCI_Packet):
return inner
@staticmethod
def registered(cls):
cls.name = cls.__name__.upper()
cls.event_code = key_with_value(HCI_EVENT_NAMES, cls.name)
if cls.event_code is None:
def registered(event_class):
event_class.name = event_class.__name__.upper()
event_class.event_code = key_with_value(HCI_EVENT_NAMES, event_class.name)
if event_class.event_code is None:
raise KeyError('event not found in HCI_EVENT_NAMES')
# Register a factory for this class
HCI_Event.event_classes[cls.event_code] = cls
HCI_Event.event_classes[event_class.event_code] = event_class
return cls
return event_class
@staticmethod
def from_bytes(packet):
@@ -4025,7 +4109,8 @@ class HCI_Event(HCI_Packet):
raise ValueError('invalid packet length')
if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call loops
# We do this dispatch here and not in the subclass in order to avoid call
# loops
subevent_code = parameters[0]
cls = HCI_Event.meta_event_classes.get(subevent_code)
if cls is None:
@@ -4086,7 +4171,7 @@ class HCI_LE_Meta_Event(HCI_Event):
'''
@staticmethod
def event(fields=[]):
def event(fields=()):
'''
Decorator used to declare and register subclasses
'''
@@ -4214,9 +4299,9 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event):
def event_type_string(self):
return HCI_LE_Advertising_Report_Event.event_type_name(self.event_type)
def to_string(self, prefix):
def to_string(self, indentation='', _=None):
return super().to_string(
prefix,
indentation,
{
'event_type': HCI_LE_Advertising_Report_Event.event_type_name,
'address_type': Address.address_type_name,
@@ -4443,9 +4528,10 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event):
self.event_type
)
def to_string(self, prefix):
def to_string(self, indentation='', _=None):
# pylint: disable=line-too-long
return super().to_string(
prefix,
indentation,
{
'event_type': HCI_LE_Extended_Advertising_Report_Event.event_type_string,
'address_type': Address.address_type_name,
@@ -4472,6 +4558,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event):
)
)
if legacy_pdu_type is not None:
# pylint: disable=line-too-long
legacy_info_string = f'({HCI_LE_Advertising_Report_Event.event_type_name(legacy_pdu_type)})'
else:
legacy_info_string = ''
@@ -4587,6 +4674,7 @@ class HCI_Inquiry_Result_Event(HCI_Event):
'link_type',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x),
},
),
@@ -4622,6 +4710,7 @@ class HCI_Connection_Complete_Event(HCI_Event):
'link_type',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_Connection_Complete_Event.link_type_name(x),
},
),
@@ -4678,6 +4767,7 @@ class HCI_Remote_Name_Request_Complete_Event(HCI_Event):
'encryption_enabled',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_Encryption_Change_Event.encryption_enabled_name(
x
),
@@ -4746,16 +4836,20 @@ class HCI_Command_Complete_Event(HCI_Event):
See Bluetooth spec @ 7.7.14 Command Complete Event
'''
return_parameters = b''
def map_return_parameters(self, return_parameters):
# Map simple 'status' return parameters to their named constant form
if type(return_parameters) is bytes and len(return_parameters) == 1:
'''Map simple 'status' return parameters to their named constant form'''
if isinstance(return_parameters, bytes) and len(return_parameters) == 1:
# Byte-array form
return HCI_Constant.status_name(return_parameters[0])
elif type(return_parameters) is int:
if isinstance(return_parameters, int):
# Already converted to an integer status code
return HCI_Constant.status_name(return_parameters)
else:
return return_parameters
return return_parameters
@staticmethod
def from_parameters(parameters):
@@ -4766,8 +4860,12 @@ class HCI_Command_Complete_Event(HCI_Event):
)
# Parse the return parameters
if type(self.return_parameters) is bytes and len(self.return_parameters) == 1:
# All commands with 1-byte return parameters return a 'status' field, convert it to an integer
if (
isinstance(self.return_parameters, bytes)
and len(self.return_parameters) == 1
):
# All commands with 1-byte return parameters return a 'status' field,
# convert it to an integer
self.return_parameters = self.return_parameters[0]
else:
cls = HCI_Command.command_classes.get(self.command_opcode)
@@ -4793,6 +4891,7 @@ class HCI_Command_Complete_Event(HCI_Event):
[
(
'status',
# pylint: disable-next=unnecessary-lambda
{'size': 1, 'mapper': lambda x: HCI_Command_Status_Event.status_name(x)},
),
('num_hci_command_packets', 1),
@@ -4810,8 +4909,8 @@ class HCI_Command_Status_Event(HCI_Event):
def status_name(status):
if status == HCI_Command_Status_Event.PENDING:
return 'PENDING'
else:
return HCI_Constant.error_name(status)
return HCI_Constant.error_name(status)
# -----------------------------------------------------------------------------
@@ -4869,10 +4968,10 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event):
color(' number_of_handles: ', 'cyan')
+ f'{len(self.connection_handles)}',
]
for i in range(len(self.connection_handles)):
for i, connection_handle in enumerate(self.connection_handles):
lines.append(
color(f' connection_handle[{i}]: ', 'cyan')
+ f'{self.connection_handles[i]}'
+ f'{connection_handle}'
)
lines.append(
color(f' num_completed_packets[{i}]: ', 'cyan')
@@ -4888,6 +4987,7 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event):
('connection_handle', 2),
(
'current_mode',
# pylint: disable-next=unnecessary-lambda
{'size': 1, 'mapper': lambda x: HCI_Mode_Change_Event.mode_name(x)},
),
('interval', 2),
@@ -5044,6 +5144,7 @@ class HCI_Read_Remote_Extended_Features_Complete_Event(HCI_Event):
# -----------------------------------------------------------------------------
@HCI_Event.event(
# pylint: disable=line-too-long
[
('status', STATUS_SPEC),
('connection_handle', 2),
@@ -5052,6 +5153,7 @@ class HCI_Read_Remote_Extended_Features_Complete_Event(HCI_Event):
'link_type',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.link_type_name(
x
),
@@ -5065,6 +5167,7 @@ class HCI_Read_Remote_Extended_Features_Complete_Event(HCI_Event):
'air_mode',
{
'size': 1,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: HCI_Synchronous_Connection_Complete_Event.air_mode_name(
x
),
@@ -5229,7 +5332,7 @@ class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event):
# -----------------------------------------------------------------------------
class HCI_AclDataPacket(HCI_Packet):
class HCI_AclDataPacket:
'''
See Bluetooth spec @ 5.4.2 HCI ACL Data Packets
'''
@@ -5268,7 +5371,13 @@ class HCI_AclDataPacket(HCI_Packet):
return self.to_bytes()
def __str__(self):
return f'{color("ACL", "blue")}: handle=0x{self.connection_handle:04x}, pb={self.pb_flag}, bc={self.bc_flag}, data_total_length={self.data_total_length}, data={self.data.hex()}'
return (
f'{color("ACL", "blue")}: '
f'handle=0x{self.connection_handle:04x}'
f'pb={self.pb_flag}, bc={self.bc_flag}, '
f'data_total_length={self.data_total_length}, '
f'data={self.data.hex()}'
)
# -----------------------------------------------------------------------------
@@ -5279,9 +5388,9 @@ class HCI_AclDataPacketAssembler:
self.l2cap_pdu_length = 0
def feed_packet(self, packet):
if (
packet.pb_flag == HCI_ACL_PB_FIRST_NON_FLUSHABLE
or packet.pb_flag == HCI_ACL_PB_FIRST_FLUSHABLE
if packet.pb_flag in (
HCI_ACL_PB_FIRST_NON_FLUSHABLE,
HCI_ACL_PB_FIRST_FLUSHABLE,
):
(l2cap_pdu_length,) = struct.unpack_from('<H', packet.data, 0)
self.current_data = packet.data

View File

@@ -16,12 +16,11 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from colors import color
from bumble.smp import SMP_CID, SMP_Command
from .colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
@@ -66,6 +65,7 @@ class PacketTracer:
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
@@ -75,10 +75,7 @@ class PacketTracer:
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif (
l2cap_pdu.cid == L2CAP_SIGNALING_CID
or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID
):
elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
@@ -95,7 +92,8 @@ class PacketTracer:
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for each direction
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
@@ -117,7 +115,8 @@ class PacketTracer:
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
@@ -125,7 +124,8 @@ class PacketTracer:
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else:
self.analyzer.emit(l2cap_pdu)
@@ -147,7 +147,8 @@ class PacketTracer:
def start_acl_stream(self, connection_handle):
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}'
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream
@@ -162,7 +163,8 @@ class PacketTracer:
def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams:
logger.info(
f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}'
f'[{self.label}] --- Removing ACL stream for connection '
f'0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle]

View File

@@ -18,7 +18,8 @@
import logging
import asyncio
import collections
from colors import color
from .colors import color
# -----------------------------------------------------------------------------
@@ -43,7 +44,7 @@ class HfpProtocol:
def feed(self, data):
# Convert the data to a string if needed
if type(data) == bytes:
if isinstance(data, bytes):
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
@@ -79,16 +80,16 @@ class HfpProtocol:
async def initialize_service(self):
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
line = await (self.next_line())
line = await (self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CIND=?')
line = await (self.next_line())
line = await (self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CIND?')
line = await (self.next_line())
line = await (self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CMER=3,0,0,1')
line = await (self.next_line())
await (self.next_line())

View File

@@ -16,16 +16,59 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import collections
import logging
from pyee import EventEmitter
from colors import color
import struct
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from .hci import (
HCI_ACL_DATA_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET,
HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND,
HCI_RESET_COMMAND,
HCI_SUCCESS,
HCI_SUPPORTED_COMMANDS_FLAGS,
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Constant,
HCI_Error,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Local_Supported_Features_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Write_Suggested_Default_Data_Length_Command,
HCI_Link_Key_Request_Negative_Reply_Command,
HCI_Link_Key_Request_Reply_Command,
HCI_Packet,
HCI_Read_Buffer_Size_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
)
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
)
from .utils import AbortableEventEmitter
from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters
# -----------------------------------------------------------------------------
# Logging
@@ -65,12 +108,13 @@ class Connection:
# -----------------------------------------------------------------------------
class Host(EventEmitter):
class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
super().__init__()
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
@@ -89,6 +133,7 @@ class Host(EventEmitter):
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
self.snooper = None
# Connect to the source and sink if specified
if controller_source:
@@ -96,7 +141,19 @@ class Host(EventEmitter):
if controller_sink:
self.set_packet_sink(controller_sink)
async def flush(self) -> None:
# Make sure no command is pending
await self.command_semaphore.acquire()
# Flush current host state, then release command semaphore
self.emit('flush')
self.command_semaphore.release()
async def reset(self):
if self.ready:
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
@@ -127,10 +184,12 @@ class Host(EventEmitter):
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't understand
# Some older controllers don't like event masks with bits they don't
# understand
le_event_mask = bytes.fromhex('1F00000000000000')
else:
le_event_mask = bytes.fromhex('FFFFF00000000000')
await self.send_command(
HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
@@ -147,7 +206,8 @@ class Host(EventEmitter):
)
logger.debug(
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
)
@@ -163,8 +223,10 @@ class Host(EventEmitter):
)
logger.debug(
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
)
if (
@@ -212,6 +274,9 @@ class Host(EventEmitter):
self.hci_sink = sink
def send_hci_packet(self, packet):
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(packet.to_bytes())
async def send_command(self, command, check_result=False):
@@ -232,9 +297,9 @@ class Host(EventEmitter):
# Check the return parameters if required
if check_result:
if type(response.return_parameters) is int:
if isinstance(response.return_parameters, int):
status = response.return_parameters
elif type(response.return_parameters) is bytes:
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
@@ -294,7 +359,8 @@ class Host(EventEmitter):
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue'
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self):
@@ -357,6 +423,9 @@ class Host(EventEmitter):
def on_hci_packet(self, packet):
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:
self.on_hci_command_packet(packet)
@@ -388,7 +457,9 @@ class Host(EventEmitter):
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
logger.warning(
f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}'
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
f'0x{event.command_opcode:X}'
)
self.pending_response.set_result(event)
@@ -403,10 +474,12 @@ class Host(EventEmitter):
def on_hci_command_complete_event(self, event):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to an actual command
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
else:
return self.on_command_processed(event)
return None
return self.on_command_processed(event)
def on_hci_command_status_event(self, event):
return self.on_command_processed(event)
@@ -419,7 +492,8 @@ class Host(EventEmitter):
else:
logger.warning(
color(
f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
)
)
self.acl_packets_in_flight = 0
@@ -439,7 +513,8 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}'
f'### CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
connection = self.connections.get(event.connection_handle)
@@ -484,7 +559,8 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}'
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
connection = self.connections.get(event.connection_handle)
@@ -524,7 +600,10 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}'
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} as '
f'{HCI_Constant.role_name(connection.role)}, '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
@@ -587,7 +666,7 @@ class Host(EventEmitter):
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
latency=event.latency,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
@@ -604,8 +683,14 @@ class Host(EventEmitter):
logger.debug('no long term key provider')
long_term_key = None
else:
long_term_key = await self.long_term_key_provider(
connection.handle, event.random_number, event.encryption_diversifier
long_term_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier,
),
)
if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
@@ -630,12 +715,14 @@ class Host(EventEmitter):
def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS:
logger.debug(
f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}'
f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}'
)
# TODO: lookup the connection and update the role
else:
logger.debug(
f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}'
f'role change for {event.bd_addr} failed: '
f'{HCI_Constant.error_name(event.status)}'
)
def on_hci_le_data_length_change_event(self, event):
@@ -694,24 +781,19 @@ class Host(EventEmitter):
def on_hci_link_key_notification_event(self, event):
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}'
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
logger.debug(
f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}'
f'simple pairing complete for {event.bd_addr}: '
f'status={HCI_Constant.status_name(event.status)}'
)
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('ssp_complete', event.bd_addr)
def on_hci_pin_code_request_event(self, event):
# For now, just refuse all requests
# TODO: delegate the decision
self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
)
self.emit('pin_code_request', event.bd_addr)
def on_hci_link_key_request_event(self, event):
async def send_link_key():
@@ -719,7 +801,11 @@ class Host(EventEmitter):
logger.debug('no link key provider')
link_key = None
else:
link_key = await self.link_key_provider(event.bd_addr)
link_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.link_key_provider(event.bd_addr),
)
if link_key:
response = HCI_Link_Key_Request_Reply_Command(
bd_addr=event.bd_addr, link_key=link_key
@@ -754,7 +840,7 @@ class Host(EventEmitter):
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, event):
def on_hci_inquiry_complete_event(self, _event):
self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event):

View File

@@ -24,8 +24,9 @@ import asyncio
import logging
import os
import json
from colors import color
from typing import Optional
from .colors import color
from .hci import Address
@@ -76,8 +77,10 @@ class PairingKeys:
@staticmethod
def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name)
if key_dict is not None:
return PairingKeys.Key.from_dict(key_dict)
if key_dict is None:
return None
return PairingKeys.Key.from_dict(key_dict)
@staticmethod
def from_dict(keys_dict):
@@ -121,13 +124,13 @@ class PairingKeys:
def print(self, prefix=''):
keys_dict = self.to_dict()
for (property, value) in keys_dict.items():
if type(value) is dict:
print(f'{prefix}{color(property, "cyan")}:')
for (container_property, value) in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(property, "cyan")}: {value}')
print(f'{prefix}{color(container_property, "cyan")}: {value}')
# -----------------------------------------------------------------------------
@@ -138,7 +141,7 @@ class KeyStore:
async def update(self, name, keys):
pass
async def get(self, name):
async def get(self, _name):
return PairingKeys()
async def get_all(self):
@@ -193,6 +196,9 @@ class JsonKeyStore(KeyStore):
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
@@ -211,7 +217,7 @@ class JsonKeyStore(KeyStore):
params = device_config.keystore.split(':', 1)[1:]
namespace = str(device_config.address)
if params:
filename = params[1]
filename = params[0]
else:
filename = None
@@ -219,7 +225,7 @@ class JsonKeyStore(KeyStore):
async def load(self):
try:
with open(self.filename, 'r') as json_file:
with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
except FileNotFoundError:
return {}
@@ -231,13 +237,13 @@ class JsonKeyStore(KeyStore):
# Save to a temporary file
temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w') as output:
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)
# Atomically replace the previous file
os.rename(temp_filename, self.filename)
async def delete(self, name):
async def delete(self, name: str) -> None:
db = await self.load()
namespace = db.get(self.namespace)
@@ -273,7 +279,7 @@ class JsonKeyStore(KeyStore):
await self.save(db)
async def get(self, name):
async def get(self, name: str) -> Optional[PairingKeys]:
db = await self.load()
namespace = db.get(self.namespace)

View File

@@ -15,14 +15,16 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import struct
from collections import deque
from colors import color
from pyee import EventEmitter
from typing import Dict, Type
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .hci import (
HCI_LE_Connection_Update_Command,
@@ -41,6 +43,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
L2CAP_SIGNALING_CID = 0x01
L2CAP_LE_SIGNALING_CID = 0x05
@@ -137,11 +140,15 @@ L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
class L2CAP_PDU:
'''
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
@@ -179,8 +186,9 @@ class L2CAP_Control_Frame:
See Bluetooth spec @ Vol 3, Part A - 4 SIGNALING PACKET FORMATS
'''
classes = {}
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name = None
@staticmethod
def from_bytes(pdu):
@@ -215,11 +223,11 @@ class L2CAP_Control_Frame:
def decode_configuration_options(data):
options = []
while len(data) >= 2:
type = data[0]
value_type = data[0]
length = data[1]
value = data[2 : 2 + length]
data = data[2 + length :]
options.append((type, value))
options.append((value_type, value))
return options
@@ -236,7 +244,8 @@ class L2CAP_Control_Frame:
cls.code = key_with_value(L2CAP_CONTROL_FRAME_NAMES, cls.name)
if cls.code is None:
raise KeyError(
f'Control Frame name {cls.name} not found in L2CAP_CONTROL_FRAME_NAMES'
f'Control Frame name {cls.name} '
'not found in L2CAP_CONTROL_FRAME_NAMES'
)
cls.fields = fields
@@ -281,6 +290,7 @@ class L2CAP_Control_Frame:
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
# pylint: disable=unnecessary-lambda
[
(
'reason',
@@ -311,6 +321,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
# pylint: disable=unnecessary-lambda
[
(
'psm',
@@ -356,6 +367,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
# pylint: disable=unnecessary-lambda
[
('destination_cid', 2),
('source_cid', 2),
@@ -373,17 +385,18 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED = 0x0002
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002
CONNECTION_REFUSED_SECURITY_BLOCK = 0x0003
CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE = 0x0004
CONNECTION_REFUSED_INVALID_SOURCE_CID = 0x0006
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x0007
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
# pylint: disable=line-too-long
RESULT_NAMES = {
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
CONNECTION_PENDING: 'CONNECTION_PENDING',
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
CONNECTION_REFUSED_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_PSM_NOT_SUPPORTED',
CONNECTION_REFUSED_SECURITY_BLOCK: 'CONNECTION_REFUSED_SECURITY_BLOCK',
CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE: 'CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE',
CONNECTION_REFUSED_INVALID_SOURCE_CID: 'CONNECTION_REFUSED_INVALID_SOURCE_CID',
@@ -406,6 +419,7 @@ class L2CAP_Configure_Request(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
# pylint: disable=unnecessary-lambda
[
('source_cid', 2),
('flags', 2),
@@ -481,6 +495,7 @@ class L2CAP_Echo_Response(L2CAP_Control_Frame):
'info_type',
{
'size': 2,
# pylint: disable-next=unnecessary-lambda
'mapper': lambda x: L2CAP_Information_Request.info_type_name(x),
},
)
@@ -524,6 +539,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame):
('info_type', {'size': 2, 'mapper': L2CAP_Information_Request.info_type_name}),
(
'result',
# pylint: disable-next=unnecessary-lambda
{'size': 2, 'mapper': lambda x: L2CAP_Information_Response.result_name(x)},
),
('data', '*'),
@@ -568,12 +584,14 @@ class L2CAP_Connection_Parameter_Update_Response(L2CAP_Control_Frame):
)
class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
'''
See Bluetooth spec @ Vol 3, Part A - 4.22 LE CREDIT BASED CONNECTION REQUEST (CODE 0x14)
See Bluetooth spec @ Vol 3, Part A - 4.22 LE CREDIT BASED CONNECTION REQUEST
(CODE 0x14)
'''
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
# pylint: disable=unnecessary-lambda,line-too-long
[
('destination_cid', 2),
('mtu', 2),
@@ -592,7 +610,8 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
)
class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
'''
See Bluetooth spec @ Vol 3, Part A - 4.23 LE CREDIT BASED CONNECTION RESPONSE (CODE 0x15)
See Bluetooth spec @ Vol 3, Part A - 4.23 LE CREDIT BASED CONNECTION RESPONSE
(CODE 0x15)
'''
CONNECTION_SUCCESSFUL = 0x0000
@@ -606,6 +625,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x000A
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
# pylint: disable=line-too-long
RESULT_NAMES = {
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
@@ -693,6 +713,7 @@ class Channel(EventEmitter):
self.destination_cid = 0
self.response = None
self.connection_result = None
self.disconnection_result = None
self.sink = None
def change_state(self, new_state):
@@ -723,6 +744,7 @@ class Channel(EventEmitter):
self.response.set_result(pdu)
self.response = None
elif self.sink:
# pylint: disable=not-callable
self.sink(pdu)
else:
logger.warning(
@@ -746,7 +768,8 @@ class Channel(EventEmitter):
)
)
# Create a future to wait for the state machine to get to a success or error state
# Create a future to wait for the state machine to get to a success or error
# state
self.connection_result = asyncio.get_running_loop().create_future()
# Wait for the connection to succeed or fail
@@ -768,10 +791,16 @@ class Channel(EventEmitter):
)
)
# Create a future to wait for the state machine to get to a success or error state
# Create a future to wait for the state machine to get to a success or error
# state
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
if self.state == self.OPEN:
self.change_state(self.CLOSED)
self.emit('close')
def send_configure_request(self):
options = L2CAP_Control_Frame.encode_configuration_options(
[
@@ -830,10 +859,10 @@ class Channel(EventEmitter):
self.connection_result = None
def on_configure_request(self, request):
if (
self.state != Channel.WAIT_CONFIG
and self.state != Channel.WAIT_CONFIG_REQ
and self.state != Channel.WAIT_CONFIG_REQ_RSP
if self.state not in (
Channel.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ,
Channel.WAIT_CONFIG_REQ_RSP,
):
logger.warning(color('invalid state', 'red'))
return
@@ -871,10 +900,7 @@ class Channel(EventEmitter):
if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ)
elif (
self.state == Channel.WAIT_CONFIG_RSP
or self.state == Channel.WAIT_CONTROL_IND
):
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
self.change_state(Channel.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
@@ -897,14 +923,15 @@ class Channel(EventEmitter):
else:
logger.warning(
color(
f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}',
'!!! configuration rejected: '
f'{L2CAP_Configure_Response.result_name(response.result)}',
'red',
)
)
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request):
if self.state == Channel.OPEN or self.state == Channel.WAIT_DISCONNECT:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -938,7 +965,12 @@ class Channel(EventEmitter):
self.manager.on_channel_closed(self)
def __str__(self):
return f'Channel({self.source_cid}->{self.destination_cid}, PSM={self.psm}, MTU={self.mtu}, state={Channel.STATE_NAMES[self.state]})'
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
f'state={Channel.STATE_NAMES[self.state]})'
)
# -----------------------------------------------------------------------------
@@ -976,7 +1008,7 @@ class LeConnectionOrientedChannel(EventEmitter):
destination_cid,
mtu,
mps,
credits,
credits, # pylint: disable=redefined-builtin
peer_mtu,
peer_mps,
peer_credits,
@@ -1001,6 +1033,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.out_queue = deque()
self.out_sdu = None
self.sink = None
self.connected = False
self.connection_result = None
self.disconnection_result = None
self.drained = asyncio.Event()
@@ -1072,10 +1105,15 @@ class LeConnectionOrientedChannel(EventEmitter):
)
)
# Create a future to wait for the state machine to get to a success or error state
# Create a future to wait for the state machine to get to a success or error
# state
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED)
def on_pdu(self, pdu):
if self.sink is None:
logger.warning('received pdu without a sink')
@@ -1110,7 +1148,8 @@ class LeConnectionOrientedChannel(EventEmitter):
# Check if the SDU is complete
if self.in_sdu_length == 0:
# We don't know the size yet, check if we have received the header to compute it
# We don't know the size yet, check if we have received the header to
# compute it
if len(self.in_sdu) >= 2:
self.in_sdu_length = struct.unpack_from('<H', self.in_sdu, 0)[0]
if self.in_sdu_length == 0:
@@ -1125,7 +1164,8 @@ class LeConnectionOrientedChannel(EventEmitter):
if len(self.in_sdu) != 2 + self.in_sdu_length:
# Overflow
logger.warning(
f'SDU overflow: sdu_length={self.in_sdu_length}, received {len(self.in_sdu) - 2}'
f'SDU overflow: sdu_length={self.in_sdu_length}, '
f'received {len(self.in_sdu) - 2}'
)
# TODO: we should disconnect
self.in_sdu = None
@@ -1134,7 +1174,7 @@ class LeConnectionOrientedChannel(EventEmitter):
# Send the SDU to the sink
logger.debug(f'SDU complete: 2+{len(self.in_sdu) - 2} bytes')
self.sink(self.in_sdu[2:])
self.sink(self.in_sdu[2:]) # pylint: disable=not-callable
# Prepare for a new SDU
self.in_sdu = None
@@ -1174,7 +1214,7 @@ class LeConnectionOrientedChannel(EventEmitter):
# Cleanup
self.connection_result = None
def on_credits(self, credits):
def on_credits(self, credits): # pylint: disable=redefined-builtin
self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}')
@@ -1228,7 +1268,8 @@ class LeConnectionOrientedChannel(EventEmitter):
# Keep what's still left to send
self.out_sdu = self.out_sdu[len(packet) :]
continue
elif self.out_queue:
if self.out_queue:
# Create the next SDU (2 bytes header plus up to MTU bytes payload)
logger.debug(
f'assembling SDU from {len(self.out_queue)} packets in output queue'
@@ -1282,13 +1323,20 @@ class LeConnectionOrientedChannel(EventEmitter):
pass
def __str__(self):
return f'CoC({self.source_cid}->{self.destination_cid}, State={self.state_name(self.state)}, PSM={self.le_psm}, MTU={self.mtu}/{self.peer_mtu}, MPS={self.mps}/{self.peer_mps}, credits={self.credits}/{self.peer_credits})'
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, '
f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, '
f'credits={self.credits}/{self.peer_credits})'
)
# -----------------------------------------------------------------------------
class ChannelManager:
def __init__(
self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
):
self._host = None
self.identifiers = {} # Incrementing identifier values by connection
@@ -1322,10 +1370,14 @@ class ChannelManager:
if connection_channels := self.channels.get(connection_handle):
return connection_channels.get(cid)
return None
def find_le_coc_channel(self, connection_handle, cid):
if connection_channels := self.le_coc_channels.get(connection_handle):
return connection_channels.get(cid)
return None
@staticmethod
def find_free_br_edr_cid(channels):
# Pick the smallest valid CID that's not already in the list
@@ -1337,6 +1389,8 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID available')
@staticmethod
def find_free_le_cid(channels):
# Pick the smallest valid CID that's not already in the list
@@ -1348,6 +1402,8 @@ class ChannelManager:
if cid not in channels:
return cid
raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits, mtu, mps):
if (
@@ -1442,24 +1498,30 @@ class ChannelManager:
return psm
def on_disconnection(self, connection_handle, reason):
def on_disconnection(self, connection_handle, _reason):
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
for _, channel in self.channels[connection_handle].items():
channel.abort()
del self.channels[connection_handle]
if connection_handle in self.le_coc_channels:
for _, channel in self.le_coc_channels[connection_handle].items():
channel.abort()
del self.le_coc_channels[connection_handle]
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
def send_pdu(self, connection, cid, pdu):
pdu_str = pdu.hex() if type(pdu) is bytes else str(pdu)
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}'
f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {pdu_str}'
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection, cid, pdu):
if cid == L2CAP_SIGNALING_CID or cid == L2CAP_LE_SIGNALING_CID:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
@@ -1479,13 +1541,17 @@ class ChannelManager:
def send_control_frame(self, connection, cid, control_frame):
logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}'
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}:\n{control_frame}'
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
def on_control_frame(self, connection, cid, control_frame):
logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}:\n{control_frame}'
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}:\n{control_frame}'
)
# Find the handler method
@@ -1518,7 +1584,7 @@ class ChannelManager:
),
)
def on_l2cap_command_reject(self, connection, cid, packet):
def on_l2cap_command_reject(self, _connection, _cid, packet):
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(self, connection, cid, request):
@@ -1536,6 +1602,7 @@ class ChannelManager:
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
status=0x0000,
),
@@ -1556,7 +1623,8 @@ class ChannelManager:
channel.on_connection_request(request)
else:
logger.warning(
f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}'
f'No server for connection 0x{connection.handle:04X} '
f'on PSM {request.psm}'
)
self.send_control_frame(
connection,
@@ -1565,7 +1633,8 @@ class ChannelManager:
identifier=request.identifier,
destination_cid=request.source_cid,
source_cid=0,
result=L2CAP_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
# pylint: disable=line-too-long
result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED,
status=0x0000,
),
)
@@ -1576,7 +1645,8 @@ class ChannelManager:
) is None:
logger.warning(
color(
f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}',
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
@@ -1590,7 +1660,8 @@ class ChannelManager:
) is None:
logger.warning(
color(
f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}',
f'channel {request.destination_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
@@ -1604,7 +1675,8 @@ class ChannelManager:
) is None:
logger.warning(
color(
f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}',
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
@@ -1618,7 +1690,8 @@ class ChannelManager:
) is None:
logger.warning(
color(
f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}',
f'channel {request.destination_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
@@ -1632,7 +1705,8 @@ class ChannelManager:
) is None:
logger.warning(
color(
f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}',
f'channel {response.source_cid} not found for '
f'0x{connection.handle:04X}:{cid}',
'red',
)
)
@@ -1648,7 +1722,7 @@ class ChannelManager:
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
)
def on_l2cap_echo_response(self, connection, cid, response):
def on_l2cap_echo_response(self, _connection, _cid, response):
logger.debug(f'<<< Echo response: data={response.data.hex()}')
# TODO notify listeners
@@ -1663,7 +1737,7 @@ class ChannelManager:
result = L2CAP_Information_Response.SUCCESS
data = sum(1 << cid for cid in self.fixed_channels).to_bytes(8, 'little')
else:
result = L2CAP_Information_Request.NO_SUPPORTED
result = L2CAP_Information_Response.NOT_SUPPORTED
self.send_control_frame(
connection,
@@ -1730,6 +1804,7 @@ class ChannelManager:
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
),
)
@@ -1748,6 +1823,7 @@ class ChannelManager:
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
),
)
@@ -1755,7 +1831,8 @@ class ChannelManager:
# Create a new channel
logger.debug(
f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}'
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeConnectionOrientedChannel(
self,
@@ -1784,6 +1861,7 @@ class ChannelManager:
mtu=mtu,
mps=mps,
initial_credits=max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
@@ -1792,7 +1870,8 @@ class ChannelManager:
server(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}'
f'No LE server for connection 0x{connection.handle:04X} '
f'on PSM {request.le_psm}'
)
self.send_control_frame(
connection,
@@ -1803,11 +1882,12 @@ class ChannelManager:
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
),
)
def on_l2cap_le_credit_based_connection_response(self, connection, cid, response):
def on_l2cap_le_credit_based_connection_response(self, connection, _cid, response):
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
if request is None:
@@ -1820,7 +1900,8 @@ class ChannelManager:
if channel is None:
logger.warning(
color(
f'received connection response for an unknown channel (cid={request.source_cid})',
'received connection response for an unknown channel '
f'(cid={request.source_cid})',
'red',
)
)
@@ -1829,7 +1910,7 @@ class ChannelManager:
# Process the response
channel.on_connection_response(response)
def on_l2cap_le_flow_control_credit(self, connection, cid, credit):
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit):
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')

View File

@@ -17,10 +17,9 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import websockets
from functools import partial
from colors import color
from bumble.colors import color
from bumble.hci import (
Address,
HCI_SUCCESS,
@@ -47,7 +46,8 @@ def parse_parameters(params_str):
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
'''
@@ -119,7 +119,8 @@ class LocalLink:
def connect(self, central_address, le_create_connection_command):
logger.debug(
f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}'
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
@@ -144,11 +145,13 @@ class LocalLink:
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
@@ -215,8 +218,11 @@ class RemoteLink:
)
async def run_connection(self):
import websockets # lazy import
# Connect to the relay
logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}')
@@ -287,11 +293,11 @@ class RemoteLink:
self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it
await self.send_targetted_message(sender, 'connected')
await self.send_targeted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _):
if not self.pending_connection:
logger.warn('received a connection ack, but no connection is pending')
logger.warning('received a connection ack, but no connection is pending')
return
# Remember the connection
@@ -313,7 +319,7 @@ class RemoteLink:
if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, message):
async def on_encrypted_message_received(self, sender, _):
# TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
@@ -335,7 +341,7 @@ class RemoteLink:
# TODO: parse the result
async def send_targetted_message(self, target, message):
async def send_targeted_message(self, target, message):
# Ensure we have a connection
websocket = await self.websocket
@@ -352,23 +358,23 @@ class RemoteLink:
self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data):
await self.send_targetted_message('*', f'advertisement:{data.hex()}')
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, sender_address, data):
def send_advertising_data(self, _, data):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targetted_message(peer_address, f'acl:{data.hex()}')
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, sender_address, peer_address, data):
def send_acl_data(self, _, peer_address, data):
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address):
await self.send_targetted_message(peer_address, 'connect')
await self.send_targeted_message(peer_address, 'connect')
def connect(self, central_address, le_create_connection_command):
def connect(self, _, le_create_connection_command):
if self.pending_connection:
logger.warn('connection already pending')
logger.warning('connection already pending')
return
self.pending_connection = le_create_connection_command
self.execute(
@@ -385,11 +391,12 @@ class RemoteLink:
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
f'disconnect {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targetted_message,
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
@@ -398,15 +405,13 @@ class RemoteLink:
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targetted_message,
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)

View File

@@ -18,7 +18,9 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from ..core import AdvertisingData
from ..device import Device
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
@@ -29,8 +31,8 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
@@ -50,31 +52,53 @@ class AshaService(TemplateService):
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: [int]):
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
def on_volume_write(_connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection, value):
def on_audio_control_point_write(_connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}'
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop')
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# TODO Respond with a status
# asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
@@ -112,9 +136,18 @@ class AshaService(TemplateService):
CharacteristicValue(write=on_volume_write),
)
# TODO add real psm value
self.psm = 0x0080
# self.psm = device.register_l2cap_channel_server(0, on_coc, 8)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
@@ -137,10 +170,6 @@ class AshaService(TemplateService):
return bytes(
AdvertisingData(
[
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)

View File

@@ -40,7 +40,7 @@ class BatteryService(TemplateService):
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
),
format=BatteryService.BATTERY_LEVEL_FORMAT,
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
super().__init__([self.battery_level_characteristic])
@@ -56,7 +56,7 @@ class BatteryServiceProxy(ProfileServiceProxy):
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter(
characteristics[0], format=BatteryService.BATTERY_LEVEL_FORMAT
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None

View File

@@ -17,7 +17,7 @@
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Tuple
from typing import Optional, Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
@@ -52,14 +52,14 @@ class DeviceInformationService(TemplateService):
def __init__(
self,
manufacturer_name: str = None,
model_number: str = None,
serial_number: str = None,
hardware_revision: str = None,
firmware_revision: str = None,
software_revision: str = None,
system_id: Tuple[int, int] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes = None
manufacturer_name: Optional[str] = None,
model_number: Optional[str] = None,
serial_number: Optional[str] = None,
hardware_revision: Optional[str] = None,
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None
# TODO: pnp_id
):
characteristics = [

View File

@@ -156,6 +156,7 @@ class HeartRateService(TemplateService):
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
# pylint: disable=unnecessary-lambda
encode=lambda value: bytes(value),
)
characteristics = [self.heart_rate_measurement_characteristic]
@@ -185,7 +186,7 @@ class HeartRateService(TemplateService):
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
),
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -224,7 +225,7 @@ class HeartRateServiceProxy(ProfileServiceProxy):
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
else:
self.heart_rate_control_point = None

0
bumble/profiles/py.typed Normal file
View File

0
bumble/py.typed Normal file
View File

View File

@@ -18,10 +18,11 @@
import logging
import asyncio
from colors import color
from pyee import EventEmitter
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError
from . import core
from .colors import color
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
# -----------------------------------------------------------------------------
# Logging
@@ -104,17 +105,17 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def fcs(buffer):
fcs = 0xFF
def compute_fcs(buffer):
result = 0xFF
for byte in buffer:
fcs = CRC_TABLE[fcs ^ byte]
return 0xFF - fcs
result = CRC_TABLE[result ^ byte]
return 0xFF - result
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = type
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = frame_type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
@@ -129,18 +130,18 @@ class RFCOMM_Frame:
# 1-byte length indicator
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4)
if type == RFCOMM_UIH_FRAME:
self.fcs = fcs(bytes([self.address, self.control]))
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = fcs(bytes([self.address, self.control]) + self.length)
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self):
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data):
type = data[0] >> 2
mcc_type = data[0] >> 2
c_r = (data[0] >> 1) & 1
length = data[1]
if data[1] & 1:
@@ -150,12 +151,12 @@ class RFCOMM_Frame:
length = (data[3] << 7) & (length >> 1)
value = data[3 : 3 + length]
return (type, c_r, value)
return (mcc_type, c_r, value)
@staticmethod
def make_mcc(type, c_r, data):
def make_mcc(mcc_type, c_r, data):
return (
bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@@ -186,7 +187,7 @@ class RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
type = data[1] & 0xEF
frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -198,9 +199,9 @@ class RFCOMM_Frame:
fcs = data[-1]
# Construct the frame and check the CRC
frame = RFCOMM_Frame(type, c_r, dlci, p_f, information)
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs:
logger.warn(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch')
return frame
@@ -214,7 +215,14 @@ class RFCOMM_Frame:
)
def __str__(self):
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})'
return (
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
f'length={len(self.information)},'
f'fcs=0x{self.fcs:02X})'
)
# -----------------------------------------------------------------------------
@@ -264,7 +272,15 @@ class RFCOMM_MCC_PN:
)
def __str__(self):
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})'
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@@ -302,7 +318,14 @@ class RFCOMM_MCC_MSC:
)
def __str__(self):
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})'
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
@@ -336,6 +359,7 @@ class DLC(EventEmitter):
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
self.connection_result = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -360,30 +384,38 @@ class DLC(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, frame):
def on_sabm_frame(self, _frame):
if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTED)
self.emit('open')
def on_ua_frame(self, frame):
def on_ua_frame(self, _frame):
if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -394,7 +426,7 @@ class DLC(EventEmitter):
# TODO: handle all states
pass
def on_disc_frame(self, frame):
def on_disc_frame(self, _frame):
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
@@ -402,25 +434,28 @@ class DLC(EventEmitter):
data = frame.information
if frame.p_f == 1:
# With credits
credits = frame.information[0]
self.tx_credits += credits
received_credits = frame.information[0]
self.tx_credits += received_credits
logger.debug(
f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}'
f'<<< Credits [{self.dlci}]: '
f'received {credits}, total={self.tx_credits}'
)
data = data[1:]
logger.debug(
f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}'
f'{color("<<< Data", "yellow")} '
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink:
self.sink(data)
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warn(color('!!! received frame with no rx credits', 'red'))
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -434,7 +469,7 @@ class DLC(EventEmitter):
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -443,7 +478,7 @@ class DLC(EventEmitter):
logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self):
if not self.state == DLC.INIT:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
self.change_state(DLC.CONNECTING)
@@ -451,7 +486,7 @@ class DLC(EventEmitter):
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
def accept(self):
if not self.state == DLC.INIT:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
pn = RFCOMM_MCC_PN(
@@ -463,7 +498,7 @@ class DLC(EventEmitter):
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTING)
@@ -471,8 +506,8 @@ class DLC(EventEmitter):
def rx_credits_needed(self):
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
else:
return 0
return 0
def process_tx(self):
# Send anything we can (or an empty frame if we need to send rx credits)
@@ -496,7 +531,9 @@ class DLC(EventEmitter):
# Send the frame
logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}'
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
f'rx_credits={self.rx_credits}, '
f'tx_credits={self.tx_credits}'
)
self.send_frame(
RFCOMM_Frame.uih(
@@ -512,8 +549,8 @@ class DLC(EventEmitter):
# Stream protocol
def write(self, data):
# We can only send bytes
if type(data) != bytes:
if type(data) == str:
if not isinstance(data, bytes):
if isinstance(data, str):
# Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8')
else:
@@ -592,14 +629,14 @@ class Multiplexer(EventEmitter):
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we receive
# a PN response (because we need the parameters), we handle DM frames at the Multiplexer
# level
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
self.on_dm_frame(frame)
else:
dlc = self.dlcs.get(frame.dlci)
if dlc is None:
logger.warn(f'no dlc for DLCI {frame.dlci}')
logger.warning(f'no dlc for DLCI {frame.dlci}')
return
dlc.on_frame(frame)
@@ -607,14 +644,14 @@ class Multiplexer(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, frame):
def on_sabm_frame(self, _frame):
if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, frame):
def on_ua_frame(self, _frame):
if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED)
if self.connection_result:
@@ -626,34 +663,34 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_dm_frame(self, frame):
def on_dm_frame(self, _frame):
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
self.open_result.set_exception(
ConnectionError(
ConnectionError.CONNECTION_REFUSED,
core.ConnectionError(
core.ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm',
)
)
else:
logger.warn(f'unexpected state for DM: {self}')
logger.warning(f'unexpected state for DM: {self}')
def on_disc_frame(self, frame):
def on_disc_frame(self, _frame):
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame):
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if type == RFCOMM_MCC_PN_TYPE:
if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif type == RFCOMM_MCC_MSC_TYPE:
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -669,7 +706,7 @@ class Multiplexer(EventEmitter):
if pn.dlci & 1:
# Not expected, this is an initiator-side number
# TODO: error out
logger.warn(f'invalid DLCI: {pn.dlci}')
logger.warning(f'invalid DLCI: {pn.dlci}')
else:
if self.acceptor:
channel_number = pn.dlci >> 1
@@ -688,7 +725,7 @@ class Multiplexer(EventEmitter):
self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
else:
# No acceptor?? shouldn't happen
logger.warn(color('!!! no acceptor registered', 'red'))
logger.warning(color('!!! no acceptor registered', 'red'))
else:
# Response
logger.debug(f'>>> PN Response: {pn}')
@@ -697,12 +734,12 @@ class Multiplexer(EventEmitter):
self.dlcs[pn.dlci] = dlc
dlc.connect()
else:
logger.warn('ignoring PN response')
logger.warning('ignoring PN response')
def on_mcc_msc(self, c_r, msc):
dlc = self.dlcs.get(msc.dlci)
if dlc is None:
logger.warn(f'no dlc for DLCI {msc.dlci}')
logger.warning(f'no dlc for DLCI {msc.dlci}')
return
dlc.on_mcc_msc(c_r, msc)
@@ -732,8 +769,8 @@ class Multiplexer(EventEmitter):
if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress')
else:
raise InvalidStateError('not connected')
raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN(
dlci=channel << 1,
@@ -744,7 +781,7 @@ class Multiplexer(EventEmitter):
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.OPENING)
@@ -784,7 +821,7 @@ class Client:
self.connection, RFCOMM_PSM
)
except ProtocolError as error:
logger.warn(f'L2CAP connection failed: {error}')
logger.warning(f'L2CAP connection failed: {error}')
raise
# Create a mutliplexer to manage DLCs with the server
@@ -815,17 +852,27 @@ class Server(EventEmitter):
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor):
# Find a free channel number
for channel in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
):
if channel not in self.acceptors:
self.acceptors[channel] = acceptor
return channel
def listen(self, acceptor, channel=0):
if channel:
if channel in self.acceptors:
# Busy
return 0
else:
# Find a free channel number
for candidate in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START,
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1,
):
if candidate not in self.acceptors:
channel = candidate
break
# All channels used...
return 0
if channel == 0:
# All channels used...
return 0
self.acceptors[channel] = acceptor
return channel
def on_connection(self, l2cap_channel):
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')

View File

@@ -15,12 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import struct
from colors import color
import colors
from typing import Dict, List, Type
from . import core
from .colors import color
from .core import InvalidStateError
from .hci import HCI_Object, name_or_number, key_with_value
@@ -34,6 +35,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
@@ -115,6 +117,8 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
@@ -167,100 +171,107 @@ class DataElement:
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
def __init__(self, type, value, value_size=None):
self.type = type
def __init__(self, element_type, value, value_size=None):
self.type = element_type
self.value = value
self.value_size = value_size
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
# Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise ValueError('integer types must have a value size specified')
@staticmethod
def nil():
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@staticmethod
def unsigned_integer(value, value_size):
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value):
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value):
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value):
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value, value_size):
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value):
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value):
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value):
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value):
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@staticmethod
def text_string(value):
def text_string(value: str) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@staticmethod
def boolean(value):
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@staticmethod
def sequence(value):
def sequence(value: List[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@staticmethod
def alternative(value):
def alternative(value: List[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod
def url(value):
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>H', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>I', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>Q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}')
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>h', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>i', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}')
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
@@ -278,11 +289,11 @@ class DataElement:
@staticmethod
def from_bytes(data):
type = data[0] >> 3
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if type == DataElement.NIL:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
@@ -305,17 +316,17 @@ class DataElement:
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type)
constructor = DataElement.type_constructors.get(element_type)
if constructor:
if (
type == DataElement.UNSIGNED_INTEGER
or type == DataElement.SIGNED_INTEGER
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(type, value_data)
result = DataElement(element_type, value_data)
result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
@@ -334,7 +345,8 @@ class DataElement:
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
elif self.value_size == 1:
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
@@ -357,11 +369,11 @@ class DataElement:
raise ValueError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.TEXT_STRING or self.type == DataElement.URL:
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
@@ -372,10 +384,10 @@ class DataElement:
if size != 0:
raise ValueError('NIL must be empty')
size_index = 0
elif (
self.type == DataElement.UNSIGNED_INTEGER
or self.type == DataElement.SIGNED_INTEGER
or self.type == DataElement.UUID
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
size_index = 0
@@ -389,11 +401,11 @@ class DataElement:
size_index = 4
else:
raise ValueError('invalid data size')
elif (
self.type == DataElement.TEXT_STRING
or self.type == DataElement.SEQUENCE
or self.type == DataElement.ALTERNATIVE
or self.type == DataElement.URL
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
@@ -419,14 +431,19 @@ class DataElement:
type_name = name_or_number(self.TYPE_NAMES, self.type)
if self.type == DataElement.NIL:
value_string = ''
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]'
elif (
self.type == DataElement.UNSIGNED_INTEGER
or self.type == DataElement.SIGNED_INTEGER
):
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
@@ -440,8 +457,8 @@ class DataElement:
# -----------------------------------------------------------------------------
class ServiceAttribute:
def __init__(self, id, value):
self.id = id
def __init__(self, attribute_id: int, value: DataElement) -> None:
self.id = attribute_id
self.value = value
@staticmethod
@@ -450,7 +467,7 @@ class ServiceAttribute:
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
if attribute_id.type != DataElement.UNSIGNED_INTEGER:
logger.warn('attribute ID element is not an integer')
logger.warning('attribute ID element is not an integer')
continue
attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value))
@@ -468,27 +485,31 @@ class ServiceAttribute:
)
@staticmethod
def id_name(id):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id)
def id_name(id_code):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod
def is_uuid_in_value(uuid, value):
# Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID:
return value.value == uuid
elif value.type == DataElement.SEQUENCE:
if value.type == DataElement.SEQUENCE:
for element in value.value:
if ServiceAttribute.is_uuid_in_value(uuid, element):
return True
return False
else:
return False
def to_string(self, color=False):
if color:
return f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},value={self.value})'
else:
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
return False
def to_string(self, with_colors=False):
if with_colors:
return (
f'Attribute(id={color(self.id_name(self.id),"magenta")},'
f'value={self.value})'
)
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
def __str__(self):
return self.to_string()
@@ -500,11 +521,13 @@ class SDP_PDU:
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
sdp_pdu_classes = {}
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
name = None
pdu_id = 0
@staticmethod
def from_bytes(pdu):
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None:
@@ -755,7 +778,7 @@ class Client:
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -787,7 +810,7 @@ class Client:
# Parse the result into attribute lists
attribute_lists_sequences = DataElement.from_bytes(accumulator)
if attribute_lists_sequences.type != DataElement.SEQUENCE:
logger.warn('unexpected data type')
logger.warning('unexpected data type')
return []
return [
@@ -805,7 +828,7 @@ class Client:
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -837,7 +860,7 @@ class Client:
# Parse the result into a list of attributes
attribute_list_sequence = DataElement.from_bytes(accumulator)
if attribute_list_sequence.type != DataElement.SEQUENCE:
logger.warn('unexpected data type')
logger.warning('unexpected data type')
return []
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
@@ -850,6 +873,7 @@ class Server:
def __init__(self, device):
self.device = device
self.service_records = {} # Service records maps, by record handle
self.channel = None
self.current_response = None
def register(self, l2cap_channel_manager):
@@ -884,7 +908,7 @@ class Server:
try:
sdp_pdu = SDP_PDU.from_bytes(pdu)
except Exception as error:
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red'))
logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
@@ -945,7 +969,7 @@ class Server:
if attribute.id >= id_range_start and attribute.id <= id_range_end
]
# Return the maching attributes, sorted by attribute id
# Return the matching attributes, sorted by attribute id
attributes.sort(key=lambda x: x.id)
attribute_list = DataElement.sequence([])
for attribute in attributes:

View File

@@ -22,14 +22,23 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import asyncio
import secrets
from pyee import EventEmitter
from colors import color
from typing import Dict, Optional, Type
from .core import *
from .hci import *
from pyee import EventEmitter
from .colors import color
from .hci import Address, HCI_LE_Enable_Encryption_Command, HCI_Object, key_with_value
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ProtocolError,
name_or_number,
)
from .keys import PairingKeys
from . import crypto
@@ -44,6 +53,7 @@ logger = logging.getLogger(__name__)
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
SMP_CID = 0x06
SMP_BR_CID = 0x07
@@ -158,6 +168,8 @@ SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031'
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
@@ -175,8 +187,9 @@ class SMP_Command:
See Bluetooth spec @ Vol 3, Part H - 3 SECURITY MANAGER PROTOCOL
'''
smp_classes = {}
smp_classes: Dict[int, Type[SMP_Command]] = {}
code = 0
name = ''
@staticmethod
def from_bytes(pdu):
@@ -206,7 +219,10 @@ class SMP_Command:
keypress = (value >> 4) & 1
ct2 = (value >> 5) & 1
return f'bonding_flags={bonding_flags}, MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
return (
f'bonding_flags={bonding_flags}, '
f'MITM={mitm}, sc={sc}, keypress={keypress}, ct2={ct2}'
)
@staticmethod
def io_capability_name(io_capability):
@@ -458,11 +474,11 @@ class AddressResolver:
def resolve(self, address):
address_bytes = bytes(address)
hash = address_bytes[0:3]
hash_part = address_bytes[0:3]
prand = address_bytes[3:6]
for (irk, resolved_address) in self.resolving_keys:
local_hash = crypto.ah(irk, prand)
if local_hash == hash:
if local_hash == hash_part:
# Match!
if resolved_address.address_type == Address.PUBLIC_DEVICE_ADDRESS:
resolved_address_type = Address.PUBLIC_IDENTITY_ADDRESS
@@ -472,6 +488,8 @@ class AddressResolver:
address=str(resolved_address), address_type=resolved_address_type
)
return None
# -----------------------------------------------------------------------------
class PairingDelegate:
@@ -480,33 +498,35 @@ class PairingDelegate:
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
DEFAULT_KEY_DISTRIBUTION = (
DEFAULT_KEY_DISTRIBUTION: int = (
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG
)
def __init__(
self,
io_capability=NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
):
io_capability: int = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution
async def accept(self):
async def accept(self) -> bool:
return True
async def confirm(self):
async def confirm(self) -> bool:
return True
async def compare_numbers(self, number, digits=6):
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
return True
async def get_number(self):
async def get_number(self) -> int:
return 0
async def display_number(self, number, digits=6):
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
pass
async def key_distribution_response(
@@ -520,7 +540,13 @@ class PairingDelegate:
# -----------------------------------------------------------------------------
class PairingConfig:
def __init__(self, sc=True, mitm=True, bonding=True, delegate=None):
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
@@ -528,7 +554,11 @@ class PairingConfig:
def __str__(self):
io_capability_str = SMP_Command.io_capability_name(self.delegate.io_capability)
return f'PairingConfig(sc={self.sc}, mitm={self.mitm}, bonding={self.bonding}, delegate[{io_capability_str}])'
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'delegate[{io_capability_str}])'
)
# -----------------------------------------------------------------------------
@@ -548,14 +578,16 @@ class Session:
# I/O Capability to pairing method decision matrix
#
# See Bluetooth spec @ Vol 3, part H - Table 2.8: Mapping of IO Capabilities to Key Generation Method
# See Bluetooth spec @ Vol 3, part H - Table 2.8: Mapping of IO Capabilities to Key
# Generation Method
#
# Map: initiator -> responder -> <method>
# where <method> may be a simple entry or a 2-element tuple, with the first element for legacy
# pairing and the second for secure connections, when the two are different.
# Each entry is either a method name, or, for PASSKEY, a tuple:
# where <method> may be a simple entry or a 2-element tuple, with the first element
# for legacy pairing and the second for secure connections, when the two are
# different. Each entry is either a method name, or, for PASSKEY, a tuple:
# (method, initiator_displays, responder_displays)
# to specify if the initiator and responder should display (True) or input a code (False).
# to specify if the initiator and responder should display (True) or input a code
# (False).
PAIRING_METHODS = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
@@ -606,6 +638,10 @@ class Session:
def __init__(self, manager, connection, pairing_config):
self.manager = manager
self.connection = connection
self.preq = None
self.pres = None
self.ea = None
self.eb = None
self.tk = bytes(16)
self.r = bytes(16)
self.stk = None
@@ -626,7 +662,9 @@ class Session:
self.peer_signature_key = None
self.peer_expected_distributions = []
self.dh_key = None
self.passkey = 0
self.confirm_value = None
self.passkey = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
@@ -726,6 +764,8 @@ class Session:
else:
return self.ltk
return None
def decide_pairing_method(
self, auth_req, initiator_io_capability, responder_io_capability
):
@@ -734,10 +774,10 @@ class Session:
return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability]
if type(details) is tuple and len(details) == 2:
if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0]
if type(details) is int:
if isinstance(details, int):
# Just a method ID
self.pairing_method = details
else:
@@ -762,11 +802,11 @@ class Session:
next_steps()
return
except Exception as error:
logger.warn(f'exception while confirm: {error}')
logger.warning(f'exception while confirm: {error}')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
asyncio.create_task(prompt())
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt():
@@ -779,11 +819,11 @@ class Session:
next_steps()
return
except Exception as error:
logger.warn(f'exception while prompting: {error}')
logger.warning(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
asyncio.create_task(prompt())
self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps):
async def prompt():
@@ -793,23 +833,25 @@ class Session:
logger.debug(f'user input: {passkey}')
next_steps(passkey)
except Exception as error:
logger.warn(f'exception while prompting: {error}')
logger.warning(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
asyncio.create_task(prompt())
self.connection.abort_on('disconnection', prompt())
def display_passkey(self):
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
# The value of TK is computed from the PIN code
if not self.sc:
self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
asyncio.create_task(
self.pairing_config.delegate.display_number(self.passkey, digits=6)
self.connection.abort_on(
'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
def input_passkey(self, next_steps=None):
@@ -821,6 +863,8 @@ class Session:
self.tk = passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
self.passkey_ready.set()
if next_steps is not None:
next_steps()
@@ -872,20 +916,29 @@ class Session:
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
z = 0
elif self.pairing_method == self.PASSKEY:
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
async def next_steps():
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
self.send_command(
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
)
# Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps())
else:
confirm_value = crypto.c1(
self.tk,
@@ -898,7 +951,7 @@ class Session:
self.ra,
)
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self):
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
@@ -921,14 +974,12 @@ class Session:
def start_encryption(self, key):
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
asyncio.create_task(
self.manager.device.host.send_command(
HCI_LE_Enable_Encryption_Command(
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
long_term_key=key,
)
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command(
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
long_term_key=key,
)
)
@@ -950,7 +1001,9 @@ class Session:
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = asyncio.create_task(self.derive_ltk())
self.ctkd_task = self.connection.abort_on(
'disconnection', self.derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -997,7 +1050,9 @@ class Session:
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = asyncio.create_task(self.derive_ltk())
self.ctkd_task = self.connection.abort_on(
'disconnection', self.derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -1057,13 +1112,14 @@ class Session:
if key_distribution_flags & SMP_SIGN_KEY_DISTRIBUTION_FLAG != 0:
self.peer_expected_distributions.append(SMP_Signing_Information_Command)
logger.debug(
f'expecting distributions: {[c.__name__ for c in self.peer_expected_distributions]}'
'expecting distributions: '
f'{[c.__name__ for c in self.peer_expected_distributions]}'
)
def check_key_distribution(self, command_class):
# First, check that the connection is encrypted
if not self.connection.is_encrypted:
logger.warn(
logger.warning(
color('received key distribution on a non-encrypted connection', 'red')
)
self.send_pairing_failed(SMP_UNSPECIFIED_REASON_ERROR)
@@ -1073,14 +1129,16 @@ class Session:
if command_class in self.peer_expected_distributions:
self.peer_expected_distributions.remove(command_class)
logger.debug(
f'remaining distributions: {[c.__name__ for c in self.peer_expected_distributions]}'
'remaining distributions: '
f'{[c.__name__ for c in self.peer_expected_distributions]}'
)
if not self.peer_expected_distributions:
self.on_peer_key_distribution_complete()
else:
logger.warn(
logger.warning(
color(
f'!!! unexpected key distribution command: {command_class.__name__}',
'!!! unexpected key distribution command: '
f'{command_class.__name__}',
'red',
)
)
@@ -1094,9 +1152,9 @@ class Session:
self.send_pairing_request_command()
# Wait for the pairing process to finish
await self.pairing_result
await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, reason):
def on_disconnection(self, _):
self.connection.remove_listener('disconnection', self.on_disconnection)
self.connection.remove_listener(
'connection_encryption_change', self.on_connection_encryption_change
@@ -1112,7 +1170,7 @@ class Session:
if self.is_initiator:
self.distribute_keys()
asyncio.create_task(self.on_pairing())
self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self):
if self.connection.is_encrypted:
@@ -1133,8 +1191,8 @@ class Session:
if self.completed:
return
else:
self.completed = True
self.completed = True
if self.pairing_result is not None and not self.pairing_result.done():
self.pairing_result.set_result(None)
@@ -1194,8 +1252,8 @@ class Session:
if self.completed:
return
else:
self.completed = True
self.completed = True
error = ProtocolError(reason, 'smp', error_name(reason))
if self.pairing_result is not None and not self.pairing_result.done():
@@ -1219,7 +1277,9 @@ class Session:
logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command):
asyncio.create_task(self.on_smp_pairing_request_command_async(command))
self.connection.abort_on(
'disconnection', self.on_smp_pairing_request_command_async(command)
)
async def on_smp_pairing_request_command_async(self, command):
# Check if the request should proceed
@@ -1239,7 +1299,7 @@ class Session:
# Check for OOB
if command.oob_data_flag != 0:
self.terminate(SMP_OOB_NOT_AVAILABLE_ERROR)
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Decide which pairing method to use
@@ -1283,7 +1343,7 @@ class Session:
def on_smp_pairing_response_command(self, command):
if self.is_responder:
logger.warn(color('received pairing response as a responder', 'red'))
logger.warning(color('received pairing response as a responder', 'red'))
return
# Save the response
@@ -1322,8 +1382,8 @@ class Session:
# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
self.display_passkey()
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
@@ -1332,7 +1392,7 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_legacy(self, command):
def on_smp_pairing_confirm_command_legacy(self, _):
if self.is_initiator:
self.send_pairing_random_command()
else:
@@ -1342,11 +1402,8 @@ class Session:
else:
self.send_pairing_confirm_command()
def on_smp_pairing_confirm_command_secure_connections(self, command):
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
def on_smp_pairing_confirm_command_secure_connections(self, _):
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.is_initiator:
self.r = crypto.r()
self.send_pairing_random_command()
@@ -1387,23 +1444,25 @@ class Session:
else:
srand = self.r
mrand = command.random_value
stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {stk.hex()}')
self.stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {self.stk.hex()}')
# Generate LTK
self.ltk = crypto.r()
if self.is_initiator:
self.start_encryption(stk)
self.start_encryption(self.stk)
else:
self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command):
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements
if self.is_initiator:
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
self.pkb, self.pka, command.random_value, bytes([0])
@@ -1434,10 +1493,7 @@ class Session:
else:
return
else:
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
# Check that the random value matches what was committed to earlier
@@ -1469,10 +1525,7 @@ class Session:
(mac_key, self.ltk) = crypto.f5(self.dh_key, self.na, self.nb, a, b)
# Compute the DH Key checks
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
ra = bytes(16)
rb = ra
elif self.pairing_method == self.PASSKEY:
@@ -1497,10 +1550,7 @@ class Session:
self.wait_before_continuing.set_result(None)
# Prompt the user for confirmation if needed
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# Compute the 6-digit code
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
@@ -1537,22 +1587,15 @@ class Session:
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.is_initiator:
if self.pairing_method == self.PASSKEY:
if self.passkey_display:
self.send_pairing_confirm_command()
else:
self.input_passkey(self.send_pairing_confirm_command)
self.send_pairing_confirm_command()
else:
# Send our public key back to the initiator
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey(self.send_public_key_command)
else:
self.send_public_key_command()
self.display_or_input_passkey()
if (
self.pairing_method == self.JUST_WORKS
or self.pairing_method == self.NUMERIC_COMPARISON
):
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# We can now send the confirmation value
self.send_pairing_confirm_command()
@@ -1572,7 +1615,7 @@ class Session:
self.wait_before_continuing = None
self.send_pairing_dhkey_check_command()
asyncio.create_task(next_steps())
self.connection.abort_on('disconnection', next_steps())
else:
self.send_pairing_dhkey_check_command()
else:
@@ -1618,7 +1661,8 @@ class Manager(EventEmitter):
def send_command(self, connection, command):
logger.debug(
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}'
f'>>> Sending SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
@@ -1640,7 +1684,8 @@ class Manager(EventEmitter):
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] {connection.peer_address}: {command}'
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session
@@ -1686,9 +1731,9 @@ class Manager(EventEmitter):
try:
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warn(f'!!! error while storing keys: {error}')
logger.warning(f'!!! error while storing keys: {error}')
asyncio.create_task(store_keys())
self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc)
@@ -1704,3 +1749,5 @@ class Manager(EventEmitter):
def get_long_term_key(self, connection, rand, ediv):
if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv)
return None

170
bumble/snoop.py Normal file
View File

@@ -0,0 +1,170 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
import logging
import struct
import datetime
from typing import BinaryIO, Generator
import os
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Snooper:
"""
Base class for snooper implementations.
A snooper is an object that will be provided with HCI packets as they are
exchanged between a host and a controller.
"""
class Direction(IntEnum):
HOST_TO_CONTROLLER = 0
CONTROLLER_TO_HOST = 1
class DataLinkType(IntEnum):
H1 = 1001
H4 = 1002
HCI_BSCP = 1003
H5 = 1004
def snoop(self, hci_packet: bytes, direction: Direction) -> None:
"""Snoop on an HCI packet."""
# -----------------------------------------------------------------------------
class BtSnooper(Snooper):
"""
Snooper that saves HCI packets using the BTSnoop format, based on RFC 1761.
"""
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MS = datetime.timedelta(microseconds=1)
def __init__(self, output: BinaryIO):
self.output = output
# Write the header
self.output.write(
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
)
def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
flags = int(direction)
packet_type = hci_packet[0]
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
flags |= 0x10
# Compute the current timestamp
timestamp = (
int((datetime.datetime.utcnow() - self.TIMESTAMP_ANCHOR) / self.ONE_MS)
+ self.TIMESTAMP_DELTA
)
# Emit the record
self.output.write(
struct.pack(
'>IIIIQ',
len(hci_packet), # Original Length
len(hci_packet), # Included Length
flags, # Packet Flags
0, # Cumulative Drops
timestamp, # Timestamp
)
+ hci_packet
)
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@contextmanager
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
Create a snooper given a specification string.
The general syntax for the specification string is:
<snooper-type>:<type-specific-arguments>
Supported snooper types are:
btsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open the file
logger.debug(f'Snoop file: {file_path}')
with open(file_path, 'wb') as snoop_file:
_SNOOPER_INSTANCE_COUNT += 1
yield BtSnooper(snoop_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import asynccontextmanager
import logging
import os
from .common import Transport, AsyncPipeSink
from ..link import RemoteLink
from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..controller import Controller
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
# Logging
@@ -28,73 +30,140 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_transport(name):
'''
def _wrap_transport(transport: Transport) -> Transport:
"""
Automatically wrap a Transport instance when a wrapping class can be inferred
from the environment.
If no wrapping class is applicable, the transport argument is returned as-is.
"""
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
try:
return SnoopingTransport.create_with(
transport, create_snooper(snooper_spec)
)
except Exception as exc:
logger.warning(f'Exception while creating snooper: {exc}')
return transport
# -----------------------------------------------------------------------------
async def open_transport(name: str) -> Transport:
"""
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The supported types are: serial,udp,tcp,pty,usb
'''
The supported types are:
* serial
* udp
* tcp-client
* tcp-server
* ws-client
* ws-server
* pty
* file
* vhci
* hci-socket
* usb
* pyusb
* android-emulator
"""
return _wrap_transport(await _open_transport(name))
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
elif scheme == 'udp' and spec:
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
elif scheme == 'tcp-client' and spec:
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
elif scheme == 'tcp-server' and spec:
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
elif scheme == 'ws-client' and spec:
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
elif scheme == 'ws-server' and spec:
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
elif scheme == 'pty':
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
elif scheme == 'file':
if scheme == 'file':
from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
elif scheme == 'vhci':
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
elif scheme == 'hci-socket':
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
elif scheme == 'usb':
if scheme == 'usb':
from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
elif scheme == 'pyusb':
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
elif scheme == 'android-emulator':
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
else:
raise ValueError('unknown transport scheme')
raise ValueError('unknown transport scheme')
# -----------------------------------------------------------------------------
async def open_transport_or_link(name):
async def open_transport_or_link(name: str) -> Transport:
"""
Open a transport or a link relay.
Args:
name:
Name of the transport or link relay to open.
When the name starts with "link-relay:", open a link relay (see RemoteLink
for details on what the arguments are).
For other namespaces, see `open_transport`.
"""
if name.startswith('link-relay:'):
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link)
@@ -103,6 +172,6 @@ async def open_transport_or_link(name):
async def close(self):
link.close()
return LinkTransport(controller, AsyncPipeSink(controller))
else:
return await open_transport(name)
return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller)))
return await open_transport(name)

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,9 +20,11 @@ import grpc
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
from .emulated_bluetooth_packets_pb2 import HCIPacket
from .emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# pylint: disable-next=no-name-in-module
from .emulated_bluetooth_packets_pb2 import HCIPacket
# -----------------------------------------------------------------------------
# Logging

View File

@@ -15,12 +15,16 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct
import asyncio
import logging
from colors import color
from typing import ContextManager
from .. import hci
from ..colors import color
from ..snoop import Snooper
# -----------------------------------------------------------------------------
@@ -65,9 +69,12 @@ class PacketPump:
# -----------------------------------------------------------------------------
class PacketParser:
'''
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
'''
# pylint: disable=attribute-defined-outside-init
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
@@ -243,6 +250,20 @@ class StreamPacketSink:
# -----------------------------------------------------------------------------
class Transport:
"""
Base class for all transports.
A Transport represents a source and a sink together.
An instance must be closed by calling close() when no longer used. Instances
implement the ContextManager protocol so that they may be used in a `async with`
statement.
An instance is iterable. The iterator yields, in order, its source and sink, so
that it may be used with a convenient call syntax like:
async with create_transport() as (source, sink):
...
"""
def __init__(self, source, sink):
self.source = source
self.sink = sink
@@ -256,7 +277,7 @@ class Transport:
def __iter__(self):
return iter((self.source, self.sink))
async def close(self):
async def close(self) -> None:
self.source.close()
self.sink.close()
@@ -278,7 +299,7 @@ class PumpedPacketSource(ParserSource):
logger.debug('source pump task done')
break
except Exception as error:
logger.warn(f'exception while waiting for packet: {error}')
logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error)
break
@@ -309,7 +330,7 @@ class PumpedPacketSink:
logger.debug('sink pump task done')
break
except Exception as error:
logger.warn(f'exception while sending packet: {error}')
logger.warning(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.create_task(pump_packets())
@@ -332,3 +353,60 @@ class PumpedTransport(Transport):
async def close(self):
await super().close()
await self.close_function()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.
The returned instance will exit the snooper context when it is closed.
"""
with contextlib.ExitStack() as exit_stack:
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source:
def __init__(self, source, snooper):
self.source = source
self.snooper = snooper
self.sink = None
def set_packet_sink(self, sink):
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink, snooper):
self.sink = sink
self.snooper = snooper
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None):
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)
self.transport = transport
self.close_snooper = close_snooper
async def close(self):
await self.transport.close()
if self.close_snooper:
self.close_snooper()

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,9 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_packets.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
@@ -31,20 +30,10 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType(
'HCIPacket',
(_message.Message,),
{
'DESCRIPTOR': _HCIPACKET,
'__module__': 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
},
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_packets_pb2', globals()
)
_sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,41 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class HCIPacket(_message.Message):
__slots__ = ["packet", "type"]
class PacketType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
PACKET_FIELD_NUMBER: _ClassVar[int]
PACKET_TYPE_ACL: HCIPacket.PacketType
PACKET_TYPE_EVENT: HCIPacket.PacketType
PACKET_TYPE_HCI_COMMAND: HCIPacket.PacketType
PACKET_TYPE_ISO: HCIPacket.PacketType
PACKET_TYPE_SCO: HCIPacket.PacketType
PACKET_TYPE_UNSPECIFIED: HCIPacket.PacketType
TYPE_FIELD_NUMBER: _ClassVar[int]
packet: bytes
type: HCIPacket.PacketType
def __init__(
self,
type: _Optional[_Union[HCIPacket.PacketType, str]] = ...,
packet: _Optional[bytes] = ...,
) -> None: ...

View File

@@ -0,0 +1,17 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,9 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
@@ -34,20 +33,8 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType(
'RawData',
(_message.Message,),
{
'DESCRIPTOR': _RAWDATA,
'__module__': 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
},
)
_sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'emulated_bluetooth_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,26 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class RawData(_message.Message):
__slots__ = ["packet"]
PACKET_FIELD_NUMBER: _ClassVar[int]
packet: bytes
def __init__(self, packet: _Optional[bytes] = ...) -> None: ...

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,9 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_vhci.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
@@ -27,15 +26,17 @@ from google.protobuf import symbol_database as _symbol_database
_sym_db = _symbol_database.Default()
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_vhci_pb2', globals()
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,19 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from typing import ClassVar as _ClassVar
DESCRIPTOR: _descriptor.FileDescriptor

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -30,8 +30,9 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_file_transport(spec):
'''
Open a File transport (typically not for a real file, but for a PTY or other unix virtual files).
The parameter string is the path of the file to open
Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files).
The parameter string is the path of the file to open.
'''
# Open the file
@@ -39,12 +40,12 @@ async def open_file_transport(spec):
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), file
StreamPacketSource, file
)
# Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), file
asyncio.BaseProtocol, file
)
packet_sink = StreamPacketSink(write_transport)

View File

@@ -40,7 +40,7 @@ async def open_hci_socket_transport(spec):
or a 0-based integer to indicate the adapter number.
'''
HCI_CHANNEL_USER = 1
HCI_CHANNEL_USER = 1 # pylint: disable=invalid-name
# Create a raw HCI socket
try:
@@ -49,10 +49,12 @@ async def open_hci_socket_transport(spec):
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError:
except AttributeError as error:
# Not supported on this platform
logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform')
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
# Compute the adapter index
if spec is None:
@@ -66,13 +68,19 @@ async def open_hci_socket_transport(spec):
try:
ctypes.cdll.LoadLibrary('libc.so.6')
libc = ctypes.CDLL('libc.so.6', use_errno=True)
except OSError:
except OSError as error:
logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform')
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int
bind_address = struct.pack(
'<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER
# pylint: disable=no-member
'<HHH',
socket.AF_BLUETOOTH,
adapter_index,
HCI_CHANNEL_USER,
)
if (
libc.bind(
@@ -85,11 +93,11 @@ async def open_hci_socket_transport(spec):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, socket):
def __init__(self, hci_socket):
super().__init__()
self.socket = socket
self.socket = hci_socket
asyncio.get_running_loop().add_reader(
socket.fileno(), self.recv_until_would_block
self.socket.fileno(), self.recv_until_would_block
)
def recv_until_would_block(self):
@@ -107,8 +115,8 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_reader(self.socket.fileno())
class HciSocketSink:
def __init__(self, socket):
self.socket = socket
def __init__(self, hci_socket):
self.socket = hci_socket
self.packets = collections.deque()
self.writer_added = False
@@ -127,10 +135,13 @@ async def open_hci_socket_transport(spec):
break
if self.packets:
# There's still something to send, ensure that we are monitoring the socket
# There's still something to send, ensure that we are monitoring the
# socket
if not self.writer_added:
asyncio.get_running_loop().add_writer(
socket.fileno(), self.send_until_would_block
# pylint: disable=no-member
self.socket.fileno(),
self.send_until_would_block,
)
self.writer_added = True
else:
@@ -148,9 +159,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_writer(self.socket.fileno())
class HciSocketTransport(Transport):
def __init__(self, socket, source, sink):
def __init__(self, hci_socket, source, sink):
super().__init__(source, sink)
self.socket = socket
self.socket = hci_socket
async def close(self):
logger.debug('closing HCI socket transport')

View File

@@ -47,11 +47,11 @@ async def open_pty_transport(spec):
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), io.open(primary, 'rb', closefd=False)
StreamPacketSource, io.open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), io.open(primary, 'wb', closefd=False)
asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

View File

@@ -17,14 +17,15 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import usb.core
import usb.util
import threading
import time
from colors import color
import usb.core
import usb.util
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
@@ -48,6 +49,7 @@ async def open_pyusb_transport(spec):
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
@@ -108,7 +110,7 @@ async def open_pyusb_transport(spec):
def run(self):
while self.stop_event is None:
time.sleep(1)
self.loop.call_soon_threadsafe(lambda: self.stop_event.set())
self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled):
@@ -116,6 +118,7 @@ async def open_pyusb_transport(spec):
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.event_thread = threading.Thread(
target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
)
@@ -134,8 +137,8 @@ async def open_pyusb_transport(spec):
)
self.sco_thread.stop_event = None
def data_received(self, packet):
self.parser.feed_data(packet)
def data_received(self, data):
self.parser.feed_data(data)
def enqueue(self, packet):
self.queue.put_nowait(packet)
@@ -179,16 +182,17 @@ async def open_pyusb_transport(spec):
except usb.core.USBTimeoutError:
continue
except usb.core.USBError:
# Don't log this: because pyusb doesn't really support multiple threads
# reading at the same time, we can get occasional USBError(errno=5)
# Input/Output errors reported, but they seem to be harmless.
# Don't log this: because pyusb doesn't really support multiple
# threads reading at the same time, we can get occasional
# USBError(errno=5) Input/Output errors reported, but they seem to
# be harmless.
# Until support for async or multi-thread support is added to pyusb,
# we'll just live with this as is...
# logger.warning(f'USB read error: {error}')
time.sleep(1) # Sleep one second to avoid busy looping
stop_event = current_thread.stop_event
self.loop.call_soon_threadsafe(lambda: stop_event.set())
self.loop.call_soon_threadsafe(stop_event.set)
class UsbTransport(Transport):
def __init__(self, device, source, sink):
@@ -200,16 +204,22 @@ async def open_pyusb_transport(spec):
await self.sink.stop()
usb.util.release_interface(self.device, 0)
usb_find = usb.core.find
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
usb_find = libusb_package.find
# Find the device according to the spec moniker
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb.core.find(
idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
else:
device_index = int(spec)
devices = list(
usb.core.find(
usb_find(
find_all=1,
bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
@@ -242,6 +252,7 @@ async def open_pyusb_transport(spec):
# Select an alternate setting for SCO, if available
sco_enabled = False
# pylint: disable=line-too-long
# NOTE: this is disabled for now, because SCO with alternate settings is broken,
# see: https://github.com/libusb/libusb/issues/36
#

View File

@@ -60,7 +60,7 @@ async def open_serial_transport(spec):
device = spec
serial_transport, packet_source = await serial_asyncio.create_serial_connection(
asyncio.get_running_loop(),
lambda: StreamPacketSource(),
StreamPacketSource,
device,
baudrate=speed,
rtscts=rtscts,

View File

@@ -37,13 +37,13 @@ async def open_tcp_client_transport(spec):
'''
class TcpPacketSource(StreamPacketSource):
def connection_lost(self, error):
logger.debug(f'connection lost: {error}')
self.terminated.set_result(error)
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.terminated.set_result(exc)
remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
lambda: TcpPacketSource(),
TcpPacketSource,
host=remote_host,
port=int(remote_port),
)

View File

@@ -49,8 +49,8 @@ async def open_tcp_server_transport(spec):
# Called when a new connection is established
def connection_made(self, transport):
peername = transport.get_extra_info('peername')
logger.debug('connection from {}'.format(peername))
peer_name = transport.get_extra_info('peer_name')
logger.debug(f'connection from {peer_name}')
self.packet_sink.transport = transport
# Called when the client is disconnected

View File

@@ -57,7 +57,7 @@ async def open_udp_transport(spec):
udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: UdpPacketSource(),
UdpPacketSource,
local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port)),
)

View File

@@ -17,13 +17,16 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import usb1
import threading
import collections
from colors import color
import ctypes
import platform
import usb1
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
@@ -33,6 +36,30 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def load_libusb():
'''
Attempt to load the libusb-1.0 C library from libusb_package in site-packages.
If the library exists, we create a DLL object and initialize the usb1 backend.
This only needs to be done once, but before a usb1.USBContext is created.
If the library does not exists, do nothing and usb1 will search default system paths
when usb1.USBContext is created.
'''
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
if libusb_path := libusb_package.get_library_path():
logger.debug(f'loading libusb library at {libusb_path}')
dll_loader = (
ctypes.WinDLL if platform.system() == 'Windows' else ctypes.CDLL
)
libusb_dll = dll_loader(
str(libusb_path), use_errno=True, use_last_error=True
)
usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec):
'''
Open a USB transport.
@@ -44,21 +71,26 @@ async def open_usb_transport(spec):
With <index> as the 0-based index to select amongst all the devices that appear
to be supporting Bluetooth HCI (0 being the first one), or
Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The
/<serial-number> suffix or #<index> suffix max be specified when more than one device with
the same vendor and product identifiers are present.
/<serial-number> suffix or #<index> suffix max be specified when more than one
device with the same vendor and product identifiers are present.
In addition, if the moniker ends with the symbol "!", the device will be used in "forced" mode:
the first USB interface of the device will be used, regardless of the interface class/subclass.
This may be useful for some devices that use a custom class/subclass but may nonetheless work as-is.
In addition, if the moniker ends with the symbol "!", the device will be used in
"forced" mode:
the first USB interface of the device will be used, regardless of the interface
class/subclass.
This may be useful for some devices that use a custom class/subclass but may
nonetheless work as-is.
Examples:
0 --> the first BT USB dongle
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and serial number 00E04C239987
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and
serial number 00E04C239987
usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_DEVICE_CLASS_DEVICE = 0x00
@@ -109,6 +141,7 @@ async def open_usb_transport(spec):
status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED:
@@ -149,15 +182,20 @@ async def open_usb_transport(spec):
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
async def close(self):
def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data
self.packets.clear()
# If we have a transfer in flight, cancel it
if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed
# Try to cancel the transfer, but that may fail because it may have
# already completed
try:
self.transfer.cancel()
@@ -176,12 +214,15 @@ async def open_usb_transport(spec):
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
}
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
@@ -212,8 +253,13 @@ async def open_usb_transport(spec):
def on_packet_received(self, transfer):
packet_type = transfer.getUserData()
status = transfer.getStatus()
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}')
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
packet = (
bytes([packet_type])
@@ -247,6 +293,7 @@ async def open_usb_transport(spec):
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
@@ -255,19 +302,26 @@ async def open_usb_transport(spec):
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self):
def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
self.dequeue_task.cancel()
# Cancel the transfers
for transfer in (self.events_in_transfer, self.acl_in_transfer):
if transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed
# Try to cancel the transfer, but that may fail because it may have
# already completed
packet_type = transfer.getUserData()
try:
transfer.cancel()
logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation to be done...'
f'waiting for IN[{packet_type}] transfer cancellation '
'to be done...'
)
await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done')
@@ -298,13 +352,16 @@ async def open_usb_transport(spec):
sink.start()
async def close(self):
await self.source.close()
await self.sink.close()
self.source.close()
self.sink.close()
await self.source.terminate()
await self.sink.terminate()
self.device.releaseInterface(self.interface)
self.device.close()
self.context.close()
# Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext()
context.open()
try:
@@ -383,6 +440,7 @@ async def open_usb_transport(spec):
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
@@ -431,10 +489,13 @@ async def open_usb_transport(spec):
acl_out,
events_in,
)
else:
logger.debug(
f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}'
)
logger.debug(
f'skipping configuration {configuration_index + 1} / '
f'interface {setting.getNumber()}'
)
return None
endpoints = find_endpoints(found)
if endpoints is None:
@@ -452,6 +513,7 @@ async def open_usb_transport(spec):
device = found.open()
# Auto-detach the kernel driver if supported
# pylint: disable=no-member
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
logger.debug('auto-detaching kernel driver')

View File

@@ -44,11 +44,13 @@ async def open_ws_server_transport(spec):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future()
self.server = None
super().__init__(source, sink)
async def serve(self, local_host, local_port):
self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.serve(
ws_handler=self.on_connection,
host=local_host if local_host != '_' else None,
@@ -58,15 +60,17 @@ async def open_ws_server_transport(spec):
async def on_connection(self, connection):
logger.debug(
f'new connection on {connection.local_address} from {connection.remote_address}'
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
self.connection.set_result(connection)
# pylint: disable=no-member
try:
async for packet in connection:
if type(packet) is bytes:
if isinstance(packet, bytes):
self.source.parser.feed_data(packet)
else:
logger.warn('discarding packet: not a BINARY frame')
logger.warning('discarding packet: not a BINARY frame')
except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}')

View File

@@ -19,10 +19,12 @@ import asyncio
import logging
import traceback
import collections
import sys
from typing import Awaitable, Set, TypeVar
from functools import wraps
from colors import color
from pyee import EventEmitter
from .colors import color
# -----------------------------------------------------------------------------
# Logging
@@ -45,6 +47,7 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter.
"""
# pylint: disable=protected-access
def register(self, emitter):
for method_name in dir(cls):
@@ -62,7 +65,41 @@ def composite_listener(cls):
# -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter):
_T = TypeVar('_T')
class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
def on_event(*_):
if future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future
# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
def __init__(self):
super().__init__()
self._listener = None
@@ -73,6 +110,7 @@ class CompositeEventEmitter(EventEmitter):
@listener.setter
def listener(self, listener):
# pylint: disable=protected-access
if self._listener:
# Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro():
@@ -119,6 +157,9 @@ class AsyncRunner:
# Shared default queue
default_queue = WorkQueue()
# Shared set of running tasks
running_tasks: Set[Awaitable] = set()
@staticmethod
def run_in_task(queue=None):
"""
@@ -136,7 +177,8 @@ class AsyncRunner:
await coroutine
except Exception:
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}'
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
asyncio.create_task(run())
@@ -148,6 +190,19 @@ class AsyncRunner:
return decorator
@staticmethod
def spawn(coroutine):
"""
Spawn a task to run a coroutine in a "fire and forget" mode.
Using this method instead of just calling `asyncio.create_task(coroutine)`
is necessary when you don't keep a reference to the task, because `asyncio`
only keeps weak references to alive tasks.
"""
task = asyncio.create_task(coroutine)
AsyncRunner.running_tasks.add(task)
task.add_done_callback(AsyncRunner.running_tasks.remove)
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:

View File

@@ -2,7 +2,7 @@ Bumble Documentation
====================
The documentation consists of a collection of markdown text files, with the root of the file
hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`.
hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`.
You can read the documentation as text, with any text viewer or your favorite markdown viewer,
or generate a static HTML "site" using `mkdocs`, which you can then open with any browser.
@@ -14,9 +14,9 @@ The `mkdocs` directory contains all the data (actual documentation) and metadata
`mkdocs/mkdocs.yml` contains the site configuration.
`mkdocs/src/` is the directory where the actual documentation text, in markdown format, is located.
To build, from the project's root directory:
To build, from the project's root directory:
```
$ mkdocs build -f docs/mkdocs/mkdocs.yml
$ mkdocs build -f docs/mkdocs/mkdocs.yml
```
You can then open `docs/mkdocs/site/index.html` with any web browser.

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}
{"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}
{"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,8 @@ nav:
- Getting Started: getting_started.md
- Development:
- Python Environments: development/python_environments.md
- Contributing: development/contributing.md
- Code Style: development/code_style.md
- Use Cases:
- Overview: use_cases/index.md
- Use Case 1: use_cases/use_case_1.md
@@ -41,7 +43,7 @@ nav:
- Apps & Tools:
- Overview: apps_and_tools/index.md
- Console: apps_and_tools/console.md
- Link Relay: apps_and_tools/link_relay.md
- Bench: apps_and_tools/bench.md
- HCI Bridge: apps_and_tools/hci_bridge.md
- Golden Gate Bridge: apps_and_tools/gg_bridge.md
- Show: apps_and_tools/show.md
@@ -49,6 +51,7 @@ nav:
- Pair: apps_and_tools/pair.md
- Unbond: apps_and_tools/unbond.md
- USB Probe: apps_and_tools/usb_probe.md
- Link Relay: apps_and_tools/link_relay.md
- Hardware:
- Overview: hardware/index.md
- Platforms:
@@ -60,7 +63,7 @@ nav:
- Examples:
- Overview: examples/index.md
copyright: Copyright 2021-2022 Google LLC
copyright: Copyright 2021-2023 Google LLC
theme:
name: 'material'

View File

@@ -3,4 +3,4 @@ mkdocs == 1.4.0
mkdocs-material == 8.5.6
mkdocs-material-extensions == 1.0.3
pymdown-extensions == 9.6
mkdocstrings-python == 0.7.1
mkdocstrings-python == 0.7.1

View File

@@ -1,2 +1,2 @@
API EXAMPLES
============
============

View File

@@ -1,2 +1,2 @@
API DEVELOPER GUIDE
===================
===================

View File

@@ -16,4 +16,3 @@ Bumble Python API
### HCI_Disconnect_Command
::: bumble.hci.HCI_Disconnect_Command

View File

@@ -0,0 +1,158 @@
BENCH TOOL
==========
The "bench" tool implements a number of different ways of measuring the
throughput and/or latency between two devices.
# General Usage
```
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options:
--device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
-c, --packet-count COUNT Packet count (server role)
-sd, --start-delay SECONDS Start delay (server role)
--help Show this message and exit.
Commands:
central Run as a central (initiates the connection)
peripheral Run as a peripheral (waits for a connection)
```
## Options for the ``central`` Command
```
Usage: bumble-bench central [OPTIONS] TRANSPORT
Run as a central (initiates the connection)
Options:
--peripheral ADDRESS_OR_NAME Address or name to connect to
--connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms)
--phy [1m|2m|coded] PHY to use
--help Show this message and exit.
```
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device
running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts.
Independently of whether the device is the Central or Peripheral, each device selects a
``mode`` and and ``role`` to run as. The ``mode`` and ``role`` of the Central and Peripheral
must be compatible.
Device 1 mode | Device 2 mode
------------------|------------------
``gatt-client`` | ``gatt-server``
``l2cap-client`` | ``l2cap-server``
``rfcomm-client`` | ``rfcomm-server``
Device 1 role | Device 2 role
--------------|--------------
``sender`` | ``receiver``
``ping`` | ``pong``
# Examples
In the following examples, we have two USB Bluetooth controllers, one on `usb:0` and
the other on `usb:1`, and two consoles/terminals. We will run a command in each.
!!! example "GATT Throughput"
Using the default mode and role for the Central and Peripheral.
In the first console/terminal:
```
$ bumble-bench peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench central usb:1
```
In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-server peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-client central usb:1
```
!!! example "RFComm Throughput"
In the first console/terminal:
```
$ bumble-bench --mode rfcomm-server peripheral usb:0
```
NOTE: the BT address of the Peripheral will be printed out, use it with the
``--peripheral`` option for the Central.
In this example, we use a larger packet size and packet count than the default.
In the second console/terminal:
```
$ bumble-bench --mode rfcomm-client --packet-size 2000 --packet-count 100 central --peripheral 00:16:A4:5A:40:F2 usb:1
```
!!! example "Ping/Pong Latency"
In the first console/terminal:
```
$ bumble-bench --role pong peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --role ping central usb:1
```
!!! example "Reversed modes with GATT and custom connection interval"
In the first console/terminal:
```
$ bumble-bench --mode gatt-client peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode gatt-server central --ci 10 usb:1
```
!!! example "Reversed modes with L2CAP and custom PHY"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-client peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-server central --phy 2m usb:1
```
!!! example "Reversed roles with L2CAP"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-client --role sender peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-server --role receiver central usb:1
```

View File

@@ -1,2 +1,2 @@
GOLDEN GATE BRIDGE
==================
==================

View File

@@ -28,5 +28,3 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`.

View File

@@ -5,10 +5,10 @@ Included in the project are a few apps and tools, built on top of the core libra
These include:
* [Console](console.md) - an interactive text-based console
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic)
* [Pair](pair.md) - Pair/bond two devices (LE and Classic)
* [Unbond](unbond.md) - Remove a previously established bond
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool"
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

View File

@@ -31,4 +31,3 @@ The WebSocket path used by a connecting client indicates which virtual "chat roo
It is possible to connect to a "chat room" in a relay as an observer, rather than a virtual controller. In this case, a text-based console can be used to observe what is going on in the "chat room". Tools like [`wscat`](https://github.com/websockets/wscat#readme) or [`websocat`](https://github.com/vi/websocat) can be used for that.
Example: `wscat --connect ws://localhost:10723/test`

View File

@@ -3,8 +3,8 @@ USB PROBE TOOL
This tool lists all the USB devices, with details about each device.
For each device, the different possible Bumble transport strings that can
refer to it are listed.
If the device is known to be a Bluetooth HCI device, its identifier is printed
refer to it are listed.
If the device is known to be a Bluetooth HCI device, its identifier is printed
in reverse colors, and the transport names in cyan color.
For other devices, regardless of their type, the transport names are printed
in red. Whether that device is actually a Bluetooth device or not depends on
@@ -30,7 +30,7 @@ When running from the source distribution:
$ python3 apps/usb-probe.py
```
or
or
```
$ python3 apps/usb-probe.py --verbose
@@ -38,7 +38,7 @@ $ python3 apps/usb-probe.py --verbose
!!! example
```
$ python3 apps/usb_probe.py
$ python3 apps/usb_probe.py
ID 0A12:0001
Bumble Transport Names: usb:0 or usb:0A12:0001
@@ -47,4 +47,4 @@ $ python3 apps/usb-probe.py --verbose
Subclass/Protocol: 1/1 [Bluetooth]
Manufacturer: None
Product: USB2.0-BT
```
```

View File

@@ -1,2 +1,2 @@
CONTROLLER
==========
==========

View File

@@ -1,2 +1,2 @@
GATT
====
====

View File

@@ -1,2 +1,2 @@
HOST
====
====

View File

@@ -1,2 +1,2 @@
SECURITY MANAGER
================
================

View File

@@ -0,0 +1,43 @@
CODE STYLE
==========
The Python code style used in this project follows the [Black code style](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html).
# Formatting
For now, we are configuring the `black` formatter with the option to leave quotes unchanged.
The preferred quote style is single quotes, which isn't a configurable option for `Black`, so we are not enforcing it. This may change in the future.
## Ignoring Commit for Git Blame
The adoption of `Black` as a formatter came in late in the project, with already a large code base. As a result, a large number of files were changed in a single commit, which gets in the way of tracing authorship with `git blame`. The file `git-blame-ignore-revs` contains the commit hash of when that mass-formatting event occurred, which you can use to skip it in a `git blame` analysis:
!!! example "Ignoring a commit with `git blame`"
```
$ git blame --ignore-revs-file .git-blame-ignore-revs
```
# Linting
The project includes a `pylint` configuration (see the `pyproject.toml` file for details).
The `pre-commit` checks only enforce that there are no errors. But we strongly recommend that you run the linter with warnings enabled at least, and possibly the "Refactor" ('R') and "Convention" ('C') categories as well.
To run the linter, use the `project.lint` invoke command.
!!! example "Running the linter with default options"
With the default settings, Errors and Warnings are enabled, but Refactor and Convention categories are not.
```
$ invoke project.lint
```
!!! example "Running the linter with all categories"
```
$ invoke project.lint --disable=""
```
# Editor/IDE Integration
## Visual Studio Code
The project includes a `.vscode/settings.json` file that specifies the `black` formatter and enables an editor ruler at 88 columns.
You may want to configure your own environment to "format on save" with `black` if you find that useful. We are not making that choice at the workspace level.

Some files were not shown because too many files have changed in this diff Show More