Compare commits

..

186 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
6e0129b71d fix libusb loading on Windows 2022-12-18 22:00:26 -08:00
Gilles Boccon-Gibod
7ae3a1d973 Merge pull request #101 from google/gbg/formatting-linting-automation
formatting linting automation
2022-12-16 19:39:28 -08:00
Gilles Boccon-Gibod
c2959dadb4 formatting and linting automation
Squashed commits:
[cd479ba] formatting and linting automation
[7fbfabb] formatting and linting automation
[c4f9505] fix after rebase
[f506ad4] rename job
[441d517] update doc (+7 squashed commits)
[2e1b416] fix invoke and github action
[6ae5bb4] doc for git blame
[44b5461] add GitHub action
[b07474f] add docs
[4cd9a6f] more linter fixes
[db71901] wip
[540dc88] wip
2022-12-15 23:07:17 -08:00
Michael Mogenson
80fe2ea422 Merge pull request #102 from mogenson/libusb_package
Load libusb-1.0 shared library from libusb_package wheel
2022-12-15 21:53:56 -05:00
Lucas Abel
08e6590a76 Merge pull request #88 from google/uael/abort_on_event
Host: spawn each asynchronous task with the right aliveness
2022-12-15 12:46:37 -08:00
Abel Lucas
f580ffcbc3 device: set as Secure Connection when encrypted with AES 2022-12-15 17:02:21 +00:00
Abel Lucas
5178c866ac classic: add to .encrypt the possibilty to disable encryption 2022-12-15 17:02:21 +00:00
Abel Lucas
441933bd64 reverted: 662704e "classic: complete authentication when being the .authenticate acceptor" 2022-12-15 17:02:21 +00:00
Abel Lucas
287df94090 host: spawn each asynchronous task with the right aliveness 2022-12-15 17:02:21 +00:00
Michael Mogenson
86f9496575 Load libusb-1.0 shared library from libusb_package wheel
It would be nice to pip install bumble without having to first install
the libusb system dependency. Expecially on platforms like Windows and
Mac, without a default package manager.

The libusb_package Python package distributes prebuilt libusb-1.0 shared
libraries for each OS and architecture as binary wheels for the pyusb
project. Add this package as a dependency for bumble.

For the pyusb transport, the libusb_package.find() function is a drop-in
replacement for pyusb.core.find(). It searches the libusb_package
site-path before system paths and creates a pyusb backend.

For the usb transport, use libusb_package.get_library_path() to return a
path to the libusb-1.0 library in site-packages. If this path exists,
create a ctypes DLL and init the usb1 backend. This only needs to be
done once. All future calls to usb1 will use this opened library.
If the library path does not exist, do nothing, and usb1 will search
default system paths when the usb1.USBContext object is created.

This commit pins the libusb_package dependency at 1.0.26.0 to ensure
every bumble install uses the exact same version of the libusb library.
2022-12-15 10:22:02 -05:00
Gilles Boccon-Gibod
f5fe3d87f2 Merge pull request #98 from yuyangh/yuyangh/update_asha_advertising
update ASHA AdvertisingData
2022-12-12 15:23:47 -08:00
Yuyang Huang
f65bed2ec4 Merge branch 'main' into yuyangh/update_asha_advertising 2022-12-12 13:35:17 -08:00
Gilles Boccon-Gibod
3efe35065d Merge pull request #96 from google/gbg/black
format with Black
2022-12-12 13:27:58 -08:00
Yuyang Huang
83b42488ea update ASHA AdvertisingData
previously the ASHA AdvertisingData uses INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, now it lets user to define whether it is complete list or incomplete list
2022-12-12 11:42:48 -08:00
Gilles Boccon-Gibod
135df0dcc0 format with Black 2022-12-10 09:40:12 -08:00
Gilles Boccon-Gibod
8bef344879 Merge pull request #94 from AlanRosenthal/alan/bumble_version_in_show_device
Add bumble's version to `show device`
2022-12-10 09:00:30 -08:00
Alan Rosenthal
55e2f23e29 Add bumble's version to show device 2022-12-09 12:23:45 -05:00
Gilles Boccon-Gibod
297246fa4c Merge pull request #92 from yuyangh/yuyangh/add_ASHA_GATT
add ASHA profile
2022-12-07 13:16:01 -08:00
Yuyang Huang
52db1cfcc1 improve code style 2022-12-06 07:38:05 -08:00
Yuyang Huang
29f9a79502 improve get service advertising data 2022-12-05 11:22:07 -08:00
Gilles Boccon-Gibod
c86125de4f Merge pull request #93 from AlanRosenthal/alan/add_default_services
Add Device::add_default_services()
2022-12-01 12:36:03 -08:00
Yuyang Huang
697d5df3f8 code style update 2022-12-01 10:50:15 -08:00
Yuyang Huang
87aa4f617e add ASHA advertising factory method 2022-12-01 10:40:30 -08:00
Alan Rosenthal
a8eff737e6 Add Device::add_default_services()
This will allow a test to:
a: add services to a device
b: reset services via `Server()`
c: add the default services back
2022-12-01 17:02:54 +00:00
Gilles Boccon-Gibod
4417eb636c Merge pull request #83 from AlanRosenthal/alan/pytest_fixes_2
Test all python versions in CI
2022-11-29 12:48:35 -08:00
Alan Rosenthal
f4e5e61bbb Test all python versions in CI
Followed instructions here: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#using-the-python-starter-workflow
2022-11-29 20:22:56 +00:00
Lucas Abel
ba7a60025f Merge pull request #89 from google/uael/misc
Typos & CI fixes
2022-11-29 12:14:03 -08:00
Yuyang Huang
d92b7e9b74 add ASHA Profile 2022-11-29 10:34:59 -08:00
Yuyang Huang
b0336adf1c add ASHA GATT UUID 2022-11-29 10:02:40 -08:00
Abel Lucas
691450c7de gatt: fix CharacteristicDeclaration.__str__ and associated test 2022-11-29 16:43:47 +00:00
Abel Lucas
99a0eb21c1 address: fix deprecated use of combined @classmethod and @property 2022-11-29 16:33:12 +00:00
Abel Lucas
ab4859bd94 device: fix typos 2022-11-29 16:33:12 +00:00
Lucas Abel
0d70cbde64 Merge pull request #75 from google/uael/fixes
Pairing: device/host fixes & improvements
2022-11-28 21:42:43 -08:00
Gilles Boccon-Gibod
f41d0682b2 Merge pull request #80 from AlanRosenthal/alan/gatt_server_getter
Added class CharacteristicDeclaration, gatt_server getters
2022-11-28 19:21:08 -08:00
Gilles Boccon-Gibod
062dc1e53d Merge pull request #85 from AlanRosenthal/alan/gatt_server_console2
Add `bumble-console --device-config` support for gatt services
2022-11-28 19:19:25 -08:00
Abel Lucas
662704e551 classic: complete authentication when being the .authenticate acceptor 2022-11-29 00:28:39 +00:00
Abel Lucas
02a474c44e smp: emit enough information on pairing complete to deduce security level 2022-11-29 00:28:38 +00:00
Abel Lucas
a1c7aec492 device: fix .find_connection_by_bd_addr 2022-11-29 00:28:38 +00:00
Abel Lucas
6112f00049 device: introduce BR/EDR pending connections
This commit enable the BR/EDR pairing to run asynchronously to
the connection being established.

When in security mode 3, a controller shall start authentication as
part of the connection, which result in HCI events being sent on a BD
address without a completed connection (ie. no connection handle).
2022-11-29 00:28:38 +00:00
Alan Rosenthal
f56ac14f2c Add bumble-console --device-config support for gatt services
This PR adds support for bumble-console to be preloaded with gatt services via `--device-config`.
This PR also adds some type annotations
2022-11-28 14:11:27 -05:00
Gilles Boccon-Gibod
a739fc71ce Merge pull request #84 from google/gbg/78
use libusb auto-detach feature
2022-11-27 12:42:02 -08:00
Alan Rosenthal
b89f9030a0 Added class CharacteristicDeclaration, gatt_server getters
* Converted CharacteristicDeclaration implementation to class
* Added ability to get a gatt_server attribute by service UUID, characteristics UUID, descriptor UUID
2022-11-27 19:22:25 +00:00
Gilles Boccon-Gibod
9e5a85bd10 use libusb auto-detach feature 2022-11-25 17:52:13 -08:00
Gilles Boccon-Gibod
b437bd8619 Merge pull request #82 from AlanRosenthal/alan/pytest_fixes
Fix test failures
2022-11-25 13:19:53 -08:00
Alan Rosenthal
a3e4674819 Fix test failures
a. `DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead`
Updated call in `bumble/smp.py`

b. `ModuleNotFoundError: No module named 'bumble.apps'`
Updated imports in `tests/import_test.py`

c. Added `pytest-html` for easier viewing of test results
Added package in `setup.cfg`, and hook in `tasks.py`

d. Updated workflows to use `invoke test`

This is a partial fix of #81
2022-11-23 11:31:27 -05:00
Abel Lucas
5f1d57fcb0 device: simplify and fixes remote name request 2022-11-22 21:20:56 +00:00
Gilles Boccon-Gibod
ae0b739e4a Merge pull request #79 from google/gbg/fix-host-reset
fix sequencing logic broken by earlier merge
2022-11-22 09:16:26 -08:00
Michael Mogenson
0570b59796 Merge pull request #77 from mogenson/l2cap-on-event
Fix for 'Host' object has no attribute 'add_listener'
2022-11-22 09:39:04 -05:00
Gilles Boccon-Gibod
22218627f6 fix sequencing logic broken by earlier merge 2022-11-21 21:07:47 -08:00
Michael Mogenson
1c72242264 Fix for 'Host' object has no attribute 'add_listener'
Pyee's add_listener() method was not added until release 9.0.0. Bumble's
setup.cfg specifies a minimum pyee version of 8.2.2.

Remove the call to add_listener() in l2cap.py. If the add_listener() API
is prefered over the on(), another solution would be to bump the pyee
version requirement.
2022-11-21 12:31:21 -05:00
Abel Lucas
9c133706e6 keys: add a way to remove all bonds from key store 2022-11-18 18:22:15 +00:00
Michael Mogenson
4988a31487 Merge pull request #76 from mogenson/connection-error-params
Swap arguments to ConnectionError in RFCOMM Multiplexer
2022-11-18 10:48:02 -05:00
Michael Mogenson
e6c062117f Swap arguments to ConnectionError in RFCOMM Multiplexer
Minor fixup. Change the order of arguments to ConnectionError to set the
transport and address correctly in rfcomm.py on_dm_frame().
2022-11-18 10:02:40 -05:00
Gilles Boccon-Gibod
f2133235d5 Merge pull request #73 from google/gbg/faster-l2cap-test
lower the number of test cases for l2cap in order to speed up the test
2022-11-15 10:49:55 -08:00
Gilles Boccon-Gibod
867e8c13dc lower the number of test cases for l2cap in order to speed up the test 2022-11-14 17:26:09 -08:00
Lucas Abel
25ce38c3f5 Merge pull request #72 from google/uael/public-str-address
address: add public information to the stringified value
2022-11-14 17:16:47 -08:00
Abel Lucas
c0810230a6 address: add public information to the stringified value
This affect the way security keys are stored. For instance the same
key can be used both as public and random, and it need to be stored
separately one from each other.
2022-11-14 20:05:12 +00:00
Michael Mogenson
27c46eef9d Merge pull request #71 from mogenson/prefer-notify
Add prefer_notify option to gatt_client subscribe()
2022-11-13 19:53:09 -05:00
Michael Mogenson
c140876157 Add prefer_notify option to gatt_client subscribe()
If characteristic supports Notify and Indicate, the prefer_notify option
will subscribe with Notify if True or Indicate if False.

If characteristic only supports one property, Notify or Indicate, that
mode will be selected, regardless of the prefer_notify setting.

Tested with a characteristic that supports both Notify and Indicate and
verified that prefer_notify sets the desired mode.
2022-11-13 19:38:12 -05:00
Lucas Abel
d743656f09 Merge pull request #68 from google/uael/pairing-improvements
Pairing improvements
2022-11-11 21:03:17 -08:00
Abel Lucas
b91d0e24c1 device: handle HCI passkey notification event 2022-11-11 18:43:35 +00:00
Abel Lucas
eb46f60c87 le: save own_address_type on ACL connection for SMP to be able to use the right self address 2022-11-10 02:06:37 +00:00
Abel Lucas
8bbba7c84c pairing: always ask user for confirmation, even in JUST_WORKS method 2022-11-10 01:58:02 +00:00
Gilles Boccon-Gibod
ee54df2d08 Merge pull request #65 from google/gbg/fix-classic-connect-await
fix classic connection event filtering
2022-11-09 14:40:29 -08:00
Gilles Boccon-Gibod
6549e53398 Merge pull request #60 from google/gbg/fix-console-logs
use a formatter object, not a string
2022-11-09 13:19:26 -08:00
Gilles Boccon-Gibod
0f219eff12 address PR comments 2022-11-09 13:18:30 -08:00
Gilles Boccon-Gibod
4a1345cf95 only force the type if the address is passed as a string 2022-11-08 19:10:13 -08:00
Gilles Boccon-Gibod
8a1cdef152 fix classic connection event filtering 2022-11-08 17:33:29 -08:00
Gilles Boccon-Gibod
6e1baf0344 use a formatter object, not a string 2022-11-08 13:19:41 -08:00
Lucas Abel
cea1905ffb Merge pull request #59 from google/uael/device-cleanup
le: pass `own_address_type` to BLE `Device.connect`
2022-11-08 11:50:40 -08:00
Abel Lucas
af8e0d4dc7 le: pass own_address_type to BLE Device.connect 2022-11-08 18:22:54 +00:00
Gilles Boccon-Gibod
875195aebb Merge pull request #58 from AlanRosenthal/main
Add definition of `Client Characteristic Configuration bit`
2022-11-08 09:34:22 -08:00
Gilles Boccon-Gibod
5aee37aeab Merge pull request #34 from google/gbg/l2cap-bridge
Add L2CAP CoC support
2022-11-07 16:57:17 -08:00
Gilles Boccon-Gibod
edcb7d05d6 fix merge conflict 2022-11-07 16:51:40 -08:00
Gilles Boccon-Gibod
ce9004f0ac Add L2CAP CoC support (squashed)
[85542e0] fix test
[3748781] add ASAH sink example
[e782e29] add app
[83daa30] wip
[7f138a0] add test
[f732108] allow different address syntax
[9d0bbf8] rename deprecated methods
[eb303d5] add LE CoC support
2022-11-07 16:45:37 -08:00
Alan Rosenthal
d4228e3b5b Add definition of Client Characteristic Configuration bit 2022-11-07 19:43:22 -05:00
Lucas Abel
be8f8ac68f Merge pull request #55 from google/uael/device-improvements
Device improvements
2022-11-07 15:22:41 -08:00
Abel Lucas
ca16410a6d device: add option to check for the address type when using find_connection_by_bd_addr 2022-11-07 22:17:01 +00:00
Abel Lucas
b95888eb39 le: permit legacy scanning even when extended is supported 2022-11-07 22:15:54 +00:00
Abel Lucas
56ed46adfa classic: add BR/EDR accept connection logic 2022-11-04 17:26:59 +00:00
Abel Lucas
7044102e05 classic: upgrade Device.cancel_connection logic to support canceling ongoing BR/EDR connections 2022-11-04 17:26:59 +00:00
Abel Lucas
ca8f284888 le: add own_address_type parameter to Device.start_advertising 2022-11-04 17:26:59 +00:00
Abel Lucas
e9e14f5183 le: make the device connecting state relative to LE only
We may need to add a distinct BR/EDR connecting state in the future.
2022-11-04 17:26:59 +00:00
Abel Lucas
b961affd3d device: update Device.connect documentation to match BR/EDR behavior 2022-11-04 17:26:59 +00:00
Abel Lucas
51ddb36c91 device: add auto_restart mechanism to .start_discovery (default to True) 2022-11-04 17:26:59 +00:00
Abel Lucas
78534b659a device: enhance .request_remote_name to also accept an Address as argument 2022-11-04 17:26:59 +00:00
Abel Lucas
ce9472bf42 core: change AdvertisingData.get default raw behavior to False 2022-11-04 17:26:59 +00:00
Abel Lucas
fc331b7aea core: improve Advertisement.ad_data_to_object with support for more data types 2022-11-04 17:26:59 +00:00
Abel Lucas
8119bc210c host: pass remote_host_supported_features event to upper layer
The `HCI_Remote_Name_Request` command may trigger this HCI event.
Instead of warn for not being handled, pass it to upper layer.
2022-11-02 20:23:14 +00:00
Abel Lucas
65deefdc64 host: allow bytes return paramaters when checking command result 2022-11-02 20:23:14 +00:00
Michael Mogenson
2920c3b0d1 Merge pull request #53 from mogenson/mogenson/show-device-tab
Add a show device tab
2022-10-24 09:15:43 -04:00
Michael Mogenson
f5cd825dbc Merge pull request #51 from mogenson/mogenson/console-py-rand-addr
Use random address in console.py if device config is not provided
2022-10-24 09:15:10 -04:00
Gilles Boccon-Gibod
cf4c43c4ff Merge pull request #48 from google/uael/classic-parallel-connect
classic: update `Device.connect` to allow parallels connection creation
2022-10-23 20:52:08 -07:00
Gilles Boccon-Gibod
da2f596e52 Merge pull request #50 from google/uael/command-timeout
device: raise a `CommandTimeoutError` error on command timeout
2022-10-23 20:49:06 -07:00
Gilles Boccon-Gibod
c8aa0b4ef6 Merge pull request #54 from google/gbg/fix-regression-001
use the correct constants as previously renamed
2022-10-23 20:43:43 -07:00
Gilles Boccon-Gibod
75ac276c8b use the correct constants as previously renamed 2022-10-21 17:12:26 -07:00
Michael Mogenson
dd4023ff56 Add a show device tab
Show configuration data about the Bumble device. Make this the default
tab on startup.
2022-10-21 16:00:03 -04:00
Michael Mogenson
dde8c5e1c2 Use random address in console.py if device config is not provided
If a device configuration is not provided on startup, generate a random
BT address instead of using a default static value of
"F0:F1:F2:F3:F4:F5". This is helpful to avoid colisions when there are
two instances of console.py running nearby.

Testing:
Started console.py and began advertising a few times. Scanned from a
second instance of console.py and observed that the advertising address
changed with each restart.
2022-10-21 15:32:58 -04:00
Michael Mogenson
8ed1f4b50d Merge pull request #52 from mogenson/mogenson/console-py-clear-scan-results
add 'scan clear' command to console.py
2022-10-21 14:34:08 -04:00
Michael Mogenson
92de7dff4f add 'scan clear' command to console.py
Add command to clear scan results and known addresses. Useful for
determining if a peripheral has stopped advertising.

Also, check if a scan is in progress before connecting. If it is, stop
scanning. Some BT controllers will fail to connect while scanning.

Testing:
Can clear scan results before, during, and after scan. Can clear scan
results while disconnected and connected.
2022-10-21 13:58:21 -04:00
Abel Lucas
16b4f18c92 tests: add parallel device connection test 2022-10-21 15:49:03 +00:00
Gilles Boccon-Gibod
46f4b82d29 Merge pull request #46 from AlanRosenthal/main
Add runtime switch for filtering by address.
2022-10-20 19:20:28 -07:00
Abel Lucas
4e2f66f709 device: raise a CommandTimeoutError error on command timeout 2022-10-20 22:11:07 +00:00
Alan Rosenthal
3d79d7def5 Add runtime switch for filtering by address.
* scan on [filter pattern]
* filter address <filter pattern>
2022-10-20 14:47:14 -04:00
Abel Lucas
915405a9bd examples: update run_classic_connect example to take multiple addresses instead of one 2022-10-20 14:53:39 +00:00
Abel Lucas
45dd849d9f classic: update ConnectionError to take transport and peer address 2022-10-20 14:53:03 +00:00
Abel Lucas
7208fd6642 classic: update Device.connect to allow parallels connection creation
According to the specification nothing prevent the Host from creating
multiple connections at the same time. This commit add this mechanisme
by matching the `connection` and `connection_failure` events against the
peer address.
2022-10-19 17:44:44 +00:00
Gilles Boccon-Gibod
eb8556ccf6 gbg/extended scanning (#47)
Squashed:
* add extended report class
* more HCI commands
* add AdvertisingType
* add phy options
* fix tests
2022-10-19 10:06:00 -07:00
Octavian Purdila
4d96b821bc Merge pull request #44 from google/tavip/fix-address-resolution
Fix address resolution handling
2022-10-12 10:09:33 -07:00
Gilles Boccon-Gibod
78b36d2049 Merge pull request #45 from google/gbg/add-missing-app
add controller-info CLI app to setup
2022-10-11 22:21:08 -07:00
Gilles Boccon-Gibod
3e0cad1456 add controller-info CLI app to setup 2022-10-11 22:15:23 -07:00
Octavian Purdila
b4de38cdc3 Fix address resolution handling
In one of the refactors the command address_resolution field was
changed to address_reslution_enable but the controller code was not
updated.
2022-10-11 22:53:42 +00:00
Gilles Boccon-Gibod
68d9fbc159 Merge pull request #42 from google/gbg/improve-linux-doc
Refactor and improve the doc for Bumble on Linux
2022-10-11 14:35:14 -07:00
Gilles Boccon-Gibod
a916b7a21a Merge pull request #43 from google/gbg/proxy-write-with-response
support with_response on adapters
2022-10-11 07:41:28 -07:00
Gilles Boccon-Gibod
6ff52df8bd better/safer Linux recommendations 2022-10-10 20:11:55 -07:00
Gilles Boccon-Gibod
7fa2eb7658 support with_response on adapters 2022-10-10 12:11:51 -07:00
Gilles Boccon-Gibod
86618e52ef Refactor and improve the doc for Bumble on Linux 2022-10-09 12:56:06 -07:00
Gilles Boccon-Gibod
fbb46dd736 Merge pull request #41 from google/gbg/cli-scripts
use arg-less main() functions in all scripts
2022-10-07 16:16:35 -07:00
Gilles Boccon-Gibod
d1e119f176 use arg-less main() functions in all scripts 2022-10-07 13:56:42 -07:00
Gilles Boccon-Gibod
2fc7a0bf04 Merge pull request #39 from google/gbg/usb-descriptors
improve USB device detection logic
2022-10-06 15:39:32 -07:00
Gilles Boccon-Gibod
d6c4644b23 reorder the order of printing 2022-10-06 10:40:28 -07:00
Gilles Boccon-Gibod
073757d5dd Merge pull request #40 from google/gbg/gatt-mtu
maintain the att mtu only at the connection level
2022-10-05 13:53:47 -07:00
Gilles Boccon-Gibod
20dedbd923 maintain the att mtu only at the connection level 2022-10-04 20:04:43 -07:00
Octavian Purdila
df1962e8da apps/usb_probe.py: handle libusb1 exceptions
Some USB device properties are only accessible if the user has the
appropriate permissions. Handle libusb1 errors to graciously skip
showing details for these devices.
2022-10-04 23:38:13 +00:00
Gilles Boccon-Gibod
0edd6b731f Merge pull request #37 from google/gbg/gatt-notify-with-value
add support for notifying with a transient value
2022-10-04 10:33:04 -07:00
Gilles Boccon-Gibod
d2227f017f improve USB device detection logic 2022-10-04 09:59:48 -07:00
Gilles Boccon-Gibod
a2f18cffc9 Merge pull request #38 from google/gbg/usb-interface-discovery
add support for dynamic discovery of USB endpoint addresses
2022-09-21 11:40:13 -07:00
Gilles Boccon-Gibod
db5e52f1df add support for alternate settings 2022-09-20 22:25:40 -07:00
Gilles Boccon-Gibod
d7da5a9379 add support for dynamic discovery of USB endpoints 2022-09-20 16:39:12 -07:00
Gilles Boccon-Gibod
80569bc9f3 add support for notifying with a transient value 2022-09-06 12:42:35 -07:00
Gilles Boccon-Gibod
daa05b8996 Merge pull request #36 from google/gbg/pairing-with-no-distribution
gbg/pairing with no distribution
2022-09-02 10:17:31 -07:00
Gilles Boccon-Gibod
624e860762 support empty distributions in both directions 2022-08-30 18:50:48 -07:00
Gilles Boccon-Gibod
159cbf7774 support pairing with no key distribution 2022-08-30 18:28:24 -07:00
Gilles Boccon-Gibod
d188041694 Merge pull request #35 from zxzxwu/ctkd
Support CTKD over BR/EDR
2022-08-30 06:19:57 -07:00
Josh Wu
99cba19d7c Support CTKD over BR/EDR
Self test is not available Bumble BR/EDR local transport is not
implemented yet.

Test: Internal test - CTKD over BR/EDR
2022-08-30 11:19:22 +08:00
Gilles Boccon-Gibod
84d70ad4f3 add usb_probe tool and improve compatibility (#33)
* add usb_probe tool and improve compatibility with older/non-compliant devices

* fix logic test

* add doc
2022-08-26 12:41:55 -07:00
zxzxwu
996a9e28f4 Handle L2CAP info dynamically (#28)
* Add feature and MTU fields in L2CAP manager constructor
* Add register/unregister API for fixed channels
2022-08-18 08:25:59 -07:00
zxzxwu
27cb4c586b Delegate Classic connectable and discoverable (#27)
For remote-initiated test cases, we need the device to be
scan-configurable.
2022-08-17 14:20:32 -07:00
Gilles Boccon-Gibod
1f78243ea6 add test.release task to facilitate CI integration (#26) 2022-08-16 13:37:26 -07:00
Ray
216ce2abd0 Add release tasks (#6)
Added two tasks to tasks.py, release and release_tests.

Applied black formatter

authored-by: Raymundo Ramirez Mata <raymundora@google.com>
2022-08-16 11:50:30 -07:00
Gilles Boccon-Gibod
431445e6a2 fix imports (#25) 2022-08-16 11:29:56 -07:00
Michael Mogenson
d7cc546248 Update supported commands in console.py docs (#24)
Co-authored-by: Michael Mogenson <mogenson@google.com>
2022-08-12 14:23:21 -07:00
Gilles Boccon-Gibod
29fd19f40d gbg/fix subscribe lambda (#23)
* don't use a lambda as a subscriber

* update tests to look at side effects instead of internals
2022-08-12 14:22:31 -07:00
Michael Mogenson
14dfc1a501 Add subscribe and unsubscribe commands to console.py (#22)
Subscribe command will enable notify or indicate events from the
characteristic, depending on supported characteristic properties, and
print received values to the output window.

Unsubscribe will stop notify or indicate events.

Rename find_attribute() to find_characteristic() and return a
characteristic for a set of UUIDS, a characteristic for an attribute
handle, or None.

Print read and received values has a hex string.

Add an unsubscribe implementation to gatt_client.py. Reset the CCCD bits
to 0x0000. Remove a matching subsciber, if one is provided. Otherwise
remove all subscribers for a characteristic, since no more notify or
indicates events will be comming.

authored-by: Michael Mogenson <mogenson@google.com>
2022-08-12 11:49:01 -07:00
Gilles Boccon-Gibod
938282e961 Update python-publish.yml 2022-08-04 14:40:40 -07:00
Gilles Boccon-Gibod
900c15b151 Update python-publish.yml
trigger on published release
2022-08-04 14:30:25 -07:00
Gilles Boccon-Gibod
9ea93be723 add missing package entry (#21) 2022-08-04 14:27:21 -07:00
Gilles Boccon-Gibod
894ab023c7 Update python-publish.yml
don't run on PRs
2022-08-04 10:50:28 -07:00
Gilles Boccon-Gibod
7bbb37b2da Merge pull request #20 from google/gbg/test-gatt-long-read
add long read self test
2022-08-04 10:33:27 -07:00
Gilles Boccon-Gibod
3fa5d320de add long read self test 2022-08-03 16:19:04 -07:00
Gilles Boccon-Gibod
16d684c199 Merge pull request #19 from google/gbg/pypi-publish
add long description
2022-08-03 16:11:51 -07:00
Gilles Boccon-Gibod
c28aa2ebb6 add long description 2022-08-01 18:18:32 -07:00
Gilles Boccon-Gibod
28586382f4 don't publish to test PyPI
publishing to PyPI doesn't work with SCM versioning
2022-08-01 18:16:46 -07:00
Gilles Boccon-Gibod
76f08977c4 support SCM versioning 2022-08-01 17:30:00 -07:00
Gilles Boccon-Gibod
15cbf52da4 Update python-build-test.yml
Get history and tags for SCM versioning to work
2022-08-01 17:27:11 -07:00
Gilles Boccon-Gibod
f4f84dffef Update python-publish.yml
add action to fetch tags in order for SCM versioning to work
2022-08-01 17:21:19 -07:00
Gilles Boccon-Gibod
6dfb07d7b9 Create python-publish.yml 2022-08-01 16:35:32 -07:00
Gilles Boccon-Gibod
d7ce62beaa Merge pull request #18 from turon/docs/quick_start
[docs] Add some getting started information to the top-level README.
2022-07-31 12:00:36 -07:00
Gilles Boccon-Gibod
0e2a184edb Merge pull request #17 from mogenson/console_py_do_write
Implement 'write' command for console.py
2022-07-30 16:02:47 -07:00
Martin Turon
e6ee5ae996 [docs] Add references to some of the docs to the top-level for discoverability. 2022-07-30 14:18:08 -07:00
Martin Turon
f1836e659f [docs] Add some getting started information to the top-level README. 2022-07-30 14:13:55 -07:00
Michael Mogenson
99218d3abf Implement 'write' command for console.py
Syntax is `write <attribute> <value>`. Supports a value of type string,
hexadecimal string, or integer.

Ex:
- `write 180D.2A38 hello`
- `write 180D.2A38 0xbeef`
- `write 180D.2A38 123`

Write with response method is used if supported by characteristic,
otherwise write without response.

Add a find_attribute() method to consolidate common logic of finding a
characteristic or attribute handle in `do_read()` and `do_write()`.

Tested with run_gatt_server.py example to verify sent data.
2022-07-29 19:45:24 -04:00
Gilles Boccon-Gibod
b5ba0bef63 Merge pull request #16 from google/jdm/connection-context-manager
Adding in Device.connected_to context manager and Peer.sustain
2022-07-27 17:25:06 -07:00
Jayson Messenger
9cd1890faa Adding in context manager for Connection and Peer classes
* Connection implements async context manager to disconnect when
  context is left
    * The Connection only calls disconnect if the context manager exits
      without an exception
* Peer implements async context manager to discover when entering the
  context
* Device.connect_as_gatt implements an async context manager to nest the
  connection and peer context managers
* Added HCI_StatusError that can be raised when a HCI Command Status
  event is received that doesn't show "PENDING" as status
* Added Connection.sustain to wait for a timeout or disconnect
* Peer.sustain also maps to Connectin.sustain
* Updated battery_client.py to use .connect_as_gatt and .sustain
* Updated heart_rate_client.py to use .connect_as_gatt and .sustain
2022-07-27 14:03:12 -04:00
Gilles Boccon-Gibod
472702a9d9 Merge pull request #12 from google/gbg/more-hci-types
more hci types
2022-07-26 18:00:21 -07:00
Gilles Boccon-Gibod
b38740e5b7 Merge pull request #15 from google/gbg/hr-profile
add support for the heart rate service
2022-07-26 13:17:22 -07:00
Gilles Boccon-Gibod
3040df3179 add support for the heart rate service 2022-07-23 09:38:44 -07:00
Gilles Boccon-Gibod
c66b357de6 Merge pull request #13 from google/gbg/standard-profiles
support for type adapters and framework for standard GATT profiles
2022-07-22 10:21:39 -07:00
Gilles Boccon-Gibod
e156ed3758 add in-context uuids and service proxy factories 2022-07-20 19:56:40 -07:00
Gilles Boccon-Gibod
0ffed3deff Merge pull request #11 from zxzxwu/main
Implement CTKD over LE and key distribution delegation
2022-07-20 15:35:26 -07:00
Josh Wu
2f949a1182 Delegate SMP key distribution
* Delegate SMP key distribution
* Align LE pairing key expectation
* Parametrize SMP self test, and add key distribution coverage
2022-07-21 01:19:36 +08:00
Josh Wu
4e2fae5145 Implement CTKD over LE 2022-07-21 01:19:25 +08:00
Gilles Boccon-Gibod
2b58364c51 Merge pull request #14 from zxzxwu/conn-lookup
Refactor find_connection_by_bd_addr
2022-07-20 08:26:04 -07:00
Josh Wu
e3bf7c4b53 Refactor find_connection_by_bd_addr
* Compare only address bytes because Address.__eq__ also compares types.
* Add a transport field to find connection to a device on specific
  transport. (It's possible to connect a device on both BR/EDR and LE)
2022-07-20 21:32:20 +08:00
Gilles Boccon-Gibod
009ecfce96 use list comprehension 2022-07-19 19:53:18 -07:00
Gilles Boccon-Gibod
d6075df356 add tool 2022-07-19 19:53:18 -07:00
Gilles Boccon-Gibod
ebd0a0c8ca more complete set of HCI types and constants 2022-07-19 19:53:18 -07:00
Gilles Boccon-Gibod
bd28892734 add support for type adapters and framework for adding standard GATT profiles 2022-07-19 19:42:21 -07:00
Gilles Boccon-Gibod
b64fa65921 Merge pull request #10 from zxzxwu/main
Make pairing and link mode configurable
2022-07-18 12:48:07 -07:00
Josh Wu
7d87c3cc3a Make pairing and link mode configurable 2022-07-18 14:28:21 +08:00
Gilles Boccon-Gibod
94fc81c183 Merge pull request #7 from DeltaEvo/david/config-load
Make DeviceConfiguration loadable from a dict
2022-06-25 05:08:44 -07:00
Gilles Boccon-Gibod
b65b395fc4 Merge pull request #8 from Arrowbox/threadsafe_usb_close
Use threadsafe call when setting event_loop_done
2022-06-25 05:07:42 -07:00
David Duarte
0f157d55f7 Make DeviceConfiguration loadable from a dict 2022-06-24 09:57:16 +00:00
Jayson Messenger
925d79491f Use threadsafe call when setting event_loop_done
Previously, the close method would hang waiting on the future to be
done.
2022-06-23 15:19:05 -04:00
Gilles Boccon-Gibod
3d14df909c Merge pull request #5 from google/gbg/disconnection-event-routing
fix the routing of disconnection events
2022-06-15 10:50:14 -07:00
Gilles Boccon-Gibod
153788afe3 fix the routing of disconnection events 2022-06-14 14:38:40 -07:00
179 changed files with 18590 additions and 6817 deletions

2
.git-blame-ignore-revs Normal file
View File

@@ -0,0 +1,2 @@
# Migrate code style to Black
135df0dcc01ab765f432e19b1a5202d29bd55545

35
.github/workflows/code-check.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
# Check the code against the formatter and linter
name: Code format and lint check
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
check:
name: Check Code
runs-on: ubuntu-latest
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development]"
- name: Check
run: |
invoke project.pre-commit

View File

@@ -14,21 +14,30 @@ jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
fail-fast: false
steps: steps:
- uses: actions/checkout@v3 - name: Check out from Git
- name: Set up Python 3.10 uses: actions/checkout@v3
uses: actions/setup-python@v3 - name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[test,development,documentation]" python -m pip install ".[build,test,development,documentation]"
- name: Test with pytest - name: Test
run: | run: |
pytest invoke test
- name: Build - name: Build
run: | run: |
inv build inv build
inv mkdocs inv build.mkdocs

37
.github/workflows/python-publish.yml vendored Normal file
View File

@@ -0,0 +1,37 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
runs-on: ubuntu-latest
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install build
- name: Build package
run: python -m build
- name: Publish package to PyPI
if: github.event_name == 'release' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

7
.gitignore vendored
View File

@@ -3,8 +3,9 @@ build/
dist/ dist/
*.egg-info/ *.egg-info/
*~ *~
bumble/__pycache__
docs/mkdocs/site docs/mkdocs/site
tests/__pycache__
test-results.xml test-results.xml
bumble/transport/__pycache__ __pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json

75
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,75 @@
{
"cSpell.words": [
"Abortable",
"altsetting",
"ansiblue",
"ansicyan",
"ansigreen",
"ansimagenta",
"ansired",
"ansiyellow",
"appendleft",
"ASHA",
"asyncio",
"ATRAC",
"avdtp",
"bitpool",
"bitstruct",
"BSCP",
"BTPROTO",
"CCCD",
"cccds",
"cmac",
"CONNECTIONLESS",
"csrcs",
"datagram",
"DATALINK",
"delayreport",
"deregisters",
"deregistration",
"dhkey",
"diversifier",
"Fitbit",
"GATTLINK",
"HANDSFREE",
"keydown",
"keyup",
"levelname",
"libc",
"libusb",
"MITM",
"NDIS",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"psms",
"pyee",
"pyusb",
"rfcomm",
"ROHC",
"rssi",
"SEID",
"seids",
"SERV",
"ssrc",
"strerror",
"subband",
"subbands",
"subevent",
"Subrating",
"substates",
"tobytes",
"tsep",
"usbmodem",
"vhci",
"websockets",
"xcursor",
"ycursor"
],
"[python]": {
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled"
}

View File

@@ -199,4 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.

View File

@@ -9,18 +9,49 @@
Bluetooth Stack for Apps, Emulation, Test and Experimentation Bluetooth Stack for Apps, Emulation, Test and Experimentation
============================================================= =============================================================
<img src="docs/mkdocs/src/images/logo_framed.png" alt="drawing" width="200" height="200"/> <img src="docs/mkdocs/src/images/logo_framed.png" alt="Logo" width="200" height="200"/>
Bumble is a full-featured Bluetooth stack written entirely in Python. It supports most of the common Bluetooth Low Energy (BLE) and Bluetooth Classic (BR/EDR) protocols and profiles, including GAP, L2CAP, ATT, GATT, SMP, SDP, RFCOMM, HFP, HID and A2DP. The stack can be used with physical radios via HCI over USB, UART, or the Linux VHCI, as well as virtual radios, including the virtual Bluetooth support of the Android emulator. Bumble is a full-featured Bluetooth stack written entirely in Python. It supports most of the common Bluetooth Low Energy (BLE) and Bluetooth Classic (BR/EDR) protocols and profiles, including GAP, L2CAP, ATT, GATT, SMP, SDP, RFCOMM, HFP, HID and A2DP. The stack can be used with physical radios via HCI over USB, UART, or the Linux VHCI, as well as virtual radios, including the virtual Bluetooth support of the Android emulator.
## Documentation ## Documentation
Browse the pre-built [Online Documentation](https://google.github.io/bumble/), Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with: or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with:
``` ```
mkdocs build -f docs/mkdocs/mkdocs.yml mkdocs build -f docs/mkdocs/mkdocs.yml
``` ```
## Usage
### Getting Started
For a quick start to using Bumble, see the [Getting Started](docs/mkdocs/src/getting_started.md) guide.
### Dependencies
To install package dependencies needed to run the bumble examples, execute the following commands:
```
python -m pip install --upgrade pip
python -m pip install ".[test,development,documentation]"
```
### Examples
Refer to the [Examples Documentation](examples/README.md) for details on the included example scripts and how to run them.
The complete [list of Examples](/docs/mkdocs/src/examples/index.md), and what they are designed to do is here.
There are also a set of [Apps and Tools](docs/mkdocs/src/apps_and_tools/index.md) that show the utility of Bumble.
### Using Bumble With a USB Dongle
Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License ## License
Licensed under the [Apache 2.0](LICENSE) License. Licensed under the [Apache 2.0](LICENSE) License.

View File

@@ -47,5 +47,3 @@ NOTE: this assumes you're running a Link Relay on port `10723`.
## `console.py` ## `console.py`
A simple text-based-ui interactive Bluetooth device with GATT client capabilities. A simple text-based-ui interactive Bluetooth device with GATT client capabilities.

File diff suppressed because it is too large Load Diff

162
apps/controller_info.py Normal file
View File

@@ -0,0 +1,162 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import logging
import click
from colors import color
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# -----------------------------------------------------------------------------
async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if response.return_parameters.status == HCI_SUCCESS:
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def async_main(transport):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
host = Host(hci_source, hci_sink)
await host.reset()
# Print version
print(color('Version:', 'yellow'))
print(
color(' Manufacturer: ', 'green'),
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info
await get_classic_info(host)
# Get the LE info
await get_le_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
for command in host.supported_commands:
print(' ', HCI_Command.command_name(command))
# -----------------------------------------------------------------------------
@click.command()
@click.argument('transport')
def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(transport))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -28,11 +28,14 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]') print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2') print('example: python controllers.py pty:ble1 pty:ble2')
return return
# Create a loccal link to attach the controllers to # Create a local link to attach the controllers to
link = LocalLink() link = LocalLink()
# Create a transport and controller for all requested names # Create a transport and controller for all requested names
@@ -41,7 +44,12 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]): for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name) transport = await open_transport_or_link(transport_name)
transports.append(transport) transports.append(transport)
controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link) controller = Controller(
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controllers.append(controller) controllers.append(controller)
# Wait until the user interrupts # Wait until the user interrupts
@@ -54,7 +62,7 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())

View File

@@ -21,7 +21,7 @@ import logging
import click import click
from colors import color from colors import color
from bumble.core import ProtocolError, TimeoutError import bumble.core
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.gatt import show_services from bumble.gatt import show_services
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -32,10 +32,10 @@ async def dump_gatt_db(peer, done):
# Discover all services # Discover all services
print(color('### Discovering Services and Characteristics', 'magenta')) print(color('### Discovering Services and Characteristics', 'magenta'))
await peer.discover_services() await peer.discover_services()
await peer.discover_characteristics()
for service in peer.services: for service in peer.services:
await service.discover_characteristics()
for characteristic in service.characteristics: for characteristic in service.characteristics:
await peer.discover_descriptors(characteristic) await characteristic.discover_descriptors()
print(color('=== Services ===', 'yellow')) print(color('=== Services ===', 'yellow'))
show_services(peer.services) show_services(peer.services)
@@ -47,11 +47,11 @@ async def dump_gatt_db(peer, done):
for attribute in attributes: for attribute in attributes:
print(attribute) print(attribute)
try: try:
value = await peer.read_value(attribute) value = await attribute.read_value()
print(color(f'{value.hex()}', 'green')) print(color(f'{value.hex()}', 'green'))
except ProtocolError as error: except bumble.core.ProtocolError as error:
print(color(error, 'red')) print(color(error, 'red'))
except TimeoutError: except bumble.core.TimeoutError:
print(color('read timeout', 'red')) print(color('read timeout', 'red'))
if done is not None: if done is not None:
@@ -64,9 +64,13 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device # Create a device
if device_config: if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on() await device.power_on()
if address_or_name: if address_or_name:
@@ -81,7 +85,12 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else: else:
# Wait for a connection # Wait for a connection
done = asyncio.get_running_loop().create_future() done = asyncio.get_running_loop().create_future()
device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done))) device.on(
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue')) print(color('### Waiting for connection...', 'blue'))
@@ -99,7 +108,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection. wait for an incoming connection.
""" """
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) asyncio.run(async_main(device_config, encrypt, transport, address_or_name))

View File

@@ -17,13 +17,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os import os
import struct
import logging import logging
import click import click
from colors import color from colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant from bumble.hci import HCI_Constant
@@ -32,24 +33,73 @@ from bumble.hci import HCI_Constant
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_PREFERRED_MTU = 256 GG_PREFERRED_MTU = 256
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GattlinkHubBridge(Device.Listener): class GattlinkL2capEndpoint:
def __init__(self): def __init__(self):
self.peer = None self.l2cap_channel = None
self.rx_socket = None self.l2cap_packet = b''
self.tx_socket = None self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.rx_characteristic = None self.rx_characteristic = None
self.tx_characteristic = None self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection): async def on_connection(self, connection):
print(f'=== Connected to {connection}') print(f'=== Connected to {connection}')
self.peer = Peer(connection) self.peer = Peer(connection)
@@ -73,122 +123,229 @@ class GattlinkHubBridge(Device.Listener):
gattlink_service = services[0] gattlink_service = services[0]
# Discover all the characteristics for the service # Discover all the characteristics for the service
characteristics = await self.peer.discover_characteristics(service = gattlink_service) characteristics = await gattlink_service.discover_characteristics()
print(color('=== Characteristics discovered', 'yellow')) print(color('=== Characteristics discovered', 'yellow'))
for characteristic in characteristics: for characteristic in characteristics:
if characteristic.uuid == GG_GATTLINK_RX_CHARACTERISTIC_UUID: if characteristic.uuid == GG_GATTLINK_RX_CHARACTERISTIC_UUID:
self.rx_characteristic = characteristic self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID: elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic self.tx_characteristic = characteristic
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic) print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic) print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
# Subscribe to TX if self.l2cap_psm_characteristic:
if self.tx_characteristic: # Subscribe to and then read the PSM value
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received) await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow')) print(color('=== Subscribed to Gattlink TX', 'yellow'))
else: else:
print(color('!!! Gattlink TX not found', 'red')) print(color('!!! No Gattlink TX or PSM found', 'red'))
def on_connection_failure(self, error): def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}')) print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason): def on_disconnection(self, reason):
print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red')) print(
color(
f'!!! Disconnected from {self.peer}, '
f'reason={HCI_Constant.error_name(reason)}',
'red',
)
)
self.tx_characteristic = None self.tx_characteristic = None
self.rx_characteristic = None self.rx_characteristic = None
self.peer = None self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received # Called by the GATT client when a notification is received
def on_tx_received(self, value): def on_tx_received(self, value):
print(color('>>> TX:', 'magenta'), value.hex()) print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
if self.tx_socket: if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value) self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def on_l2cap_psm_received(self, value):
pass psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(Device.Listener):
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def connection_made(self, transport):
pass pass
# Called by asyncio when a UDP datagram is received # Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address): def datagram_received(self, data, _address):
print(color('<<< RX:', 'magenta'), data.hex()) print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
# TODO: use a queue instead of creating a task everytime if self.l2cap_channel:
if self.peer and self.rx_characteristic: print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data)) asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port): class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
self.transport = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, _connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(
f'### [GATT TX] subscription from {peer}: '
f'{"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# -----------------------------------------------------------------------------
async def run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Instantiate a bridge object # Instantiate a bridge object
bridge = GattlinkNodeBridge() device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
# Create a UDP to RX bridge (receive from UDP, send to RX) # Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint( await loop.create_datagram_endpoint(
lambda: bridge, lambda: bridge, local_addr=(receive_host, receive_port)
local_addr=(receive_host, receive_port)
) )
# Create a UDP to TX bridge (receive from TX, send to UDP) # Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint( bridge.tx_socket, _ = await loop.create_datagram_endpoint(
# pylint: disable-next=unnecessary-lambda
lambda: asyncio.DatagramProtocol(), lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port) remote_addr=(send_host, send_port),
) )
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on() await device.power_on()
await bridge.start()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
# Wait until the source terminates # Wait until the source terminates
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -197,15 +354,44 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
@click.command() @click.command()
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('device_address') @click.argument('device_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to') @click.argument('role_or_peer_address')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to') @click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on') @click.option(
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on') '-rh',
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port): '--receive-host',
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) type=str,
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port)) default='127.0.0.1',
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@@ -34,16 +34,29 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]') print(
print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078') 'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> '
'[command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 '
'serial:/dev/tty.usbmodem0006839912171,1000000 '
'0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return return
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink): async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source,
hci_host_sink,
):
print('>>> connected') print('>>> connected')
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink): async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source,
hci_controller_sink,
):
print('>>> connected') print('>>> connected')
command_short_circuits = [] command_short_circuits = []
@@ -51,36 +64,43 @@ async def async_main():
for op_code_str in sys.argv[3].split(','): for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str: if ':' in op_code_str:
ogf, ocf = op_code_str.split(':') ogf, ocf = op_code_str.split(':')
command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))) command_short_circuits.append(
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
else: else:
command_short_circuits.append(int(op_code_str, 16)) command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet): def host_to_controller_filter(hci_packet):
if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits: if (
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
# Respond with a success response # Respond with a success response
logger.debug('short-circuiting packet') logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets = 1, num_hci_command_packets=1,
command_opcode = hci_packet.op_code, command_opcode=hci_packet.op_code,
return_parameters = bytes([hci.HCI_SUCCESS]) return_parameters=bytes([hci.HCI_SUCCESS]),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True) return (response.to_bytes(), True)
return None
_ = HCI_Bridge( _ = HCI_Bridge(
hci_host_source, hci_host_source,
hci_host_sink, hci_host_sink,
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter, host_to_controller_filter,
None None,
) )
await asyncio.get_running_loop().create_future() await asyncio.get_running_loop().create_future()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())

350
apps/l2cap_bridge.py Normal file
View File

@@ -0,0 +1,350 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import click
from colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:'
f'{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, exc):
print(color(f'!!! TCP connection lost: {exc}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
self.pipe.l2cap_channel.write(data)
try:
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain,
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -16,7 +16,6 @@
# Imports # Imports
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
import sys import sys
import websockets
import logging import logging
import json import json
import asyncio import asyncio
@@ -25,6 +24,7 @@ import uuid
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
from colors import color from colors import color
import websockets
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -65,9 +65,9 @@ class Connection:
""" """
def __init__(self, room, websocket): def __init__(self, room, websocket):
self.room = room self.room = room
self.websocket = websocket self.websocket = websocket
self.address = str(uuid.uuid4()) self.address = str(uuid.uuid4())
async def send_message(self, message): async def send_message(self, message):
try: try:
@@ -98,7 +98,11 @@ class Connection:
self.address = address self.address = address
def __str__(self): def __str__(self):
return f'Connection(address="{self.address}", client={self.websocket.remote_address[0]}:{self.websocket.remote_address[1]})' return (
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@@ -110,9 +114,9 @@ class Room:
""" """
def __init__(self, relay, name): def __init__(self, relay, name):
self.relay = relay self.relay = relay
self.name = name self.name = name
self.observers = [] self.observers = []
self.connections = [] self.connections = []
async def add_connection(self, connection): async def add_connection(self, connection):
@@ -139,13 +143,15 @@ class Room:
# Parse the message to decide how to handle it # Parse the message to decide how to handle it
if message.startswith('@'): if message.startswith('@'):
# This is a targetted message # This is a targeted message
await self.on_targetted_message(connection, message) await self.on_targeted_message(connection, message)
elif message.startswith('/'): elif message.startswith('/'):
# This is an RPC request # This is an RPC request
await self.on_rpc_request(connection, message) await self.on_rpc_request(connection, message)
else: else:
await connection.send_message(f'result:{error_to_json("error: invalid message")}') await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message): async def broadcast_message(self, sender, message):
''' '''
@@ -155,7 +161,9 @@ class Room:
async def on_rpc_request(self, connection, message): async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1) command, *params = message.split(' ', 1)
if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None): if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try: try:
result = await handler(connection, params) result = await handler(connection, params)
except Exception as error: except Exception as error:
@@ -165,7 +173,7 @@ class Room:
await connection.send_message(result or 'result:{}') await connection.send_message(result or 'result:{}')
async def on_targetted_message(self, connection, message): async def on_targeted_message(self, connection, message):
target, *payload = message.split(' ', 1) target, *payload = message.split(' ', 1)
if not payload: if not payload:
return error_to_json('missing arguments') return error_to_json('missing arguments')
@@ -174,7 +182,8 @@ class Room:
# Determine what targets to send to # Determine what targets to send to
if target == '*': if target == '*':
# Send to all connections in the room except the connection from which the message was received # Send to all connections in the room except the connection from which the
# message was received
connections = [c for c in self.connections if c != connection] connections = [c for c in self.connections if c != connection]
else: else:
connections = self.find_connections_by_address(target) connections = self.find_connections_by_address(target)
@@ -192,7 +201,9 @@ class Room:
current_address = connection.address current_address = connection.address
new_address = params[0] new_address = params[0]
connection.set_address(new_address) connection.set_address(new_address)
await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}') await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@@ -210,9 +221,10 @@ class Relay:
def start(self): def start(self):
logger.info(f'Starting Relay on port {self.port}') logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None) return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(connection): async def serve_as_controller(self, connection):
pass pass
async def serve(self, websocket, path): async def serve(self, websocket, path):
@@ -246,24 +258,24 @@ def main():
print('ERROR: Python 3.6.1 or higher is required') print('ERROR: Python 3.6.1 or higher is required')
sys.exit(1) sys.exit(1)
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Parse arguments # Parse arguments
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay') arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level') arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)') arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument('--port', arg_parser.add_argument(
type = int, '--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
default = DEFAULT_RELAY_PORT, )
help = 'Port to listen on')
args = arg_parser.parse_args() args = arg_parser.parse_args()
# Setup logger # Setup logger
if args.log_config: if args.log_config:
from logging import config from logging import config # pylint: disable=import-outside-toplevel
config.fileConfig(args.log_config) config.fileConfig(args.log_config)
else: else:
logging.basicConfig(level = getattr(logging, args.log_level.upper())) logging.basicConfig(level=getattr(logging, args.log_level.upper()))
# Start a relay # Start a relay
relay = Relay(args.port) relay = Relay(args.port)

View File

@@ -33,30 +33,32 @@ from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
Service, Service,
Characteristic, Characteristic,
CharacteristicValue CharacteristicValue,
) )
from bumble.att import ( from bumble.att import (
ATT_Error, ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR, ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR ATT_INSUFFICIENT_ENCRYPTION_ERROR,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt): def __init__(self, mode, connection, capability_string, prompt):
super().__init__({ super().__init__(
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, {
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
}[capability_string.lower()]) 'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
}[capability_string.lower()]
)
self.mode = mode self.mode = mode
self.peer = Peer(connection) self.peer = Peer(connection)
self.peer_name = None self.peer_name = None
self.prompt = prompt self.prompt = prompt
async def update_peer_name(self): async def update_peer_name(self):
if self.peer_name is not None: if self.peer_name is not None:
@@ -84,15 +86,17 @@ class Delegate(PairingDelegate):
while True: while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow')) response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip() response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
elif response == 'no':
return False
else:
# Accept silently
return True
async def compare_numbers(self, number, digits): if response == 'no':
return False
# Accept silently
return True
async def compare_numbers(self, number, digits=6):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt # Wait a bit to allow some of the log lines to print before we prompt
@@ -103,11 +107,17 @@ class Delegate(PairingDelegate):
print(color(f'### Pairing with {self.peer_name}', 'yellow')) print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow')) print(color('###-----------------------------------', 'yellow'))
while True: while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow')) response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
)
response = response.lower().strip() response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
elif response == 'no':
if response == 'no':
return False return False
async def get_number(self): async def get_number(self):
@@ -126,7 +136,7 @@ class Delegate(PairingDelegate):
except ValueError: except ValueError:
pass pass
async def display_number(self, number, digits): async def display_number(self, number, digits=6):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt # Wait a bit to allow some of the log lines to print before we prompt
@@ -143,15 +153,19 @@ class Delegate(PairingDelegate):
async def get_peer_name(peer, mode): async def get_peer_name(peer, mode):
if mode == 'classic': if mode == 'classic':
return await peer.request_name() return await peer.request_name()
else:
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0]) # Try to get the peer name from GATT
if values: services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
return values[0].decode('utf-8') if not services:
return None
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -164,12 +178,12 @@ def read_with_error(connection):
if AUTHENTICATION_ERROR_RETURNED[0]: if AUTHENTICATION_ERROR_RETURNED[0]:
return bytes([1]) return bytes([1])
else:
AUTHENTICATION_ERROR_RETURNED[0] = True AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR) raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, value): def write_with_error(connection, _value):
if not connection.is_encrypted: if not connection.is_encrypted:
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR) raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
@@ -183,14 +197,14 @@ def on_connection(connection, request):
print(color(f'<<< Connection: {connection}', 'green')) print(color(f'<<< Connection: {connection}', 'green'))
# Listen for pairing events # Listen for pairing events
connection.on('pairing_start', on_pairing_start) connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing) connection.on('pairing', on_pairing)
connection.on('pairing_failure', on_pairing_failure) connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes # Listen for encryption changes
connection.on( connection.on(
'connection_encryption_change', 'connection_encryption_change',
lambda: on_connection_encryption_change(connection) lambda: on_connection_encryption_change(connection),
) )
# Request pairing if needed # Request pairing if needed
@@ -202,7 +216,12 @@ def on_connection(connection, request):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_connection_encryption_change(connection): def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue')) print(
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
@@ -241,7 +260,7 @@ async def pair(
keystore_file, keystore_file,
device_config, device_config,
hci_transport, hci_transport,
address_or_name address_or_name,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -272,9 +291,11 @@ async def pair(
'552957FB-CF1F-4A31-9535-E78847E1A714', '552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=read_with_error, write=write_with_error) CharacteristicValue(
read=read_with_error, write=write_with_error
),
) )
] ],
) )
) )
@@ -288,10 +309,7 @@ async def pair(
# Set up a pairing config factory # Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig( device.pairing_config_factory = lambda connection: PairingConfig(
sc, sc, mitm, bond, Delegate(mode, connection, io, prompt)
mitm,
bond,
Delegate(mode, connection, io, prompt)
) )
# Connect to a peer or wait for a connection # Connect to a peer or wait for a connection
@@ -319,21 +337,70 @@ async def pair(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True) @click.option(
@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True) '--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True) )
@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True) @click.option(
@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True) '--sc',
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing') @click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys') @click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config') @click.argument('device-config')
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name): def main(
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) mode,
asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name)) sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(
pair(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -25,14 +25,14 @@ from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
from bumble.hci import HCI_LE_Advertising_Report_Event from bumble.device import Advertisement
from bumble.core import AdvertisingData from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_rssi_bar(rssi): def make_rssi_bar(rssi):
DISPLAY_MIN_RSSI = -105 DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30 DISPLAY_MAX_RSSI = -30
DEFAULT_RSSI_BAR_WIDTH = 30 DEFAULT_RSSI_BAR_WIDTH = 30
blocks = ['', '', '', '', '', '', '', ''] blocks = ['', '', '', '', '', '', '', '']
@@ -48,19 +48,24 @@ class AdvertisementPrinter:
self.min_rssi = min_rssi self.min_rssi = min_rssi
self.resolver = resolver self.resolver = resolver
def print_advertisement(self, address, address_color, ad_data, rssi): def print_advertisement(self, advertisement):
if self.min_rssi is not None and rssi < self.min_rssi: address = advertisement.address
address_color = 'yellow' if advertisement.is_connectable else 'red'
if self.min_rssi is not None and advertisement.rssi < self.min_rssi:
return return
address_qualifier = '' address_qualifier = ''
resolution_qualifier = '' resolution_qualifier = ''
if self.resolver and address.is_resolvable: if self.resolver and advertisement.address.is_resolvable:
resolved = self.resolver.resolve(address) resolved = self.resolver.resolve(advertisement.address)
if resolved is not None: if resolved is not None:
resolution_qualifier = f'(resolved from {address})' resolution_qualifier = f'(resolved from {advertisement.address})'
address = resolved address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type] address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public: if address.is_public:
type_color = 'cyan' type_color = 'cyan'
else: else:
@@ -74,18 +79,32 @@ class AdvertisementPrinter:
type_color = 'blue' type_color = 'blue'
address_qualifier = '(non-resolvable)' address_qualifier = '(non-resolvable)'
rssi_bar = make_rssi_bar(rssi)
separator = '\n ' separator = '\n '
print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}RSSI:{rssi:4} {rssi_bar}{separator}{ad_data.to_string(separator)}\n') rssi_bar = make_rssi_bar(advertisement.rssi)
if not advertisement.is_legacy:
phy_info = (
f'PHY: {HCI_Constant.le_phy_name(advertisement.primary_phy)}/'
f'{HCI_Constant.le_phy_name(advertisement.secondary_phy)} '
f'{separator}'
)
else:
phy_info = ''
def on_advertisement(self, address, ad_data, rssi, connectable): print(
address_color = 'yellow' if connectable else 'red' f'>>> {color(address, address_color)} '
self.print_advertisement(address, address_color, ad_data, rssi) f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertising_report(self, address, ad_data, rssi, event_type): def on_advertisement(self, advertisement):
print(f'{color("EVENT", "green")}: {HCI_LE_Advertising_Report_Event.event_type_name(event_type)}') self.print_advertisement(advertisement)
ad_data = AdvertisingData.from_bytes(ad_data)
self.print_advertisement(address, 'yellow', ad_data, rssi) def on_advertising_report(self, report):
print(f'{color("EVENT", "green")}: {report.event_type_string()}')
self.print_advertisement(Advertisement.from_advertising_report(report))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -94,20 +113,25 @@ async def scan(
passive, passive,
scan_interval, scan_interval,
scan_window, scan_window,
phy,
filter_duplicates, filter_duplicates,
raw, raw,
keystore_file, keystore_file,
device_config, device_config,
transport transport,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
if device_config: if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
if keystore_file: if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file) keystore = JsonKeyStore(namespace=None, filename=keystore_file)
@@ -126,11 +150,18 @@ async def scan(
device.on('advertisement', printer.on_advertisement) device.on('advertisement', printer.on_advertisement)
await device.power_on() await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
scanning_phys = [{'1m': HCI_LE_1M_PHY, 'coded': HCI_LE_CODED_PHY}[phy]]
await device.start_scanning( await device.start_scanning(
active=(not passive), active=(not passive),
scan_interval=scan_interval, scan_interval=scan_interval,
scan_window=scan_window, scan_window=scan_window,
filter_duplicates=filter_duplicates filter_duplicates=filter_duplicates,
scanning_phys=scanning_phys,
) )
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -142,14 +173,51 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning') @click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval') @click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window') @click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level') @click.option(
@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones') '--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY'
)
@click.option(
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses') @click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device') @click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport') @click.argument('transport')
def main(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport): def main(
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) min_rssi,
asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport)) passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -27,13 +27,14 @@ from bumble.helpers import PacketTracer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SnoopPacketReader: class SnoopPacketReader:
''' '''
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...) Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not
exactly the same...)
''' '''
DATALINK_H1 = 1001 DATALINK_H1 = 1001
DATALINK_H4 = 1002 DATALINK_H4 = 1002
DATALINK_BSCP = 1003 DATALINK_BSCP = 1003
DATALINK_H5 = 1004 DATALINK_H5 = 1004
def __init__(self, source): def __init__(self, source):
self.source = source self.source = source
@@ -41,9 +42,13 @@ class SnoopPacketReader:
# Read the header # Read the header
identification_pattern = source.read(8) identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000': if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError('not a valid snoop file, unexpected identification pattern') raise ValueError(
(self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8)) 'not a valid snoop file, unexpected identification pattern'
if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1: )
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
raise ValueError(f'datalink type {self.data_link_type} not supported') raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self): def next_packet(self):
@@ -55,9 +60,9 @@ class SnoopPacketReader:
original_length, original_length,
included_length, included_length,
packet_flags, packet_flags,
cumulative_drops, _cumulative_drops,
timestamp_seconds, _timestamp_seconds,
timestamp_microsecond _timestamp_microsecond,
) = struct.unpack('>IIIIII', header) ) = struct.unpack('>IIIIII', header)
# Abort on truncated packets # Abort on truncated packets
@@ -79,24 +84,34 @@ class SnoopPacketReader:
else: else:
packet_type = hci.HCI_ACL_DATA_PACKET packet_type = hci.HCI_ACL_DATA_PACKET
return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length)) return (
else: packet_flags & 1,
return (packet_flags & 1, self.source.read(included_length)) bytes([packet_type]) + self.source.read(included_length),
)
return (packet_flags & 1, self.source.read(included_length))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Main # Main
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file') @click.option(
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.argument('filename') @click.argument('filename')
def show(format, filename): # pylint: disable=redefined-builtin
def main(format, filename):
input = open(filename, 'rb') input = open(filename, 'rb')
if format == 'h4': if format == 'h4':
packet_reader = PacketReader(input) packet_reader = PacketReader(input)
def read_next_packet(): def read_next_packet():
(0, packet_reader.next_packet()) return (0, packet_reader.next_packet())
else: else:
packet_reader = SnoopPacketReader(input) packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet read_next_packet = packet_reader.next_packet
@@ -112,9 +127,8 @@ def show(format, filename):
except Exception as error: except Exception as error:
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))
pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
show() main() # pylint: disable=no-value-for-parameter

View File

@@ -54,7 +54,7 @@ async def unbond(keystore_file, device_config, address):
@click.argument('device-config') @click.argument('device-config')
@click.argument('address', required=False) @click.argument('address', required=False)
def main(keystore_file, device_config, address): def main(keystore_file, device_config, address):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address)) asyncio.run(unbond(keystore_file, device_config, address))

278
apps/usb_probe.py Normal file
View File

@@ -0,0 +1,278 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# This tool lists all the USB devices, with details about each device.
# For each device, the different possible Bumble transport strings that can
# refer to it are listed. If the device is known to be a Bluetooth HCI device,
# its identifier is printed in reverse colors, and the transport names in cyan color.
# For other devices, regardless of their type, the transport names are printed
# in red. Whether that device is actually a Bluetooth device or not depends on
# whether it is a Bluetooth device that uses a non-standard Class, or some other
# type of device (there's no way to tell).
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import os
import logging
import click
import usb1
from colors import color
from bumble.transport.usb import load_libusb
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_DEVICE_CLASSES = {
0x00: 'Device',
0x01: 'Audio',
0x02: 'Communications and CDC Control',
0x03: 'Human Interface Device',
0x05: 'Physical',
0x06: 'Still Imaging',
0x07: 'Printer',
0x08: 'Mass Storage',
0x09: 'Hub',
0x0A: 'CDC Data',
0x0B: 'Smart Card',
0x0D: 'Content Security',
0x0E: 'Video',
0x0F: 'Personal Healthcare',
0x10: 'Audio/Video',
0x11: 'Billboard',
0x12: 'USB Type-C Bridge',
0x3C: 'I3C',
0xDC: 'Diagnostic',
USB_DEVICE_CLASS_WIRELESS_CONTROLLER: (
'Wireless Controller',
{
0x01: {
0x01: 'Bluetooth',
0x02: 'UWB',
0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP',
}
},
),
0xEF: 'Miscellaneous',
0xFE: 'Application Specific',
0xFF: 'Vendor Specific',
}
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
# -----------------------------------------------------------------------------
def show_device_details(device):
for configuration in device:
print(f' Configuration {configuration.getConfigurationValue()}')
for interface in configuration:
for setting in interface:
alternate_setting = setting.getAlternateSetting()
suffix = (
f'/{alternate_setting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(), setting.getSubClass(), setting.getProtocol()
)
details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = (
'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
)
# -----------------------------------------------------------------------------
def get_class_info(cls, subclass, protocol):
class_info = USB_DEVICE_CLASSES.get(cls)
protocol_string = ''
if class_info is None:
class_string = f'0x{cls:02X}'
else:
if isinstance(class_info, tuple):
class_string = class_info[0]
subclass_info = class_info[1].get(subclass)
if subclass_info:
protocol_string = subclass_info.get(protocol)
if protocol_string is not None:
protocol_string = f' [{protocol_string}]'
else:
class_string = class_info
subclass_string = f'{subclass}/{protocol}{protocol_string}'
return (class_string, subclass_string)
# -----------------------------------------------------------------------------
def is_bluetooth_hci(device):
# Check if the device class indicates a match
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
for configuration in device:
for interface in configuration:
for setting in interface:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
# -----------------------------------------------------------------------------
@click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass()
device_subclass = device.getDeviceSubClass()
device_protocol = device.getDeviceProtocol()
device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info(
device_class, device_subclass, device_protocol
)
try:
device_serial_number = device.getSerialNumber()
except usb1.USBError:
device_serial_number = None
try:
device_manufacturer = device.getManufacturer()
except usb1.USBError:
device_manufacturer = None
try:
device_product = device.getProduct()
except usb1.USBError:
device_product = None
device_is_bluetooth_hci = is_bluetooth_hci(device)
if device_is_bluetooth_hci:
bluetooth_device_count += 1
fg_color = 'black'
bg_color = 'yellow'
else:
fg_color = 'yellow'
bg_color = 'black'
# Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = []
basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
if device_id not in devices:
bumble_transport_names.append(basic_transport_name)
else:
bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None:
if (
device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results
print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names:
print(
color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None:
print(color(' Serial: ', 'green'), device_serial_number)
if device_manufacturer is not None:
print(color(' Manufacturer: ', 'green'), device_manufacturer)
if device_product is not None:
print(color(' Product: ', 'green'), device_product)
if verbose:
show_device_details(device)
print()
devices.setdefault(device_id, []).append(device_serial_number)
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,4 @@
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown version"

View File

@@ -16,10 +16,9 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
import bitstruct
import logging import logging
from collections import namedtuple from collections import namedtuple
from colors import color import bitstruct
from .company_ids import COMPANY_IDENTIFIERS from .company_ids import COMPANY_IDENTIFIERS
from .sdp import ( from .sdp import (
@@ -30,7 +29,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from .core import ( from .core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
@@ -38,7 +37,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number name_or_number,
) )
@@ -51,6 +50,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00 A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01 A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -127,71 +127,115 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE' MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
} }
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def flags_to_list(flags, values): def flags_to_list(flags, values):
result = [] result = []
for i in range(len(values)): for i, value in enumerate(values):
if flags & (1 << (len(values) - i - 1)): if flags & (1 << (len(values) - i - 1)):
result.append(values[i]) result.append(value)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(service_record_handle),
])), ),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AUDIO_SOURCE_SERVICE) SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
])), DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ),
DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(AVDTP_PSM) DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]),
]), ),
DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AVDTP_PROTOCOL_ID), SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(version_int) DataElement.sequence(
]) [
])), DataElement.sequence(
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ [
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(AVDTP_PSM),
])), ]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
] ]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(service_record_handle),
])), ),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AUDIO_SINK_SERVICE) SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
])), DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ),
DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(AVDTP_PSM) DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
]), ),
DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AVDTP_PROTOCOL_ID), SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(version_int) DataElement.sequence(
]) [
])), DataElement.sequence(
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ [
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(AVDTP_PSM),
])), ]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
] ]
@@ -206,8 +250,8 @@ class SbcMediaCodecInformation(
'subbands', 'subbands',
'allocation_method', 'allocation_method',
'minimum_bitpool_value', 'minimum_bitpool_value',
'maximum_bitpool_value' 'maximum_bitpool_value',
] ],
) )
): ):
''' '''
@@ -215,36 +259,25 @@ class SbcMediaCodecInformation(
''' '''
BIT_FIELDS = 'u4u4u4u2u2u8u8' BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = { SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
16000: 1 << 3,
32000: 1 << 2,
44100: 1 << 1,
48000: 1
}
CHANNEL_MODE_BITS = { CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3, SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2, SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1, SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1 SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
} }
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = { ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1, SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1 SBC_LOUDNESS_ALLOCATION_METHOD: 1,
} }
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)) return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
@@ -255,16 +288,16 @@ class SbcMediaCodecInformation(
subbands, subbands,
allocation_method, allocation_method,
minimum_bitpool_value, minimum_bitpool_value,
maximum_bitpool_value maximum_bitpool_value,
): ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode = cls.CHANNEL_MODE_BITS[channel_mode], channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length = cls.BLOCK_LENGTH_BITS[block_length], block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands = cls.SUBBANDS_BITS[subbands], subbands=cls.SUBBANDS_BITS[subbands],
allocation_method = cls.ALLOCATION_METHOD_BITS[allocation_method], allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value = minimum_bitpool_value, minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value maximum_bitpool_value=maximum_bitpool_value,
) )
@classmethod @classmethod
@@ -276,16 +309,20 @@ class SbcMediaCodecInformation(
subbands, subbands,
allocation_methods, allocation_methods,
minimum_bitpool_value, minimum_bitpool_value,
maximum_bitpool_value maximum_bitpool_value,
): ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), sampling_frequency=sum(
channel_mode = sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes), cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
block_length = sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths), ),
subbands = sum(cls.SUBBANDS_BITS[x] for x in subbands), channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods), block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
minimum_bitpool_value = minimum_bitpool_value, subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
maximum_bitpool_value = maximum_bitpool_value allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
) )
def __bytes__(self): def __bytes__(self):
@@ -294,30 +331,26 @@ class SbcMediaCodecInformation(
def __str__(self): def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness'] allocation_methods = ['SNR', 'Loudness']
return '\n'.join([ return '\n'.join(
'SbcMediaCodecInformation(', # pylint: disable=line-too-long
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}', [
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}', 'SbcMediaCodecInformation(',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}', f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}', f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}', f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}', f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
')' f' minimum_bitpool_value: {self.minimum_bitpool_value}',
]) f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AacMediaCodecInformation( class AacMediaCodecInformation(
namedtuple( namedtuple(
'AacMediaCodecInformation', 'AacMediaCodecInformation',
[ ['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
) )
): ):
''' '''
@@ -326,13 +359,13 @@ class AacMediaCodecInformation(
BIT_FIELDS = 'u8u12u2p2u1u23' BIT_FIELDS = 'u8u12u2p2u1u23'
OBJECT_TYPE_BITS = { OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5, MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4 MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
} }
SAMPLING_FREQUENCY_BITS = { SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11, 8000: 1 << 11,
11025: 1 << 10, 11025: 1 << 10,
12000: 1 << 9, 12000: 1 << 9,
16000: 1 << 8, 16000: 1 << 8,
@@ -343,66 +376,66 @@ class AacMediaCodecInformation(
48000: 1 << 3, 48000: 1 << 3,
64000: 1 << 2, 64000: 1 << 2,
88200: 1 << 1, 88200: 1 << 1,
96000: 1 96000: 1,
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
} }
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data))
@classmethod
def from_discrete_values(
cls,
object_type,
sampling_frequency,
channels,
vbr,
bitrate
):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type = cls.OBJECT_TYPE_BITS[object_type], *bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels = cls.CHANNELS_BITS[channels],
vbr = vbr,
bitrate = bitrate
) )
@classmethod @classmethod
def from_lists( def from_discrete_values(
cls, cls, object_type, sampling_frequency, channels, vbr, bitrate
object_types,
sampling_frequencies,
channels,
vbr,
bitrate
): ):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type = sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels = sum(cls.CHANNELS_BITS[x] for x in channels), channels=cls.CHANNELS_BITS[channels],
vbr = vbr, vbr=vbr,
bitrate = bitrate bitrate=bitrate,
)
@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
vbr=vbr,
bitrate=bitrate,
) )
def __bytes__(self): def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self) return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self): def __str__(self):
object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]'] object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2] channels = [1, 2]
return '\n'.join([ # pylint: disable=line-too-long
'AacMediaCodecInformation(', return '\n'.join(
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}', [
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}', 'AacMediaCodecInformation(',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}', f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' vbr: {self.vbr}', f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' bitrate: {self.bitrate}' f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
')' f' vbr: {self.vbr}',
]) f' bitrate: {self.bitrate}' ')',
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -418,37 +451,34 @@ class VendorSpecificMediaCodecInformation:
def __init__(self, vendor_id, codec_id, value): def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id self.vendor_id = vendor_id
self.codec_id = codec_id self.codec_id = codec_id
self.value = value self.value = value
def __bytes__(self): def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value) return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self): def __str__(self):
return '\n'.join([ # pylint: disable=line-too-long
'VendorSpecificMediaCodecInformation(', return '\n'.join(
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})', [
f' codec_id: {self.codec_id:04X}', 'VendorSpecificMediaCodecInformation(',
f' value: {self.value.hex()}' f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
')' f' codec_id: {self.codec_id:04X}',
]) f' value: {self.value.hex()}' ')',
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcFrame: class SbcFrame:
def __init__( def __init__(
self, self, sampling_frequency, block_count, channel_mode, subband_count, payload
sampling_frequency,
block_count,
channel_mode,
subband_count,
payload
): ):
self.sampling_frequency = sampling_frequency self.sampling_frequency = sampling_frequency
self.block_count = block_count self.block_count = block_count
self.channel_mode = channel_mode self.channel_mode = channel_mode
self.subband_count = subband_count self.subband_count = subband_count
self.payload = payload self.payload = payload
@property @property
def sample_count(self): def sample_count(self):
@@ -463,7 +493,13 @@ class SbcFrame:
return self.sample_count / self.sampling_frequency return self.sample_count / self.sampling_frequency
def __str__(self): def __str__(self):
return f'SBC(sf={self.sampling_frequency},cm={self.channel_mode},br={self.bitrate},sc={self.sample_count},size={len(self.payload)})' return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'size={len(self.payload)})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -487,24 +523,30 @@ class SbcParser:
# Extract some of the header fields # Extract some of the header fields
sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3] sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3]
blocks = 4 * (1 + ((header[1] >> 4) & 3)) blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3 channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2 channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4 subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2] bitpool = header[2]
# Compute the frame length # Compute the frame length
frame_length = 4 + (4 * subbands * channels) // 8 frame_length = 4 + (4 * subbands * channels) // 8
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE): if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8 frame_length += (blocks * channels * bitpool) // 8
else: else:
frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8 frame_length += (
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
# Read the rest of the frame # Read the rest of the frame
payload = header + await self.read(frame_length - 4) payload = header + await self.read(frame_length - 4)
# Emit the next frame # Emit the next frame
yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload) yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames() return generate_frames()
@@ -512,19 +554,20 @@ class SbcParser:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcPacketSource: class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities): def __init__(self, read, mtu, codec_capabilities):
self.read = read self.read = read
self.mtu = mtu self.mtu = mtu
self.codec_capabilities = codec_capabilities self.codec_capabilities = codec_capabilities
@property @property
def packets(self): def packets(self):
async def generate_packets(): async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0 sequence_number = 0
timestamp = 0 timestamp = 0
frames = [] frames = []
frames_size = 0 frames_size = 0
max_rtp_payload = self.mtu - 12 - 1 max_rtp_payload = self.mtu - 12 - 1
# NOTE: this doesn't support frame fragments # NOTE: this doesn't support frame fragments
@@ -532,18 +575,25 @@ class SbcPacketSource:
async for frame in sbc_parser.frames: async for frame in sbc_parser.frames:
print(frame) print(frame)
if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16: if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
# Need to flush what has been accumulated so far # Need to flush what has been accumulated so far
# Emit a packet # Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames]) sbc_payload = bytes([len(frames)]) + b''.join(
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload) [frame.payload for frame in frames]
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet yield packet
# Prepare for next packets # Prepare for next packets
sequence_number += 1 sequence_number += 1
timestamp += sum([frame.sample_count for frame in frames]) timestamp += sum((frame.sample_count for frame in frames))
frames = [frame] frames = [frame]
frames_size = len(frame.payload) frames_size = len(frame.payload)
else: else:

View File

@@ -22,15 +22,20 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct
from colors import color from colors import color
from pyee import EventEmitter from pyee import EventEmitter
from .core import * from bumble.core import UUID, name_or_number
from .hci import * from bumble.hci import HCI_Object, key_with_value
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
ATT_CID = 0x04 ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
@@ -163,19 +168,14 @@ ATT_ERROR_NAMES = {
ATT_DEFAULT_MTU = 23 ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'} HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731 # pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y)
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# ----------------------------------------------------------------------------- # pylint: enable=line-too-long
# Utils # pylint: disable=invalid-name
# -----------------------------------------------------------------------------
def key_with_value(dictionary, target_value):
for key, value in dictionary.items():
if value == target_value:
return key
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
@@ -196,8 +196,10 @@ class ATT_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
''' '''
pdu_classes = {} pdu_classes = {}
op_code = 0 op_code = 0
name = None
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu):
@@ -274,11 +276,13 @@ class ATT_PDU:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), [
('attribute_handle_in_error', HANDLE_FIELD_SPEC), ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}) ('attribute_handle_in_error', HANDLE_FIELD_SPEC),
]) ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}),
]
)
class ATT_Error_Response(ATT_PDU): class ATT_Error_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -286,9 +290,7 @@ class ATT_Error_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('client_rx_mtu', 2)])
('client_rx_mtu', 2)
])
class ATT_Exchange_MTU_Request(ATT_PDU): class ATT_Exchange_MTU_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -296,9 +298,7 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('server_rx_mtu', 2)])
('server_rx_mtu', 2)
])
class ATT_Exchange_MTU_Response(ATT_PDU): class ATT_Exchange_MTU_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -306,10 +306,9 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('starting_handle', HANDLE_FIELD_SPEC), [('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)]
('ending_handle', HANDLE_FIELD_SPEC) )
])
class ATT_Find_Information_Request(ATT_PDU): class ATT_Find_Information_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -317,10 +316,7 @@ class ATT_Find_Information_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('format', 1), ('information_data', '*')])
('format', 1),
('information_data', '*')
])
class ATT_Find_Information_Response(ATT_PDU): class ATT_Find_Information_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -332,7 +328,7 @@ class ATT_Find_Information_Response(ATT_PDU):
uuid_size = 2 if self.format == 1 else 16 uuid_size = 2 if self.format == 1 else 16
while offset + uuid_size <= len(self.information_data): while offset + uuid_size <= len(self.information_data):
handle = struct.unpack_from('<H', self.information_data, offset)[0] handle = struct.unpack_from('<H', self.information_data, offset)[0]
uuid = self.information_data[2 + offset:2 + offset + uuid_size] uuid = self.information_data[2 + offset : 2 + offset + uuid_size]
self.information.append((handle, uuid)) self.information.append((handle, uuid))
offset += 2 + uuid_size offset += 2 + uuid_size
@@ -346,20 +342,33 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
('format', 1), self.__dict__,
('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])}) [
], ' ') ('format', 1),
(
'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('starting_handle', HANDLE_FIELD_SPEC), [
('ending_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*') ('attribute_type', UUID_2_FIELD_SPEC),
]) ('attribute_value', '*'),
]
)
class ATT_Find_By_Type_Value_Request(ATT_PDU): class ATT_Find_By_Type_Value_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -367,9 +376,7 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('handles_information_list', '*')])
('handles_information_list', '*')
])
class ATT_Find_By_Type_Value_Response(ATT_PDU): class ATT_Find_By_Type_Value_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -379,7 +386,9 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = [] self.handles_information = []
offset = 0 offset = 0
while offset + 4 <= len(self.handles_information_list): while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset) found_attribute_handle, group_end_handle = struct.unpack_from(
'<HH', self.handles_information_list, offset
)
self.handles_information.append((found_attribute_handle, group_end_handle)) self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4 offset += 4
@@ -393,18 +402,34 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])}) self.__dict__,
], ' ') [
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('starting_handle', HANDLE_FIELD_SPEC), [
('ending_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC) ('ending_handle', HANDLE_FIELD_SPEC),
]) ('attribute_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Type_Request(ATT_PDU): class ATT_Read_By_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -412,10 +437,7 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Type_Response(ATT_PDU): class ATT_Read_By_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -424,9 +446,15 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list): while self.length != 0 and offset + self.length <= len(
attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset) self.attribute_data_list
attribute_value = self.attribute_data_list[offset + 2:offset + self.length] ):
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
self.attributes.append((attribute_handle, attribute_value)) self.attributes.append((attribute_handle, attribute_value))
offset += self.length offset += self.length
@@ -440,17 +468,26 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
('length', 1), self.__dict__,
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])}) [
], ' ') ('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)])
('attribute_handle', HANDLE_FIELD_SPEC)
])
class ATT_Read_Request(ATT_PDU): class ATT_Read_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -458,9 +495,7 @@ class ATT_Read_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_value', '*')])
('attribute_value', '*')
])
class ATT_Read_Response(ATT_PDU): class ATT_Read_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -468,10 +503,7 @@ class ATT_Read_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)])
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
class ATT_Read_Blob_Request(ATT_PDU): class ATT_Read_Blob_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -479,9 +511,7 @@ class ATT_Read_Blob_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('part_attribute_value', '*')])
('part_attribute_value', '*')
])
class ATT_Read_Blob_Response(ATT_PDU): class ATT_Read_Blob_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -489,9 +519,7 @@ class ATT_Read_Blob_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('set_of_handles', '*')])
('set_of_handles', '*')
])
class ATT_Read_Multiple_Request(ATT_PDU): class ATT_Read_Multiple_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -499,9 +527,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('set_of_values', '*')])
('set_of_values', '*')
])
class ATT_Read_Multiple_Response(ATT_PDU): class ATT_Read_Multiple_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -509,11 +535,13 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('starting_handle', HANDLE_FIELD_SPEC), [
('ending_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC) ('ending_handle', HANDLE_FIELD_SPEC),
]) ('attribute_group_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Group_Type_Request(ATT_PDU): class ATT_Read_By_Group_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -521,10 +549,7 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Group_Type_Response(ATT_PDU): class ATT_Read_By_Group_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -533,10 +558,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list): while self.length != 0 and offset + self.length <= len(
attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset) self.attribute_data_list
attribute_value = self.attribute_data_list[offset + 4:offset + self.length] ):
self.attributes.append((attribute_handle, end_group_handle, attribute_value)) attribute_handle, end_group_handle = struct.unpack_from(
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
offset += self.length offset += self.length
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -549,18 +582,29 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
('length', 1), self.__dict__,
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])}) [
], ' ') ('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Request(ATT_PDU): class ATT_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -576,10 +620,7 @@ class ATT_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Command(ATT_PDU): class ATT_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -587,11 +628,13 @@ class ATT_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('attribute_handle', HANDLE_FIELD_SPEC), [
('attribute_value', '*') ('attribute_handle', HANDLE_FIELD_SPEC),
# ('authentication_signature', 'TODO') ('attribute_value', '*')
]) # ('authentication_signature', 'TODO')
]
)
class ATT_Signed_Write_Command(ATT_PDU): class ATT_Signed_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -599,11 +642,13 @@ class ATT_Signed_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('attribute_handle', HANDLE_FIELD_SPEC), [
('value_offset', 2), ('attribute_handle', HANDLE_FIELD_SPEC),
('part_attribute_value', '*') ('value_offset', 2),
]) ('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Request(ATT_PDU): class ATT_Prepare_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -611,11 +656,13 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('attribute_handle', HANDLE_FIELD_SPEC), [
('value_offset', 2), ('attribute_handle', HANDLE_FIELD_SPEC),
('part_attribute_value', '*') ('value_offset', 2),
]) ('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Response(ATT_PDU): class ATT_Prepare_Write_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -639,10 +686,7 @@ class ATT_Execute_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Notification(ATT_PDU): class ATT_Handle_Value_Notification(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -650,10 +694,7 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Indication(ATT_PDU): class ATT_Handle_Value_Indication(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
@@ -671,58 +712,80 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(EventEmitter): class Attribute(EventEmitter):
# Permission flags # Permission flags
READABLE = 0x01 READABLE = 0x01
WRITEABLE = 0x02 WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04 READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08 WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10 READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20 WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40 READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80 WRITE_REQUIRES_AUTHORIZATION = 0x80
def __init__(self, attribute_type, permissions, value = b''): def __init__(self, attribute_type, permissions, value=b''):
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
self.end_group_handle = 0
self.permissions = permissions self.permissions = permissions
# Convert the type to a UUID # Convert the type to a UUID object if it isn't already
if type(attribute_type) is bytes: if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type) self.type = UUID.from_bytes(attribute_type)
else: else:
self.type = attribute_type self.type = attribute_type
# Convert the value to a byte array # Convert the value to a byte array
if type(value) is str: if isinstance(value, str):
self.value = bytes(value, 'utf-8') self.value = bytes(value, 'utf-8')
else: else:
self.value = value self.value = value
def read_value(self, connection): def encode_value(self, value):
if type(self.value) is bytes: return value
return self.value
else: def decode_value(self, value_bytes):
if read := getattr(self.value, 'read', None): return value_bytes
try:
return read(connection) def read_value(self, connection):
except ATT_Error as error: if read := getattr(self.value, 'read', None):
raise ATT_Error(error_code=error.error_code, att_handle=self.handle) try:
else: value = read(connection) # pylint: disable=not-callable
return bytes(self.value) except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
return self.encode_value(value)
def write_value(self, connection, value_bytes):
value = self.decode_value(value_bytes)
def write_value(self, connection, value):
if write := getattr(self.value, 'write', None): if write := getattr(self.value, 'write', None):
try: try:
write(connection, value) write(connection, value) # pylint: disable=not-callable
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle) raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else: else:
self.value = value self.value = value
self.emit('write', connection, value) self.emit('write', connection, value)
def __repr__(self): def __repr__(self):
if len(self.value) > 0: if isinstance(self.value, bytes):
value_str = self.value.hex()
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}' value_string = f', value={self.value.hex()}'
else: else:
value_string = '' value_string = ''
return f'Attribute(handle=0x{self.handle:04X}, type={self.type}, permissions={self.permissions}{value_string})' return (
f'Attribute(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'permissions={self.permissions}{value_string})'
)

File diff suppressed because it is too large Load Diff

View File

@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
class HCI_Bridge: class HCI_Bridge:
class Forwarder: class Forwarder:
def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace): def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace):
self.hci_sink = hci_sink self.hci_sink = hci_sink
self.sender_hci_sink = sender_hci_sink self.sender_hci_sink = sender_hci_sink
self.packet_filter = packet_filter self.packet_filter = packet_filter
self.trace = trace self.trace = trace
def on_packet(self, packet): def on_packet(self, packet):
# Convert the packet bytes to an object # Convert the packet bytes to an object
@@ -61,15 +61,15 @@ class HCI_Bridge:
hci_host_sink, hci_host_sink,
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter = None, host_to_controller_filter=None,
controller_to_host_filter = None controller_to_host_filter=None,
): ):
tracer = PacketTracer(emit_message=logger.info) tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder( host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink, hci_controller_sink,
hci_host_sink, hci_host_sink,
host_to_controller_filter, host_to_controller_filter,
lambda packet: tracer.trace(packet, 0) lambda packet: tracer.trace(packet, 0),
) )
hci_host_source.set_packet_sink(host_to_controller_forwarder) hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink, hci_host_sink,
hci_controller_sink, hci_controller_sink,
controller_to_host_filter, controller_to_host_filter,
lambda packet: tracer.trace(packet, 1) lambda packet: tracer.trace(packet, 1),
) )
hci_controller_source.set_packet_sink(controller_to_host_forwarder) hci_controller_source.set_packet_sink(controller_to_host_forwarder)

View File

@@ -17,6 +17,7 @@
# the `generate_company_id_list.py` script # the `generate_company_id_list.py` script
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable=line-too-long
COMPANY_IDENTIFIERS = { COMPANY_IDENTIFIERS = {
0x0000: "Ericsson Technology Licensing", 0x0000: "Ericsson Technology Licensing",
0x0001: "Nokia Mobile Phones", 0x0001: "Nokia Mobile Phones",
@@ -196,28 +197,28 @@ COMPANY_IDENTIFIERS = {
0x00AF: "Cinetix", 0x00AF: "Cinetix",
0x00B0: "Passif Semiconductor Corp", 0x00B0: "Passif Semiconductor Corp",
0x00B1: "Saris Cycling Group, Inc", 0x00B1: "Saris Cycling Group, Inc",
0x00B2: "Bekey A/S", 0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.", 0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.", 0x00B4: "BDE Technology Co., Ltd.",
0x00B5: "Swirl Networks", 0x00B5: "Swirl Networks",
0x00B6: "Meso international", 0x00B6: "Meso international",
0x00B7: "TreLab Ltd", 0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)", 0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.", 0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.", 0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited", 0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc", 0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation", 0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America", 0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited", 0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation", 0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd", 0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.", 0x00C2: "Geneq Inc.",
0x00C3: "adidas AG", 0x00C3: "adidas AG",
0x00C4: "LG Electronics", 0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation", 0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV", 0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.", 0x00C7: "Quuppa Oy.",
0x00C8: "GeLo Inc", 0x00C8: "GeLo Inc",
0x00C9: "Evluma", 0x00C9: "Evluma",
0x00CA: "MC10", 0x00CA: "MC10",
@@ -249,10 +250,10 @@ COMPANY_IDENTIFIERS = {
0x00E4: "Laird Connectivity, Inc. formerly L.S. Research Inc.", 0x00E4: "Laird Connectivity, Inc. formerly L.S. Research Inc.",
0x00E5: "Eden Software Consultants Ltd.", 0x00E5: "Eden Software Consultants Ltd.",
0x00E6: "Freshtemp", 0x00E6: "Freshtemp",
0x00E7: "KS Technologies", 0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies", 0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems", 0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company", 0x00EA: "Nielsen-Kellerman Company",
0x00EB: "Server Technology Inc.", 0x00EB: "Server Technology Inc.",
0x00EC: "BioResearch Associates", 0x00EC: "BioResearch Associates",
0x00ED: "Jolly Logic, LLC", 0x00ED: "Jolly Logic, LLC",
@@ -2704,5 +2705,5 @@ COMPANY_IDENTIFIERS = {
0x0A7C: "WAFERLOCK", 0x0A7C: "WAFERLOCK",
0x0A7D: "Freedman Electronics Pty Ltd", 0x0A7D: "Freedman Electronics Pty Ltd",
0x0A7E: "Keba AG", 0x0A7E: "Keba AG",
0x0A7F: "Intuity Medical" 0x0A7F: "Intuity Medical",
} }

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,8 @@ from .company_ids import COMPANY_IDENTIFIERS
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
BT_CENTRAL_ROLE = 0 BT_CENTRAL_ROLE = 0
BT_PERIPHERAL_ROLE = 1 BT_PERIPHERAL_ROLE = 1
@@ -30,6 +32,9 @@ BT_BR_EDR_TRANSPORT = 0
BT_LE_TRANSPORT = 1 BT_LE_TRANSPORT = 1
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -58,17 +63,25 @@ def padded_bytes(buffer, size):
return buffer + bytes(padding_size) return buffer + bytes(padding_size)
def get_dict_key_by_value(dictionary, value):
for key, val in dictionary.items():
if val == value:
return key
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BaseError(Exception): class BaseError(Exception):
""" Base class for errors with an error code, error name and namespace""" """Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''): def __init__(self, error_code, error_namespace='', error_name='', details=''):
super().__init__() super().__init__()
self.error_code = error_code self.error_code = error_code
self.error_namespace = error_namespace self.error_namespace = error_namespace
self.error_name = error_name self.error_name = error_name
self.details = details self.details = details
def __str__(self): def __str__(self):
if self.error_namespace: if self.error_namespace:
@@ -84,22 +97,40 @@ class BaseError(Exception):
class ProtocolError(BaseError): class ProtocolError(BaseError):
""" Protocol Error """ """Protocol Error"""
class TimeoutError(Exception): class TimeoutError(Exception): # pylint: disable=redefined-builtin
""" Timeout Error """ """Timeout Error"""
class CommandTimeoutError(Exception):
"""Command Timeout Error"""
class InvalidStateError(Exception): class InvalidStateError(Exception):
""" Invalid State Error """ """Invalid State Error"""
class ConnectionError(BaseError): class ConnectionError(BaseError): # pylint: disable=redefined-builtin
""" Connection Error """ """Connection Error"""
FAILURE = 0x01
FAILURE = 0x01
CONNECTION_REFUSED = 0x02 CONNECTION_REFUSED = 0x02
def __init__(
self,
error_code,
transport,
peer_address,
error_namespace='',
error_name='',
details='',
):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# UUID # UUID
@@ -112,26 +143,33 @@ class UUID:
''' '''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
''' '''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB') BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created UUIDS = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name = None): def __init__(self, uuid_str_or_int, name=None):
if type(uuid_str_or_int) is int: if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int) self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else: else:
if len(uuid_str_or_int) == 36: if len(uuid_str_or_int) == 36:
if uuid_str_or_int[8] != '-' or uuid_str_or_int[13] != '-' or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-': if (
uuid_str_or_int[8] != '-'
or uuid_str_or_int[13] != '-'
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
raise ValueError('invalid UUID format') raise ValueError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '') uuid_str = uuid_str_or_int.replace('-', '')
else: else:
uuid_str = uuid_str_or_int uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4: if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError('invalid UUID format') raise ValueError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name self.name = name
def register(self): def register(self):
# Register this object in the class registry, and update the entry's name if it wasn't set already # Register this object in the class registry, and update the entry's name if
# it wasn't set already
for uuid in self.UUIDS: for uuid in self.UUIDS:
if self == uuid: if self == uuid:
if uuid.name is None: if uuid.name is None:
@@ -142,39 +180,40 @@ class UUID:
return self return self
@classmethod @classmethod
def from_bytes(cls, uuid_bytes, name = None): def from_bytes(cls, uuid_bytes, name=None):
if len(uuid_bytes) in {2, 4, 16}: if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls) self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes self.uuid_bytes = uuid_bytes
self.name = name self.name = name
return self.register() return self.register()
else:
raise ValueError('only 2, 4 and 16 bytes are allowed') raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod @classmethod
def from_16_bits(cls, uuid_16, name = None): def from_16_bits(cls, uuid_16, name=None):
return cls.from_bytes(struct.pack('<H', uuid_16), name) return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod @classmethod
def from_32_bits(cls, uuid_32, name = None): def from_32_bits(cls, uuid_32, name=None):
return cls.from_bytes(struct.pack('<I', uuid_32), name) return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod @classmethod
def parse_uuid(cls, bytes, offset): def parse_uuid(cls, uuid_as_bytes, offset):
return len(bytes), cls.from_bytes(bytes[offset:]) return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod @classmethod
def parse_uuid_2(cls, bytes, offset): def parse_uuid_2(cls, uuid_as_bytes, offset):
return offset + 2, cls.from_bytes(bytes[offset:offset + 2]) return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128 = False): def to_bytes(self, force_128=False):
if len(self.uuid_bytes) == 16 or not force_128: if len(self.uuid_bytes) == 16 or not force_128:
return self.uuid_bytes return self.uuid_bytes
elif len(self.uuid_bytes) == 4:
if len(self.uuid_bytes) == 4:
return self.uuid_bytes + UUID.BASE_UUID return self.uuid_bytes + UUID.BASE_UUID
else:
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
def to_pdu_bytes(self): def to_pdu_bytes(self):
''' '''
@@ -183,27 +222,30 @@ class UUID:
"All 32-bit Attribute UUIDs shall be converted to 128-bit UUIDs when the "All 32-bit Attribute UUIDs shall be converted to 128-bit UUIDs when the
Attribute UUID is contained in an ATT PDU." Attribute UUID is contained in an ATT PDU."
''' '''
return self.to_bytes(force_128 = (len(self.uuid_bytes) == 4)) return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self): def to_hex_str(self):
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4: if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper() return bytes(reversed(self.uuid_bytes)).hex().upper()
else:
return ''.join([ return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(), bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(), bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(), bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex() bytes(reversed(self.uuid_bytes[0:6])).hex(),
]).upper() ]
).upper()
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.to_bytes()
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, UUID): if isinstance(other, UUID):
return self.to_bytes(force_128 = True) == other.to_bytes(force_128 = True) return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
elif type(other) is str:
if isinstance(other, str):
return UUID(other) == self return UUID(other) == self
return False return False
@@ -213,23 +255,26 @@ class UUID:
def __str__(self): def __str__(self):
if len(self.uuid_bytes) == 2: if len(self.uuid_bytes) == 2:
v = struct.unpack('<H', self.uuid_bytes)[0] uuid = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{v:04X}' result = f'UUID-16:{uuid:04X}'
elif len(self.uuid_bytes) == 4: elif len(self.uuid_bytes) == 4:
v = struct.unpack('<I', self.uuid_bytes)[0] uuid = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}' result = f'UUID-32:{uuid:08X}'
else: else:
result = '-'.join([ result = '-'.join(
bytes(reversed(self.uuid_bytes[12:16])).hex(), [
bytes(reversed(self.uuid_bytes[10:12])).hex(), bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(), bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex() bytes(reversed(self.uuid_bytes[6:8])).hex(),
]).upper() bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
if self.name is not None: if self.name is not None:
return result + f' ({self.name})' return result + f' ({self.name})'
else:
return result return result
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@@ -238,6 +283,8 @@ class UUID:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Common UUID constants # Common UUID constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
# Protocol Identifiers # Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP') BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -343,11 +390,17 @@ BT_HDP_SERVICE = UUID.from_16_bits(0x1400,
BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source') BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source')
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink') BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# pylint: enable=line-too-long
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# DeviceClass # DeviceClass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DeviceClass: class DeviceClass:
# fmt: off
# pylint: disable=line-too-long
# Major Service Classes (flags combined with OR) # Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0) LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
LE_AUDIO_SERVICE_CLASS = (1 << 1) LE_AUDIO_SERVICE_CLASS = (1 << 1)
@@ -515,11 +568,18 @@ class DeviceClass:
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
} }
# fmt: on
# pylint: enable=line-too-long
@staticmethod @staticmethod
def split_class_of_device(class_of_device): def split_class_of_device(class_of_device):
# Split the bit fields of the composite class of device value into: # Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class) # (service_classes, major_device_class, minor_device_class)
return ((class_of_device >> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F)) return (
(class_of_device >> 13 & 0x7FF),
(class_of_device >> 8 & 0x1F),
(class_of_device >> 2 & 0x3F),
)
@staticmethod @staticmethod
def pack_class_of_device(service_classes, major_device_class, minor_device_class): def pack_class_of_device(service_classes, major_device_class, minor_device_class):
@@ -527,7 +587,9 @@ class DeviceClass:
@staticmethod @staticmethod
def service_class_labels(service_class_flags): def service_class_labels(service_class_flags):
return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS) return bit_flags_to_strings(
service_class_flags, DeviceClass.SERVICE_CLASS_LABELS
)
@staticmethod @staticmethod
def major_device_class_name(device_class): def major_device_class_name(device_class):
@@ -545,6 +607,9 @@ class DeviceClass:
# Advertising Data # Advertising Data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AdvertisingData: class AdvertisingData:
# fmt: off
# pylint: disable=line-too-long
# This list is only partial, it still needs to be filled in from the spec # This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01 FLAGS = 0x01
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
@@ -656,7 +721,12 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08 BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10 BR_EDR_HOST_FLAG = 0x10
def __init__(self, ad_structures = []): # fmt: on
# pylint: enable=line-too-long
def __init__(self, ad_structures=None):
if ad_structures is None:
ad_structures = []
self.ad_structures = ad_structures[:] self.ad_structures = ad_structures[:]
@staticmethod @staticmethod
@@ -667,19 +737,17 @@ class AdvertisingData:
@staticmethod @staticmethod
def flags_to_string(flags, short=False): def flags_to_string(flags, short=False):
flag_names = [ flag_names = (
'LE Limited', ['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H']
'LE General', if short
'No BR/EDR', else [
'BR/EDR C', 'LE Limited Discoverable Mode',
'BR/EDR H' 'LE General Discoverable Mode',
] if short else [ 'BR/EDR Not Supported',
'LE Limited Discoverable Mode', 'Simultaneous LE and BR/EDR (Controller)',
'LE General Discoverable Mode', 'Simultaneous LE and BR/EDR (Host)',
'BR/EDR Not Supported', ]
'Simultaneous LE and BR/EDR (Controller)', )
'Simultaneous LE and BR/EDR (Host)'
]
return ','.join(bit_flags_to_strings(flags, flag_names)) return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod @staticmethod
@@ -687,16 +755,18 @@ class AdvertisingData:
uuids = [] uuids = []
offset = 0 offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data): while (uuid_size * (offset + 1)) <= len(ad_data):
uuids.append(UUID.from_bytes(ad_data[offset:offset + uuid_size])) uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
offset += uuid_size offset += uuid_size
return uuids return uuids
@staticmethod @staticmethod
def uuid_list_to_string(ad_data, uuid_size): def uuid_list_to_string(ad_data, uuid_size):
return ', '.join([ return ', '.join(
str(uuid) [
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size) str(uuid)
]) for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
]
)
@staticmethod @staticmethod
def ad_data_to_string(ad_type, ad_data): def ad_data_to_string(ad_type, ad_data):
@@ -756,40 +826,65 @@ class AdvertisingData:
return f'[{ad_type_str}]: {ad_data_str}' return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements
@staticmethod @staticmethod
def ad_data_to_object(ad_type, ad_data): def ad_data_to_object(ad_type, ad_data):
if ad_type in { if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
}: AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 2) return AdvertisingData.uuid_list_to_objects(ad_data, 2)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
}: AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 4) return AdvertisingData.uuid_list_to_objects(ad_data, 4)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
}: AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 16) return AdvertisingData.uuid_list_to_objects(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:]) return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:]) return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:]) return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
elif ad_type in {
if ad_type in (
AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME AdvertisingData.COMPLETE_LOCAL_NAME,
}: AdvertisingData.URI,
):
return ad_data.decode("utf-8") return ad_data.decode("utf-8")
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return ad_data[0] return ad_data[0]
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
):
return struct.unpack('<H', ad_data)[0]
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return struct.unpack('<HH', ad_data)
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:]) return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
else:
return ad_data return ad_data
def append(self, data): def append(self, data):
offset = 0 offset = 0
@@ -798,30 +893,40 @@ class AdvertisingData:
offset += 1 offset += 1
if length > 0: if length > 0:
ad_type = data[offset] ad_type = data[offset]
ad_data = data[offset + 1:offset + length] ad_data = data[offset + 1 : offset + length]
self.ad_structures.append((ad_type, ad_data)) self.ad_structures.append((ad_type, ad_data))
offset += length offset += length
def get(self, type_id, return_all=False, raw=True): def get(self, type_id, return_all=False, raw=False):
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
If return_all is True, returns a (possibly empty) list of matches, If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches. else returns the first entry, or None if no structure matches.
''' '''
def process_ad_data(ad_data): def process_ad_data(ad_data):
return ad_data if raw else self.ad_data_to_object(type_id, ad_data) return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all: if return_all:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] return [
else: process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None) ]
return next(
(process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id),
None,
)
def __bytes__(self): def __bytes__(self):
return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]) return b''.join(
[bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]
)
def to_string(self, separator=', '): def to_string(self, separator=', '):
return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]) return separator.join(
[AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]
)
def __str__(self): def __str__(self):
return self.to_string() return self.to_string()
@@ -831,13 +936,17 @@ class AdvertisingData:
# Connection Parameters # Connection Parameters
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConnectionParameters: class ConnectionParameters:
def __init__(self, connection_interval, connection_latency, supervision_timeout): def __init__(self, connection_interval, peripheral_latency, supervision_timeout):
self.connection_interval = connection_interval self.connection_interval = connection_interval
self.connection_latency = connection_latency self.peripheral_latency = peripheral_latency
self.supervision_timeout = supervision_timeout self.supervision_timeout = supervision_timeout
def __str__(self): def __str__(self):
return f'ConnectionParameters(connection_interval={self.connection_interval}, connection_latency={self.connection_latency}, supervision_timeout={self.supervision_timeout}' return (
f'ConnectionParameters(connection_interval={self.connection_interval}, '
f'peripheral_latency={self.peripheral_latency}, '
f'supervision_timeout={self.supervision_timeout}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -24,19 +24,16 @@
import logging import logging
import operator import operator
import platform import platform
if platform.system() != 'Emscripten': if platform.system() != 'Emscripten':
import secrets import secrets
from cryptography.hazmat.primitives.ciphers import ( from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
Cipher,
algorithms,
modes
)
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key, generate_private_key,
ECDH, ECDH,
EllipticCurvePublicNumbers, EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers, EllipticCurvePrivateNumbers,
SECP256R1 SECP256R1,
) )
from cryptography.hazmat.primitives import cmac from cryptography.hazmat.primitives import cmac
else: else:
@@ -66,16 +63,26 @@ class EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False) d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False) x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False) y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key() private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key) return cls(private_key)
@property @property
def x(self): def x(self):
return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big') return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property @property
def y(self): def y(self):
return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big') return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x, public_key_y): def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False) x = int.from_bytes(public_key_x, byteorder='big', signed=False)
@@ -92,7 +99,7 @@ class EccKey:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def xor(x, y): def xor(x, y):
assert(len(x) == len(y)) assert len(x) == len(y)
return bytes(map(operator.xor, x, y)) return bytes(map(operator.xor, x, y))
@@ -118,7 +125,7 @@ def e(key, data):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def ah(k, r): def ah(k, r): # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
''' '''
@@ -129,9 +136,10 @@ def ah(k, r):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def c1(k, r, preq, pres, iat, rat, ia, ra): def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for LE Legacy Pairing See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
''' '''
p1 = bytes([iat, rat]) + preq + pres p1 = bytes([iat, rat]) + preq + pres
@@ -142,7 +150,8 @@ def c1(k, r, preq, pres, iat, rat, ia, ra):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def s1(k, r1, r2): def s1(k, r1, r2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy Pairing See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
''' '''
return e(k, r2[0:8] + r1[0:8]) return e(k, r2[0:8] + r1[0:8])
@@ -163,67 +172,106 @@ def aes_cmac(m, k):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f4(u, v, x, z): def f4(u, v, x, z):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4 See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
''' '''
return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x))))) return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f5(w, n1, n2, a1, a2): def f5(w, n1, n2, a1, a2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation Function f5 See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
''' '''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE') salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt) t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6c, 0x65]) key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return ( return (
bytes(reversed(aes_cmac( bytes(
bytes([0]) + reversed(
key_id + aes_cmac(
bytes(reversed(n1)) + bytes([0])
bytes(reversed(n2)) + + key_id
bytes(reversed(a1)) + + bytes(reversed(n1))
bytes(reversed(a2)) + + bytes(reversed(n2))
bytes([1, 0]), + bytes(reversed(a1))
t + bytes(reversed(a2))
))), + bytes([1, 0]),
bytes(reversed(aes_cmac( t,
bytes([1]) + )
key_id + )
bytes(reversed(n1)) + ),
bytes(reversed(n2)) + bytes(
bytes(reversed(a1)) + reversed(
bytes(reversed(a2)) + aes_cmac(
bytes([1, 0]), bytes([1])
t + key_id
))) + bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f6(w, n1, n2, r, io_cap, a1, a2): def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6 See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
''' '''
return bytes(reversed(aes_cmac( return bytes(
bytes(reversed(n1)) + reversed(
bytes(reversed(n2)) + aes_cmac(
bytes(reversed(r)) + bytes(reversed(n1))
bytes(reversed(io_cap)) + + bytes(reversed(n2))
bytes(reversed(a1)) + + bytes(reversed(r))
bytes(reversed(a2)), + bytes(reversed(io_cap))
bytes(reversed(w)) + bytes(reversed(a1))
))) + bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def g2(u, v, x, y): def g2(u, v, x, y):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2 See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
''' '''
return int.from_bytes( return int.from_bytes(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:], aes_cmac(
byteorder='big' bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
) )
# -----------------------------------------------------------------------------
def h6(w, key_id):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
'''
return aes_cmac(key_id, w)
# -----------------------------------------------------------------------------
def h7(salt, w):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
'''
return aes_cmac(w, salt)

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ from .gatt import (
Characteristic, Characteristic,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC GATT_APPEARANCE_CHARACTERISTIC,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -38,22 +38,22 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GenericAccessService(Service): class GenericAccessService(Service):
def __init__(self, device_name, appearance = (0, 0)): def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic( device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
device_name.encode('utf-8')[:248] device_name.encode('utf-8')[:248],
) )
appearance_characteristic = Characteristic( appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]) struct.pack('<H', (appearance[0] << 6) | appearance[1]),
) )
super().__init__(GATT_GENERIC_ACCESS_SERVICE, [ super().__init__(
device_name_characteristic, GATT_GENERIC_ACCESS_SERVICE,
appearance_characteristic [device_name_characteristic, appearance_characteristic],
]) )

View File

@@ -22,12 +22,18 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging import logging
import struct
from typing import Sequence
from colors import color from colors import color
from .core import * from .core import UUID, get_dict_key_by_value
from .hci import * from .att import Attribute
from .att import *
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -37,6 +43,9 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
GATT_REQUEST_TIMEOUT = 30 # seconds GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512 GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -53,13 +62,13 @@ GATT_NEXT_DST_CHANGE_SERVICE = UUID.from_16_bits(0x1807, 'Next DS
GATT_GLUCOSE_SERVICE = UUID.from_16_bits(0x1808, 'Glucose') GATT_GLUCOSE_SERVICE = UUID.from_16_bits(0x1808, 'Glucose')
GATT_HEALTH_THERMOMETER_SERVICE = UUID.from_16_bits(0x1809, 'Health Thermometer') GATT_HEALTH_THERMOMETER_SERVICE = UUID.from_16_bits(0x1809, 'Health Thermometer')
GATT_DEVICE_INFORMATION_SERVICE = UUID.from_16_bits(0x180A, 'Device Information') GATT_DEVICE_INFORMATION_SERVICE = UUID.from_16_bits(0x180A, 'Device Information')
GATT_DEVICE_HEART_RATE_SERVICE = UUID.from_16_bits(0x180D, 'Heart Rate') GATT_HEART_RATE_SERVICE = UUID.from_16_bits(0x180D, 'Heart Rate')
GATT_PHONE_ALTERT_STATUS_SERVICE = UUID.from_16_bits(0x180E, 'Phone Alert Status') GATT_PHONE_ALERT_STATUS_SERVICE = UUID.from_16_bits(0x180E, 'Phone Alert Status')
GATT_DEVICE_BATTERY_SERVICE = UUID.from_16_bits(0x180F, 'Battery') GATT_BATTERY_SERVICE = UUID.from_16_bits(0x180F, 'Battery')
GATT_BLOOD_PRESSURE_SERVICE = UUID.from_16_bits(0x1810, 'Blood Pressure') GATT_BLOOD_PRESSURE_SERVICE = UUID.from_16_bits(0x1810, 'Blood Pressure')
GATT_ALTERT_NOTIFICATION_SERVICE = UUID.from_16_bits(0x1811, 'Alert Notification') GATT_ALERT_NOTIFICATION_SERVICE = UUID.from_16_bits(0x1811, 'Alert Notification')
GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE = UUID.from_16_bits(0x1812, 'Human Interface Device') GATT_HUMAN_INTERFACE_DEVICE_SERVICE = UUID.from_16_bits(0x1812, 'Human Interface Device')
GATT_DEVICE_SCAN_PARAMETERS_SERVICE = UUID.from_16_bits(0x1813, 'Scan Parameters') GATT_SCAN_PARAMETERS_SERVICE = UUID.from_16_bits(0x1813, 'Scan Parameters')
GATT_RUNNING_SPEED_AND_CADENCE_SERVICE = UUID.from_16_bits(0x1814, 'Running Speed and Cadence') GATT_RUNNING_SPEED_AND_CADENCE_SERVICE = UUID.from_16_bits(0x1814, 'Running Speed and Cadence')
GATT_AUTOMATION_IO_SERVICE = UUID.from_16_bits(0x1815, 'Automation IO') GATT_AUTOMATION_IO_SERVICE = UUID.from_16_bits(0x1815, 'Automation IO')
GATT_CYCLING_SPEED_AND_CADENCE_SERVICE = UUID.from_16_bits(0x1816, 'Cycling Speed and Cadence') GATT_CYCLING_SPEED_AND_CADENCE_SERVICE = UUID.from_16_bits(0x1816, 'Cycling Speed and Cadence')
@@ -119,7 +128,7 @@ GATT_ENVIRONMENTAL_SENSING_CONFIGURATION_DESCRIPTOR = UUID.from_16_bits(0x290B,
GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C, 'Environmental Sensing Measurement') GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C, 'Environmental Sensing Measurement')
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting') GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting') GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BE_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data') GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
# Device Information Service # Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID') GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -132,33 +141,52 @@ GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC = UUID.from_16_bits(0x2A2
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2A2A, 'IEEE 11073-20601 Regulatory Certification Data List') GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2A2A, 'IEEE 11073-20601 Regulatory Certification Data List')
GATT_PNP_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A50, 'PnP ID') GATT_PNP_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A50, 'PnP ID')
# Human Interface Device # Human Interface Device Service
GATT_HID_INFORMATION_CHARACTERISTIC = UUID.from_16_bits(0x2A4A, 'HID Information') GATT_HID_INFORMATION_CHARACTERISTIC = UUID.from_16_bits(0x2A4A, 'HID Information')
GATT_REPORT_MAP_CHARACTERISTIC = UUID.from_16_bits(0x2A4B, 'Report Map') GATT_REPORT_MAP_CHARACTERISTIC = UUID.from_16_bits(0x2A4B, 'Report Map')
GATT_HID_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A4C, 'HID Control Point') GATT_HID_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A4C, 'HID Control Point')
GATT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A4D, 'Report') GATT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A4D, 'Report')
GATT_PROTOCOL_MODE_CHARACTERISTIC = UUID.from_16_bits(0x2A4E, 'Protocol Mode') GATT_PROTOCOL_MODE_CHARACTERISTIC = UUID.from_16_bits(0x2A4E, 'Protocol Mode')
# Heart Rate Service
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC = UUID.from_16_bits(0x2A37, 'Heart Rate Measurement')
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2A38, 'Body Sensor Location')
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart Rate Control Point')
# Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Misc # Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name') GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance') GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
GATT_PERIPHERAL_PRIVACY_FLAG_CHARACTERISTIC = UUID.from_16_bits(0x2A02, 'Peripheral Privacy Flag') GATT_PERIPHERAL_PRIVACY_FLAG_CHARACTERISTIC = UUID.from_16_bits(0x2A02, 'Peripheral Privacy Flag')
GATT_RECONNECTION_ADDRESS_CHARACTERISTIC = UUID.from_16_bits(0x2A03, 'Reconnection Address') GATT_RECONNECTION_ADDRESS_CHARACTERISTIC = UUID.from_16_bits(0x2A03, 'Reconnection Address')
GATT_PERIPHERAL_PREFERRREED_CONNECTION_PARAMETERS_CHARACTERISTIC = UUID.from_16_bits(0x2A04, 'Peripheral Preferred Connection Parameters') GATT_PERIPHERAL_PREFERRED_CONNECTION_PARAMETERS_CHARACTERISTIC = UUID.from_16_bits(0x2A04, 'Peripheral Preferred Connection Parameters')
GATT_SERVICE_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2A05, 'Service Changed') GATT_SERVICE_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2A05, 'Service Changed')
GATT_ALERT_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A06, 'Alert Level') GATT_ALERT_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A06, 'Alert Level')
GATT_TX_POWER_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A07, 'Tx Power Level') GATT_TX_POWER_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A07, 'Tx Power Level')
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level') GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A22, 'Boot Keyboard Input Report')
GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A22, 'Boot Keyboard Input Report') GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time') GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report') GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# pylint: enable=line-too-long
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services): def show_services(services):
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -176,24 +204,51 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION See Vol 3, Part G - 3.1 SERVICE DEFINITION
''' '''
def __init__(self, uuid, characteristics, primary=True): def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already # Convert the uuid to a UUID object if it isn't already
if type(uuid) is str: if isinstance(uuid, str):
uuid = UUID(uuid) uuid = UUID(uuid)
super().__init__( super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Attribute.READABLE, Attribute.READABLE,
uuid.to_pdu_bytes() uuid.to_pdu_bytes(),
) )
self.uuid = uuid self.uuid = uuid
self.included_services = [] self.included_services = []
self.characteristics = characteristics[:] self.characteristics = characteristics[:]
self.end_group_handle = 0 self.primary = primary
self.primary = primary
def get_advertising_data(self):
"""
Get Service specific advertising data
Defined by each Service, default value is empty
:return Service data for advertising
"""
return None
def __str__(self): def __str__(self):
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}' return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid})'
f'{"" if self.primary else "*"}'
)
# -----------------------------------------------------------------------------
class TemplateService(Service):
'''
Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property
'''
UUID = None
def __init__(self, characteristics, primary=True):
super().__init__(self.UUID, characteristics, primary)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -203,80 +258,113 @@ class Characteristic(Attribute):
''' '''
# Property flags # Property flags
BROADCAST = 0x01 BROADCAST = 0x01
READ = 0x02 READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04 WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08 WRITE = 0x08
NOTIFY = 0x10 NOTIFY = 0x10
INDICATE = 0X20 INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0X40 AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0X80 EXTENDED_PROPERTIES = 0x80
PROPERTY_NAMES = { PROPERTY_NAMES = {
BROADCAST: 'BROADCAST', BROADCAST: 'BROADCAST',
READ: 'READ', READ: 'READ',
WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE', WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE',
WRITE: 'WRITE', WRITE: 'WRITE',
NOTIFY: 'NOTIFY', NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE', INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES', AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES' EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES',
} }
@staticmethod @staticmethod
def property_name(property): def property_name(property_int):
return Characteristic.PROPERTY_NAMES.get(property, '') return Characteristic.PROPERTY_NAMES.get(property_int, '')
def __init__(self, uuid, properties, permissions, value = b'', descriptors = []): @staticmethod
# Convert the uuid to a UUID object if it isn't already def properties_as_string(properties):
if type(uuid) is str: return ','.join(
uuid = UUID(uuid) [
Characteristic.property_name(p)
for p in Characteristic.PROPERTY_NAMES
if properties & p
]
)
@staticmethod
def string_to_properties(properties_str: str):
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Characteristic.PROPERTY_NAMES, y),
properties_str.split(","),
0,
)
def __init__(
self,
uuid,
properties,
permissions,
value=b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
self.uuid = uuid self.uuid = self.type
self.properties = properties if isinstance(properties, str):
self._descriptors = descriptors self.properties = Characteristic.string_to_properties(properties)
self._descriptors_discovered = False else:
self.end_group_handle = 0 self.properties = properties
self.attach_descriptors() self.descriptors = descriptors
def attach_descriptors(self):
""" Let all the descriptors know they are attached to this characteristic """
for descriptor in self._descriptors:
descriptor.characteristic = self
def add_descriptor(self, descriptor):
descriptor.characteristic = self
self.descriptors.append(descriptor)
def get_descriptor(self, descriptor_type): def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors: for descriptor in self.descriptors:
if descriptor.uuid == descriptor_type: if descriptor.type == descriptor_type:
return descriptor return descriptor
@property return None
def descriptors(self):
return self._descriptors
@descriptors.setter
def descriptors(self, value):
self._descriptors = value
self._descriptors_discovered = True
self.attach_descriptors()
@property
def descriptors_discovered(self):
return self._descriptors_discovered
def get_properties_as_string(self):
return ','.join([self.property_name(p) for p in self.PROPERTY_NAMES.keys() if self.properties & p])
def __str__(self): def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={self.get_properties_as_string()})' return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
# -----------------------------------------------------------------------------
class CharacteristicDeclaration(Attribute):
'''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
'''
def __init__(self, characteristic, value_handle):
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, properties='
f'{Characteristic.properties_as_string(self.characteristic.properties)})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue: class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None): def __init__(self, read=None, write=None):
self._read = read self._read = read
self._write = write self._write = write
@@ -289,20 +377,208 @@ class CharacteristicValue:
self._write(connection, value) self._write(connection, value)
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in (
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe',
):
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(
self.encode_value(value), with_response
)
def encode_value(self, value):
return value
def decode_value(self, value):
return value
def wrapped_subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self):
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
# -----------------------------------------------------------------------------
class DelegatedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic)
self.encode = encode
self.decode = decode
def encode_value(self, value):
return self.encode(value) if self.encode else value
def decode_value(self, value):
return self.decode(value) if self.decode else value
# -----------------------------------------------------------------------------
class PackedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
For formats with a single value, the adapted `read_value` and `write_value`
methods return/accept single values. For formats with multiple values,
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic, pack_format):
super().__init__(characteristic)
self.struct = struct.Struct(pack_format)
def pack(self, *values):
return self.struct.pack(*values)
def unpack(self, buffer):
return self.struct.unpack(buffer)
def encode_value(self, value):
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value):
unpacked = self.unpack(value)
return unpacked[0] if len(unpacked) == 1 else unpacked
# -----------------------------------------------------------------------------
class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
'''
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, pack_format, keys):
super().__init__(characteristic, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values):
return super().pack(*(values[key] for key in self.keys))
def unpack(self, buffer):
return dict(zip(self.keys, super().unpack(buffer)))
# -----------------------------------------------------------------------------
class UTF8CharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
return value.encode('utf-8')
def decode_value(self, value):
return value.decode('utf-8')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Descriptor(Attribute): class Descriptor(Attribute):
''' '''
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
''' '''
def __init__(self, uuid, permissions, value = b''):
# Convert the uuid to a UUID object if it isn't already
if type(uuid) is str:
uuid = UUID(uuid)
super().__init__(uuid, permissions, value)
self.uuid = uuid
self.characteristic = None
def __str__(self): def __str__(self):
return f'Descriptor(handle=0x{self.handle:04X}, uuid={self.uuid}, value={self.read_value(None).hex()})' return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
)
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002

View File

@@ -26,19 +26,41 @@
import asyncio import asyncio
import logging import logging
import struct import struct
from colors import color
from .core import ProtocolError, TimeoutError from colors import color
from .hci import * from pyee import EventEmitter
from .att import *
from .hci import HCI_Constant
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_ERROR_RESPONSE,
ATT_INVALID_OFFSET_ERROR,
ATT_PDU,
ATT_RESPONSES,
ATT_Exchange_MTU_Request,
ATT_Find_By_Type_Value_Request,
ATT_Find_Information_Request,
ATT_Handle_Value_Confirmation,
ATT_Read_Blob_Request,
ATT_Read_By_Group_Type_Request,
ATT_Read_By_Type_Request,
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
from .gatt import ( from .gatt import (
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_REQUEST_TIMEOUT,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Service, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
Descriptor ClientCharacteristicConfigurationBits,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -47,55 +69,193 @@ from .gatt import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
def __init__(self, client, handle, end_group_handle, attribute_type):
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
async def read_value(self, no_long_read=False):
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
async def write_value(self, value, with_response=False):
return await self.client.write_value(
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
class ServiceProxy(AttributeProxy):
@staticmethod
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
service = services[0] if services else None
return service_class(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
)
super().__init__(client, handle, end_group_handle, attribute_type)
self.uuid = uuid
self.characteristics = []
async def discover_characteristics(self, uuids=()):
return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
def __str__(self):
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
def __init__(self, client, handle, end_group_handle, uuid, properties):
super().__init__(client, handle, end_group_handle, uuid)
self.uuid = uuid
self.properties = properties
self.descriptors = []
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
if descriptor.type == descriptor_type:
return descriptor
return None
async def discover_descriptors(self):
return await self.client.discover_descriptors(self)
async def subscribe(self, subscriber=None, prefer_notify=True):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
def __str__(self):
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
class ProfileServiceProxy:
'''
Base class for profile-specific service proxies
'''
@classmethod
def from_client(cls, client):
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GATT Client # GATT Client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
self.mtu = ATT_DEFAULT_MTU self.mtu_exchange_done = False
self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1)
self.request_semaphore = asyncio.Semaphore(1) self.pending_request = None
self.pending_request = None self.pending_response = None
self.pending_response = None self.notification_subscribers = (
self.notification_subscribers = {} # Notification subscribers, by attribute handle {}
self.indication_subscribers = {} # Indication subscribers, by attribute handle ) # Notification subscribers, by attribute handle
self.services = [] self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
def send_gatt_pdu(self, pdu): def send_gatt_pdu(self, pdu):
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command): async def send_command(self, command):
logger.debug(f'GATT Command from client: [0x{self.connection.handle:04X}] {command}') logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request): async def send_request(self, request):
logger.debug(f'GATT Request from client: [0x{self.connection.handle:04X}] {request}') logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection) # Wait until we can send (only one pending command at a time for the connection)
response = None response = None
async with self.request_semaphore: async with self.request_semaphore:
assert(self.pending_request is None) assert self.pending_request is None
assert(self.pending_response is None) assert self.pending_response is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future() self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request self.pending_request = request
try: try:
self.send_gatt_pdu(request.to_bytes()) self.send_gatt_pdu(request.to_bytes())
response = await asyncio.wait_for(self.pending_response, GATT_REQUEST_TIMEOUT) response = await asyncio.wait_for(
except asyncio.TimeoutError: self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Request timeout', 'red')) logger.warning(color('!!! GATT Request timeout', 'red'))
raise TimeoutError(f'GATT timeout for {request.name}') raise core.TimeoutError(f'GATT timeout for {request.name}') from error
finally: finally:
self.pending_request = None self.pending_request = None
self.pending_response = None self.pending_response = None
return response return response
def send_confirmation(self, confirmation): def send_confirmation(self, confirmation):
logger.debug(f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}') logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu): async def request_mtu(self, mtu):
@@ -107,31 +267,37 @@ class Client:
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:
return return self.connection.att_mtu
# Send the request # Send the request
self.mtu_exchange_done = True self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu = mtu)) response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError( raise ProtocolError(
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu) # Compute the final MTU
return self.mtu self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service = None): def get_characteristics_by_uuid(self, uuid, service=None):
services = [service] if service else self.services services = [service] if service else self.services
return [c for c in [c for s in services for c in s.characteristics] if c.uuid == uuid] return [
c
for c in [c for s in services for c in s.characteristics]
if c.uuid == uuid
]
def on_service_discovered(self, service): def on_service_discovered(self, service):
''' Add a service to the service list if it wasn't already there ''' '''Add a service to the service list if it wasn't already there'''
already_known = False already_known = False
for existing_service in self.services: for existing_service in self.services:
if existing_service.handle == service.handle: if existing_service.handle == service.handle:
@@ -140,7 +306,7 @@ class Client:
if not already_known: if not already_known:
self.services.append(service) self.services.append(service)
async def discover_services(self, uuids = None): async def discover_services(self, uuids=None):
''' '''
See Vol 3, Part G - 4.4.1 Discover All Primary Services See Vol 3, Part G - 4.4.1 Discover All Primary Services
''' '''
@@ -149,9 +315,9 @@ class Client:
while starting_handle < 0xFFFF: while starting_handle < 0xFFFF:
response = await self.send_request( response = await self.send_request(
ATT_Read_By_Group_Type_Request( ATT_Read_By_Group_Type_Request(
starting_handle = starting_handle, starting_handle=starting_handle,
ending_handle = 0xFFFF, ending_handle=0xFFFF,
attribute_group_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
) )
) )
if response is None: if response is None:
@@ -162,21 +328,37 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
for attribute_handle, end_group_handle, attribute_value in response.attributes: for (
if attribute_handle < starting_handle or end_group_handle < attribute_handle: attribute_handle,
end_group_handle,
attribute_value,
) in response.attributes:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right # Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}') logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return return
# Create a primary service object # Create a service proxy for this service
service = Service(UUID.from_bytes(attribute_value), [], True) service = ServiceProxy(
service.handle = attribute_handle self,
service.end_group_handle = end_group_handle attribute_handle,
end_group_handle,
UUID.from_bytes(attribute_value),
True,
)
# Filter out returned services based on the given uuids list # Filter out returned services based on the given uuids list
if (not uuids) or (service.uuid in uuids): if (not uuids) or (service.uuid in uuids):
@@ -200,7 +382,7 @@ class Client:
''' '''
# Force uuid to be a UUID object # Force uuid to be a UUID object
if type(uuid) is str: if isinstance(uuid, str):
uuid = UUID(uuid) uuid = UUID(uuid)
starting_handle = 0x0001 starting_handle = 0x0001
@@ -208,10 +390,10 @@ class Client:
while starting_handle < 0xFFFF: while starting_handle < 0xFFFF:
response = await self.send_request( response = await self.send_request(
ATT_Find_By_Type_Value_Request( ATT_Find_By_Type_Value_Request(
starting_handle = starting_handle, starting_handle=starting_handle,
ending_handle = 0xFFFF, ending_handle=0xFFFF,
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value = uuid.to_pdu_bytes() attribute_value=uuid.to_pdu_bytes(),
) )
) )
if response is None: if response is None:
@@ -222,21 +404,29 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
for attribute_handle, end_group_handle in response.handles_information: for attribute_handle, end_group_handle in response.handles_information:
if attribute_handle < starting_handle or end_group_handle < attribute_handle: if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right # Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}') logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return return
# Create a primary service object # Create a service proxy for this service
service = Service(uuid, [], True) service = ServiceProxy(
service.handle = attribute_handle self, attribute_handle, end_group_handle, uuid, True
service.end_group_handle = end_group_handle )
# Add the service to the peer's service list # Add the service to the peer's service list
services.append(service) services.append(service)
@@ -255,7 +445,7 @@ class Client:
return services return services
async def discover_included_services(self, service): async def discover_included_services(self, _service):
''' '''
See Vol 3, Part G - 4.5.1 Find Included Services See Vol 3, Part G - 4.5.1 Find Included Services
''' '''
@@ -264,11 +454,12 @@ class Client:
async def discover_characteristics(self, uuids, service): async def discover_characteristics(self, uuids, service):
''' '''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 Discover Characteristics by UUID See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID
''' '''
# Cast the UUIDs type from string to object if needed # Cast the UUIDs type from string to object if needed
uuids = [UUID(uuid) if type(uuid) is str else uuid for uuid in uuids] uuids = [UUID(uuid) if isinstance(uuid, str) else uuid for uuid in uuids]
# Decide which services to discover for # Decide which services to discover for
services = [service] if service else self.services services = [service] if service else self.services
@@ -277,15 +468,15 @@ class Client:
discovered_characteristics = [] discovered_characteristics = []
for service in services: for service in services:
starting_handle = service.handle starting_handle = service.handle
ending_handle = service.end_group_handle ending_handle = service.end_group_handle
characteristics = [] characteristics = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle = starting_handle, starting_handle=starting_handle,
ending_handle = ending_handle, ending_handle=ending_handle,
attribute_type = GATT_CHARACTERISTIC_ATTRIBUTE_TYPE attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
) )
) )
if response is None: if response is None:
@@ -296,7 +487,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
@@ -314,8 +508,9 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value) properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:]) characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = Characteristic(characteristic_uuid, properties, 0) characteristic = CharacteristicProxy(
characteristic.handle = handle self, handle, 0, characteristic_uuid, properties
)
# Set the previous characteristic's end handle # Set the previous characteristic's end handle
if characteristics: if characteristics:
@@ -331,22 +526,26 @@ class Client:
characteristics[-1].end_group_handle = service.end_group_handle characteristics[-1].end_group_handle = service.end_group_handle
# Set the service's characteristics # Set the service's characteristics
characteristics = [c for c in characteristics if not uuids or c.uuid in uuids] characteristics = [
c for c in characteristics if not uuids or c.uuid in uuids
]
service.characteristics = characteristics service.characteristics = characteristics
discovered_characteristics.extend(characteristics) discovered_characteristics.extend(characteristics)
return discovered_characteristics return discovered_characteristics
async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None): async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
):
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
''' '''
if characteristic: if characteristic:
starting_handle = characteristic.handle + 1 starting_handle = characteristic.handle + 1
ending_handle = characteristic.end_group_handle ending_handle = characteristic.end_group_handle
elif start_handle and end_handle: elif start_handle and end_handle:
starting_handle = start_handle starting_handle = start_handle
ending_handle = end_handle ending_handle = end_handle
else: else:
return [] return []
@@ -354,8 +553,7 @@ class Client:
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle = starting_handle, starting_handle=starting_handle, ending_handle=ending_handle
ending_handle = ending_handle
) )
) )
if response is None: if response is None:
@@ -366,7 +564,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while discovering descriptors: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return [] return []
break break
@@ -382,8 +583,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}') logger.warning(f'bogus handle value: {attribute_handle}')
return [] return []
descriptor = Descriptor(UUID.from_bytes(attribute_uuid), 0) descriptor = DescriptorProxy(
descriptor.handle = attribute_handle self, attribute_handle, UUID.from_bytes(attribute_uuid)
)
descriptors.append(descriptor) descriptors.append(descriptor)
# TODO: read descriptor value # TODO: read descriptor value
@@ -401,13 +603,12 @@ class Client:
Discover all attributes, regardless of type Discover all attributes, regardless of type
''' '''
starting_handle = 0x0001 starting_handle = 0x0001
ending_handle = 0xFFFF ending_handle = 0xFFFF
attributes = [] attributes = []
while True: while True:
response = await self.send_request( response = await self.send_request(
ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle = starting_handle, starting_handle=starting_handle, ending_handle=ending_handle
ending_handle = ending_handle
) )
) )
if response is None: if response is None:
@@ -417,7 +618,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while discovering attributes: '
f'{HCI_Constant.error_name(response.error_code)}'
)
return [] return []
break break
@@ -427,8 +631,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}') logger.warning(f'bogus handle value: {attribute_handle}')
return [] return []
attribute = Attribute(attribute_uuid, 0) attribute = AttributeProxy(
attribute.handle = attribute_handle self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
)
attributes.append(attribute) attributes.append(attribute)
# Move on to the next attributes # Move on to the next attributes
@@ -436,35 +641,86 @@ class Client:
return attributes return attributes
async def subscribe(self, characteristic, subscriber=None): async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic, do it now # If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic) await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor # Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd: if not cccd:
logger.warning('subscribing to characteristic with no CCCD descriptor') logger.warning('subscribing to characteristic with no CCCD descriptor')
return return
# Set the subscription bits and select the subscriber set if (
bits = 0 characteristic.properties & Characteristic.NOTIFY
subscriber_sets = [] and characteristic.properties & Characteristic.INDICATE
if characteristic.properties & Characteristic.NOTIFY: ):
bits |= 0x0001 if prefer_notify:
subscriber_sets.append(self.notification_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.NOTIFICATION
if characteristic.properties & Characteristic.INDICATE: subscribers = self.notification_subscribers
bits |= 0x0002 else:
subscriber_sets.append(self.indication_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
elif characteristic.properties & Characteristic.NOTIFY:
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
subscribers = self.notification_subscribers
elif characteristic.properties & Characteristic.INDICATE:
bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
else:
raise InvalidStateError("characteristic is not notify or indicate")
# Add subscribers to the sets # Add subscribers to the sets
for subscriber_set in subscriber_sets: subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None: if subscriber is not None:
subscriber_set.add(subscriber) subscriber_set.add(subscriber)
subscriber_set.add(lambda value: characteristic.emit('update', self.connection, value)) # Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
subscriber_set.add(characteristic)
await self.write_value(cccd, struct.pack('<H', bits), with_response=True) await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False): async def read_value(self, attribute, no_long_read=False):
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -473,8 +729,10 @@ class Client:
''' '''
# Send a request to read # Send a request to read
attribute_handle = attribute if type(attribute) is int else attribute.handle attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
response = await self.send_request(ATT_Read_Request(attribute_handle = attribute_handle)) response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
@@ -482,35 +740,40 @@ class Client:
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
# If the value is the max size for the MTU, try to read more unless the caller # If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that # specifically asked not to do that
attribute_value = response.attribute_value attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1: if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value') logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value) offset = len(attribute_value)
while True: while True:
response = await self.send_request( response = await self.send_request(
ATT_Read_Blob_Request(attribute_handle = attribute_handle, value_offset = offset) ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset
)
) )
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR or response.error_code == ATT_INVALID_OFFSET_ERROR: if response.error_code in (
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_INVALID_OFFSET_ERROR,
):
break break
raise ProtocolError( raise ProtocolError(
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
part = response.part_attribute_value part = response.part_attribute_value
attribute_value += part attribute_value += part
if len(part) < self.mtu - 1: if len(part) < self.connection.att_mtu - 1:
break break
offset += len(part) offset += len(part)
@@ -525,18 +788,18 @@ class Client:
if service is None: if service is None:
starting_handle = 0x0001 starting_handle = 0x0001
ending_handle = 0xFFFF ending_handle = 0xFFFF
else: else:
starting_handle = service.handle starting_handle = service.handle
ending_handle = service.end_group_handle ending_handle = service.end_group_handle
characteristics_values = [] characteristics_values = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle = starting_handle, starting_handle=starting_handle,
ending_handle = ending_handle, ending_handle=ending_handle,
attribute_type = uuid attribute_type=uuid,
) )
) )
if response is None: if response is None:
@@ -547,7 +810,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}') logger.warning(
'!!! unexpected error while reading characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return [] return []
break break
@@ -572,47 +838,54 @@ class Client:
async def write_value(self, attribute, value, with_response=False): async def write_value(self, attribute, value, with_response=False):
''' '''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
`attribute` can be an Attribute object, or a handle value `attribute` can be an Attribute object, or a handle value
''' '''
# Send a request or command to write # Send a request or command to write
attribute_handle = attribute if type(attribute) is int else attribute.handle attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
if with_response: if with_response:
response = await self.send_request( response = await self.send_request(
ATT_Write_Request( ATT_Write_Request(
attribute_handle = attribute_handle, attribute_handle=attribute_handle, attribute_value=value
attribute_value = value
) )
) )
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError( raise ProtocolError(
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), response ATT_PDU.error_name(response.error_code),
response,
) )
else: else:
await self.send_command( await self.send_command(
ATT_Write_Command( ATT_Write_Command(
attribute_handle = attribute_handle, attribute_handle=attribute_handle, attribute_value=value
attribute_value = value
) )
) )
def on_gatt_pdu(self, att_pdu): def on_gatt_pdu(self, att_pdu):
logger.debug(f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}') logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in ATT_RESPONSES: if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None: if self.pending_request is None:
# Not expected! # Not expected!
logger.warning('!!! unexpected response, there is no pending request') logger.warning('!!! unexpected response, there is no pending request')
return return
# Sanity check: the response should match the pending request unless it is an error response # Sanity check: the response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE: if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace('_REQUEST', '_RESPONSE') expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
)
if att_pdu.name != expected_response_name: if att_pdu.name != expected_response_name:
logger.warning(f'!!! mismatched response: expected {expected_response_name}') logger.warning(
f'!!! mismatched response: expected {expected_response_name}'
)
return return
# Return the response to the coroutine that is waiting for it # Return the response to the coroutine that is waiting for it
@@ -623,15 +896,27 @@ class Client:
if handler is not None: if handler is not None:
handler(att_pdu) handler(att_pdu)
else: else:
logger.warning(f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}') logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
def on_att_handle_value_notification(self, notification): def on_att_handle_value_notification(self, notification):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get(notification.attribute_handle, []) subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
)
if not subscribers: if not subscribers:
logger.warning('!!! received notification with no subscriber') logger.warning('!!! received notification with no subscriber')
for subscriber in subscribers: for subscriber in subscribers:
subscriber(notification.attribute_value) if callable(subscriber):
subscriber(notification.attribute_value)
else:
subscriber.emit('update', notification.attribute_value)
def on_att_handle_value_indication(self, indication): def on_att_handle_value_indication(self, indication):
# Call all subscribers # Call all subscribers
@@ -639,7 +924,10 @@ class Client:
if not subscribers: if not subscribers:
logger.warning('!!! received indication with no subscriber') logger.warning('!!! received indication with no subscriber')
for subscriber in subscribers: for subscriber in subscribers:
subscriber(indication.attribute_value) if callable(subscriber):
subscriber(indication.attribute_value)
else:
subscriber.emit('update', indication.attribute_value)
# Confirm that we received the indication # Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation()) self.send_confirmation(ATT_Handle_Value_Confirmation())

View File

@@ -26,13 +26,53 @@
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
import struct
from typing import Tuple, Optional
from pyee import EventEmitter from pyee import EventEmitter
from colors import color from colors import color
from .core import * from .core import UUID
from .hci import * from .att import (
from .att import * ATT_ATTRIBUTE_NOT_FOUND_ERROR,
from .gatt import * ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
ATT_INVALID_HANDLE_ERROR,
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
ATT_Error_Response,
ATT_Exchange_MTU_Response,
ATT_Find_By_Type_Value_Response,
ATT_Find_Information_Response,
ATT_Handle_Value_Indication,
ATT_Handle_Value_Notification,
ATT_Read_Blob_Response,
ATT_Read_By_Group_Type_Response,
ATT_Read_By_Type_Response,
ATT_Read_Response,
ATT_Write_Response,
Attribute,
)
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_INCLUDE_ATTRIBUTE_TYPE,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
Service,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -40,27 +80,47 @@ from .gatt import *
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GATT Server # GATT Server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(EventEmitter):
def __init__(self, device): def __init__(self, device):
super().__init__() super().__init__()
self.device = device self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate self.max_mtu = (
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.mtus = {} # Map of ATT MTU values by connection handle )
self.subscribers = (
{}
) # Map of subscriber states by connection handle and attribute handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None) self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu): def send_gatt_pdu(self, connection_handle, pdu):
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self): def next_handle(self):
return 1 + len(self.attributes) return 1 + len(self.attributes)
def get_advertising_service_data(self):
return {
attribute: data
for attribute in self.attributes
if isinstance(attribute, Service)
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle): def get_attribute(self, handle):
attribute = self.attributes_by_handle.get(handle) attribute = self.attributes_by_handle.get(handle)
if attribute: if attribute:
@@ -74,15 +134,74 @@ class Server(EventEmitter):
return attribute return attribute
return None return None
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
return next(
(
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and attribute.uuid == service_uuid
),
None,
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and attribute.characteristic.uuid == characteristic_uuid
),
None,
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
if not characteristics:
return None
(_, characteristic_value) = characteristics
return next(
(
attribute
for attribute in map(
self.get_attribute,
range(
characteristic_value.handle + 1,
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute): def add_attribute(self, attribute):
# Assign a handle to this attribute # Assign a handle to this attribute
attribute.handle = self.next_handle() attribute.handle = self.next_handle()
attribute.end_group_handle = attribute.handle # TODO: keep track of descriptors in the group attribute.end_group_handle = (
attribute.handle
) # TODO: keep track of descriptors in the group
# Add this attribute to the list # Add this attribute to the list
self.attributes.append(attribute) self.attributes.append(attribute)
def add_service(self, service): def add_service(self, service: Service):
# Add the service attribute to the DB # Add the service attribute to the DB
self.add_attribute(service) self.add_attribute(service)
@@ -90,16 +209,9 @@ class Server(EventEmitter):
# Add all characteristics # Add all characteristics
for characteristic in service.characteristics: for characteristic in service.characteristics:
# Add a Characteristic Declaration (Vol 3, Part G - 3.3.1 Characteristic Declaration) # Add a Characteristic Declaration
declaration_bytes = struct.pack( characteristic_declaration = CharacteristicDeclaration(
'<BH', characteristic, self.next_handle() + 1
characteristic.properties,
self.next_handle() + 1, # The value will be the next attribute after this declaration
) + characteristic.uuid.to_pdu_bytes()
characteristic_declaration = Attribute(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Attribute.READABLE,
declaration_bytes
) )
self.add_attribute(characteristic_declaration) self.add_attribute(characteristic_declaration)
@@ -113,17 +225,26 @@ class Server(EventEmitter):
# If the characteristic supports subscriptions, add a CCCD descriptor # If the characteristic supports subscriptions, add a CCCD descriptor
# unless there is one already # unless there is one already
if ( if (
characteristic.properties & (Characteristic.NOTIFY | Characteristic.INDICATE) and characteristic.properties
characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) is None & (Characteristic.NOTIFY | Characteristic.INDICATE)
and characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
is None
): ):
self.add_attribute( self.add_attribute(
# pylint: disable=line-too-long
Descriptor( Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE, Attribute.READABLE | Attribute.WRITEABLE,
CharacteristicValue( CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(connection, characteristic), read=lambda connection, characteristic=characteristic: self.read_cccd(
write=lambda connection, value, characteristic=characteristic: self.write_cccd(connection, characteristic, value) connection, characteristic
) ),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
) )
) )
@@ -150,26 +271,39 @@ class Server(EventEmitter):
return cccd or bytes([0, 0]) return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value): def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}') logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check # Sanity check
if len(value) != 2: if len(value) != 2:
logger.warn('CCCD value not 2 bytes long') logger.warning('CCCD value not 2 bytes long')
return return
cccds = self.subscribers.setdefault(connection.handle, {}) cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}') logger.debug(f'CCCDs: {cccds}')
notify_enabled = (value[0] & 0x01 != 0) notify_enabled = value[0] & 0x01 != 0
indicate_enabled = (value[0] & 0x02 != 0) indicate_enabled = value[0] & 0x02 != 0
characteristic.emit('subscription', connection, notify_enabled, indicate_enabled) characteristic.emit(
self.emit('characteristic_subscription', connection, characteristic, notify_enabled, indicate_enabled) 'subscription', connection, notify_enabled, indicate_enabled
)
self.emit(
'characteristic_subscription',
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection, response): def send_response(self, connection, response):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}') logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, force=False): async def notify_subscriber(self, connection, attribute, value=None, force=False):
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -178,47 +312,35 @@ class Server(EventEmitter):
return return
cccd = subscribers.get(attribute.handle) cccd = subscribers.get(attribute.handle)
if not cccd: if not cccd:
logger.debug(f'not notifying, no subscribers for handle {attribute.handle:04X}') logger.debug(
f'not notifying, no subscribers for handle {attribute.handle:04X}'
)
return return
if len(cccd) != 2 or (cccd[0] & 0x01 == 0): if len(cccd) != 2 or (cccd[0] & 0x01 == 0):
logger.debug(f'not notifying, cccd={cccd.hex()}') logger.debug(f'not notifying, cccd={cccd.hex()}')
return return
# Get the value # Get or encode the value
value = attribute.read_value(connection) value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed # Truncate if needed
mtu = self.get_mtu(connection) if len(value) > connection.att_mtu - 3:
if len(value) > mtu - 3: value = value[: connection.att_mtu - 3]
value = value[:mtu - 3]
# Notify # Notify
notification = ATT_Handle_Value_Notification( notification = ATT_Handle_Value_Notification(
attribute_handle = attribute.handle, attribute_handle=attribute.handle, attribute_value=value
attribute_value = value
) )
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}') logger.debug(
self.send_gatt_pdu(connection.handle, notification.to_bytes()) f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
async def notify_subscribers(self, attribute, force=False): async def indicate_subscriber(self, connection, attribute, value=None, force=False):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Notify for each connection
if connections:
await asyncio.wait([
self.notify_subscriber(connection, attribute, force)
for connection in connections
])
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -227,64 +349,84 @@ class Server(EventEmitter):
return return
cccd = subscribers.get(attribute.handle) cccd = subscribers.get(attribute.handle)
if not cccd: if not cccd:
logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}') logger.debug(
f'not indicating, no subscribers for handle {attribute.handle:04X}'
)
return return
if len(cccd) != 2 or (cccd[0] & 0x02 == 0): if len(cccd) != 2 or (cccd[0] & 0x02 == 0):
logger.debug(f'not indicating, cccd={cccd.hex()}') logger.debug(f'not indicating, cccd={cccd.hex()}')
return return
# Get the value # Get or encode the value
value = attribute.read_value(connection) value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed # Truncate if needed
mtu = self.get_mtu(connection) if len(value) > connection.att_mtu - 3:
if len(value) > mtu - 3: value = value[: connection.att_mtu - 3]
value = value[:mtu - 3]
# Indicate # Indicate
indication = ATT_Handle_Value_Indication( indication = ATT_Handle_Value_Indication(
attribute_handle = attribute.handle, attribute_handle=attribute.handle, attribute_value=value
attribute_value = value )
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
) )
logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}')
# Wait until we can send (only one pending indication at a time per connection) # Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]: async with self.indication_semaphores[connection.handle]:
assert(self.pending_confirmations[connection.handle] is None) assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future() self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT) await asyncio.wait_for(
except asyncio.TimeoutError: self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally: finally:
self.pending_confirmations[connection.handle] = None self.pending_confirmations[connection.handle] = None
async def indicate_subscribers(self, attribute): async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
# Get all the connections for which there's at least one subscription # Get all the connections for which there's at least one subscription
connections = [ connections = [
connection for connection in [ connection
for connection in [
self.device.lookup_connection(connection_handle) self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items() for (connection_handle, subscribers) in self.subscribers.items()
if subscribers.get(attribute.handle) if force or subscribers.get(attribute.handle)
] ]
if connection is not None if connection is not None
] ]
# Indicate for each connection # Indicate or notify for each connection
if connections: if connections:
await asyncio.wait([ coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
self.indicate_subscriber(connection, attribute) await asyncio.wait(
for connection in connections [
]) asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection): def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers: if connection.handle in self.subscribers:
del self.subscribers[connection.handle] del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores: if connection.handle in self.indication_semaphores:
@@ -302,17 +444,17 @@ class Server(EventEmitter):
except ATT_Error as error: except ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}') logger.debug(f'normal exception returned by handler: {error}')
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = att_pdu.op_code, request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error = error.att_handle, attribute_handle_in_error=error.att_handle,
error_code = error.error_code error_code=error.error_code,
) )
self.send_response(connection, response) self.send_response(connection, response)
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = att_pdu.op_code, request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error = 0x0000, attribute_handle_in_error=0x0000,
error_code = ATT_UNLIKELY_ERROR_ERROR error_code=ATT_UNLIKELY_ERROR_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
raise error raise error
@@ -323,10 +465,13 @@ class Server(EventEmitter):
self.on_att_request(connection, att_pdu) self.on_att_request(connection, att_pdu)
else: else:
# Just ignore # Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}') logger.warning(
color(
def get_mtu(self, connection): f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU) 'red',
)
+ str(att_pdu)
)
####################################################### #######################################################
# ATT handlers # ATT handlers
@@ -335,11 +480,16 @@ class Server(EventEmitter):
''' '''
Handler for requests without a more specific handler Handler for requests without a more specific handler
''' '''
logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}') logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
)
+ str(pdu)
)
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = pdu.op_code, request_opcode_in_error=pdu.op_code,
attribute_handle_in_error = 0x0000, attribute_handle_in_error=0x0000,
error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -347,12 +497,18 @@ class Server(EventEmitter):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
''' '''
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu)) self.send_response(
self.mtus[connection.handle] = mtu connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu)) )
# Notify the device # Compute the final MTU
self.device.on_connection_att_mtu_update(connection.handle, mtu) if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(self, connection, request): def on_att_find_information_request(self, connection, request):
''' '''
@@ -360,22 +516,29 @@ class Server(EventEmitter):
''' '''
# Check the request parameters # Check the request parameters
if request.starting_handle == 0 or request.starting_handle > request.ending_handle: if (
self.send_response(connection, ATT_Error_Response( request.starting_handle == 0
request_opcode_in_error = request.op_code, or request.starting_handle > request.ending_handle
attribute_handle_in_error = request.starting_handle, ):
error_code = ATT_INVALID_HANDLE_ERROR self.send_response(
)) connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return return
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
uuid_size = 0 uuid_size = 0
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.handle >= request.starting_handle and for attribute in self.attributes
attribute.handle <= request.ending_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
): ):
# TODO: check permissions # TODO: check permissions
@@ -402,14 +565,14 @@ class Server(EventEmitter):
for attribute in attributes for attribute in attributes
] ]
response = ATT_Find_Information_Response( response = ATT_Find_Information_Response(
format = 1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2, format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data = b''.join(information_data_list) information_data=b''.join(information_data_list),
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -420,15 +583,16 @@ class Server(EventEmitter):
''' '''
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.handle >= request.starting_handle and for attribute in self.attributes
attribute.handle <= request.ending_handle and if attribute.handle >= request.starting_handle
attribute.type == request.attribute_type and and attribute.handle <= request.ending_handle
attribute.read_value(connection) == request.attribute_value and and attribute.type == request.attribute_type
pdu_space_available >= 4 and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -440,25 +604,27 @@ class Server(EventEmitter):
if attributes: if attributes:
handles_information_list = [] handles_information_list = []
for attribute in attributes: for attribute in attributes:
if attribute.type in { if attribute.type in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
}: ):
# Part of a group # Part of a group
group_end_handle = attribute.end_group_handle group_end_handle = attribute.end_group_handle
else: else:
# Not part of a group # Not part of a group
group_end_handle = attribute.handle group_end_handle = attribute.handle
handles_information_list.append(struct.pack('<HH', attribute.handle, group_end_handle)) handles_information_list.append(
struct.pack('<HH', attribute.handle, group_end_handle)
)
response = ATT_Find_By_Type_Value_Response( response = ATT_Find_By_Type_Value_Response(
handles_information_list = b''.join(handles_information_list) handles_information_list=b''.join(handles_information_list)
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -468,21 +634,21 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
mtu = self.get_mtu(connection) pdu_space_available = connection.att_mtu - 2
pdu_space_available = mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.type == request.attribute_type and for attribute in self.attributes
attribute.handle >= request.starting_handle and if attribute.type == request.attribute_type
attribute.handle <= request.ending_handle and and attribute.handle >= request.starting_handle
pdu_space_available and attribute.handle <= request.ending_handle
and pdu_space_available
): ):
# TODO: check permissions # TODO: check permissions
# Check the attribute value size # Check the attribute value size
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253) max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
@@ -500,16 +666,17 @@ class Server(EventEmitter):
pdu_space_available -= entry_size pdu_space_available -= entry_size
if attributes: if attributes:
attribute_data_list = [struct.pack('<H', handle) + value for handle, value in attributes] attribute_data_list = [
struct.pack('<H', handle) + value for handle, value in attributes
]
response = ATT_Read_By_Type_Response( response = ATT_Read_By_Type_Response(
length = entry_size, length=entry_size, attribute_data_list=b''.join(attribute_data_list)
attribute_data_list = b''.join(attribute_data_list)
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -522,15 +689,13 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions # TODO: check permissions
value = attribute.read_value(connection) value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value)) value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response( response = ATT_Read_Response(attribute_value=value[:value_size])
attribute_value = value[:value_size]
)
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -541,30 +706,33 @@ class Server(EventEmitter):
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions # TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection) value = attribute.read_value(connection)
if request.value_offset > len(value): if request.value_offset > len(value):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR error_code=ATT_INVALID_OFFSET_ERROR,
) )
elif len(value) <= mtu - 1: elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
) )
else: else:
part_size = min(mtu - 1, len(value) - request.value_offset) part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response( response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size] part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -572,32 +740,32 @@ class Server(EventEmitter):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
''' '''
if request.attribute_group_type not in { if request.attribute_group_type not in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE GATT_INCLUDE_ATTRIBUTE_TYPE,
}: ):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
return return
mtu = self.get_mtu(connection) pdu_space_available = connection.att_mtu - 2
pdu_space_available = mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.type == request.attribute_group_type and for attribute in self.attributes
attribute.handle >= request.starting_handle and if attribute.type == request.attribute_group_type
attribute.handle <= request.ending_handle and and attribute.handle >= request.starting_handle
pdu_space_available and attribute.handle <= request.ending_handle
and pdu_space_available
): ):
# Check the attribute value size # Check the attribute value size
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251) max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
# We need to truncate # We need to truncate
attribute_value = attribute_value[:max_attribute_size] attribute_value = attribute_value[:max_attribute_size]
@@ -611,7 +779,9 @@ class Server(EventEmitter):
break break
# Add the attribute to the list # Add the attribute to the list
attributes.append((attribute.handle, attribute.end_group_handle, attribute_value)) attributes.append(
(attribute.handle, attribute.end_group_handle, attribute_value)
)
pdu_space_available -= entry_size pdu_space_available -= entry_size
if attributes: if attributes:
@@ -620,14 +790,14 @@ class Server(EventEmitter):
for handle, end_group_handle, value in attributes for handle, end_group_handle, value in attributes
] ]
response = ATT_Read_By_Group_Type_Response( response = ATT_Read_By_Group_Type_Response(
length = len(attribute_data_list[0]), length=len(attribute_data_list[0]),
attribute_data_list = b''.join(attribute_data_list) attribute_data_list=b''.join(attribute_data_list),
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error = request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error = request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -640,22 +810,28 @@ class Server(EventEmitter):
# Check that the attribute exists # Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle) attribute = self.get_attribute(request.attribute_handle)
if attribute is None: if attribute is None:
self.send_response(connection, ATT_Error_Response( self.send_response(
request_opcode_in_error = request.op_code, connection,
attribute_handle_in_error = request.attribute_handle, ATT_Error_Response(
error_code = ATT_INVALID_HANDLE_ERROR request_opcode_in_error=request.op_code,
)) attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return return
# TODO: check permissions # TODO: check permissions
# Check the request parameters # Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE: if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(connection, ATT_Error_Response( self.send_response(
request_opcode_in_error = request.op_code, connection,
attribute_handle_in_error = request.attribute_handle, ATT_Error_Response(
error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR request_opcode_in_error=request.op_code,
)) attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
),
)
return return
# Accept the value # Accept the value
@@ -686,13 +862,15 @@ class Server(EventEmitter):
except Exception as error: except Exception as error:
logger.warning(f'!!! ignoring exception: {error}') logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, confirmation): def on_att_handle_value_confirmation(self, connection, _confirmation):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
''' '''
if self.pending_confirmations[connection.handle] is None: if self.pending_confirmations[connection.handle] is None:
# Not expected! # Not expected!
logger.warning('!!! unexpected confirmation, there is no pending indication') logger.warning(
'!!! unexpected confirmation, there is no pending indication'
)
return return
self.pending_confirmations[connection.handle].set_result(None) self.pending_confirmations[connection.handle].set_result(None)

File diff suppressed because it is too large Load Diff

View File

@@ -18,8 +18,9 @@
import logging import logging
from colors import color from colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import ( from .l2cap import (
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST, L2CAP_CONNECTION_REQUEST,
@@ -27,20 +28,17 @@ from .l2cap import (
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Response L2CAP_Connection_Response,
) )
from .hci import ( from .hci import (
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler HCI_AclDataPacketAssembler,
) )
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM from .sdp import SDP_PDU, SDP_PSM
from .avdtp import ( from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -51,8 +49,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
PSM_NAMES = { PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM', RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP', SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP' AVDTP_PSM: 'AVDTP'
# TODO: add more PSM values # TODO: add more PSM values
} }
@@ -61,19 +59,23 @@ PSM_NAMES = {
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
def __init__(self, analyzer): def __init__(self, analyzer):
self.analyzer = analyzer self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
if l2cap_pdu.cid == ATT_CID: if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload) att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu) self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID: elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload) control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame) self.analyzer.emit(control_frame)
@@ -81,16 +83,26 @@ class PacketTracer:
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: if (
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer: if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid): if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection # Found a pending connection
self.psms[control_frame.destination_cid] = psm self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for each direction # For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM: if psm == AVDTP_PSM:
self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message) self.avdtp_assemblers[
self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message) control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
else: else:
# Try to find the PSM associated with this PDU # Try to find the PSM associated with this PDU
@@ -102,31 +114,42 @@ class PacketTracer:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM: elif psm == AVDTP_PSM:
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}') self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid) assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler: if assembler:
assembler.on_pdu(l2cap_pdu.payload) assembler.on_pdu(l2cap_pdu.payload)
else: else:
psm_string = name_or_number(PSM_NAMES, psm) psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}') self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message): def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}') self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet): def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
class Analyzer: class Analyzer:
def __init__(self, label, emit_message): def __init__(self, label, emit_message):
self.label = label self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle): def start_acl_stream(self, connection_handle):
logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}') logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self) stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream self.acl_streams[connection_handle] = stream
@@ -139,7 +162,10 @@ class PacketTracer:
def end_acl_stream(self, connection_handle): def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}') logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
f'0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle] del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
@@ -171,9 +197,13 @@ class PacketTracer:
self, self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'), host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'), controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info emit_message=logger.info,
): ):
self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message) self.host_to_controller_analyzer = PacketTracer.Analyzer(
self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message) host_to_controller_label, emit_message
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer

View File

@@ -34,16 +34,16 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HfpProtocol: class HfpProtocol:
def __init__(self, dlc): def __init__(self, dlc):
self.dlc = dlc self.dlc = dlc
self.buffer = '' self.buffer = ''
self.lines = collections.deque() self.lines = collections.deque()
self.lines_available = asyncio.Event() self.lines_available = asyncio.Event()
dlc.sink = self.feed dlc.sink = self.feed
def feed(self, data): def feed(self, data):
# Convert the data to a string if needed # Convert the data to a string if needed
if type(data) == bytes: if isinstance(data, bytes):
data = data.decode('utf-8') data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}') logger.debug(f'<<< Data received: {data}')
@@ -52,7 +52,7 @@ class HfpProtocol:
self.buffer += data self.buffer += data
while (separator := self.buffer.find('\r')) >= 0: while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip() line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1:] self.buffer = self.buffer[separator + 1 :]
if len(line) > 0: if len(line) > 0:
self.on_line(line) self.on_line(line)
@@ -79,16 +79,16 @@ class HfpProtocol:
async def initialize_service(self): async def initialize_service(self):
# Perform Service Level Connection Initialization # Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
line = await(self.next_line()) await (self.next_line())
line = await(self.next_line()) await (self.next_line())
self.send_command_line('AT+CIND=?') self.send_command_line('AT+CIND=?')
line = await(self.next_line()) await (self.next_line())
line = await(self.next_line()) await (self.next_line())
self.send_command_line('AT+CIND?') self.send_command_line('AT+CIND?')
line = await(self.next_line()) await (self.next_line())
line = await(self.next_line()) await (self.next_line())
self.send_command_line('AT+CMER=3,0,0,1') self.send_command_line('AT+CMER=3,0,0,1')
line = await(self.next_line()) await (self.next_line())

View File

@@ -16,16 +16,60 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import collections
import logging import logging
from pyee import EventEmitter import struct
from colors import color from colors import color
from .hci import * from bumble.l2cap import L2CAP_PDU
from .l2cap import *
from .att import * from .hci import (
from .gatt import * HCI_ACL_DATA_PACKET,
from .smp import * HCI_COMMAND_COMPLETE_EVENT,
from .core import ConnectionParameters HCI_COMMAND_PACKET,
HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND,
HCI_RESET_COMMAND,
HCI_SUCCESS,
HCI_SUPPORTED_COMMANDS_FLAGS,
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Constant,
HCI_Error,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Local_Supported_Features_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Write_Suggested_Default_Data_Length_Command,
HCI_Link_Key_Request_Negative_Reply_Command,
HCI_Link_Key_Request_Reply_Command,
HCI_PIN_Code_Request_Negative_Reply_Command,
HCI_Packet,
HCI_Read_Buffer_Size_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
)
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
)
from .utils import AbortableEventEmitter
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -36,56 +80,60 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27 HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1 HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27 HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection: class Connection:
def __init__(self, host, handle, role, peer_address): def __init__(self, host, handle, role, peer_address, transport):
self.host = host self.host = host
self.handle = handle self.handle = handle
self.role = role self.role = role
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def on_hci_acl_data_packet(self, packet): def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet) self.assembler.feed_packet(packet)
def on_acl_pdu(self, pdu): def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
if l2cap_pdu.cid == ATT_CID:
self.host.on_gatt_pdu(self, l2cap_pdu.payload)
elif l2cap_pdu.cid == SMP_CID:
self.host.on_smp_pdu(self, l2cap_pdu.payload)
else:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(EventEmitter): class Host(AbortableEventEmitter):
def __init__(self, controller_source = None, controller_sink = None): def __init__(self, controller_source=None, controller_sink=None):
super().__init__() super().__init__()
self.hci_sink = None self.hci_sink = None
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle self.reset_done = False
self.pending_command = None self.connections = {} # Connections, by connection handle
self.pending_response = None self.pending_command = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque() self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0 self.acl_packets_in_flight = 0
self.local_supported_commands = bytes(64) self.local_version = None
self.command_semaphore = asyncio.Semaphore(1) self.local_supported_commands = bytes(64)
self.long_term_key_provider = None self.local_le_features = 0
self.link_key_provider = None self.suggested_max_tx_octets = 251 # Max allowed
self.pairing_io_capability_provider = None # Classic only self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
# Connect to the source and sink if specified # Connect to the source and sink if specified
if controller_source: if controller_source:
@@ -93,38 +141,122 @@ class Host(EventEmitter):
if controller_sink: if controller_sink:
self.set_packet_sink(controller_sink) self.set_packet_sink(controller_sink)
async def flush(self):
# Make sure no command is pending
await self.command_semaphore.acquire()
# Flush current host state, then release command semaphore
self.emit('flush')
self.command_semaphore.release()
async def reset(self): async def reset(self):
await self.send_command(HCI_Reset_Command()) if self.ready:
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True self.ready = True
response = await self.send_command(HCI_Read_Local_Supported_Commands_Command()) response = await self.send_command(
if response.return_parameters.status != HCI_SUCCESS: HCI_Read_Local_Supported_Commands_Command(), check_result=True
raise ProtocolError(response.return_parameters.status, 'hci') )
self.local_supported_commands = response.return_parameters.supported_commands self.local_supported_commands = response.return_parameters.supported_commands
await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFFFF'))) if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = bytes.fromhex('FFFFF00000000000'))) response = await self.send_command(
await self.send_command(HCI_Read_Local_Version_Information_Command()) HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
await self.send_command(HCI_Write_LE_Host_Support_Command(le_supported_host = 1, simultaneous_le_host = 0)) )
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command()) if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
if response.return_parameters.status == HCI_SUCCESS: response = await self.send_command(
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length HCI_Read_Local_Version_Information_Command(), check_result=True
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets )
logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={response.return_parameters.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={response.return_parameters.hc_total_num_le_acl_data_packets}') self.local_version = response.return_parameters
await self.send_command(
HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('FFFFFFFFFFFFFF3F'))
)
if (
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't
# understand
le_event_mask = bytes.fromhex('1F00000000000000')
else: else:
logger.warn(f'HCI_LE_Read_Buffer_Size_Command failed: {response.return_parameters.status}') le_event_mask = bytes.fromhex('FFFFF00000000000')
if response.return_parameters.hc_le_acl_data_packet_length == 0 or response.return_parameters.hc_total_num_le_acl_data_packets == 0:
# Read the non-LE-specific values await self.send_command(
response = await self.send_command(HCI_Read_Buffer_Size_Command()) HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
if response.return_parameters.status == HCI_SUCCESS: )
self.hc_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_le_acl_data_packet_length = self.hc_le_acl_data_packet_length or self.hc_acl_data_packet_length if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets response = await self.send_command(
self.hc_total_num_le_acl_data_packets = self.hc_total_num_le_acl_data_packets or self.hc_total_num_acl_data_packets HCI_Read_Buffer_Size_Command(), check_result=True
logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}') )
else: self.hc_acl_data_packet_length = (
logger.warn(f'HCI_Read_Buffer_Size_Command failed: {response.return_parameters.status}') response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
)
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
) and self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await self.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
)
)
self.reset_done = True self.reset_done = True
@@ -144,13 +276,13 @@ class Host(EventEmitter):
def send_hci_packet(self, packet): def send_hci_packet(self, packet):
self.hci_sink.on_packet(packet.to_bytes()) self.hci_sink.on_packet(packet.to_bytes())
async def send_command(self, command): async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
async with self.command_semaphore: async with self.command_semaphore:
assert(self.pending_command is None) assert self.pending_command is None
assert(self.pending_response is None) assert self.pending_response is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future() self.pending_response = asyncio.get_running_loop().create_future()
@@ -159,11 +291,29 @@ class Host(EventEmitter):
try: try:
self.send_hci_packet(command) self.send_hci_packet(command)
response = await self.pending_response response = await self.pending_response
# TODO: check error values
# Check the return parameters if required
if check_result:
if isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != HCI_SUCCESS:
logger.warning(
f'{command.name} failed ({HCI_Constant.error_name(status)})'
)
raise HCI_Error(status)
return response return response
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}') logger.warning(
# raise error f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
)
raise error
finally: finally:
self.pending_command = None self.pending_command = None
self.pending_response = None self.pending_response = None
@@ -183,15 +333,18 @@ class Host(EventEmitter):
offset = 0 offset = 0
pb_flag = 0 pb_flag = 0
while bytes_remaining: while bytes_remaining:
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length) data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
acl_packet = HCI_AclDataPacket( acl_packet = HCI_AclDataPacket(
connection_handle = connection_handle, connection_handle=connection_handle,
pb_flag = pb_flag, pb_flag=pb_flag,
bc_flag = 0, bc_flag=0,
data_total_length = data_total_length, data_total_length=data_total_length,
data = l2cap_pdu[offset:offset + data_total_length] data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
) )
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
self.queue_acl_packet(acl_packet) self.queue_acl_packet(acl_packet)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
@@ -202,22 +355,63 @@ class Host(EventEmitter):
self.check_acl_packet_queue() self.check_acl_packet_queue()
if len(self.acl_packet_queue): if len(self.acl_packet_queue):
logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue') logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self): def check_acl_packet_queue(self):
# Send all we can # Send all we can (TODO: support different LE/Classic limits)
while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets: while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop() packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet) self.send_hci_packet(packet)
self.acl_packets_in_flight += 1 self.acl_packets_in_flight += 1
def supports_command(self, command):
# Find the support flag position for this command
for (octet, flags) in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
for (flag_position, value) in enumerate(flags):
if value == command:
# Check if the flag is set
if octet < len(self.local_supported_commands) and flag_position < 8:
return (
self.local_supported_commands[octet] & (1 << flag_position)
) != 0
return False
@property
def supported_commands(self):
commands = []
for (octet, flags) in enumerate(self.local_supported_commands):
if octet < len(HCI_SUPPORTED_COMMANDS_FLAGS):
for flag in range(8):
if flags & (1 << flag) != 0:
command = HCI_SUPPORTED_COMMANDS_FLAGS[octet][flag]
if command is not None:
commands.append(command)
return commands
def supports_le_feature(self, feature):
return (self.local_le_features & (1 << feature)) != 0
@property
def supported_le_features(self):
return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
]
# Packet Sink protocol (packets coming from the controller via HCI) # Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet): def on_packet(self, packet):
hci_packet = HCI_Packet.from_bytes(packet) hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or ( if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET and hci_packet.hci_packet_type == HCI_EVENT_PACKET
hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
hci_packet.command_opcode == HCI_RESET_COMMAND and hci_packet.command_opcode == HCI_RESET_COMMAND
): ):
self.on_hci_packet(hci_packet) self.on_hci_packet(hci_packet)
else: else:
@@ -249,12 +443,6 @@ class Host(EventEmitter):
if connection := self.connections.get(packet.connection_handle): if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet) connection.on_hci_acl_data_packet(packet)
def on_gatt_pdu(self, connection, pdu):
self.emit('gatt_pdu', connection.handle, pdu)
def on_smp_pdu(self, connection, pdu):
self.emit('smp_pdu', connection.handle, pdu)
def on_l2cap_pdu(self, connection, cid, pdu): def on_l2cap_pdu(self, connection, cid, pdu):
self.emit('l2cap_pdu', connection.handle, cid, pdu) self.emit('l2cap_pdu', connection.handle, cid, pdu)
@@ -262,7 +450,11 @@ class Host(EventEmitter):
if self.pending_response: if self.pending_response:
# Check that it is what we were expecting # Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode: if self.pending_command.op_code != event.command_opcode:
logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}') logger.warning(
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
f'0x{event.command_opcode:X}'
)
self.pending_response.set_result(event) self.pending_response.set_result(event)
else: else:
@@ -276,10 +468,12 @@ class Host(EventEmitter):
def on_hci_command_complete_event(self, event): def on_hci_command_complete_event(self, event):
if event.command_opcode == 0: if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to an actual command # This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event') logger.debug('no-command event')
else: return None
return self.on_command_processed(event)
return self.on_command_processed(event)
def on_hci_command_status_event(self, event): def on_hci_command_status_event(self, event):
return self.on_command_processed(event) return self.on_command_processed(event)
@@ -290,36 +484,49 @@ class Host(EventEmitter):
self.acl_packets_in_flight -= total_packets self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue() self.check_acl_packet_queue()
else: else:
logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight')) logger.warning(
color(
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
)
)
self.acl_packets_in_flight = 0 self.acl_packets_in_flight = 0
# Classic only # Classic only
def on_hci_connection_request_event(self, event): def on_hci_connection_request_event(self, event):
# For now, just accept everything # Notify the listeners
# TODO: delegate the decision self.emit(
self.send_command_sync( 'connection_request',
HCI_Accept_Connection_Request_Command( event.bd_addr,
bd_addr = event.bd_addr, event.class_of_device,
role = 0x01 # Remain the peripheral event.link_type,
)
) )
def on_hci_le_connection_complete_event(self, event): def on_hci_le_connection_complete_event(self, event):
# Check if this is a cancellation # Check if this is a cancellation
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
# Create/update the connection # Create/update the connection
logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}') logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
connection = self.connections.get(event.connection_handle) connection = self.connections.get(event.connection_handle)
if connection is None: if connection is None:
connection = Connection(self, event.connection_handle, event.role, event.peer_address) connection = Connection(
self,
event.connection_handle,
event.role,
event.peer_address,
BT_LE_TRANSPORT,
)
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
# Notify the client # Notify the client
connection_parameters = ConnectionParameters( connection_parameters = ConnectionParameters(
event.conn_interval, event.connection_interval,
event.conn_latency, event.peripheral_latency,
event.supervision_timeout event.supervision_timeout,
) )
self.emit( self.emit(
'connection', 'connection',
@@ -328,13 +535,15 @@ class Host(EventEmitter):
event.peer_address, event.peer_address,
None, None,
event.role, event.role,
connection_parameters connection_parameters,
) )
else: else:
logger.debug(f'### CONNECTION FAILED: {event.status}') logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners # Notify the listeners
self.emit('connection_failure', event.status) self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
)
def on_hci_le_enhanced_connection_complete_event(self, event): def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now # Just use the same implementation as for the non-enhanced event for now
@@ -343,11 +552,20 @@ class Host(EventEmitter):
def on_hci_connection_complete_event(self, event): def on_hci_connection_complete_event(self, event):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
# Create/update the connection # Create/update the connection
logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}') logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
connection = self.connections.get(event.connection_handle) connection = self.connections.get(event.connection_handle)
if connection is None: if connection is None:
connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr) connection = Connection(
self,
event.connection_handle,
BT_CENTRAL_ROLE,
event.bd_addr,
BT_BR_EDR_TRANSPORT,
)
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
# Notify the client # Notify the client
@@ -358,13 +576,15 @@ class Host(EventEmitter):
event.bd_addr, event.bd_addr,
None, None,
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
None None,
) )
else: else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client # Notify the client
self.emit('connection_failure', event.status) self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
)
def on_hci_disconnection_complete_event(self, event): def on_hci_disconnection_complete_event(self, event):
# Find the connection # Find the connection
@@ -373,7 +593,12 @@ class Host(EventEmitter):
return return
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}') logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} as '
f'{HCI_Constant.role_name(connection.role)}, '
f'reason={event.reason}'
)
del self.connections[event.connection_handle] del self.connections[event.connection_handle]
# Notify the listeners # Notify the listeners
@@ -382,7 +607,7 @@ class Host(EventEmitter):
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners # Notify the listeners
self.emit('disconnection_failure', event.status) self.emit('disconnection_failure', event.connection_handle, event.status)
def on_hci_le_connection_update_complete_event(self, event): def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
@@ -392,13 +617,17 @@ class Host(EventEmitter):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
connection_parameters = ConnectionParameters( connection_parameters = ConnectionParameters(
event.conn_interval, event.connection_interval,
event.conn_latency, event.peripheral_latency,
event.supervision_timeout event.supervision_timeout,
)
self.emit(
'connection_parameters_update', connection.handle, connection_parameters
) )
self.emit('connection_parameters_update', connection.handle, connection_parameters)
else: else:
self.emit('connection_parameters_update_failure', connection.handle, event.status) self.emit(
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event): def on_hci_le_phy_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
@@ -414,13 +643,10 @@ class Host(EventEmitter):
def on_hci_le_advertising_report_event(self, event): def on_hci_le_advertising_report_event(self, event):
for report in event.reports: for report in event.reports:
self.emit( self.emit('advertising_report', report)
'advertising_report',
report.address, def on_hci_le_extended_advertising_report_event(self, event):
report.data, self.on_hci_le_advertising_report_event(event)
report.rssi,
report.event_type
)
def on_hci_le_remote_connection_parameter_request_event(self, event): def on_hci_le_remote_connection_parameter_request_event(self, event):
if event.connection_handle not in self.connections: if event.connection_handle not in self.connections:
@@ -431,13 +657,13 @@ class Host(EventEmitter):
# TODO: delegate the decision # TODO: delegate the decision
self.send_command_sync( self.send_command_sync(
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle = event.connection_handle, connection_handle=event.connection_handle,
interval_min = event.interval_min, interval_min=event.interval_min,
interval_max = event.interval_max, interval_max=event.interval_max,
latency = event.latency, latency=event.latency,
timeout = event.timeout, timeout=event.timeout,
minimum_ce_length = 0, min_ce_length=0,
maximum_ce_length = 0 max_ce_length=0,
) )
) )
@@ -451,19 +677,23 @@ class Host(EventEmitter):
logger.debug('no long term key provider') logger.debug('no long term key provider')
long_term_key = None long_term_key = None
else: else:
long_term_key = await self.long_term_key_provider( long_term_key = await self.abort_on(
connection.handle, 'flush',
event.random_number, # pylint: disable-next=not-callable
event.encryption_diversifier self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier,
),
) )
if long_term_key: if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command( response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle = event.connection_handle, connection_handle=event.connection_handle,
long_term_key = long_term_key long_term_key=long_term_key,
) )
else: else:
response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
connection_handle = event.connection_handle connection_handle=event.connection_handle
) )
await self.send_command(response) await self.send_command(response)
@@ -478,10 +708,16 @@ class Host(EventEmitter):
def on_hci_role_change_event(self, event): def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}') logger.debug(
f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}'
)
# TODO: lookup the connection and update the role # TODO: lookup the connection and update the role
else: else:
logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}') logger.debug(
f'role change for {event.bd_addr} failed: '
f'{HCI_Constant.error_name(event.status)}'
)
def on_hci_le_data_length_change_event(self, event): def on_hci_le_data_length_change_event(self, event):
self.emit( self.emit(
@@ -490,7 +726,7 @@ class Host(EventEmitter):
event.max_tx_octets, event.max_tx_octets,
event.max_tx_time, event.max_tx_time,
event.max_rx_octets, event.max_rx_octets,
event.max_rx_time event.max_rx_time,
) )
def on_hci_authentication_complete_event(self, event): def on_hci_authentication_complete_event(self, event):
@@ -498,21 +734,35 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle) self.emit('connection_authentication', event.connection_handle)
else: else:
self.emit('connection_authentication_failure', event.connection_handle, event.status) self.emit(
'connection_authentication_failure',
event.connection_handle,
event.status,
)
def on_hci_encryption_change_event(self, event): def on_hci_encryption_change_event(self, event):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled) self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
)
else: else:
self.emit('connection_encryption_failure', event.connection_handle, event.status) self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event): def on_hci_encryption_key_refresh_complete_event(self, event):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle) self.emit('connection_encryption_key_refresh', event.connection_handle)
else: else:
self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status) self.emit(
'connection_encryption_key_refresh_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event): def on_hci_link_supervision_timeout_changed_event(self, event):
pass pass
@@ -524,19 +774,23 @@ class Host(EventEmitter):
pass pass
def on_hci_link_key_notification_event(self, event): def on_hci_link_key_notification_event(self, event):
logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}') logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type) self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event): def on_hci_simple_pairing_complete_event(self, event):
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') logger.debug(
f'simple pairing complete for {event.bd_addr}: '
f'status={HCI_Constant.status_name(event.status)}'
)
def on_hci_pin_code_request_event(self, event): def on_hci_pin_code_request_event(self, event):
# For now, just refuse all requests # For now, just refuse all requests
# TODO: delegate the decision # TODO: delegate the decision
self.send_command_sync( self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command( HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
bd_addr = event.bd_addr
)
) )
def on_hci_link_key_request_event(self, event): def on_hci_link_key_request_event(self, event):
@@ -545,15 +799,18 @@ class Host(EventEmitter):
logger.debug('no link key provider') logger.debug('no link key provider')
link_key = None link_key = None
else: else:
link_key = await self.link_key_provider(event.bd_addr) link_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.link_key_provider(event.bd_addr),
)
if link_key: if link_key:
response = HCI_Link_Key_Request_Reply_Command( response = HCI_Link_Key_Request_Reply_Command(
bd_addr = event.bd_addr, bd_addr=event.bd_addr, link_key=link_key
link_key = link_key
) )
else: else:
response = HCI_Link_Key_Request_Negative_Reply_Command( response = HCI_Link_Key_Request_Negative_Reply_Command(
bd_addr = event.bd_addr bd_addr=event.bd_addr
) )
await self.send_command(response) await self.send_command(response)
@@ -567,12 +824,21 @@ class Host(EventEmitter):
pass pass
def on_hci_user_confirmation_request_event(self, event): def on_hci_user_confirmation_request_event(self, event):
self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value) self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event): def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr) self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_inquiry_complete_event(self, event): def on_hci_user_passkey_notification_event(self, event):
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, _event):
self.emit('inquiry_complete') self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event): def on_hci_inquiry_result_with_rssi_event(self, event):
@@ -582,7 +848,7 @@ class Host(EventEmitter):
response.bd_addr, response.bd_addr,
response.class_of_device, response.class_of_device,
b'', b'',
response.rssi response.rssi,
) )
def on_hci_extended_inquiry_result_event(self, event): def on_hci_extended_inquiry_result_event(self, event):
@@ -591,7 +857,7 @@ class Host(EventEmitter):
event.bd_addr, event.bd_addr,
event.class_of_device, event.class_of_device,
event.extended_inquiry_response, event.extended_inquiry_response,
event.rssi event.rssi,
) )
def on_hci_remote_name_request_complete_event(self, event): def on_hci_remote_name_request_complete_event(self, event):
@@ -599,3 +865,10 @@ class Host(EventEmitter):
self.emit('remote_name_failure', event.bd_addr, event.status) self.emit('remote_name_failure', event.bd_addr, event.status)
else: else:
self.emit('remote_name', event.bd_addr, event.remote_name) self.emit('remote_name', event.bd_addr, event.remote_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)

View File

@@ -20,6 +20,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import os import os
import json import json
@@ -38,10 +39,10 @@ logger = logging.getLogger(__name__)
class PairingKeys: class PairingKeys:
class Key: class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None): def __init__(self, value, authenticated=False, ediv=None, rand=None):
self.value = value self.value = value
self.authenticated = authenticated self.authenticated = authenticated
self.ediv = ediv self.ediv = ediv
self.rand = rand self.rand = rand
@classmethod @classmethod
def from_dict(cls, key_dict): def from_dict(cls, key_dict):
@@ -64,31 +65,33 @@ class PairingKeys:
return key_dict return key_dict
def __init__(self): def __init__(self):
self.address_type = None self.address_type = None
self.ltk = None self.ltk = None
self.ltk_central = None self.ltk_central = None
self.ltk_peripheral = None self.ltk_peripheral = None
self.irk = None self.irk = None
self.csrk = None self.csrk = None
self.link_key = None # Classic self.link_key = None # Classic
@staticmethod @staticmethod
def key_from_dict(keys_dict, key_name): def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is not None: if key_dict is None:
return PairingKeys.Key.from_dict(key_dict) return None
return PairingKeys.Key.from_dict(key_dict)
@staticmethod @staticmethod
def from_dict(keys_dict): def from_dict(keys_dict):
keys = PairingKeys() keys = PairingKeys()
keys.address_type = keys_dict.get('address_type') keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral') keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys return keys
@@ -120,9 +123,9 @@ class PairingKeys:
def print(self, prefix=''): def print(self, prefix=''):
keys_dict = self.to_dict() keys_dict = self.to_dict()
for (property, value) in keys_dict.items(): for (container_property, value) in keys_dict.items():
if type(value) is dict: if isinstance(value, dict):
print(f'{prefix}{color(property, "cyan")}:') print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items(): for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}') print(f'{prefix} {color(key_property, "green")}: {key_value}')
else: else:
@@ -137,12 +140,16 @@ class KeyStore:
async def update(self, name, keys): async def update(self, name, keys):
pass pass
async def get(self, name): async def get(self, _name):
return PairingKeys() return PairingKeys()
async def get_all(self): async def get_all(self):
return [] return []
async def delete_all(self):
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
@@ -161,7 +168,7 @@ class KeyStore:
separator = '' separator = ''
for (name, keys) in entries: for (name, keys) in entries:
print(separator + prefix + color(name, 'yellow')) print(separator + prefix + color(name, 'yellow'))
keys.print(prefix = prefix + ' ') keys.print(prefix=prefix + ' ')
separator = '\n' separator = '\n'
@staticmethod @staticmethod
@@ -178,9 +185,9 @@ class KeyStore:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore): class JsonKeyStore(KeyStore):
APP_NAME = 'Bumble' APP_NAME = 'Bumble'
APP_AUTHOR = 'Google' APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing' KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__' DEFAULT_NAMESPACE = '__DEFAULT__'
def __init__(self, namespace, filename=None): def __init__(self, namespace, filename=None):
@@ -188,10 +195,13 @@ class JsonKeyStore(KeyStore):
if filename is None: if filename is None:
# Use a default for the current user # Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs import appdirs
self.directory_name = os.path.join( self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
self.KEYS_DIR
) )
json_filename = f'{self.namespace}.json'.lower().replace(':', '-') json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename) self.filename = os.path.join(self.directory_name, json_filename)
@@ -214,7 +224,7 @@ class JsonKeyStore(KeyStore):
async def load(self): async def load(self):
try: try:
with open(self.filename, 'r') as json_file: with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file) return json.load(json_file)
except FileNotFoundError: except FileNotFoundError:
return {} return {}
@@ -226,7 +236,7 @@ class JsonKeyStore(KeyStore):
# Save to a temporary file # Save to a temporary file
temp_filename = self.filename + '.tmp' temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w') as output: with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4) json.dump(db, output, sort_keys=True, indent=4)
# Atomically replace the previous file # Atomically replace the previous file
@@ -257,7 +267,16 @@ class JsonKeyStore(KeyStore):
if namespace is None: if namespace is None:
return [] return []
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()] return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
await self.save(db)
async def get(self, name): async def get(self, name):
db = await self.load() db = await self.load()

File diff suppressed because it is too large Load Diff

View File

@@ -17,15 +17,16 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import asyncio import asyncio
import websockets
from functools import partial from functools import partial
from colors import color from colors import color
import websockets
from bumble.hci import ( from bumble.hci import (
Address, Address,
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR HCI_CONNECTION_TIMEOUT_ERROR,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -47,7 +48,8 @@ def parse_parameters(params_str):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges (see Vol 6, Part B - 2.4 DATA CHANNEL PDU) # TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LocalLink: class LocalLink:
''' '''
@@ -55,7 +57,7 @@ class LocalLink:
''' '''
def __init__(self): def __init__(self):
self.controllers = set() self.controllers = set()
self.pending_connection = None self.pending_connection = None
def add_controller(self, controller): def add_controller(self, controller):
@@ -103,23 +105,31 @@ class LocalLink:
return return
# Connect to the first controller with a matching address # Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(le_create_connection_command.peer_address): if peripheral_controller := self.find_controller(
central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS) le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
peripheral_controller.on_link_central_connected(central_address) peripheral_controller.on_link_central_connected(central_address)
return return
# No peripheral found # No peripheral found
central_controller.on_link_peripheral_connection_complete( central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
) )
def connect(self, central_address, le_create_connection_command): def connect(self, central_address, le_create_connection_command):
logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}') logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command) self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete) asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command): def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection # Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)): if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found') logger.warning('!!! Initiating controller not found')
@@ -127,16 +137,26 @@ class LocalLink:
# Disconnect from the first controller with a matching address # Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address): if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason) peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command] args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args) asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): # pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}') logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address): if central_controller := self.find_controller(central_address):
@@ -152,15 +172,18 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a A Link implementation that communicates with other virtual controllers via a
WebSocket relay WebSocket relay
''' '''
def __init__(self, uri): def __init__(self, uri):
self.controller = None self.controller = None
self.uri = uri self.uri = uri
self.execution_queue = asyncio.Queue() self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future() self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None self.rpc_result = None
self.pending_connection = None self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = set() # List of addresses that have connected to us self.peripheral_connections = (
set()
) # List of addresses that have connected to us
# Connect and run asynchronously # Connect and run asynchronously
asyncio.create_task(self.run_connection()) asyncio.create_task(self.run_connection())
@@ -192,11 +215,14 @@ class RemoteLink:
try: try:
await item await item
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}') logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self): async def run_connection(self):
# Connect to the relay # Connect to the relay
logger.debug(f'connecting to {self.uri}') logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri) websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket) self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}') logger.debug(f'connected to {self.uri}')
@@ -227,7 +253,9 @@ class RemoteLink:
self.central_connections.remove(address) self.central_connections.remove(address)
if address in self.peripheral_connections: if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR) self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address) self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target): async def on_unreachable_received(self, target):
@@ -244,7 +272,9 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement): async def on_advertisement_message_received(self, sender, advertisement):
try: try:
self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement)) self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
except Exception: except Exception:
logger.exception('exception') logger.exception('exception')
@@ -263,11 +293,11 @@ class RemoteLink:
self.controller.on_link_central_connected(Address(sender)) self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it # Accept the connection by responding to it
await self.send_targetted_message(sender, 'connected') await self.send_targeted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _): async def on_connected_message_received(self, sender, _):
if not self.pending_connection: if not self.pending_connection:
logger.warn('received a connection ack, but no connection is pending') logger.warning('received a connection ack, but no connection is pending')
return return
# Remember the connection # Remember the connection
@@ -275,7 +305,9 @@ class RemoteLink:
# Notify the controller # Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}') logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS) self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message): async def on_disconnect_message_received(self, sender, message):
# Notify the controller # Notify the controller
@@ -287,7 +319,7 @@ class RemoteLink:
if sender in self.peripheral_connections: if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender) self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, message): async def on_encrypted_message_received(self, sender, _):
# TODO parse params to get real args # TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16)) self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
@@ -296,7 +328,7 @@ class RemoteLink:
websocket = await self.websocket websocket = await self.websocket
# Create a future value to hold the eventual result # Create a future value to hold the eventual result
assert(self.rpc_result is None) assert self.rpc_result is None
self.rpc_result = asyncio.get_running_loop().create_future() self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command # Send the command
@@ -309,7 +341,7 @@ class RemoteLink:
# TODO: parse the result # TODO: parse the result
async def send_targetted_message(self, target, message): async def send_targeted_message(self, target, message):
# Ensure we have a connection # Ensure we have a connection
websocket = await self.websocket websocket = await self.websocket
@@ -326,35 +358,61 @@ class RemoteLink:
self.execute(self.notify_address_changed) self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data): async def send_advertising_data_to_relay(self, data):
await self.send_targetted_message('*', f'advertisement:{data.hex()}') await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, sender_address, data): def send_advertising_data(self, _, data):
self.execute(partial(self.send_advertising_data_to_relay, data)) self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data): async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targetted_message(peer_address, f'acl:{data.hex()}') await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, sender_address, peer_address, data): def send_acl_data(self, _, peer_address, data):
self.execute(partial(self.send_acl_data_to_relay, peer_address, data)) self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address): async def send_connection_request_to_relay(self, peer_address):
await self.send_targetted_message(peer_address, 'connect') await self.send_targeted_message(peer_address, 'connect')
def connect(self, central_address, le_create_connection_command): def connect(self, _, le_create_connection_command):
if self.pending_connection: if self.pending_connection:
logger.warn('connection already pending') logger.warning('connection already pending')
return return
self.pending_connection = le_create_connection_command self.pending_connection = le_create_connection_command
self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address))) self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command): def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') logger.debug(
self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}')) f'disconnect {central_address} -> '
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command) f'{peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk) asyncio.get_running_loop().call_soon(
self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}')) self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)

View File

@@ -0,0 +1,13 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,157 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from ..core import AdvertisingData
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int]):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
# Handler for volume control
def on_volume_write(_connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
# Handler for audio control commands
def on_audio_control_point_write(_connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# TODO Respond with a status
# asyncio.create_task(device.notify_subscribers(audio_status_characteristic,
# force=True))
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# TODO add real psm value
self.psm = 0x0080
# self.psm = device.register_l2cap_channel_server(0, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -0,0 +1,62 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
GATT_BATTERY_SERVICE,
GATT_BATTERY_LEVEL_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter,
)
# -----------------------------------------------------------------------------
class BatteryService(TemplateService):
UUID = GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B'
def __init__(self, read_battery_level):
self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
super().__init__([self.battery_level_characteristic])
# -----------------------------------------------------------------------------
class BatteryServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = BatteryService
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None

View File

@@ -0,0 +1,138 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_MODEL_NUMBER_STRING_CHARACTERISTIC,
GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC,
GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC,
GATT_SYSTEM_ID_CHARACTERISTIC,
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
)
# -----------------------------------------------------------------------------
class DeviceInformationService(TemplateService):
UUID = GATT_DEVICE_INFORMATION_SERVICE
@staticmethod
def pack_system_id(oui, manufacturer_id):
return struct.pack('<Q', oui << 40 | manufacturer_id)
@staticmethod
def unpack_system_id(buffer):
system_id = struct.unpack('<Q', buffer)[0]
return (system_id >> 40, system_id & 0xFFFFFFFFFF)
def __init__(
self,
manufacturer_name: str = None,
model_number: str = None,
serial_number: str = None,
hardware_revision: str = None,
firmware_revision: str = None,
software_revision: str = None,
system_id: Tuple[int, int] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes = None
# TODO: pnp_id
):
characteristics = [
Characteristic(uuid, Characteristic.READ, Characteristic.READABLE, field)
for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
)
if field is not None
]
if system_id is not None:
characteristics.append(
Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id),
)
)
if ieee_regulatory_certification_data_list is not None:
characteristics.append(
Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list,
)
)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService
def __init__(self, service_proxy):
self.service_proxy = service_proxy
for (field, uuid) in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0])
else:
characteristic = None
self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC
):
self.system_id = DelegatedCharacteristicAdapter(
characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id,
)
else:
self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
self.ieee_regulatory_certification_data_list = characteristics[0]
else:
self.ieee_regulatory_certification_data_list = None

View File

@@ -0,0 +1,237 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from enum import IntEnum
import struct
from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error
from ..gatt import (
GATT_HEART_RATE_SERVICE,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
)
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
FOOT = 6
class HeartRateMeasurement:
def __init__(
self,
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
if rr_intervals:
for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
flags = data[0]
offset = 1
if flags & 1:
hr = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
hr = struct.unpack_from('B', data, offset)[0]
offset += 1
if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0
else:
sensor_contact_detected = None
if flags & (1 << 3):
energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2
else:
energy_expended = None
if flags & (1 << 4):
rr_intervals = tuple(
struct.unpack_from('<H', data, offset + i * 2)[0] / 1024
for i in range((len(data) - offset) // 2)
)
else:
rr_intervals = ()
return cls(hr, sensor_contact_detected, energy_expended, rr_intervals)
def __bytes__(self):
if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate)
else:
flags = 1
data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None:
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None:
flags |= 1 << 3
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= 1 << 4
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
]
)
return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None,
):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
),
# pylint: disable=unnecessary-lambda
encode=lambda value: bytes(value),
)
characteristics = [self.heart_rate_measurement_characteristic]
if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)]),
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
reset_energy_expended(connection)
else:
raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
characteristics.append(self.heart_rate_control_point_characteristic)
super().__init__(characteristics)
# -----------------------------------------------------------------------------
class HeartRateServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes,
)
else:
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
)

View File

@@ -17,10 +17,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import asyncio import asyncio
from colors import color
from .utils import EventEmitter from colors import color
from .core import InvalidStateError, ProtocolError, ConnectionError from pyee import EventEmitter
from . import core
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -31,6 +33,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
RFCOMM_PSM = 0x0003 RFCOMM_PSM = 0x0003
@@ -97,22 +101,24 @@ RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1 RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def fcs(buffer): def compute_fcs(buffer):
fcs = 0xFF result = 0xFF
for byte in buffer: for byte in buffer:
fcs = CRC_TABLE[fcs ^ byte] result = CRC_TABLE[result ^ byte]
return 0xFF - fcs return 0xFF - result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RFCOMM_Frame: class RFCOMM_Frame:
def __init__(self, type, c_r, dlci, p_f, information = b'', with_credits = False): def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = type self.type = frame_type
self.c_r = c_r self.c_r = c_r
self.dlci = dlci self.dlci = dlci
self.p_f = p_f self.p_f = p_f
self.information = information self.information = information
length = len(information) length = len(information)
if with_credits: if with_credits:
@@ -123,19 +129,19 @@ class RFCOMM_Frame:
else: else:
# 1-byte length indicator # 1-byte length indicator
self.length = bytes([(length << 1) | 1]) self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1 self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4) self.control = frame_type | (p_f << 4)
if type == RFCOMM_UIH_FRAME: if frame_type == RFCOMM_UIH_FRAME:
self.fcs = fcs(bytes([self.address, self.control])) self.fcs = compute_fcs(bytes([self.address, self.control]))
else: else:
self.fcs = fcs(bytes([self.address, self.control]) + self.length) self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self): def type_name(self):
return RFCOMM_FRAME_TYPE_NAMES[self.type] return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod @staticmethod
def parse_mcc(data): def parse_mcc(data):
type = data[0] >> 2 mcc_type = data[0] >> 2
c_r = (data[0] >> 1) & 1 c_r = (data[0] >> 1) & 1
length = data[1] length = data[1]
if data[1] & 1: if data[1] & 1:
@@ -143,13 +149,16 @@ class RFCOMM_Frame:
value = data[2:] value = data[2:]
else: else:
length = (data[3] << 7) & (length >> 1) length = (data[3] << 7) & (length >> 1)
value = data[3:3 + length] value = data[3 : 3 + length]
return (type, c_r, value) return (mcc_type, c_r, value)
@staticmethod @staticmethod
def make_mcc(type, c_r, data): def make_mcc(mcc_type, c_r, data):
return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data return (
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod @staticmethod
def sabm(c_r, dlci): def sabm(c_r, dlci):
@@ -168,15 +177,17 @@ class RFCOMM_Frame:
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1) return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod @staticmethod
def uih(c_r, dlci, information, p_f = 0): def uih(c_r, dlci, information, p_f=0):
return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1)) return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
# Extract fields # Extract fields
dlci = (data[0] >> 2) & 0x3F dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01 c_r = (data[0] >> 1) & 0x01
type = data[1] & 0xEF frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01 p_f = (data[1] >> 4) & 0x01
length = data[2] length = data[2]
if length & 0x01: if length & 0x01:
@@ -188,132 +199,182 @@ class RFCOMM_Frame:
fcs = data[-1] fcs = data[-1]
# Construct the frame and check the CRC # Construct the frame and check the CRC
frame = RFCOMM_Frame(type, c_r, dlci, p_f, information) frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs: if frame.fcs != fcs:
logger.warn(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}') logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch') raise ValueError('fcs mismatch')
return frame return frame
def __bytes__(self): def __bytes__(self):
return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs]) return (
bytes([self.address, self.control])
+ self.length
+ self.information
+ bytes([self.fcs])
)
def __str__(self): def __str__(self):
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})' return (
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
f'length={len(self.information)},'
f'fcs=0x{self.fcs:02X})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RFCOMM_MCC_PN: class RFCOMM_MCC_PN:
def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size): def __init__(
self.dlci = dlci self,
self.cl = cl dlci,
self.priority = priority cl,
self.ack_timer = ack_timer priority,
self.max_frame_size = max_frame_size ack_timer,
max_frame_size,
max_retransmissions,
window_size,
):
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions self.max_retransmissions = max_retransmissions
self.window_size = window_size self.window_size = window_size
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return RFCOMM_MCC_PN( return RFCOMM_MCC_PN(
dlci = data[0], dlci=data[0],
cl = data[1], cl=data[1],
priority = data[2], priority=data[2],
ack_timer = data[3], ack_timer=data[3],
max_frame_size = data[4] | data[5] << 8, max_frame_size=data[4] | data[5] << 8,
max_retransmissions = data[6], max_retransmissions=data[6],
window_size = data[7] window_size=data[7],
) )
def __bytes__(self): def __bytes__(self):
return bytes([ return bytes(
self.dlci & 0xFF, [
self.cl & 0xFF, self.dlci & 0xFF,
self.priority & 0xFF, self.cl & 0xFF,
self.ack_timer & 0xFF, self.priority & 0xFF,
self.max_frame_size & 0xFF, self.ack_timer & 0xFF,
(self.max_frame_size >> 8) & 0xFF, self.max_frame_size & 0xFF,
self.max_retransmissions & 0xFF, (self.max_frame_size >> 8) & 0xFF,
self.window_size & 0xFF self.max_retransmissions & 0xFF,
]) self.window_size & 0xFF,
]
)
def __str__(self): def __str__(self):
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})' return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RFCOMM_MCC_MSC: class RFCOMM_MCC_MSC:
def __init__(self, dlci, fc, rtc, rtr, ic, dv): def __init__(self, dlci, fc, rtc, rtr, ic, dv):
self.dlci = dlci self.dlci = dlci
self.fc = fc self.fc = fc
self.rtc = rtc self.rtc = rtc
self.rtr = rtr self.rtr = rtr
self.ic = ic self.ic = ic
self.dv = dv self.dv = dv
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return RFCOMM_MCC_MSC( return RFCOMM_MCC_MSC(
dlci = data[0] >> 2, dlci=data[0] >> 2,
fc = data[1] >> 1 & 1, fc=data[1] >> 1 & 1,
rtc = data[1] >> 2 & 1, rtc=data[1] >> 2 & 1,
rtr = data[1] >> 3 & 1, rtr=data[1] >> 3 & 1,
ic = data[1] >> 6 & 1, ic=data[1] >> 6 & 1,
dv = data[1] >> 7 & 1 dv=data[1] >> 7 & 1,
) )
def __bytes__(self): def __bytes__(self):
return bytes([ return bytes(
(self.dlci << 2) | 3, [
1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7 (self.dlci << 2) | 3,
]) 1
| self.fc << 1
| self.rtc << 2
| self.rtr << 3
| self.ic << 6
| self.dv << 7,
]
)
def __str__(self): def __str__(self):
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})' return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DLC(EventEmitter): class DLC(EventEmitter):
# States # States
INIT = 0x00 INIT = 0x00
CONNECTING = 0x01 CONNECTING = 0x01
CONNECTED = 0x02 CONNECTED = 0x02
DISCONNECTING = 0x03 DISCONNECTING = 0x03
DISCONNECTED = 0x04 DISCONNECTED = 0x04
RESET = 0x05 RESET = 0x05
STATE_NAMES = { STATE_NAMES = {
INIT: 'INIT', INIT: 'INIT',
CONNECTING: 'CONNECTING', CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED', CONNECTED: 'CONNECTED',
DISCONNECTING: 'DISCONNECTING', DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED', DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET' RESET: 'RESET',
} }
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits): def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
super().__init__() super().__init__()
self.multiplexer = multiplexer self.multiplexer = multiplexer
self.dlci = dlci self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2 self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits self.tx_credits = initial_tx_credits
self.tx_buffer = b'' self.tx_buffer = b''
self.state = DLC.INIT self.state = DLC.INIT
self.role = multiplexer.role self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0 self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None self.sink = None
self.connection_result = None
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead) self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
@staticmethod @staticmethod
def state_name(state): def state_name(state):
return DLC.STATE_NAMES[state] return DLC.STATE_NAMES[state]
def change_state(self, new_state): def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}') logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state self.state = new_state
def send_frame(self, frame): def send_frame(self, frame):
@@ -323,58 +384,40 @@ class DLC(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame) handler(frame)
def on_sabm_frame(self, frame): def on_sabm_frame(self, _frame):
if self.state != DLC.CONNECTING: if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red')) logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return return
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci)) self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
# Exchange the modem status with the peer # Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci, mcc = RFCOMM_Frame.make_mcc(
fc = 0, mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}') logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED) self.change_state(DLC.CONNECTED)
self.emit('open') self.emit('open')
def on_ua_frame(self, frame): def on_ua_frame(self, _frame):
if self.state != DLC.CONNECTING: if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red')) logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return return
# Exchange the modem status with the peer # Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci, mcc = RFCOMM_Frame.make_mcc(
fc = 0, mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}') logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED) self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self) self.multiplexer.on_dlc_open_complete(self)
@@ -383,29 +426,36 @@ class DLC(EventEmitter):
# TODO: handle all states # TODO: handle all states
pass pass
def on_disc_frame(self, frame): def on_disc_frame(self, _frame):
# TODO: handle all states # TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci)) self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
def on_uih_frame(self, frame): def on_uih_frame(self, frame):
data = frame.information data = frame.information
if frame.p_f == 1: if frame.p_f == 1:
# With credits # With credits
credits = frame.information[0] received_credits = frame.information[0]
self.tx_credits += credits self.tx_credits += received_credits
logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}') logger.debug(
f'<<< Credits [{self.dlci}]: '
f'received {credits}, total={self.tx_credits}'
)
data = data[1:] data = data[1:]
logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}') logger.debug(
f'{color("<<< Data", "yellow")} '
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink: if len(data) and self.sink:
self.sink(data) self.sink(data) # pylint: disable=not-callable
# Update the credits # Update the credits
if self.rx_credits > 0: if self.rx_credits > 0:
self.rx_credits -= 1 self.rx_credits -= 1
else: else:
logger.warn(color('!!! received frame with no rx credits', 'red')) logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits) # Check if there's anything to send (including credits)
self.process_tx() self.process_tx()
@@ -417,69 +467,47 @@ class DLC(EventEmitter):
if c_r: if c_r:
# Command # Command
logger.debug(f'<<< MCC MSC Command: {msc}') logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci, mcc = RFCOMM_Frame.make_mcc(
fc = 0, mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}') logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
else: else:
# Response # Response
logger.debug(f'<<< MCC MSC Response: {msc}') logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self): def connect(self):
if not self.state == DLC.INIT: if self.state != DLC.INIT:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
self.change_state(DLC.CONNECTING) self.change_state(DLC.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future() self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame( self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
RFCOMM_Frame.sabm(
c_r = self.c_r,
dlci = self.dlci
)
)
def accept(self): def accept(self):
if not self.state == DLC.INIT: if self.state != DLC.INIT:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
pn = RFCOMM_MCC_PN( pn = RFCOMM_MCC_PN(
dlci = self.dlci, dlci=self.dlci,
cl = 0xE0, cl=0xE0,
priority = 7, priority=7,
ack_timer = 0, ack_timer=0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0, max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 0, data = bytes(pn)) mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTING) self.change_state(DLC.CONNECTING)
def rx_credits_needed(self): def rx_credits_needed(self):
if self.rx_credits <= self.rx_threshold: if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
else:
return 0 return 0
def process_tx(self): def process_tx(self):
# Send anything we can (or an empty frame if we need to send rx credits) # Send anything we can (or an empty frame if we need to send rx credits)
@@ -487,13 +515,13 @@ class DLC(EventEmitter):
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0: while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
# Get the next chunk, up to MTU size # Get the next chunk, up to MTU size
if rx_credits_needed > 0: if rx_credits_needed > 0:
chunk = bytes([rx_credits_needed]) + self.tx_buffer[:self.mtu - 1] chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1:] self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
self.rx_credits += rx_credits_needed self.rx_credits += rx_credits_needed
tx_credit_spent = (len(chunk) > 1) tx_credit_spent = len(chunk) > 1
else: else:
chunk = self.tx_buffer[:self.mtu] chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk):] self.tx_buffer = self.tx_buffer[len(chunk) :]
tx_credit_spent = True tx_credit_spent = True
# Update the tx credits # Update the tx credits
@@ -502,13 +530,17 @@ class DLC(EventEmitter):
self.tx_credits -= 1 self.tx_credits -= 1
# Send the frame # Send the frame
logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}') logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
f'rx_credits={self.rx_credits}, '
f'tx_credits={self.tx_credits}'
)
self.send_frame( self.send_frame(
RFCOMM_Frame.uih( RFCOMM_Frame.uih(
c_r = self.c_r, c_r=self.c_r,
dlci = self.dlci, dlci=self.dlci,
information = chunk, information=chunk,
p_f = 1 if rx_credits_needed > 0 else 0 p_f=1 if rx_credits_needed > 0 else 0,
) )
) )
@@ -517,8 +549,8 @@ class DLC(EventEmitter):
# Stream protocol # Stream protocol
def write(self, data): def write(self, data):
# We can only send bytes # We can only send bytes
if type(data) != bytes: if not isinstance(data, bytes):
if type(data) == str: if isinstance(data, str):
# Automatically convert strings to bytes using UTF-8 # Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8') data = data.encode('utf-8')
else: else:
@@ -542,34 +574,34 @@ class Multiplexer(EventEmitter):
RESPONDER = 0x01 RESPONDER = 0x01
# States # States
INIT = 0x00 INIT = 0x00
CONNECTING = 0x01 CONNECTING = 0x01
CONNECTED = 0x02 CONNECTED = 0x02
OPENING = 0x03 OPENING = 0x03
DISCONNECTING = 0x04 DISCONNECTING = 0x04
DISCONNECTED = 0x05 DISCONNECTED = 0x05
RESET = 0x06 RESET = 0x06
STATE_NAMES = { STATE_NAMES = {
INIT: 'INIT', INIT: 'INIT',
CONNECTING: 'CONNECTING', CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED', CONNECTED: 'CONNECTED',
OPENING: 'OPENING', OPENING: 'OPENING',
DISCONNECTING: 'DISCONNECTING', DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED', DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET' RESET: 'RESET',
} }
def __init__(self, l2cap_channel, role): def __init__(self, l2cap_channel, role):
super().__init__() super().__init__()
self.role = role self.role = role
self.l2cap_channel = l2cap_channel self.l2cap_channel = l2cap_channel
self.state = Multiplexer.INIT self.state = Multiplexer.INIT
self.dlcs = {} # DLCs, by DLCI self.dlcs = {} # DLCs, by DLCI
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.open_result = None self.open_result = None
self.acceptor = None self.acceptor = None
# Become a sink for the L2CAP channel # Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu l2cap_channel.sink = self.on_pdu
@@ -579,7 +611,9 @@ class Multiplexer(EventEmitter):
return Multiplexer.STATE_NAMES[state] return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state): def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}') logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state self.state = new_state
def send_frame(self, frame): def send_frame(self, frame):
@@ -595,14 +629,14 @@ class Multiplexer(EventEmitter):
self.on_frame(frame) self.on_frame(frame)
else: else:
if frame.type == RFCOMM_DM_FRAME: if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we receive # DM responses are for a DLCI, but since we only create the dlc when we
# a PN response (because we need the parameters), we handle DM frames at the Multiplexer # receive a PN response (because we need the parameters), we handle DM
# level # frames at the Multiplexer level
self.on_dm_frame(frame) self.on_dm_frame(frame)
else: else:
dlc = self.dlcs.get(frame.dlci) dlc = self.dlcs.get(frame.dlci)
if dlc is None: if dlc is None:
logger.warn(f'no dlc for DLCI {frame.dlci}') logger.warning(f'no dlc for DLCI {frame.dlci}')
return return
dlc.on_frame(frame) dlc.on_frame(frame)
@@ -610,14 +644,14 @@ class Multiplexer(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower()) handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame) handler(frame)
def on_sabm_frame(self, frame): def on_sabm_frame(self, _frame):
if self.state != Multiplexer.INIT: if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM') logger.debug('not in INIT state, ignoring SABM')
return return
self.change_state(Multiplexer.CONNECTED) self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 1, dlci = 0)) self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, frame): def on_ua_frame(self, _frame):
if self.state == Multiplexer.CONNECTING: if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED) self.change_state(Multiplexer.CONNECTED)
if self.connection_result: if self.connection_result:
@@ -629,25 +663,34 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
def on_dm_frame(self, frame): def on_dm_frame(self, _frame):
if self.state == Multiplexer.OPENING: if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED) self.change_state(Multiplexer.CONNECTED)
if self.open_result: if self.open_result:
self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED)) self.open_result.set_exception(
core.ConnectionError(
core.ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm',
)
)
else: else:
logger.warn(f'unexpected state for DM: {self}') logger.warning(f'unexpected state for DM: {self}')
def on_disc_frame(self, frame): def on_disc_frame(self, _frame):
self.change_state(Multiplexer.DISCONNECTED) self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0)) self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame): def on_uih_frame(self, frame):
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if type == RFCOMM_MCC_PN_TYPE: if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value) pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn) self.on_mcc_pn(c_r, pn)
elif type == RFCOMM_MCC_MSC_TYPE: elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value) mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs) self.on_mcc_msc(c_r, mcs)
@@ -663,7 +706,7 @@ class Multiplexer(EventEmitter):
if pn.dlci & 1: if pn.dlci & 1:
# Not expected, this is an initiator-side number # Not expected, this is an initiator-side number
# TODO: error out # TODO: error out
logger.warn(f'invalid DLCI: {pn.dlci}') logger.warning(f'invalid DLCI: {pn.dlci}')
else: else:
if self.acceptor: if self.acceptor:
channel_number = pn.dlci >> 1 channel_number = pn.dlci >> 1
@@ -679,10 +722,10 @@ class Multiplexer(EventEmitter):
dlc.accept() dlc.accept()
else: else:
# No acceptor, we're in Disconnected Mode # No acceptor, we're in Disconnected Mode
self.send_frame(RFCOMM_Frame.dm(c_r = 1, dlci = pn.dlci)) self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
else: else:
# No acceptor?? shouldn't happen # No acceptor?? shouldn't happen
logger.warn(color('!!! no acceptor registered', 'red')) logger.warning(color('!!! no acceptor registered', 'red'))
else: else:
# Response # Response
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
@@ -691,12 +734,12 @@ class Multiplexer(EventEmitter):
self.dlcs[pn.dlci] = dlc self.dlcs[pn.dlci] = dlc
dlc.connect() dlc.connect()
else: else:
logger.warn('ignoring PN response') logger.warning('ignoring PN response')
def on_mcc_msc(self, c_r, msc): def on_mcc_msc(self, c_r, msc):
dlc = self.dlcs.get(msc.dlci) dlc = self.dlcs.get(msc.dlci)
if dlc is None: if dlc is None:
logger.warn(f'no dlc for DLCI {msc.dlci}') logger.warning(f'no dlc for DLCI {msc.dlci}')
return return
dlc.on_mcc_msc(c_r, msc) dlc.on_mcc_msc(c_r, msc)
@@ -706,7 +749,7 @@ class Multiplexer(EventEmitter):
self.change_state(Multiplexer.CONNECTING) self.change_state(Multiplexer.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future() self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r = 1, dlci = 0)) self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
return await self.connection_result return await self.connection_result
async def disconnect(self): async def disconnect(self):
@@ -715,34 +758,38 @@ class Multiplexer(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future() self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.DISCONNECTING) self.change_state(Multiplexer.DISCONNECTING)
self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0)) self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
)
)
await self.disconnection_result await self.disconnection_result
async def open_dlc(self, channel): async def open_dlc(self, channel):
if self.state != Multiplexer.CONNECTED: if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING: if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress') raise InvalidStateError('open already in progress')
else:
raise InvalidStateError('not connected') raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN( pn = RFCOMM_MCC_PN(
dlci = channel << 1, dlci=channel << 1,
cl = 0xF0, cl=0xF0,
priority = 7, priority=7,
ack_timer = 0, ack_timer=0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0, max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 1, data = bytes(pn)) mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}') logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future() self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.OPENING) self.change_state(Multiplexer.OPENING)
self.send_frame( self.send_frame(
RFCOMM_Frame.uih( RFCOMM_Frame.uih(
c_r = 1 if self.role == Multiplexer.INITIATOR else 0, c_r=1 if self.role == Multiplexer.INITIATOR else 0,
dlci = 0, dlci=0,
information = mcc information=mcc,
) )
) )
result = await self.open_result result = await self.open_result
@@ -762,17 +809,19 @@ class Multiplexer(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
def __init__(self, device, connection): def __init__(self, device, connection):
self.device = device self.device = device
self.connection = connection self.connection = connection
self.l2cap_channel = None self.l2cap_channel = None
self.multiplexer = None self.multiplexer = None
async def start(self): async def start(self):
# Create a new L2CAP connection # Create a new L2CAP connection
try: try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM) self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
except ProtocolError as error: except ProtocolError as error:
logger.warn(f'L2CAP connection failed: {error}') logger.warning(f'L2CAP connection failed: {error}')
raise raise
# Create a mutliplexer to manage DLCs with the server # Create a mutliplexer to manage DLCs with the server
@@ -796,16 +845,18 @@ class Client:
class Server(EventEmitter): class Server(EventEmitter):
def __init__(self, device): def __init__(self, device):
super().__init__() super().__init__()
self.device = device self.device = device
self.multiplexer = None self.multiplexer = None
self.acceptors = {} self.acceptors = {}
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection) device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor): def listen(self, acceptor):
# Find a free channel number # Find a free channel number
for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1): for channel in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
):
if channel not in self.acceptors: if channel not in self.acceptors:
self.acceptors[channel] = acceptor self.acceptors[channel] = acceptor
return channel return channel

View File

@@ -33,6 +33,9 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
SDP_PSM = 0x0001 SDP_PSM = 0x0001
@@ -112,49 +115,68 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified # To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DataElement: class DataElement:
NIL = 0 NIL = 0
UNSIGNED_INTEGER = 1 UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2 SIGNED_INTEGER = 2
UUID = 3 UUID = 3
TEXT_STRING = 4 TEXT_STRING = 4
BOOLEAN = 5 BOOLEAN = 5
SEQUENCE = 6 SEQUENCE = 6
ALTERNATIVE = 7 ALTERNATIVE = 7
URL = 8 URL = 8
TYPE_NAMES = { TYPE_NAMES = {
NIL: 'NIL', NIL: 'NIL',
UNSIGNED_INTEGER: 'UNSIGNED_INTEGER', UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
SIGNED_INTEGER: 'SIGNED_INTEGER', SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID', UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING', TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN', BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE', SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE', ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL' URL: 'URL',
} }
type_constructors = { type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None), NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y), UNSIGNED_INTEGER: lambda x, y: DataElement(
SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y), DataElement.UNSIGNED_INTEGER,
UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))), DataElement.unsigned_integer_from_bytes(x),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')), value_size=y,
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), ),
SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)), SIGNED_INTEGER: lambda x, y: DataElement(
ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)), DataElement.SIGNED_INTEGER,
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')) DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
} }
def __init__(self, type, value, value_size=None): def __init__(self, element_type, value, value_size=None):
self.type = type self.type = element_type
self.value = value self.value = value
self.value_size = value_size self.value_size = value_size
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica # Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER: self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None: if value_size is None:
raise ValueError('integer types must have a value size specified') raise ValueError('integer types must have a value size specified')
@@ -222,27 +244,33 @@ class DataElement:
def unsigned_integer_from_bytes(data): def unsigned_integer_from_bytes(data):
if len(data) == 1: if len(data) == 1:
return data[0] return data[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>H', data)[0] return struct.unpack('>H', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>I', data)[0] return struct.unpack('>I', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>Q', data)[0] return struct.unpack('>Q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}') raise ValueError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def signed_integer_from_bytes(data): def signed_integer_from_bytes(data):
if len(data) == 1: if len(data) == 1:
return struct.unpack('b', data)[0] return struct.unpack('b', data)[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>h', data)[0] return struct.unpack('>h', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>i', data)[0] return struct.unpack('>i', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>q', data)[0] return struct.unpack('>q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}') raise ValueError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def list_from_bytes(data): def list_from_bytes(data):
@@ -250,7 +278,7 @@ class DataElement:
while data: while data:
element = DataElement.from_bytes(data) element = DataElement.from_bytes(data)
elements.append(element) elements.append(element)
data = data[len(bytes(element)):] data = data[len(bytes(element)) :]
return elements return elements
@staticmethod @staticmethod
@@ -260,11 +288,11 @@ class DataElement:
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
type = data[0] >> 3 element_type = data[0] >> 3
size_index = data[0] & 7 size_index = data[0] & 7
value_offset = 0 value_offset = 0
if size_index == 0: if size_index == 0:
if type == DataElement.NIL: if element_type == DataElement.NIL:
value_size = 0 value_size = 0
else: else:
value_size = 1 value_size = 1
@@ -286,16 +314,21 @@ class DataElement:
value_size = struct.unpack('>I', data[1:5])[0] value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4 value_offset = 4
value_data = data[1 + value_offset:1 + value_offset + value_size] value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type) constructor = DataElement.type_constructors.get(element_type)
if constructor: if constructor:
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER: if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size) result = constructor(value_data, value_size)
else: else:
result = constructor(value_data) result = constructor(value_data)
else: else:
result = DataElement(type, value_data) result = DataElement(element_type, value_data)
result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result return result
def to_bytes(self): def to_bytes(self):
@@ -311,7 +344,8 @@ class DataElement:
elif self.type == DataElement.UNSIGNED_INTEGER: elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0: if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative') raise ValueError('UNSIGNED_INTEGER cannot be negative')
elif self.value_size == 1:
if self.value_size == 1:
data = struct.pack('B', self.value) data = struct.pack('B', self.value)
elif self.value_size == 2: elif self.value_size == 2:
data = struct.pack('>H', self.value) data = struct.pack('>H', self.value)
@@ -334,11 +368,11 @@ class DataElement:
raise ValueError('invalid value_size') raise ValueError('invalid value_size')
elif self.type == DataElement.UUID: elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value))) data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.TEXT_STRING or self.type == DataElement.URL: elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
data = self.value.encode('utf8') data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN: elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0]) data = bytes([1 if self.value else 0])
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE: elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value]) data = b''.join([bytes(element) for element in self.value])
else: else:
data = self.value data = self.value
@@ -349,9 +383,11 @@ class DataElement:
if size != 0: if size != 0:
raise ValueError('NIL must be empty') raise ValueError('NIL must be empty')
size_index = 0 size_index = 0
elif (self.type == DataElement.UNSIGNED_INTEGER or elif self.type in (
self.type == DataElement.SIGNED_INTEGER or DataElement.UNSIGNED_INTEGER,
self.type == DataElement.UUID): DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1: if size <= 1:
size_index = 0 size_index = 0
elif size == 2: elif size == 2:
@@ -364,10 +400,12 @@ class DataElement:
size_index = 4 size_index = 4
else: else:
raise ValueError('invalid data size') raise ValueError('invalid data size')
elif (self.type == DataElement.TEXT_STRING or elif self.type in (
self.type == DataElement.SEQUENCE or DataElement.TEXT_STRING,
self.type == DataElement.ALTERNATIVE or DataElement.SEQUENCE,
self.type == DataElement.URL): DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF: if size <= 0xFF:
size_index = 5 size_index = 5
size_bytes = bytes([size]) size_bytes = bytes([size])
@@ -392,11 +430,19 @@ class DataElement:
type_name = name_or_number(self.TYPE_NAMES, self.type) type_name = name_or_number(self.TYPE_NAMES, self.type)
if self.type == DataElement.NIL: if self.type == DataElement.NIL:
value_string = '' value_string = ''
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE: elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else '' container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ',' element_separator = '\n' if pretty else ','
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]' elements = [
elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER: element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}' value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement): elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation) value_string = self.value.to_string(pretty, indentation)
@@ -410,17 +456,17 @@ class DataElement:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ServiceAttribute: class ServiceAttribute:
def __init__(self, id, value): def __init__(self, attribute_id, value):
self.id = id self.id = attribute_id
self.value = value self.value = value
@staticmethod @staticmethod
def list_from_data_elements(elements): def list_from_data_elements(elements):
attribute_list = [] attribute_list = []
for i in range(0, len(elements) // 2): for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i:2 * (i + 1)] attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
if attribute_id.type != DataElement.UNSIGNED_INTEGER: if attribute_id.type != DataElement.UNSIGNED_INTEGER:
logger.warn('attribute ID element is not an integer') logger.warning('attribute ID element is not an integer')
continue continue
attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value)) attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value))
@@ -428,30 +474,41 @@ class ServiceAttribute:
@staticmethod @staticmethod
def find_attribute_in_list(attribute_list, attribute_id): def find_attribute_in_list(attribute_list, attribute_id):
return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None) return next(
(
attribute.value
for attribute in attribute_list
if attribute.id == attribute_id
),
None,
)
@staticmethod @staticmethod
def id_name(id): def id_name(id_code):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id) return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod @staticmethod
def is_uuid_in_value(uuid, value): def is_uuid_in_value(uuid, value):
# Find if a uuid matches a value, either directly or recursing into sequences # Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID: if value.type == DataElement.UUID:
return value.value == uuid return value.value == uuid
elif value.type == DataElement.SEQUENCE:
if value.type == DataElement.SEQUENCE:
for element in value.value: for element in value.value:
if ServiceAttribute.is_uuid_in_value(uuid, element): if ServiceAttribute.is_uuid_in_value(uuid, element):
return True return True
return False return False
else:
return False
def to_string(self, color=False): return False
if color:
return f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},value={self.value})' def to_string(self, with_colors=False):
else: if with_colors:
return f'Attribute(id={self.id_name(self.id)},value={self.value})' return (
f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},'
f'value={self.value})'
)
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
def __str__(self): def __str__(self):
return self.to_string() return self.to_string()
@@ -462,11 +519,14 @@ class SDP_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
''' '''
sdp_pdu_classes = {} sdp_pdu_classes = {}
name = None
pdu_id = 0
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu):
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0) pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id) cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None: if cls is None:
@@ -484,13 +544,15 @@ class SDP_PDU:
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset): def parse_service_record_handle_list_preceded_by_count(data, offset):
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)] handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list return offset + count * 4, handle_list
@staticmethod @staticmethod
def parse_bytes_preceded_by_length(data, offset): def parse_bytes_preceded_by_length(data, offset):
length = struct.unpack_from('>H', data, offset - 2)[0] length = struct.unpack_from('>H', data, offset - 2)[0]
return offset + length, data[offset:offset + length] return offset + length, data[offset : offset + length]
@staticmethod @staticmethod
def error_name(error_code): def error_name(error_code):
@@ -532,7 +594,10 @@ class SDP_PDU:
HCI_Object.init_from_fields(self, self.fields, kwargs) HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None: if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields) parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu self.pdu = pdu
self.transaction_id = transaction_id self.transaction_id = transaction_id
@@ -555,9 +620,7 @@ class SDP_PDU:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})
])
class SDP_ErrorResponse(SDP_PDU): class SDP_ErrorResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
@@ -565,11 +628,13 @@ class SDP_ErrorResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('service_search_pattern', DataElement.parse_from_bytes), [
('maximum_service_record_count', '>2'), ('service_search_pattern', DataElement.parse_from_bytes),
('continuation_state', '*') ('maximum_service_record_count', '>2'),
]) ('continuation_state', '*'),
]
)
class SDP_ServiceSearchRequest(SDP_PDU): class SDP_ServiceSearchRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
@@ -577,12 +642,17 @@ class SDP_ServiceSearchRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('total_service_record_count', '>2'), [
('current_service_record_count', '>2'), ('total_service_record_count', '>2'),
('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count), ('current_service_record_count', '>2'),
('continuation_state', '*') (
]) 'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchResponse(SDP_PDU): class SDP_ServiceSearchResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
@@ -590,12 +660,14 @@ class SDP_ServiceSearchResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('service_record_handle', '>4'), [
('maximum_attribute_byte_count', '>2'), ('service_record_handle', '>4'),
('attribute_id_list', DataElement.parse_from_bytes), ('maximum_attribute_byte_count', '>2'),
('continuation_state', '*') ('attribute_id_list', DataElement.parse_from_bytes),
]) ('continuation_state', '*'),
]
)
class SDP_ServiceAttributeRequest(SDP_PDU): class SDP_ServiceAttributeRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
@@ -603,11 +675,13 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('attribute_list_byte_count', '>2'), [
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length), ('attribute_list_byte_count', '>2'),
('continuation_state', '*') ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
]) ('continuation_state', '*'),
]
)
class SDP_ServiceAttributeResponse(SDP_PDU): class SDP_ServiceAttributeResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
@@ -615,12 +689,14 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('service_search_pattern', DataElement.parse_from_bytes), [
('maximum_attribute_byte_count', '>2'), ('service_search_pattern', DataElement.parse_from_bytes),
('attribute_id_list', DataElement.parse_from_bytes), ('maximum_attribute_byte_count', '>2'),
('continuation_state', '*') ('attribute_id_list', DataElement.parse_from_bytes),
]) ('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU): class SDP_ServiceSearchAttributeRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
@@ -628,11 +704,13 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
('attribute_lists_byte_count', '>2'), [
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length), ('attribute_lists_byte_count', '>2'),
('continuation_state', '*') ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
]) ('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU): class SDP_ServiceSearchAttributeResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
@@ -642,9 +720,9 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
def __init__(self, device): def __init__(self, device):
self.device = device self.device = device
self.pending_request = None self.pending_request = None
self.channel = None self.channel = None
async def connect(self, connection): async def connect(self, connection):
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM) result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
@@ -659,7 +737,9 @@ class Client:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
# Request and accumulate until there's no more continuation # Request and accumulate until there's no more continuation
service_record_handle_list = [] service_record_handle_list = []
@@ -668,10 +748,10 @@ class Client:
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response_pdu = await self.channel.send_request(
SDP_ServiceSearchRequest( SDP_ServiceSearchRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern, service_search_pattern=service_search_pattern,
maximum_service_record_count = 0xFFFF, maximum_service_record_count=0xFFFF,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -689,11 +769,15 @@ class Client:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) DataElement.unsigned_integer(
if type(attribute_id) is tuple attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
] ]
@@ -706,11 +790,11 @@ class Client:
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response_pdu = await self.channel.send_request(
SDP_ServiceSearchAttributeRequest( SDP_ServiceSearchAttributeRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern, service_search_pattern=service_search_pattern,
maximum_attribute_byte_count = 0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list = attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -725,7 +809,7 @@ class Client:
# Parse the result into attribute lists # Parse the result into attribute lists
attribute_lists_sequences = DataElement.from_bytes(accumulator) attribute_lists_sequences = DataElement.from_bytes(accumulator)
if attribute_lists_sequences.type != DataElement.SEQUENCE: if attribute_lists_sequences.type != DataElement.SEQUENCE:
logger.warn('unexpected data type') logger.warning('unexpected data type')
return [] return []
return [ return [
@@ -740,8 +824,10 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) DataElement.unsigned_integer(
if type(attribute_id) is tuple attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
] ]
@@ -754,11 +840,11 @@ class Client:
while watchdog > 0: while watchdog > 0:
response_pdu = await self.channel.send_request( response_pdu = await self.channel.send_request(
SDP_ServiceAttributeRequest( SDP_ServiceAttributeRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value transaction_id=0, # Transaction ID TODO: pick a real value
service_record_handle = service_record_handle, service_record_handle=service_record_handle,
maximum_attribute_byte_count = 0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list = attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -773,7 +859,7 @@ class Client:
# Parse the result into a list of attributes # Parse the result into a list of attributes
attribute_list_sequence = DataElement.from_bytes(accumulator) attribute_list_sequence = DataElement.from_bytes(accumulator)
if attribute_list_sequence.type != DataElement.SEQUENCE: if attribute_list_sequence.type != DataElement.SEQUENCE:
logger.warn('unexpected data type') logger.warning('unexpected data type')
return [] return []
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value) return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
@@ -784,8 +870,9 @@ class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x43])
def __init__(self, device): def __init__(self, device):
self.device = device self.device = device
self.service_records = {} # Service records maps, by record handle self.service_records = {} # Service records maps, by record handle
self.channel = None
self.current_response = None self.current_response = None
def register(self, l2cap_channel_manager): def register(self, l2cap_channel_manager):
@@ -820,11 +907,10 @@ class Server:
try: try:
sdp_pdu = SDP_PDU.from_bytes(pdu) sdp_pdu = SDP_PDU.from_bytes(pdu)
except Exception as error: except Exception as error:
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red')) logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = 0, transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
) )
) )
@@ -840,16 +926,16 @@ class Server:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = sdp_pdu.transaction_id, transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
) )
) )
else: else:
logger.error(color('SDP Request not handled???', 'red')) logger.error(color('SDP Request not handled???', 'red'))
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = sdp_pdu.transaction_id, transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
) )
) )
@@ -872,17 +958,18 @@ class Server:
if attribute_id.value_size == 4: if attribute_id.value_size == 4:
# Attribute ID range # Attribute ID range
id_range_start = attribute_id.value >> 16 id_range_start = attribute_id.value >> 16
id_range_end = attribute_id.value & 0xFFFF id_range_end = attribute_id.value & 0xFFFF
else: else:
id_range_start = attribute_id.value id_range_start = attribute_id.value
id_range_end = attribute_id.value id_range_end = attribute_id.value
attributes += [ attributes += [
attribute for attribute in service attribute
for attribute in service
if attribute.id >= id_range_start and attribute.id <= id_range_end if attribute.id >= id_range_start and attribute.id <= id_range_end
] ]
# Return the maching attributes, sorted by attribute id # Return the matching attributes, sorted by attribute id
attributes.sort(key = lambda x: x.id) attributes.sort(key=lambda x: x.id)
attribute_list = DataElement.sequence([]) attribute_list = DataElement.sequence([])
for attribute in attributes: for attribute in attributes:
attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id)) attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
@@ -896,8 +983,8 @@ class Server:
if not self.current_response: if not self.current_response:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
return return
@@ -910,30 +997,38 @@ class Server:
service_record_handles = list(matching_services.keys()) service_record_handles = list(matching_services.keys())
# Only return up to the maximum requested # Only return up to the maximum requested
service_record_handles_subset = service_record_handles[:request.maximum_service_record_count] service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count
]
# Serialize to a byte array, and remember the total count # Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}') logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = ( self.current_response = (
len(service_record_handles), len(service_record_handles),
service_record_handles_subset service_record_handles_subset,
) )
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
service_record_handles = self.current_response[1][:request.maximum_service_record_count] service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
self.current_response = ( self.current_response = (
self.current_response[0], self.current_response[0],
self.current_response[1][request.maximum_service_record_count:] self.current_response[1][request.maximum_service_record_count :],
)
continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
) )
continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles])
self.send_response( self.send_response(
SDP_ServiceSearchResponse( SDP_ServiceSearchResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
total_service_record_count = self.current_response[0], total_service_record_count=self.current_response[0],
current_service_record_count = len(service_record_handles), current_service_record_count=len(service_record_handles),
service_record_handle_list = service_record_handle_list, service_record_handle_list=service_record_handle_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
@@ -943,8 +1038,8 @@ class Server:
if not self.current_response: if not self.current_response:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
return return
@@ -957,27 +1052,31 @@ class Server:
if service is None: if service is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
) )
) )
return return
# Get the attributes for the service # Get the attributes for the service
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
# Serialize to a byte array # Serialize to a byte array
logger.debug(f'Attributes: {attribute_list}') logger.debug(f'Attributes: {attribute_list}')
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) attribute_list, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count = len(attribute_list), attribute_list_byte_count=len(attribute_list),
attribute_list = attribute_list, attribute_list=attribute_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
@@ -987,8 +1086,8 @@ class Server:
if not self.current_response: if not self.current_response:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
else: else:
@@ -996,12 +1095,16 @@ class Server:
self.current_response = None self.current_response = None
# Find the matching services # Find the matching services
matching_services = self.match_services(request.service_search_pattern).values() matching_services = self.match_services(
request.service_search_pattern
).values()
# Filter the required attributes # Filter the required attributes
attribute_lists = DataElement.sequence([]) attribute_lists = DataElement.sequence([])
for service in matching_services: for service in matching_services:
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
if attribute_list.value: if attribute_list.value:
attribute_lists.value.append(attribute_list) attribute_lists.value.append(attribute_list)
@@ -1010,12 +1113,14 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) attribute_lists, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id = request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count = len(attribute_lists), attribute_lists_byte_count=len(attribute_lists),
attribute_lists = attribute_lists, attribute_lists=attribute_lists,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )

File diff suppressed because it is too large Load Diff

View File

@@ -35,48 +35,76 @@ async def open_transport(name):
Where <parameters> depend on the type (and may be empty for some types). Where <parameters> depend on the type (and may be empty for some types).
The supported types are: serial,udp,tcp,pty,usb The supported types are: serial,udp,tcp,pty,usb
''' '''
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1) scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from .serial import open_serial_transport
return await open_serial_transport(spec[0]) return await open_serial_transport(spec[0])
elif scheme == 'udp' and spec:
if scheme == 'udp' and spec:
from .udp import open_udp_transport from .udp import open_udp_transport
return await open_udp_transport(spec[0]) return await open_udp_transport(spec[0])
elif scheme == 'tcp-client' and spec:
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0]) return await open_tcp_client_transport(spec[0])
elif scheme == 'tcp-server' and spec:
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0]) return await open_tcp_server_transport(spec[0])
elif scheme == 'ws-client' and spec:
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0]) return await open_ws_client_transport(spec[0])
elif scheme == 'ws-server' and spec:
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0]) return await open_ws_server_transport(spec[0])
elif scheme == 'pty':
if scheme == 'pty':
from .pty import open_pty_transport from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None) return await open_pty_transport(spec[0] if spec else None)
elif scheme == 'file':
if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None) return await open_file_transport(spec[0] if spec else None)
elif scheme == 'vhci':
if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None) return await open_vhci_transport(spec[0] if spec else None)
elif scheme == 'hci-socket':
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None) return await open_hci_socket_transport(spec[0] if spec else None)
elif scheme == 'usb':
if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None) return await open_usb_transport(spec[0] if spec else None)
elif scheme == 'pyusb':
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None) return await open_pyusb_transport(spec[0] if spec else None)
elif scheme == 'android-emulator':
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None) return await open_android_emulator_transport(spec[0] if spec else None)
else:
raise ValueError('unknown transport scheme') raise ValueError('unknown transport scheme')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -84,12 +112,12 @@ async def open_transport_or_link(name):
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])
await link.wait_until_connected() await link.wait_until_connected()
controller = Controller('remote', link = link) controller = Controller('remote', link=link)
class LinkTransport(Transport): class LinkTransport(Transport):
async def close(self): async def close(self):
link.close() link.close()
return LinkTransport(controller, AsyncPipeSink(controller)) return LinkTransport(controller, AsyncPipeSink(controller))
else:
return await open_transport(name) return await open_transport(name)

View File

@@ -59,15 +59,10 @@ async def open_android_emulator_transport(spec):
return bytes([packet.type]) + packet.packet return bytes([packet.type]) + packet.packet
async def write(self, packet): async def write(self, packet):
await self.hci_device.write( await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:]))
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = 8554 server_port = 8554
if spec is not None: if spec is not None:
@@ -100,7 +95,7 @@ async def open_android_emulator_transport(spec):
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write), PumpedPacketSink(hci_device.write),
channel.close channel.close,
) )
transport.start() transport.start()

View File

@@ -33,10 +33,10 @@ logger = logging.getLogger(__name__)
# For each packet type, the info represents: # For each packet type, the info represents:
# (length-size, length-offset, unpack-type) # (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = { HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B') hci.HCI_EVENT_PACKET: (1, 1, 'B'),
} }
@@ -48,7 +48,7 @@ class PacketPump:
def __init__(self, reader, sink): def __init__(self, reader, sink):
self.reader = reader self.reader = reader
self.sink = sink self.sink = sink
async def run(self): async def run(self):
while True: while True:
@@ -65,43 +65,51 @@ class PacketPump:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketParser: class PacketParser:
''' '''
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
''' '''
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
def __init__(self, sink = None): # pylint: disable=attribute-defined-outside-init
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
def __init__(self, sink=None):
self.sink = sink self.sink = sink
self.extended_packet_info = {} self.extended_packet_info = {}
self.reset() self.reset()
def reset(self): def reset(self):
self.state = PacketParser.NEED_TYPE self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1 self.bytes_needed = 1
self.packet = bytearray() self.packet = bytearray()
self.packet_info = None self.packet_info = None
def feed_data(self, data): def feed_data(self, data):
data_offset = 0 data_offset = 0
data_left = len(data) data_left = len(data)
while data_left and self.bytes_needed: while data_left and self.bytes_needed:
consumed = min(self.bytes_needed, data_left) consumed = min(self.bytes_needed, data_left)
self.packet.extend(data[data_offset:data_offset + consumed]) self.packet.extend(data[data_offset : data_offset + consumed])
data_offset += consumed data_offset += consumed
data_left -= consumed data_left -= consumed
self.bytes_needed -= consumed self.bytes_needed -= consumed
if self.bytes_needed == 0: if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE: if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0] packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type) self.packet_info = HCI_PACKET_INFO.get(
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}') raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0] body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
self.bytes_needed = body_length self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY self.state = PacketParser.NEED_BODY
@@ -111,7 +119,9 @@ class PacketParser:
try: try:
self.sink.on_packet(bytes(self.packet)) self.sink.on_packet(bytes(self.packet))
except Exception as error: except Exception as error:
logger.warning(color(f'!!! Exception in on_packet: {error}', 'red')) logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset() self.reset()
def set_packet_sink(self, sink): def set_packet_sink(self, sink):
@@ -187,6 +197,7 @@ class AsyncPipeSink:
''' '''
Sink that forwards packets asynchronously to another sink Sink that forwards packets asynchronously to another sink
''' '''
def __init__(self, sink): def __init__(self, sink):
self.sink = sink self.sink = sink
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
@@ -202,7 +213,7 @@ class ParserSource:
""" """
def __init__(self): def __init__(self):
self.parser = PacketParser() self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink): def set_packet_sink(self, sink):
@@ -237,7 +248,7 @@ class StreamPacketSink:
class Transport: class Transport:
def __init__(self, source, sink): def __init__(self, source, sink):
self.source = source self.source = source
self.sink = sink self.sink = sink
async def __aenter__(self): async def __aenter__(self):
return self return self
@@ -258,7 +269,7 @@ class PumpedPacketSource(ParserSource):
def __init__(self, receive): def __init__(self, receive):
super().__init__() super().__init__()
self.receive_function = receive self.receive_function = receive
self.pump_task = None self.pump_task = None
def start(self): def start(self):
async def pump_packets(): async def pump_packets():
@@ -270,11 +281,11 @@ class PumpedPacketSource(ParserSource):
logger.debug('source pump task done') logger.debug('source pump task done')
break break
except Exception as error: except Exception as error:
logger.warn(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error) self.terminated.set_result(error)
break break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self):
if self.pump_task: if self.pump_task:
@@ -285,8 +296,8 @@ class PumpedPacketSource(ParserSource):
class PumpedPacketSink: class PumpedPacketSink:
def __init__(self, send): def __init__(self, send):
self.send_function = send self.send_function = send
self.packet_queue = asyncio.Queue() self.packet_queue = asyncio.Queue()
self.pump_task = None self.pump_task = None
def on_packet(self, packet): def on_packet(self, packet):
self.packet_queue.put_nowait(packet) self.packet_queue.put_nowait(packet)
@@ -301,10 +312,10 @@ class PumpedPacketSink:
logger.debug('sink pump task done') logger.debug('sink pump task done')
break break
except Exception as error: except Exception as error:
logger.warn(f'exception while sending packet: {error}') logger.warning(f'exception while sending packet: {error}')
break break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self):
if self.pump_task: if self.pump_task:

View File

@@ -21,32 +21,36 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') )
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket'] _HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType'] _HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), { HCIPacket = _reflection.GeneratedProtocolMessageType(
'DESCRIPTOR' : _HCIPACKET, 'HCIPacket',
'__module__' : 'emulated_bluetooth_packets_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket) {
}) 'DESCRIPTOR': _HCIPACKET,
'__module__': 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
},
)
_sym_db.RegisterMessage(HCIPacket) _sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start=66 _HCIPACKET._serialized_start = 66
_HCIPACKET._serialized_end=317 _HCIPACKET._serialized_end = 317
_HCIPACKET_PACKETTYPE._serialized_start=161 _HCIPACKET_PACKETTYPE._serialized_start = 161
_HCIPACKET_PACKETTYPE._serialized_end=317 _HCIPACKET_PACKETTYPE._serialized_end = 317
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@@ -29,25 +30,30 @@ _sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData'] _RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), { RawData = _reflection.GeneratedProtocolMessageType(
'DESCRIPTOR' : _RAWDATA, 'RawData',
'__module__' : 'emulated_bluetooth_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData) {
}) 'DESCRIPTOR': _RAWDATA,
'__module__': 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
},
)
_sym_db.RegisterMessage(RawData) _sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService'] _EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001' DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start=91 _RAWDATA._serialized_start = 91
_RAWDATA._serialized_end=116 _RAWDATA._serialized_end = 116
_EMULATEDBLUETOOTHSERVICE._serialized_start=119 _EMULATEDBLUETOOTHSERVICE._serialized_start = 119
_EMULATEDBLUETOOTHSERVICE._serialized_end=450 _EMULATEDBLUETOOTHSERVICE._serialized_end = 450
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@@ -39,20 +39,20 @@ class EmulatedBluetoothServiceStub(object):
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.registerClassicPhy = channel.stream_stream( self.registerClassicPhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString, response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
) )
self.registerBlePhy = channel.stream_stream( self.registerBlePhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString, response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
) )
self.registerHCIDevice = channel.stream_stream( self.registerHCIDevice = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
) )
class EmulatedBluetoothServiceServicer(object): class EmulatedBluetoothServiceServicer(object):
@@ -121,28 +121,29 @@ class EmulatedBluetoothServiceServicer(object):
def add_EmulatedBluetoothServiceServicer_to_server(servicer, server): def add_EmulatedBluetoothServiceServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'registerClassicPhy': grpc.stream_stream_rpc_method_handler( 'registerClassicPhy': grpc.stream_stream_rpc_method_handler(
servicer.registerClassicPhy, servicer.registerClassicPhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString, request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
), ),
'registerBlePhy': grpc.stream_stream_rpc_method_handler( 'registerBlePhy': grpc.stream_stream_rpc_method_handler(
servicer.registerBlePhy, servicer.registerBlePhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString, request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString, response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
), ),
'registerHCIDevice': grpc.stream_stream_rpc_method_handler( 'registerHCIDevice': grpc.stream_stream_rpc_method_handler(
servicer.registerHCIDevice, servicer.registerHCIDevice,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers) 'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,)) server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API. # This class is part of an EXPERIMENTAL API.
class EmulatedBluetoothService(object): class EmulatedBluetoothService(object):
"""An Emulated Bluetooth Service exposes the emulated bluetooth chip from the """An Emulated Bluetooth Service exposes the emulated bluetooth chip from the
android emulator. It allows you to register emulated bluetooth devices and android emulator. It allows you to register emulated bluetooth devices and
@@ -156,52 +157,88 @@ class EmulatedBluetoothService(object):
""" """
@staticmethod @staticmethod
def registerClassicPhy(request_iterator, def registerClassicPhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target, target,
options=(), '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString, emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod @staticmethod
def registerBlePhy(request_iterator, def registerBlePhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target, target,
options=(), '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString, emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod @staticmethod
def registerHCIDevice(request_iterator, def registerHCIDevice(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target, target,
options=(), '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString, emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@@ -29,15 +30,16 @@ _sym_db = _symbol_database.Default()
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService'] _VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth' DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_VHCIFORWARDINGSERVICE._serialized_start=96 _VHCIFORWARDINGSERVICE._serialized_start = 96
_VHCIFORWARDINGSERVICE._serialized_end=217 _VHCIFORWARDINGSERVICE._serialized_end = 217
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@@ -35,10 +35,10 @@ class VhciForwardingServiceStub(object):
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.attachVhci = channel.stream_stream( self.attachVhci = channel.stream_stream(
'/android.emulation.bluetooth.VhciForwardingService/attachVhci', '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
) )
class VhciForwardingServiceServicer(object): class VhciForwardingServiceServicer(object):
@@ -75,18 +75,19 @@ class VhciForwardingServiceServicer(object):
def add_VhciForwardingServiceServicer_to_server(servicer, server): def add_VhciForwardingServiceServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'attachVhci': grpc.stream_stream_rpc_method_handler( 'attachVhci': grpc.stream_stream_rpc_method_handler(
servicer.attachVhci, servicer.attachVhci,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString, request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers) 'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,)) server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API. # This class is part of an EXPERIMENTAL API.
class VhciForwardingService(object): class VhciForwardingService(object):
"""This is a service which allows you to directly intercept the VHCI packets """This is a service which allows you to directly intercept the VHCI packets
that are coming and going to the device before they are delivered to that are coming and going to the device before they are delivered to
@@ -97,18 +98,30 @@ class VhciForwardingService(object):
""" """
@staticmethod @staticmethod
def attachVhci(request_iterator, def attachVhci(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target, target,
options=(), '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString, emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -30,8 +30,9 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_file_transport(spec): async def open_file_transport(spec):
''' '''
Open a File transport (typically not for a real file, but for a PTY or other unix virtual files). Open a File transport (typically not for a real file, but for a PTY or other unix
The parameter string is the path of the file to open virtual files).
The parameter string is the path of the file to open.
''' '''
# Open the file # Open the file
@@ -39,14 +40,12 @@ async def open_file_transport(spec):
# Setup reading # Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe( read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), StreamPacketSource, file
file
) )
# Setup writing # Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe( write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), asyncio.BaseProtocol, file
file
) )
packet_sink = StreamPacketSink(write_transport) packet_sink = StreamPacketSink(write_transport)
@@ -57,4 +56,3 @@ async def open_file_transport(spec):
file.close() file.close()
return FileTransport(packet_source, packet_sink) return FileTransport(packet_source, packet_sink)

View File

@@ -40,15 +40,21 @@ async def open_hci_socket_transport(spec):
or a 0-based integer to indicate the adapter number. or a 0-based integer to indicate the adapter number.
''' '''
HCI_CHANNEL_USER = 1 HCI_CHANNEL_USER = 1 # pylint: disable=invalid-name
# Create a raw HCI socket # Create a raw HCI socket
try: try:
hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI) hci_socket = socket.socket(
except AttributeError: socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError as error:
# Not supported on this platform # Not supported on this platform
logger.info("HCI sockets not supported on this platform") logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform') raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
# Compute the adapter index # Compute the adapter index
if spec is None: if spec is None:
@@ -62,20 +68,37 @@ async def open_hci_socket_transport(spec):
try: try:
ctypes.cdll.LoadLibrary('libc.so.6') ctypes.cdll.LoadLibrary('libc.so.6')
libc = ctypes.CDLL('libc.so.6', use_errno=True) libc = ctypes.CDLL('libc.so.6', use_errno=True)
except OSError: except OSError as error:
logger.info("HCI sockets not supported on this platform") logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform') raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int) libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int libc.bind.restype = ctypes.c_int
bind_address = struct.pack('<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER) bind_address = struct.pack(
if libc.bind(hci_socket.fileno(), ctypes.create_string_buffer(bind_address), len(bind_address)) != 0: # pylint: disable=no-member
'<HHH',
socket.AF_BLUETOOTH,
adapter_index,
HCI_CHANNEL_USER,
)
if (
libc.bind(
hci_socket.fileno(),
ctypes.create_string_buffer(bind_address),
len(bind_address),
)
!= 0
):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno())) raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource): class HciSocketSource(ParserSource):
def __init__(self, socket): def __init__(self, hci_socket):
super().__init__() super().__init__()
self.socket = socket self.socket = hci_socket
asyncio.get_running_loop().add_reader(socket.fileno(), self.recv_until_would_block) asyncio.get_running_loop().add_reader(
socket.fileno(), self.recv_until_would_block
)
def recv_until_would_block(self): def recv_until_would_block(self):
logger.debug('recv until would block +++') logger.debug('recv until would block +++')
@@ -92,9 +115,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_reader(self.socket.fileno()) asyncio.get_running_loop().remove_reader(self.socket.fileno())
class HciSocketSink: class HciSocketSink:
def __init__(self, socket): def __init__(self, hci_socket):
self.socket = socket self.socket = hci_socket
self.packets = collections.deque() self.packets = collections.deque()
self.writer_added = False self.writer_added = False
def send_until_would_block(self): def send_until_would_block(self):
@@ -112,9 +135,14 @@ async def open_hci_socket_transport(spec):
break break
if self.packets: if self.packets:
# There's still something to send, ensure that we are monitoring the socket # There's still something to send, ensure that we are monitoring the
# socket
if not self.writer_added: if not self.writer_added:
asyncio.get_running_loop().add_writer(socket.fileno(), self.send_until_would_block) asyncio.get_running_loop().add_writer(
# pylint: disable=no-member
socket.fileno(),
self.send_until_would_block,
)
self.writer_added = True self.writer_added = True
else: else:
# Nothing left to send, stop monitoring the socket # Nothing left to send, stop monitoring the socket
@@ -131,9 +159,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_writer(self.socket.fileno()) asyncio.get_running_loop().remove_writer(self.socket.fileno())
class HciSocketTransport(Transport): class HciSocketTransport(Transport):
def __init__(self, socket, source, sink): def __init__(self, hci_socket, source, sink):
super().__init__(source, sink) super().__init__(source, sink)
self.socket = socket self.socket = hci_socket
async def close(self): async def close(self):
logger.debug('closing HCI socket transport') logger.debug('closing HCI socket transport')

View File

@@ -47,13 +47,11 @@ async def open_pty_transport(spec):
tty.setraw(replica) tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe( read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), StreamPacketSource, io.open(primary, 'rb', closefd=False)
io.open(primary, 'rb', closefd=False)
) )
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe( write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
io.open(primary, 'wb', closefd=False)
) )
packet_sink = StreamPacketSink(write_transport) packet_sink = StreamPacketSink(write_transport)

View File

@@ -17,10 +17,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging import logging
import usb.core
import usb.util
import threading import threading
import time import time
import libusb_package
import usb.core
import usb.util
from colors import color from colors import color
from .common import Transport, ParserSource from .common import Transport, ParserSource
@@ -48,25 +50,26 @@ async def open_pyusb_transport(spec):
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901 04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
''' '''
USB_RECIPIENT_DEVICE = 0x00 # pylint: disable=invalid-name
USB_REQUEST_TYPE_CLASS = 0x01 << 5 USB_RECIPIENT_DEVICE = 0x00
USB_ENDPOINT_EVENTS_IN = 0x81 USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_ACL_IN = 0x82 USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_SCO_IN = 0x83 USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_ACL_OUT = 0x02 USB_ENDPOINT_SCO_IN = 0x83
USB_ENDPOINT_ACL_OUT = 0x02
# USB_ENDPOINT_SCO_OUT = 0x03 # USB_ENDPOINT_SCO_OUT = 0x03
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
READ_SIZE = 1024 READ_SIZE = 1024
READ_TIMEOUT = 1000 READ_TIMEOUT = 1000
class UsbPacketSink: class UsbPacketSink:
def __init__(self, device): def __init__(self, device):
self.device = device self.device = device
self.thread = threading.Thread(target=self.run) self.thread = threading.Thread(target=self.run)
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.stop_event = None self.stop_event = None
def on_packet(self, packet): def on_packet(self, packet):
@@ -80,9 +83,17 @@ async def open_pyusb_transport(spec):
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:]) self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:]) self.device.ctrl_transfer(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
)
else: else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red')) logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
except usb.core.USBTimeoutError: except usb.core.USBTimeoutError:
logger.warning('USB Write Timeout') logger.warning('USB Write Timeout')
except usb.core.USBError as error: except usb.core.USBError as error:
@@ -100,22 +111,21 @@ async def open_pyusb_transport(spec):
def run(self): def run(self):
while self.stop_event is None: while self.stop_event is None:
time.sleep(1) time.sleep(1)
self.loop.call_soon_threadsafe(lambda: self.stop_event.set()) self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled): def __init__(self, device, sco_enabled):
super().__init__() super().__init__()
self.device = device self.device = device
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.dequeue_task = None
self.event_thread = threading.Thread( self.event_thread = threading.Thread(
target=self.run, target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
) )
self.event_thread.stop_event = None self.event_thread.stop_event = None
self.acl_thread = threading.Thread( self.acl_thread = threading.Thread(
target=self.run, target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
) )
self.acl_thread.stop_event = None self.acl_thread.stop_event = None
@@ -124,12 +134,12 @@ async def open_pyusb_transport(spec):
if sco_enabled: if sco_enabled:
self.sco_thread = threading.Thread( self.sco_thread = threading.Thread(
target=self.run, target=self.run,
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET) args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
) )
self.sco_thread.stop_event = None self.sco_thread.stop_event = None
def data_received(self, packet): def data_received(self, data):
self.parser.feed_data(packet) self.parser.feed_data(data)
def enqueue(self, packet): def enqueue(self, packet):
self.queue.put_nowait(packet) self.queue.put_nowait(packet)
@@ -155,7 +165,7 @@ async def open_pyusb_transport(spec):
# Create stop events and wait for them to be signaled # Create stop events and wait for them to be signaled
self.event_thread.stop_event = asyncio.Event() self.event_thread.stop_event = asyncio.Event()
self.acl_thread.stop_event = asyncio.Event() self.acl_thread.stop_event = asyncio.Event()
await self.event_thread.stop_event.wait() await self.event_thread.stop_event.wait()
await self.acl_thread.stop_event.wait() await self.acl_thread.stop_event.wait()
if self.sco_enabled: if self.sco_enabled:
@@ -173,16 +183,17 @@ async def open_pyusb_transport(spec):
except usb.core.USBTimeoutError: except usb.core.USBTimeoutError:
continue continue
except usb.core.USBError: except usb.core.USBError:
# Don't log this: because pyusb doesn't really support multiple threads # Don't log this: because pyusb doesn't really support multiple
# reading at the same time, we can get occasional USBError(errno=5) # threads reading at the same time, we can get occasional
# Input/Output errors reported, but they seem to be harmless. # USBError(errno=5) Input/Output errors reported, but they seem to
# be harmless.
# Until support for async or multi-thread support is added to pyusb, # Until support for async or multi-thread support is added to pyusb,
# we'll just live with this as is... # we'll just live with this as is...
# logger.warning(f'USB read error: {error}') # logger.warning(f'USB read error: {error}')
time.sleep(1) # Sleep one second to avoid busy looping time.sleep(1) # Sleep one second to avoid busy looping
stop_event = current_thread.stop_event stop_event = current_thread.stop_event
self.loop.call_soon_threadsafe(lambda: stop_event.set()) self.loop.call_soon_threadsafe(stop_event.set)
class UsbTransport(Transport): class UsbTransport(Transport):
def __init__(self, device, source, sink): def __init__(self, device, source, sink):
@@ -197,15 +208,19 @@ async def open_pyusb_transport(spec):
# Find the device according to the spec moniker # Find the device according to the spec moniker
if ':' in spec: if ':' in spec:
vendor_id, product_id = spec.split(':') vendor_id, product_id = spec.split(':')
device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)) device = libusb_package.find(
idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
else: else:
device_index = int(spec) device_index = int(spec)
devices = list(usb.core.find( devices = list(
find_all = 1, libusb_package.find(
bDeviceClass = USB_DEVICE_CLASS_WIRELESS_CONTROLLER, find_all=1,
bDeviceSubClass = USB_DEVICE_SUBCLASS_RF_CONTROLLER, bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
)) bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
)
if len(devices) > device_index: if len(devices) > device_index:
device = devices[device_index] device = devices[device_index]
else: else:
@@ -232,6 +247,7 @@ async def open_pyusb_transport(spec):
# Select an alternate setting for SCO, if available # Select an alternate setting for SCO, if available
sco_enabled = False sco_enabled = False
# pylint: disable=line-too-long
# NOTE: this is disabled for now, because SCO with alternate settings is broken, # NOTE: this is disabled for now, because SCO with alternate settings is broken,
# see: https://github.com/libusb/libusb/issues/36 # see: https://github.com/libusb/libusb/issues/36
# #
@@ -273,4 +289,4 @@ async def open_pyusb_transport(spec):
packet_source.start() packet_source.start()
packet_sink.start() packet_sink.start()
return UsbTransport(device, packet_source, packet_sink) return UsbTransport(device, packet_source, packet_sink)

View File

@@ -60,13 +60,12 @@ async def open_serial_transport(spec):
device = spec device = spec
serial_transport, packet_source = await serial_asyncio.create_serial_connection( serial_transport, packet_source = await serial_asyncio.create_serial_connection(
asyncio.get_running_loop(), asyncio.get_running_loop(),
lambda: StreamPacketSource(), StreamPacketSource,
device, device,
baudrate=speed, baudrate=speed,
rtscts=rtscts, rtscts=rtscts,
dsrdtr=dsrdtr dsrdtr=dsrdtr,
) )
packet_sink = StreamPacketSink(serial_transport) packet_sink = StreamPacketSink(serial_transport)
return Transport(packet_source, packet_sink) return Transport(packet_source, packet_sink)

View File

@@ -37,13 +37,13 @@ async def open_tcp_client_transport(spec):
''' '''
class TcpPacketSource(StreamPacketSource): class TcpPacketSource(StreamPacketSource):
def connection_lost(self, error): def connection_lost(self, exc):
logger.debug(f'connection lost: {error}') logger.debug(f'connection lost: {exc}')
self.terminated.set_result(error) self.terminated.set_result(exc)
remote_host, remote_port = spec.split(':') remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection( tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
lambda: TcpPacketSource(), TcpPacketSource,
host=remote_host, host=remote_host,
port=int(remote_port), port=int(remote_port),
) )

View File

@@ -45,12 +45,12 @@ async def open_tcp_server_transport(spec):
class TcpServerProtocol: class TcpServerProtocol:
def __init__(self, packet_source, packet_sink): def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source self.packet_source = packet_source
self.packet_sink = packet_sink self.packet_sink = packet_sink
# Called when a new connection is established # Called when a new connection is established
def connection_made(self, transport): def connection_made(self, transport):
peername = transport.get_extra_info('peername') peer_name = transport.get_extra_info('peer_name')
logger.debug('connection from {}'.format(peername)) logger.debug(f'connection from {peer_name}')
self.packet_sink.transport = transport self.packet_sink.transport = transport
# Called when the client is disconnected # Called when the client is disconnected
@@ -78,7 +78,7 @@ async def open_tcp_server_transport(spec):
local_host, local_port = spec.split(':') local_host, local_port = spec.split(':')
packet_source = StreamPacketSource() packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink() packet_sink = TcpServerPacketSink()
await asyncio.get_running_loop().create_server( await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink), lambda: TcpServerProtocol(packet_source, packet_sink),
host=local_host if local_host != '_' else None, host=local_host if local_host != '_' else None,

View File

@@ -53,10 +53,13 @@ async def open_udp_transport(spec):
local, remote = spec.split(',') local, remote = spec.split(',')
local_host, local_port = local.split(':') local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':') remote_host, remote_port = remote.split(':')
udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint( (
lambda: UdpPacketSource(), udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
UdpPacketSource,
local_addr=(local_host, int(local_port)), local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port)) remote_addr=(remote_host, int(remote_port)),
) )
packet_sink = UdpPacketSink(udp_transport) packet_sink = UdpPacketSink(udp_transport)

View File

@@ -17,9 +17,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging import logging
import usb1
import threading import threading
import collections import collections
import ctypes
import platform
import libusb_package
import usb1
from colors import color from colors import color
from .common import Transport, ParserSource from .common import Transport, ParserSource
@@ -33,42 +37,79 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def load_libusb():
'''
Attempt to load the libusb-1.0 C library from libusb_package in site-packages.
If the library exists, we create a DLL object and initialize the usb1 backend.
This only needs to be done once, but before a usb1.USBContext is created.
If the library does not exists, do nothing and usb1 will search default system paths
when usb1.USBContext is created.
'''
if libusb_path := libusb_package.get_library_path():
logger.debug(f'loading libusb library at {libusb_path}')
dll_loader = ctypes.WinDLL if platform.system() == 'Windows' else ctypes.CDLL
libusb_dll = dll_loader(str(libusb_path), use_errno=True, use_last_error=True)
usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec): async def open_usb_transport(spec):
''' '''
Open a USB transport. Open a USB transport.
The parameter string has this syntax: The moniker string has this syntax:
either <index> or <vendor>:<product>[/<serial-number>] either <index> or
<vendor>:<product> or
<vendor>:<product>/<serial-number>] or
<vendor>:<product>#<index>
With <index> as the 0-based index to select amongst all the devices that appear With <index> as the 0-based index to select amongst all the devices that appear
to be supporting Bluetooth HCI (0 being the first one), or to be supporting Bluetooth HCI (0 being the first one), or
Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The
/<serial-number> suffix max be specified when more than one device with the same /<serial-number> suffix or #<index> suffix max be specified when more than one
vendor and product identifiers are present. device with the same vendor and product identifiers are present.
In addition, if the moniker ends with the symbol "!", the device will be used in
"forced" mode:
the first USB interface of the device will be used, regardless of the interface
class/subclass.
This may be useful for some devices that use a custom class/subclass but may
nonetheless work as-is.
Examples: Examples:
0 --> the first BT USB dongle 0 --> the first BT USB dongle
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901 04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and serial number 00E04C239987 04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and
serial number 00E04C239987
usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
''' '''
USB_RECIPIENT_DEVICE = 0x00 # pylint: disable=invalid-name
USB_REQUEST_TYPE_CLASS = 0x01 << 5 USB_RECIPIENT_DEVICE = 0x00
USB_ENDPOINT_EVENTS_IN = 0x81 USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_ACL_IN = 0x82 USB_DEVICE_CLASS_DEVICE = 0x00
USB_ENDPOINT_ACL_OUT = 0x02 USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024 READ_SIZE = 1024
class UsbPacketSink: class UsbPacketSink:
def __init__(self, device): def __init__(self, device, acl_out):
self.device = device self.device = device
self.transfer = device.getTransfer() self.acl_out = acl_out
self.packets = collections.deque() # Queue of packets waiting to be sent self.transfer = device.getTransfer()
self.loop = asyncio.get_running_loop() self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.cancel_done = self.loop.create_future() self.cancel_done = self.loop.create_future()
self.closed = False self.closed = False
def start(self): def start(self):
pass pass
@@ -92,12 +133,15 @@ async def open_usb_transport(spec):
status = transfer.getStatus() status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}') # logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent_) self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else: else:
logger.warning(color(f'!!! out transfer not completed: status={status}', 'red')) logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
)
def on_packet_sent_(self): def on_packet_sent_(self):
if self.packets: if self.packets:
@@ -112,32 +156,38 @@ async def open_usb_transport(spec):
packet_type = packet[0] packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk( self.transfer.setBulk(
USB_ENDPOINT_ACL_OUT, self.acl_out, packet[1:], callback=self.on_packet_sent
packet[1:],
callback=self.on_packet_sent
) )
logger.debug('submit ACL') logger.debug('submit ACL')
self.transfer.submit() self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl( self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:], packet[1:],
callback=self.on_packet_sent callback=self.on_packet_sent,
) )
logger.debug('submit COMMAND') logger.debug('submit COMMAND')
self.transfer.submit() self.transfer.submit()
else: else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red')) logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
async def close(self): def close(self):
self.closed = True self.closed = True
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data # Empty the packet queue so that we don't send any more data
self.packets.clear() self.packets.clear()
# If we have a transfer in flight, cancel it # If we have a transfer in flight, cancel it
if self.transfer.isSubmitted(): if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed # Try to cancel the transfer, but that may fail because it may have
# already completed
try: try:
self.transfer.cancel() self.transfer.cancel()
@@ -148,18 +198,23 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed') logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device): def __init__(self, context, device, acl_in, events_in):
super().__init__() super().__init__()
self.context = context self.context = context
self.device = device self.device = device
self.loop = asyncio.get_running_loop() self.acl_in = acl_in
self.queue = asyncio.Queue() self.events_in = events_in
self.closed = False self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future() self.event_loop_done = self.loop.create_future()
self.cancel_done = { self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(), hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future() hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
} }
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events # Create a thread to process events
self.event_thread = threading.Thread(target=self.run) self.event_thread = threading.Thread(target=self.run)
@@ -168,19 +223,19 @@ async def open_usb_transport(spec):
# Set up transfer objects for input # Set up transfer objects for input
self.events_in_transfer = device.getTransfer() self.events_in_transfer = device.getTransfer()
self.events_in_transfer.setInterrupt( self.events_in_transfer.setInterrupt(
USB_ENDPOINT_EVENTS_IN, self.events_in,
READ_SIZE, READ_SIZE,
callback=self.on_packet_received, callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET user_data=hci.HCI_EVENT_PACKET,
) )
self.events_in_transfer.submit() self.events_in_transfer.submit()
self.acl_in_transfer = device.getTransfer() self.acl_in_transfer = device.getTransfer()
self.acl_in_transfer.setBulk( self.acl_in_transfer.setBulk(
USB_ENDPOINT_ACL_IN, self.acl_in,
READ_SIZE, READ_SIZE,
callback=self.on_packet_received, callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET user_data=hci.HCI_ACL_DATA_PACKET,
) )
self.acl_in_transfer.submit() self.acl_in_transfer.submit()
@@ -190,16 +245,28 @@ async def open_usb_transport(spec):
def on_packet_received(self, transfer): def on_packet_received(self, transfer):
packet_type = transfer.getUserData() packet_type = transfer.getUserData()
status = transfer.getStatus() status = transfer.getStatus()
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type}') # logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_COMPLETED:
packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()] packet = (
bytes([packet_type])
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet) self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None) self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return return
else: else:
logger.warning(color(f'!!! transfer not completed: status={status}', 'red')) logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
)
# Re-submit the transfer so we can receive more data # Re-submit the transfer so we can receive more data
transfer.submit() transfer.submit()
@@ -214,84 +281,143 @@ async def open_usb_transport(spec):
def run(self): def run(self):
logger.debug('starting USB event loop') logger.debug('starting USB event loop')
while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted(): while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try: try:
self.context.handleEvents() self.context.handleEvents()
except usb1.USBErrorInterrupted: except usb1.USBErrorInterrupted:
pass pass
logger.debug('USB event loop done') logger.debug('USB event loop done')
self.event_loop_done.set_result(None) self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self): def close(self):
self.closed = True self.closed = True
async def terminate(self):
if not self.closed:
self.close()
self.dequeue_task.cancel() self.dequeue_task.cancel()
# Cancel the transfers # Cancel the transfers
for transfer in (self.events_in_transfer, self.acl_in_transfer): for transfer in (self.events_in_transfer, self.acl_in_transfer):
if transfer.isSubmitted(): if transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed # Try to cancel the transfer, but that may fail because it may have
# already completed
packet_type = transfer.getUserData() packet_type = transfer.getUserData()
try: try:
transfer.cancel() transfer.cancel()
logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...') logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation '
'to be done...'
)
await self.cancel_done[packet_type] await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done') logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError: except usb1.USBError:
logger.debug(f'IN[{packet_type}] transfer likely already completed') logger.debug(
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate # Wait for the thread to terminate
await self.event_loop_done await self.event_loop_done
class UsbTransport(Transport): class UsbTransport(Transport):
def __init__(self, context, device, interface, source, sink): def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink) super().__init__(source, sink)
self.context = context self.context = context
self.device = device self.device = device
self.interface = interface self.interface = interface
# Get exclusive access # Get exclusive access
device.claimInterface(interface) device.claimInterface(interface)
# Set the alternate setting if not the default
if setting != 0:
device.setInterfaceAltSetting(interface, setting)
# The source and sink can now start # The source and sink can now start
source.start() source.start()
sink.start() sink.start()
async def close(self): async def close(self):
await self.source.close() self.source.close()
await self.sink.close() self.sink.close()
await self.source.terminate()
await self.sink.terminate()
self.device.releaseInterface(self.interface) self.device.releaseInterface(self.interface)
self.device.close() self.device.close()
self.context.close() self.context.close()
# Find the device according to the spec moniker # Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext() context = usb1.USBContext()
context.open() context.open()
try: try:
found = None found = None
if spec.endswith('!'):
spec = spec[:-1]
forced_mode = True
else:
forced_mode = False
if ':' in spec: if ':' in spec:
vendor_id, product_id = spec.split(':') vendor_id, product_id = spec.split(':')
serial_number = None
device_index = 0
if '/' in product_id: if '/' in product_id:
product_id, serial_number = product_id.split('/') product_id, serial_number = product_id.split('/')
for device in context.getDeviceIterator(skip_on_error=True): elif '#' in product_id:
if ( product_id, device_index_str = product_id.split('#')
device.getVendorID() == int(vendor_id, 16) and device_index = int(device_index_str)
device.getProductID() == int(product_id, 16) and
device.getSerialNumber() == serial_number for device in context.getDeviceIterator(skip_on_error=True):
): try:
device_serial_number = device.getSerialNumber()
except usb1.USBError:
device_serial_number = None
if (
device.getVendorID() == int(vendor_id, 16)
and device.getProductID() == int(product_id, 16)
and (serial_number is None or serial_number == device_serial_number)
):
if device_index == 0:
found = device found = device
break break
device.close() device_index -= 1
else: device.close()
found = context.getByVendorIDAndProductID(int(vendor_id, 16), int(product_id, 16), skip_on_error=True)
else: else:
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
# Check if the device class indicates a match
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
for configuration in device:
for interface in configuration:
for setting in interface:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
device_index = int(spec) device_index = int(spec)
for device in context.getDeviceIterator(skip_on_error=True): for device in context.getDeviceIterator(skip_on_error=True):
if ( if device_is_bluetooth_hci(device):
device.getDeviceClass() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER and
device.getDeviceSubClass() == USB_DEVICE_SUBCLASS_RF_CONTROLLER and
device.getDeviceProtocol() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
):
if device_index == 0: if device_index == 0:
found = device found = device
break break
@@ -303,34 +429,107 @@ async def open_usb_transport(spec):
raise ValueError('device not found') raise ValueError('device not found')
logger.debug(f'USB Device: {found}') logger.debug(f'USB Device: {found}')
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
not forced_mode
and (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
)
!= USB_BT_HCI_CLASS_TUPLE
):
continue
events_in = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif (
attributes & 0x03
== USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT
):
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if (
acl_in is not None
and acl_out is not None
and events_in is not None
):
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in,
)
logger.debug(
f'skipping configuration {configuration_index + 1} / '
f'interface {setting.getNumber()}'
)
return None
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
f'interface={interface}, '
f'setting={setting}, '
f'acl_in=0x{acl_in:02X}, '
f'acl_out=0x{acl_out:02X}, '
f'events_in=0x{events_in:02X}, '
)
device = found.open() device = found.open()
# Auto-detach the kernel driver if supported
# pylint: disable=no-member
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
logger.debug('auto-detaching kernel driver')
device.setAutoDetachKernelDriver(True)
except usb1.USBError as error:
logger.warning(f'unable to auto-detach kernel driver: {error}')
# Set the configuration if needed # Set the configuration if needed
try: try:
configuration = device.getConfiguration() current_configuration = device.getConfiguration()
logger.debug(f'current configuration = {configuration}') logger.debug(f'current configuration = {current_configuration}')
except usb1.USBError: except usb1.USBError:
current_configuration = 0
if current_configuration != configuration:
try: try:
logger.debug('setting configuration 1') logger.debug(f'setting configuration {configuration}')
device.setConfiguration(1) device.setConfiguration(configuration)
except usb1.USBError: except usb1.USBError:
logger.debug('failed to set configuration 1') logger.warning('failed to set configuration')
# Use the first interface source = UsbPacketSource(context, device, acl_in, events_in)
interface = 0 sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
# Detach the kernel driver if supported and needed
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
if device.kernelDriverActive(interface):
logger.debug("detaching kernel driver")
device.detachKernelDriver(interface)
except usb1.USBError:
pass
source = UsbPacketSource(context, device)
sink = UsbPacketSink(device)
return UsbTransport(context, device, interface, source, sink)
except usb1.USBError as error: except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red')) logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
context.close() context.close()

View File

@@ -33,7 +33,7 @@ async def open_vhci_transport(spec):
path at /dev/vhci), or the path of a VHCI device path at /dev/vhci), or the path of a VHCI device
''' '''
HCI_VENDOR_PKT = 0xff HCI_VENDOR_PKT = 0xFF
HCI_BREDR = 0x00 # Controller type HCI_BREDR = 0x00 # Controller type
# Open the VHCI device # Open the VHCI device
@@ -56,4 +56,3 @@ async def open_vhci_transport(spec):
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
return transport return transport

View File

@@ -43,7 +43,7 @@ async def open_ws_client_transport(spec):
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(websocket.recv), PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send), PumpedPacketSink(websocket.send),
websocket.close websocket.close,
) )
transport.start() transport.start()
return transport return transport

View File

@@ -41,30 +41,36 @@ async def open_ws_server_transport(spec):
class WsServerTransport(Transport): class WsServerTransport(Transport):
def __init__(self): def __init__(self):
source = ParserSource() source = ParserSource()
sink = PumpedPacketSink(self.send_packet) sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future() self.connection = asyncio.get_running_loop().create_future()
self.server = None
super().__init__(source, sink) super().__init__(source, sink)
async def serve(self, local_host, local_port): async def serve(self, local_host, local_port):
self.sink.start() self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.serve( self.server = await websockets.serve(
ws_handler = self.on_connection, ws_handler=self.on_connection,
host = local_host if local_host != '_' else None, host=local_host if local_host != '_' else None,
port = int(local_port) port=int(local_port),
) )
logger.debug(f'websocket server ready on port {local_port}') logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection): async def on_connection(self, connection):
logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}') logger.debug(
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
self.connection.set_result(connection) self.connection.set_result(connection)
# pylint: disable=no-member
try: try:
async for packet in connection: async for packet in connection:
if type(packet) is bytes: if isinstance(packet, bytes):
self.source.parser.feed_data(packet) self.source.parser.feed_data(packet)
else: else:
logger.warn('discarding packet: not a BINARY frame') logger.warning('discarding packet: not a BINARY frame')
except websockets.WebSocketException as error: except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}') logger.debug(f'exception while receiving packet: {error}')

View File

@@ -18,6 +18,9 @@
import asyncio import asyncio
import logging import logging
import traceback import traceback
import collections
import sys
from typing import Awaitable
from functools import wraps from functools import wraps
from colors import color from colors import color
from pyee import EventEmitter from pyee import EventEmitter
@@ -33,6 +36,7 @@ logger = logging.getLogger(__name__)
def setup_event_forwarding(emitter, forwarder, event_name): def setup_event_forwarding(emitter, forwarder, event_name):
def emit(*args, **kwargs): def emit(*args, **kwargs):
forwarder.emit(event_name, *args, **kwargs) forwarder.emit(event_name, *args, **kwargs)
emitter.on(event_name, emit) emitter.on(event_name, emit)
@@ -43,6 +47,8 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter. the <event_name> event with an emitter.
""" """
# pylint: disable=protected-access
def register(self, emitter): def register(self, emitter):
for method_name in dir(cls): for method_name in dir(cls):
if method_name.startswith('on_'): if method_name.startswith('on_'):
@@ -53,13 +59,42 @@ def composite_listener(cls):
if method_name.startswith('on_'): if method_name.startswith('on_'):
emitter.remove_listener(method_name[3:], getattr(self, method_name)) emitter.remove_listener(method_name[3:], getattr(self, method_name))
cls._bumble_register_composite = register cls._bumble_register_composite = register
cls._bumble_deregister_composite = deregister cls._bumble_deregister_composite = deregister
return cls return cls
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter): class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable):
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
def on_event(*_):
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future
# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._listener = None self._listener = None
@@ -70,6 +105,7 @@ class CompositeEventEmitter(EventEmitter):
@listener.setter @listener.setter
def listener(self, listener): def listener(self, listener):
# pylint: disable=protected-access
if self._listener: if self._listener:
# Call the deregistration methods for each base class that has them # Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro(): for cls in self._listener.__class__.mro():
@@ -109,7 +145,9 @@ class AsyncRunner:
try: try:
await item await item
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}') logger.warning(
f'{color("!!! Exception in work queue:", "red")} {error}'
)
# Shared default queue # Shared default queue
default_queue = WorkQueue() default_queue = WorkQueue()
@@ -130,7 +168,10 @@ class AsyncRunner:
try: try:
await coroutine await coroutine
except Exception: except Exception:
logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}') logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
asyncio.create_task(run()) asyncio.create_task(run())
else: else:
@@ -140,3 +181,103 @@ class AsyncRunner:
return wrapper return wrapper
return decorator return decorator
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(
self,
pause_source,
resume_source,
write_to_sink=None,
drain_sink=None,
threshold=0,
):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()

View File

@@ -2,7 +2,7 @@ Bumble Documentation
==================== ====================
The documentation consists of a collection of markdown text files, with the root of the file The documentation consists of a collection of markdown text files, with the root of the file
hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`. hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`.
You can read the documentation as text, with any text viewer or your favorite markdown viewer, You can read the documentation as text, with any text viewer or your favorite markdown viewer,
or generate a static HTML "site" using `mkdocs`, which you can then open with any browser. or generate a static HTML "site" using `mkdocs`, which you can then open with any browser.
@@ -14,9 +14,9 @@ The `mkdocs` directory contains all the data (actual documentation) and metadata
`mkdocs/mkdocs.yml` contains the site configuration. `mkdocs/mkdocs.yml` contains the site configuration.
`mkdocs/src/` is the directory where the actual documentation text, in markdown format, is located. `mkdocs/src/` is the directory where the actual documentation text, in markdown format, is located.
To build, from the project's root directory: To build, from the project's root directory:
``` ```
$ mkdocs build -f docs/mkdocs/mkdocs.yml $ mkdocs build -f docs/mkdocs/mkdocs.yml
``` ```
You can then open `docs/mkdocs/site/index.html` with any web browser. You can then open `docs/mkdocs/site/index.html` with any web browser.

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}} {"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"} {"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}} {"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"} {"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,8 @@ nav:
- Getting Started: getting_started.md - Getting Started: getting_started.md
- Development: - Development:
- Python Environments: development/python_environments.md - Python Environments: development/python_environments.md
- Contributing: development/contributing.md
- Code Style: development/code_style.md
- Use Cases: - Use Cases:
- Overview: use_cases/index.md - Overview: use_cases/index.md
- Use Case 1: use_cases/use_case_1.md - Use Case 1: use_cases/use_case_1.md
@@ -45,6 +47,10 @@ nav:
- HCI Bridge: apps_and_tools/hci_bridge.md - HCI Bridge: apps_and_tools/hci_bridge.md
- Golden Gate Bridge: apps_and_tools/gg_bridge.md - Golden Gate Bridge: apps_and_tools/gg_bridge.md
- Show: apps_and_tools/show.md - Show: apps_and_tools/show.md
- GATT Dump: apps_and_tools/gatt_dump.md
- Pair: apps_and_tools/pair.md
- Unbond: apps_and_tools/unbond.md
- USB Probe: apps_and_tools/usb_probe.md
- Hardware: - Hardware:
- Overview: hardware/index.md - Overview: hardware/index.md
- Platforms: - Platforms:

View File

@@ -1,6 +1,6 @@
# This requirements file is for python3 # This requirements file is for python3
mkdocs == 1.2.3 mkdocs == 1.4.0
mkdocs-material == 7.1.7 mkdocs-material == 8.5.6
mkdocs-material-extensions == 1.0.1 mkdocs-material-extensions == 1.0.3
pymdown-extensions == 8.2 pymdown-extensions == 9.6
mkdocstrings == 0.15.1 mkdocstrings-python == 0.7.1

View File

@@ -1,2 +1,2 @@
API EXAMPLES API EXAMPLES
============ ============

View File

@@ -1,2 +1,2 @@
API DEVELOPER GUIDE API DEVELOPER GUIDE
=================== ===================

View File

@@ -16,4 +16,3 @@ Bumble Python API
### HCI_Disconnect_Command ### HCI_Disconnect_Command
::: bumble.hci.HCI_Disconnect_Command ::: bumble.hci.HCI_Disconnect_Command

View File

@@ -7,10 +7,12 @@ The Console app is an interactive text user interface that offers a number of fu
* scanning * scanning
* advertising * advertising
* connecting to devices * connecting to and disconnecting from devices
* changing connection parameters * changing connection parameters
* enabling encryption
* discovering GATT services and characteristics * discovering GATT services and characteristics
* read & write GATT characteristics * reading and writing GATT characteristics
* subscribing to and unsubscribing from GATT characteristics
The console user interface has 3 main panes: The console user interface has 3 main panes:

View File

@@ -1,2 +1,2 @@
GOLDEN GATE BRIDGE GOLDEN GATE BRIDGE
================== ==================

View File

@@ -28,5 +28,3 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated). (through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`. NOTE: this assumes you're running a Link Relay on port `10723`.

View File

@@ -11,4 +11,3 @@ These include:
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool" * [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool"
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form * [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other. * [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

View File

@@ -31,4 +31,3 @@ The WebSocket path used by a connecting client indicates which virtual "chat roo
It is possible to connect to a "chat room" in a relay as an observer, rather than a virtual controller. In this case, a text-based console can be used to observe what is going on in the "chat room". Tools like [`wscat`](https://github.com/websockets/wscat#readme) or [`websocat`](https://github.com/vi/websocat) can be used for that. It is possible to connect to a "chat room" in a relay as an observer, rather than a virtual controller. In this case, a text-based console can be used to observe what is going on in the "chat room". Tools like [`wscat`](https://github.com/websockets/wscat#readme) or [`websocat`](https://github.com/vi/websocat) can be used for that.
Example: `wscat --connect ws://localhost:10723/test` Example: `wscat --connect ws://localhost:10723/test`

View File

@@ -0,0 +1,50 @@
USB PROBE TOOL
==============
This tool lists all the USB devices, with details about each device.
For each device, the different possible Bumble transport strings that can
refer to it are listed.
If the device is known to be a Bluetooth HCI device, its identifier is printed
in reverse colors, and the transport names in cyan color.
For other devices, regardless of their type, the transport names are printed
in red. Whether that device is actually a Bluetooth device or not depends on
whether it is a Bluetooth device that uses a non-standard Class, or some other
type of device (there's no way to tell).
## Usage
This command line tool may be invoked with no arguments, or with `--verbose`
for extra details.
When installed from PyPI, run as
```
$ bumble-usb-probe
```
or, for extra details, with the `--verbose` argument
```
$ bumble-usb-probe --v
```
When running from the source distribution:
```
$ python3 apps/usb-probe.py
```
or
```
$ python3 apps/usb-probe.py --verbose
```
!!! example
```
$ python3 apps/usb_probe.py
ID 0A12:0001
Bumble Transport Names: usb:0 or usb:0A12:0001
Bus/Device: 020/034
Class: Wireless Controller
Subclass/Protocol: 1/1 [Bluetooth]
Manufacturer: None
Product: USB2.0-BT
```

View File

@@ -1,2 +1,2 @@
CONTROLLER CONTROLLER
========== ==========

View File

@@ -1,2 +1,2 @@
GATT GATT
==== ====

View File

@@ -1,2 +1,2 @@
HOST HOST
==== ====

View File

@@ -1,2 +1,2 @@
SECURITY MANAGER SECURITY MANAGER
================ ================

View File

@@ -0,0 +1,43 @@
CODE STYLE
==========
The Python code style used in this project follows the [Black code style](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html).
# Formatting
For now, we are configuring the `black` formatter with the option to leave quotes unchanged.
The preferred quote style is single quotes, which isn't a configurable option for `Black`, so we are not enforcing it. This may change in the future.
## Ignoring Commit for Git Blame
The adoption of `Black` as a formatter came in late in the project, with already a large code base. As a result, a large number of files were changed in a single commit, which gets in the way of tracing authorship with `git blame`. The file `git-blame-ignore-revs` contains the commit hash of when that mass-formatting event occurred, which you can use to skip it in a `git blame` analysis:
!!! example "Ignoring a commit with `git blame`"
```
$ git blame --ignore-revs-file .git-blame-ignore-revs
```
# Linting
The project includes a `pylint` configuration (see the `pyproject.toml` file for details).
The `pre-commit` checks only enforce that there are no errors. But we strongly recommend that you run the linter with warnings enabled at least, and possibly the "Refactor" ('R') and "Convention" ('C') categories as well.
To run the linter, use the `project.lint` invoke command.
!!! example "Running the linter with default options"
With the default settings, Errors and Warnings are enabled, but Refactor and Convention categories are not.
```
$ invoke project.lint
```
!!! example "Running the linter with all categories"
```
$ invoke project.lint --disable=""
```
# Editor/IDE Integration
## Visual Studio Code
The project includes a `.vscode/settings.json` file that specifies the `black` formatter and enables an editor ruler at 88 columns.
You may want to configure your own environment to "format on save" with `black` if you find that useful. We are not making that choice at the workspace level.

View File

@@ -0,0 +1,11 @@
CONTRIBUTING TO THE PROJECT
===========================
To contribute some code to the project, you will need to submit a GitHub Pull Request (a.k.a PR). Please familiarize yourself with how that works (see [GitHub Pull Requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests))
You should follow the project's [code style](code_style.md), and pre-check your code before submitting a PR. The GitHub project is set up with some [Actions](https://github.com/features/actions) that will check that a PR passes at least the basic tests and complies with the coding style, but it is still recommended to check that for yourself before submitting a PR.
To run the basic checks (essentially: running the tests, the linter, and the formatter), use the `project.pre-commit` `invoke` command, and address any issues found:
```
$ invoke project.pre-commit
```

View File

@@ -1,11 +1,11 @@
PYTHON ENVIRONMENTS PYTHON ENVIRONMENTS
=================== ===================
When you don't want to install Bumble in your main/default python environment, When you don't want to install Bumble in your main/default python environment,
using a virtual environment, where the package and its dependencies can be using a virtual environment, where the package and its dependencies can be
installed, isolated from the rest, may be useful. installed, isolated from the rest, may be useful.
There are many flavors of python environments and dependency managers. There are many flavors of python environments and dependency managers.
This page describes a few of the most common ones. This page describes a few of the most common ones.
@@ -16,7 +16,7 @@ Visit the [`venv` documentation](https://docs.python.org/3/library/venv.html) pa
## Pyenv ## Pyenv
`pyenv` lets you easily switch between multiple versions of Python. It's simple, unobtrusive, and follows the UNIX tradition of single-purpose tools that do one thing well. `pyenv` lets you easily switch between multiple versions of Python. It's simple, unobtrusive, and follows the UNIX tradition of single-purpose tools that do one thing well.
Visit the [`pyenv` site](https://github.com/pyenv/pyenv) for instructions on how to install Visit the [`pyenv` site](https://github.com/pyenv/pyenv) for instructions on how to install
and use `pyenv` and use `pyenv`
@@ -25,10 +25,10 @@ and use `pyenv`
Conda is a convenient package manager and virtual environment. Conda is a convenient package manager and virtual environment.
The file `environment.yml` is a Conda environment file that you can use to create The file `environment.yml` is a Conda environment file that you can use to create
a new Conda environment. Once created, you can simply activate this environment when a new Conda environment. Once created, you can simply activate this environment when
working with Bumble. working with Bumble.
Visit the [Conda site](https://docs.conda.io/en/latest/) for instructions on how to install Visit the [Conda site](https://docs.conda.io/en/latest/) for instructions on how to install
and use Conda. and use Conda.
A few useful commands: A few useful commands:
### Create a new `bumble` Conda environment ### Create a new `bumble` Conda environment
``` ```

View File

@@ -69,4 +69,4 @@ An app that connects to an RFComm server and bridges the RFComm channel to a loc
An app that implements an RFComm server and, when a connection is received, bridges the channel to a local TCP socket An app that implements an RFComm server and, when a connection is received, bridges the channel to a local TCP socket
## `run_scanner.py` ## `run_scanner.py`
An app that scan for BLE devices and print the advertisements received. An app that scan for BLE devices and print the advertisements received.

Some files were not shown because too many files have changed in this diff Show More